In [1]:
# %%
# minitorch provides tensor_from_numpy, nn.softmax and the tensor type used below.
import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps
# Tensor backend constructed from CudaKernelOps; it is passed as `backend=` to the
# tensor_from_numpy calls in custom() and baseline().
backend = minitorch.TensorBackend(CudaKernelOps)
import time
import numpy as np

In [2]:
# %%
# Sizes of the four axes of the test tensors, in the order they are used
# when the arrays are built: (batch_size, nhead, from_len, to_len).
batch_size, nhead = 2, 10
from_len, to_len = 3, 3

def rand(shape, rng=None):
    """Return an array of uniform samples from [0, 1) with the given shape.

    Args:
        shape: Iterable of ints giving the size of each axis.
        rng: Optional ``numpy.random.Generator``. When supplied, samples are
            drawn from it, so a seeded generator gives reproducible inputs.
            When ``None`` (the default) the call is ``np.random.rand(*shape)``
            and uses NumPy's global random state.

    Returns:
        A NumPy array of float samples with shape ``shape``.
    """
    if rng is None:
        # Default path: draw from the global NumPy random state.
        return np.random.rand(*shape)
    return rng.random(tuple(shape))

# Both arrays share one shape: out_grad is the upstream gradient handed to
# backward(), inp is the tensor the softmax is applied to.
grad_shape = (batch_size, nhead, from_len, to_len)
out_grad = rand(grad_shape)
inp = rand(grad_shape)


def custom():
    """Backpropagate ``out_grad`` through ``attn_softmax`` and return the input gradient.

    Reads the module-level ``out_grad``, ``inp``, ``backend``, ``batch_size``
    and ``to_len``. Fresh minitorch tensors are created on every call.

    Returns:
        ``inp_mt.grad`` after ``backward(out_grad_mt)`` has been called on the
        ``attn_softmax`` output.
    """
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    # All-zero mask of shape (batch_size, 1, 1, to_len).
    # NOTE(review): presumably a zero mask leaves every score unmasked, which is
    # what makes the result comparable with the plain softmax in baseline() —
    # confirm against the attn_softmax implementation.
    mask_mt = minitorch.tensor_from_numpy(np.zeros((batch_size, 1, 1, to_len)), backend=backend, requires_grad=True)
    soft_inp_mt = inp_mt.attn_softmax(mask_mt)

    # Seed the backward pass with the upstream gradient; the gradient is then read from inp_mt.grad.
    soft_inp_mt.backward(out_grad_mt)

    return inp_mt.grad

def baseline():
    """Compute the softmax input gradient from its closed form, for comparison.

    With ``y = softmax(inp, dim=3)`` and upstream gradient ``g = out_grad``,
    the returned tensor is ``y * (g - sum(g * y, dim=3))``, where the sum is
    reshaped to a trailing axis of size 1 so it broadcasts over dim 3.

    Reads the module-level ``out_grad``, ``inp`` and ``backend``. Fresh
    minitorch tensors are created on every call.

    Returns:
        The minitorch tensor holding the gradient expression above.
    """
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    soft_inp_mt = minitorch.nn.softmax(inp_mt, dim=3)

    # Per-row dot product of upstream gradient and softmax output along dim 3.
    weighted = out_grad_mt * soft_inp_mt
    # Explicitly shape the reduction as (d0, d1, d2, 1) so the subtraction below broadcasts.
    row_dot = weighted.sum(dim=3).view(weighted.shape[0], weighted.shape[1], weighted.shape[2], 1)

    return soft_inp_mt * (out_grad_mt - row_dot)


# %%


# %%
# c: input gradient returned by custom() (backward through attn_softmax).
c = custom()
# b: the same gradient computed from the closed-form expression in baseline().
b = baseline()
# Element-wise difference of the two results; values near zero mean they agree.
print(c-b)



