In [1]:
# %%
import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps
backend = minitorch.TensorBackend(CudaKernelOps)
import time
import numpy as np

In [2]:
# %%
batch_size = 2
nhead = 10
from_len = 3
to_len = 3

def rand(shape):
    return np.random.rand(*shape)

out_grad = rand((batch_size, nhead, from_len, to_len))
inp = rand((batch_size, nhead, from_len, to_len))


def custom():
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    mask_mt = minitorch.tensor_from_numpy(np.zeros((batch_size, 1, 1, to_len)), backend=backend, requires_grad=True)
    soft_inp_mt = inp_mt.attn_softmax(mask_mt)

    start_time = time.time()
    soft_inp_mt.backward(out_grad_mt)
    end_time = time.time()

    return inp_mt.grad

def baseline():
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    soft_inp_mt = minitorch.nn.softmax(inp_mt, dim=3)

    start_time = time.time()
    tsum = out_grad_mt * soft_inp_mt
    tsum = tsum.sum(dim=3).view(tsum.shape[0], tsum.shape[1], tsum.shape[2], 1)
    res = soft_inp_mt * (out_grad_mt - tsum)
    end_time = time.time()

    return res


# %%


# %%
c = custom()
b = baseline()
print(c-b)



[
	[
		[
			[-0.000000 -0.000000 -0.000000]
			[0.000000 -0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[-0.000000 -0.000000 -0.000000]
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.000000 0.000000 -0.000000]
			[0.000000 0.000000 0.000000]
			[-0.000000 0.000000 0.000000]]
		[
			[-0.000000 -0.000000 -0.000000]
			[-0.000000 -0.000000 -0.000000]
			[0.000000 0.000000 0.000000]]
		[
			[0.721597 0.566672 0.598053]
			[0.571790 0.769677 0.625965]
			[0.801660 0.147156 0.522489]]
		[
			[0.842549 0.354289 0.716168]
			[0.822028 0.842531 0.950860]
			[0.271600 0.768335 0.311524]]
		[
			[0.194882 0.589023 0.612221]
			[0.524637 0.711209 0.629774]
			[0.343755 0.597789 0.492936]]
		[
			[0.762153 0.705383 0.674406]
			[0.818306 0.314732 0.409038]
			[0.203483 0.445408 0.247176]]
		[
			[0.601995 0.341825 0.651116]
			[0.510397 0.394909 0.215705]
		

In [3]:
print(c)


[
	[
		[
			[0.093647 -0.042401 -0.051246]
			[0.154206 0.005562 -0.159768]
			[-0.070276 0.016667 0.053608]]
		[
			[0.002756 -0.116612 0.113856]
			[-0.014992 0.008304 0.006688]
			[-0.037054 -0.099148 0.136202]]
		[
			[-0.038111 0.055119 -0.017008]
			[-0.059021 0.128074 -0.069053]
			[-0.061120 0.114608 -0.053488]]
		[
			[0.156593 -0.028268 -0.128325]
			[0.146208 -0.079415 -0.066793]
			[0.104886 -0.044499 -0.060387]]
		[
			[0.044440 -0.053384 0.008944]
			[0.017815 0.073496 -0.091310]
			[-0.070405 -0.055941 0.126346]]
		[
			[0.778685 0.522073 0.585564]
			[0.520622 0.822361 0.624449]
			[0.923689 0.006528 0.541088]]
		[
			[0.921267 0.230616 0.761123]
			[0.786992 0.839327 0.989100]
			[0.147710 0.955108 0.248642]]
		[
			[0.083755 0.606661 0.705710]
			[0.474992 0.764030 0.626598]
			[0.282103 0.659007 0.493370]]
		[
			[0.784638 0.697037 0.660268]
			[0.964749 0.221262 0.356065]
			[0.159456 0.498939 0.237673]]
		[
			[0.643234 0.254539 0.697163]
			[0.584640 0.416545 0.1

In [4]:
print(b)


[
	[
		[
			[0.093648 -0.042401 -0.051246]
			[0.154206 0.005562 -0.159768]
			[-0.070276 0.016667 0.053608]]
		[
			[0.002756 -0.116612 0.113856]
			[-0.014992 0.008304 0.006688]
			[-0.037054 -0.099148 0.136202]]
		[
			[-0.038111 0.055119 -0.017008]
			[-0.059021 0.128074 -0.069053]
			[-0.061120 0.114608 -0.053488]]
		[
			[0.156593 -0.028268 -0.128325]
			[0.146208 -0.079415 -0.066793]
			[0.104886 -0.044499 -0.060387]]
		[
			[0.044440 -0.053384 0.008944]
			[0.017815 0.073496 -0.091310]
			[-0.070405 -0.055941 0.126346]]
		[
			[0.057088 -0.044599 -0.012489]
			[-0.051168 0.052684 -0.001515]
			[0.122029 -0.140628 0.018599]]
		[
			[0.078718 -0.123673 0.044955]
			[-0.035036 -0.003204 0.038240]
			[-0.123890 0.186772 -0.062882]]
		[
			[-0.111127 0.017638 0.093489]
			[-0.049645 0.052821 -0.003176]
			[-0.061652 0.061218 0.000434]]
		[
			[0.022484 -0.008347 -0.014137]
			[0.146443 -0.093470 -0.052973]
			[-0.044027 0.053531 -0.009503]]
		[
			[0.041239 -0.087287 0.046047]
			[

In [5]:
out_grad

array([[[[0.826692  , 0.42501557, 0.32643509],
         [0.86724875, 0.50863343, 0.03402741],
         [0.06320287, 0.35797317, 0.45547104]],

        [[0.70406347, 0.24502249, 0.91608756],
         [0.64116942, 0.70829526, 0.69915332],
         [0.54065158, 0.14768044, 0.94104431]],

        [[0.19407281, 0.47843891, 0.26112177],
         [0.41166961, 0.92352441, 0.33378362],
         [0.19019248, 0.69262059, 0.2542264 ]],

        [[0.9028423 , 0.20929412, 0.02227044],
         [0.90050573, 0.3043901 , 0.30495891],
         [0.64043848, 0.24961443, 0.06529399]],

        [[0.79096915, 0.46086897, 0.71466107],
         [0.36927911, 0.56008467, 0.10113949],
         [0.27048538, 0.32680441, 0.95796152]],

        [[0.77868484, 0.52207278, 0.58556422],
         [0.52062171, 0.82236135, 0.62444911],
         [0.9236888 , 0.00652799, 0.54108831]],

        [[0.92126713, 0.23061561, 0.76112342],
         [0.78699225, 0.83932684, 0.98909979],
         [0.14770976, 0.95510754, 0.24864225]],


In [6]:
out_grad.shape


(2, 10, 3, 3)