In [1]:
# %%
import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps
backend = minitorch.TensorBackend(CudaKernelOps)
import time
import numpy as np

In [15]:
# %%
rows = 10
hidden_dim = 960

def rand(shape):
    return np.random.rand(*shape)

inp = rand((rows, hidden_dim))
out_grad = rand((rows,hidden_dim))
gamma = rand((hidden_dim,))
beta = rand((hidden_dim,))


def custom():
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    gamma_mt = minitorch.tensor_from_numpy(gamma, backend=backend, requires_grad=True)
    beta_mt = minitorch.tensor_from_numpy(beta, backend=backend, requires_grad=True)
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)

    out_mt = inp_mt.layernorm(gamma_mt, beta_mt)
    out_mt.backward(out_grad_mt)

    return inp_mt.grad, gamma_mt.grad, beta_mt.grad, inp_mt, gamma_mt, beta_mt

def baseline():
    f_input = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    f_gamma = minitorch.tensor_from_numpy(gamma, backend=backend, requires_grad=True)
    f_out_grad = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)

    f_means = f_input.mean(dim=1)
    f_vars = f_input.var(dim=1)
    f_stds = minitorch.tensor(np.sqrt(f_vars.to_numpy()).reshape(-1, 1).tolist(), backend=backend, requires_grad=True)

    xhat = (f_input - f_means) / f_stds
    print(xhat)
    dxhat = f_out_grad * f_gamma
    f_betta_grad = f_out_grad.sum(dim=0)
    f_gamma_grad = (f_out_grad * xhat).sum(dim=0)
    dinp = dxhat.sum(dim=1) + xhat * (dxhat * xhat).sum(dim=1)
    dinp = dxhat - dinp / hidden_dim
    dinp = dinp / f_stds
    return dinp, f_gamma_grad, f_betta_grad




#     return res


# %%


# %%
inp_grad_mt, gamma_mt, beta_mt, my_inp,g,b = custom()
dinp, f_gamma_grad, f_betta_grad = baseline()
# print(c-b)


threadIdx.x 0 threadIdx.y 2 grad 0.034777 xhat 1.546675
threadIdx.x 1 threadIdx.y 2 grad 0.119503 xhat 1.450822
threadIdx.x 2 threadIdx.y 2 grad 0.156209 xhat 0.803982
threadIdx.x 3 threadIdx.y 2 grad 0.961724 xhat -1.098891
threadIdx.x 4 threadIdx.y 2 grad 0.542171 xhat 0.320265
threadIdx.x 5 threadIdx.y 2 grad 0.736066 xhat -1.216780
threadIdx.x 6 threadIdx.y 2 grad 0.964330 xhat -0.109338
threadIdx.x 7 threadIdx.y 2 grad 0.298532 xhat -0.121526
threadIdx.x 8 threadIdx.y 2 grad 0.994314 xhat -0.569814
threadIdx.x 9 threadIdx.y 2 grad 0.857131 xhat 0.767572
threadIdx.x 10 threadIdx.y 2 grad 0.241567 xhat 0.931331
threadIdx.x 11 threadIdx.y 2 grad 0.888804 xhat 0.610758
threadIdx.x 12 threadIdx.y 2 grad 0.020139 xhat -1.425353
threadIdx.x 13 threadIdx.y 2 grad 0.069969 xhat -0.891831
threadIdx.x 14 threadIdx.y 2 grad 0.163875 xhat -0.271100
threadIdx.x 15 threadIdx.y 2 grad 0.918622 xhat 1.087667
threadIdx.x 16 threadIdx.y 2 grad 0.151154 xhat 1.100138
threadIdx.x 17 threadIdx.y 2 grad

In [16]:
(f_gamma_grad-gamma_mt).sum()


[-0.000016]