[
	[
		[
			[0.000000 0.000000 0.000000]
			[-0.000000 -0.000000 -0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[-0.000000 -0.000000 -0.000000]
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.000000 0.000000 0.000000]
			[-0.000000 0.000000 -0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[-0.000000 0.000000 -0.000000]
			[-0.000000 -0.000000 -0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[-0.000000 0.000000 -0.000000]
			[0.000000 0.000000 -0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[-0.000000 -0.000000 0.000000]
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.000000 0.000000 -0.000000]
			[0.000000 0.000000 0.000000]
			[-0.000000 -0.000000 -0.000000]]
		[
			[-0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]
			[-0.000000 -0.000000 -0.000000]]
		[
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0

In [3]:
# Gradient returned by custom(), i.e. inp_mt.grad after backward through attn_softmax.
print(c)


[
	[
		[
			[0.092767 -0.036855 -0.055912]
			[-0.058527 0.139779 -0.081253]
			[-0.150839 0.035328 0.115512]]
		[
			[0.048383 -0.131900 0.083517]
			[0.114204 -0.098850 -0.015353]
			[-0.036401 0.033137 0.003264]]
		[
			[0.134668 -0.085441 -0.049227]
			[-0.068219 0.116971 -0.048752]
			[0.015344 -0.008184 -0.007160]]
		[
			[-0.040916 0.064287 -0.023371]
			[0.030878 0.007157 -0.038035]
			[0.027181 -0.087059 0.059878]]
		[
			[-0.013639 0.066808 -0.053168]
			[0.031436 0.032605 -0.064041]
			[-0.077439 0.158748 -0.081309]]
		[
			[-0.038526 0.066709 -0.028184]
			[-0.006540 -0.043883 0.050423]
			[0.012664 0.076423 -0.089087]]
		[
			[0.046142 -0.038295 -0.007847]
			[0.001152 -0.025865 0.024714]
			[-0.036622 -0.001361 0.037982]]
		[
			[-0.044327 -0.025147 0.069474]
			[-0.136750 0.104994 0.031756]
			[0.077746 -0.065585 -0.012161]]
		[
			[-0.036466 0.130564 -0.094099]
			[0.079381 -0.146869 0.067488]
			[-0.007400 -0.105100 0.112500]]
		[
			[-0.033786 0.024711 0.009074]
			[

In [4]:
# Gradient returned by baseline(), for side-by-side comparison with c.
print(b)


[
	[
		[
			[0.092767 -0.036855 -0.055912]
			[-0.058527 0.139779 -0.081253]
			[-0.150839 0.035328 0.115512]]
		[
			[0.048383 -0.131900 0.083517]
			[0.114204 -0.098850 -0.015353]
			[-0.036401 0.033137 0.003264]]
		[
			[0.134668 -0.085441 -0.049227]
			[-0.068219 0.116971 -0.048752]
			[0.015344 -0.008184 -0.007160]]
		[
			[-0.040916 0.064287 -0.023371]
			[0.030878 0.007157 -0.038035]
			[0.027181 -0.087059 0.059878]]
		[
			[-0.013639 0.066808 -0.053168]
			[0.031436 0.032605 -0.064041]
			[-0.077439 0.158748 -0.081309]]
		[
			[-0.038525 0.066709 -0.028184]
			[-0.006540 -0.043883 0.050423]
			[0.012664 0.076423 -0.089087]]
		[
			[0.046142 -0.038295 -0.007847]
			[0.001152 -0.025865 0.024714]
			[-0.036622 -0.001361 0.037982]]
		[
			[-0.044327 -0.025147 0.069474]
			[-0.136750 0.104994 0.031756]
			[0.077746 -0.065585 -0.012161]]
		[
			[-0.036466 0.130564 -0.094099]
			[0.079381 -0.146869 0.067488]
			[-0.007400 -0.105100 0.112500]]
		[
			[-0.033786 0.024711 0.009074]
			[

In [5]:
# The random upstream gradient (NumPy array) that both custom() and baseline() consume.
out_grad

array([[[[0.56587975, 0.19682115, 0.17559039],
         [0.36469261, 0.8916662 , 0.29738087],
         [0.14465601, 0.77274757, 0.89132974]],

        [[0.60525186, 0.06474024, 0.77163859],
         [0.92613015, 0.19998561, 0.53686157],
         [0.07299224, 0.27433351, 0.21230501]],

        [[0.70296892, 0.05553811, 0.06892603],
         [0.14159571, 0.6480166 , 0.1700379 ],
         [0.41380827, 0.31814923, 0.31529105]],

        [[0.15017064, 0.4197974 , 0.13709971],
         [0.89929501, 0.80068936, 0.63209228],
         [0.82768403, 0.38027813, 0.89553514]],

        [[0.1928471 , 0.48925577, 0.14157875],
         [0.67492506, 0.69323384, 0.262967  ],
         [0.03590189, 0.69482865, 0.07928615]],

        [[0.23316743, 0.56815928, 0.18033389],
         [0.5507861 , 0.33832412, 0.70861781],
         [0.53081603, 0.68469427, 0.21554192]],

        [[0.93742998, 0.66823492, 0.70874898],
         [0.56140987, 0.50145505, 0.64547061],
         [0.25871415, 0.40255577, 0.50255447]],


In [6]:
# Shape check: out_grad was built as (batch_size, nhead, from_len, to_len).
out_grad.shape


(2, 10, 3, 3)