In [1]:
import torch

In [2]:
# 3x4 float32 matrix holding the values 1..12 in row-major order.
a = torch.arange(1, 13).reshape(3, 4).to(torch.float32)

In [3]:
# Display the tensor to confirm its shape and values before enabling autograd.
a

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])

In [4]:
# Flag `a` so autograd records operations on it and can populate `a.grad`
# after a backward pass.
a.requires_grad = True

In [3]:
# Build both operands from scratch: a is (3, 4), b is (4, 3), each holding 1..12
# as float32 leaf tensors tracked by autograd.
a = torch.arange(1, 13).reshape(3, 4).float().requires_grad_(True)
b = torch.arange(1, 13).reshape(4, 3).float().requires_grad_(True)

# Matrix product reduced to a scalar so backward() needs no explicit gradient.
c = torch.matmul(a, b)
d = torch.sum(c)

d.backward()

# a.grad repeats the row sums of b; b.grad repeats the column sums of a.
print(a.grad, b.grad)

tensor([[ 6., 15., 24., 33.],
        [ 6., 15., 24., 33.],
        [ 6., 15., 24., 33.]]) tensor([[15., 15., 15.],
        [18., 18., 18.],
        [21., 21., 21.],
        [24., 24., 24.]])


In [6]:
# Two separate float16 casts of `a`; each cast is its own node in the autograd graph.
d, k = a.half(), a.half()

# Elementwise sum of the two casts, then reduce to a float16 scalar.
e = torch.add(d, k)
f = torch.sum(e)

# Echo every tensor in the chain to inspect dtypes and grad_fn nodes.
a, d, k, e, f

(tensor([[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]], requires_grad=True),
 tensor([[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]], dtype=torch.float16, grad_fn=<ToCopyBackward0>),
 tensor([[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]], dtype=torch.float16, grad_fn=<ToCopyBackward0>),
 tensor([[ 2.,  4.,  6.,  8.],
         [10., 12., 14., 16.],
         [18., 20., 22., 24.]], dtype=torch.float16, grad_fn=<AddBackward0>),
 tensor(156., dtype=torch.float16, grad_fn=<SumBackward0>))

In [7]:
# Backpropagate from the float16 scalar `f` through both half() casts back to `a`.
f.backward()

In [8]:
# Every element of `a` enters `f` twice (once via d, once via k), so each
# gradient entry is 2.
a.grad

tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

In [9]:
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)

# sum() with no `dim` reduces over every element, so y is a 0-d scalar:
# tensor(10., grad_fn=<SumBackward0>)
y = x.sum()

print(x, y)
# retain_grad is a method and must be called. Assigning `True` to it would only
# shadow the method on this tensor and leave y.grad unpopulated after backward().
y.retain_grad()
y.backward()

tensor([[1., 2.],
        [3., 4.]], requires_grad=True) tensor(10., grad_fn=<SumBackward0>)


In [10]:
# Gradient of sum(x) with respect to x: 1 for every element.
x.grad

tensor([[1., 1.],
        [1., 1.]])

In [11]:
import numpy as np

In [12]:
# NumPy counterpart of the earlier tensor: integers 1..12 laid out as 3 rows of 4.
# Note that this rebinds `a`, which previously named a torch tensor.
a = np.arange(1, 13).reshape((3, 4))

In [13]:
# The scalar 1 is broadcast against every element; `a` itself is left unchanged.
b = np.add(a, 1)
b

array([[ 2,  3,  4,  5],
       [ 6,  7,  8,  9],
       [10, 11, 12, 13]])

In [124]:
# Leaf operands: a is (1, 3) holding 1..3, b is (3, 2) holding 1..6.
a = torch.arange(1, 4).reshape(1, 3).float().requires_grad_(True)
b = torch.arange(1, 7).reshape(3, 2).float().requires_grad_(True)

# c is a non-leaf tensor, so its gradient is only kept if explicitly retained.
c = torch.matmul(a, b)
c.retain_grad()
print(a.shape, b.shape, c.shape)

# Scalar loss; its gradient is retained as well so its shape can be inspected.
d = torch.sum(c)
d.retain_grad()
d.backward()

print(a.grad, b.grad)
print(c.grad.shape, d.grad.shape)

torch.Size([1, 3]) torch.Size([3, 2]) torch.Size([1, 2])
tensor([[ 3.,  7., 11.]]) tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])
torch.Size([1, 2]) torch.Size([])


In [111]:
def broadcast_axis(left, right):
    """
    Determine which axes of two shapes get broadcast against each other.

    mlx uses broadcasting before performing array ops; the gradient of a
    broadcast operand has to be reduced along the axes on which it was
    stretched, and this function reports those axes for each operand.

    Parameters
    ----------
    left, right : sequence of int
        The operand *shapes* (not the arrays themselves). Any sequence of
        ints is accepted: tuple, list, torch.Size, ...

    Returns
    -------
    (left_axes, right_axes) : tuple of two tuples of int
        Axes on which ``left`` / ``right`` is the smaller operand. Indices
        refer to the shapes after the shorter one has been left-padded with
        1s to the length of the longer one.

    Notes
    -----
    Only sizes are compared; the shapes are not checked for broadcast
    compatibility. At each axis the operand with the smaller size is
    reported as the broadcast one.

    example:
    >>> broadcast_axis((3, 1), (1, 4))
    ((1,), (0,))

    here the second axis for left, and first axis for right will be broadcasted
    """
    # Coerce to tuples so the 1-padding below works for lists and other
    # sequences (tuple + list would raise TypeError).
    left = tuple(left)
    right = tuple(right)

    ldim = len(left)
    rdim = len(right)
    maxdim = max(ldim, rdim)

    # Right-align both shapes by prepending 1s to the shorter one.
    lshape_new = (1, ) * (maxdim - ldim) + left
    rshape_new = (1, ) * (maxdim - rdim) + right

    left_axes, right_axes = [], []

    for i, (lsize, rsize) in enumerate(zip(lshape_new, rshape_new)):
        if lsize > rsize:
            right_axes.append(i)
        elif rsize > lsize:
            left_axes.append(i)

    return tuple(left_axes), tuple(right_axes)

In [125]:
# NumPy versions of both operands, detached from the autograd graph, plus the
# upstream gradient that autograd retained on `c`.
an, bn = (t.detach().numpy() for t in (a, b))
cn = c.grad.detach().numpy()

def get_expand_axis():
    """
    Pick the axes that promote 1-D matmul operands to 2-D.

    Reads the module-level arrays ``an`` (left operand) and ``bn`` (right
    operand). A 1-D left operand gets a leading axis (0,); a 1-D right operand
    gets a trailing axis (-1,). Operands of any other rank need no extra axis
    and yield an empty tuple.

    Returns a pair ``(a_expand_axis, b_expand_axis)`` of axis tuples suitable
    for ``np.expand_dims``.
    """
    a_expand_axis = (0, ) if an.ndim == 1 else ()
    b_expand_axis = (-1, ) if bn.ndim == 1 else ()
    return a_expand_axis, b_expand_axis

# Axes needed to lift 1-D operands to matrices; the upstream gradient needs both.
aa, ba = get_expand_axis()
resa = aa + ba

# Batch dimensions (everything before the last two axes) may have been
# broadcast; l / r are the batch axes to sum over for each operand's gradient.
l, r = broadcast_axis(an.shape[:-2], bn.shape[:-2])

print(l, r)

print(aa, ba, resa)

# Gradient w.r.t. a: grad_c @ b^T, reduced over a's broadcast batch axes and
# folded back into a's original shape.
a_grad = np.matmul(
    np.expand_dims(cn, resa),
    np.swapaxes(np.expand_dims(bn, ba), -1, -2),
).sum(axis=l).reshape(an.shape)

# Gradient w.r.t. b: a^T @ grad_c, reduced over b's broadcast batch axes and
# folded back into b's original shape.
b_grad = np.matmul(
    np.swapaxes(np.expand_dims(an, aa), -1, -2),
    np.expand_dims(cn, resa),
).sum(axis=r).reshape(bn.shape)

print("--------")
print(a_grad)
print(b_grad)

() ()
() () ()
--------
[[ 3.  7. 11.]]
[[1. 1.]
 [2. 2.]
 [3. 3.]]
