In [1]:
# %%
# Standard library, then third-party, then the project package.
import time

import numpy as np

import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps

# Single tensor backend (CUDA kernel ops) passed to every tensor created below.
backend = minitorch.TensorBackend(CudaKernelOps)

In [2]:
# %%
rows = 10
hidden_dim = 8

def rand(shape):
    return np.random.rand(*shape)

inp = rand((rows, hidden_dim))
out_grad = rand((rows,hidden_dim))
gamma = rand((hidden_dim,))
beta = rand((hidden_dim,))


def custom():
    """Run the custom LayerNorm forward and backward on the notebook-level data.

    Builds minitorch tensors on ``backend`` from the global numpy arrays ``inp``,
    ``gamma``, ``beta`` and ``out_grad``. It then applies ``Tensor.layernorm`` and
    backpropagates ``out_grad`` through the result.

    Returns:
        Tuple ``(inp_grad, gamma_grad, beta_grad, inp_tensor, gamma_tensor, beta_tensor)``.
        The first three are the ``.grad`` fields filled in by ``backward``. The last
        three are the leaf tensors themselves.
    """
    inp_mt = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    gamma_mt = minitorch.tensor_from_numpy(gamma, backend=backend, requires_grad=True)
    beta_mt = minitorch.tensor_from_numpy(beta, backend=backend, requires_grad=True)
    # The upstream gradient is only a constant seed for backward(), not a
    # parameter, so it does not need to track gradients itself.
    out_grad_mt = minitorch.tensor_from_numpy(out_grad, backend=backend)

    out_mt = inp_mt.layernorm(gamma_mt, beta_mt)
    out_mt.backward(out_grad_mt)

    return inp_mt.grad, gamma_mt.grad, beta_mt.grad, inp_mt, gamma_mt, beta_mt

def baseline():
    """Hand-derived reference LayerNorm backward on the notebook-level data.

    Per row of width H, with x_hat = (x - mean) / std and dx_hat = dy * gamma:

        d_beta  = sum over rows of dy
        d_gamma = sum over rows of dy * x_hat
        d_x     = (dx_hat - (sum_H(dx_hat) + x_hat * sum_H(dx_hat * x_hat)) / H) / std

    NOTE(review): the 1/H factor assumes ``Tensor.var`` is the biased (population)
    variance, and no epsilon is added before the sqrt. If ``layernorm`` adds an
    eps, small differences versus the kernel are expected. TODO confirm both.

    Returns:
        Tuple ``(inp_grad, gamma_grad, beta_grad)`` of minitorch tensors.
    """
    # Gradients are derived analytically below, so no autodiff tracking is needed.
    f_input = minitorch.tensor_from_numpy(inp, backend=backend)
    f_gamma = minitorch.tensor_from_numpy(gamma, backend=backend)
    f_out_grad = minitorch.tensor_from_numpy(out_grad, backend=backend)
    # Normalization width comes from the data itself rather than a global constant.
    n_features = f_input.shape[1]

    f_means = f_input.mean(dim=1)
    f_vars = f_input.var(dim=1)
    # sqrt is taken in numpy; reshape to a column so it broadcasts across each row.
    f_stds = minitorch.tensor(
        np.sqrt(f_vars.to_numpy()).reshape(-1, 1).tolist(), backend=backend
    )

    xhat = (f_input - f_means) / f_stds
    dxhat = f_out_grad * f_gamma
    f_beta_grad = f_out_grad.sum(dim=0)
    f_gamma_grad = (f_out_grad * xhat).sum(dim=0)
    dinp = dxhat.sum(dim=1) + xhat * (dxhat * xhat).sum(dim=1)
    dinp = dxhat - dinp / n_features
    dinp = dinp / f_stds
    return dinp, f_gamma_grad, f_beta_grad


# %%


# %%
# Run the custom kernel and the hand-derived reference on the same inputs.
inp_grad_mt, gamma_grad_mt, beta_grad_mt, inp_mt, gamma_mt, beta_mt = custom()
dinp, f_gamma_grad, f_beta_grad = baseline()


def max_abs_diff(actual, expected):
    """Largest elementwise |actual - expected| after flattening both tensors."""
    return float(np.max(np.abs(actual.to_numpy().ravel() - expected.to_numpy().ravel())))


# All three should be ~0 once the kernel is correct; NaN means the custom
# backward produced NaN.
{
    "inp_grad": max_abs_diff(inp_grad_mt, dinp),
    "gamma_grad": max_abs_diff(gamma_grad_mt, f_gamma_grad),
    "beta_grad": max_abs_diff(beta_grad_mt, f_beta_grad),
}


betta threadIdx.x 0 1.295372 -1.435724 0.294708 -0.015074
betta threadIdx.x 1 1.295372 -1.435724 0.294708 -0.015074
betta threadIdx.x 2 1.295372 -1.435724 0.294708 -0.015074
betta threadIdx.x 3 1.295372 -1.435724 0.294708 -0.015074
betta threadIdx.x 4 -1.791443 0.724100 0.532538 0.395524
betta threadIdx.x 5 -1.791443 0.724100 0.532538 0.395524
betta threadIdx.x 6 -1.791443 0.724100 0.532538 0.395524
betta threadIdx.x 7 -1.791443 0.724100 0.532538 0.395524
betta threadIdx.x 8 -inf inf inf inf
betta threadIdx.x 9 -inf inf inf inf
betta threadIdx.x 10 -inf inf inf inf
betta threadIdx.x 11 -inf inf inf inf
betta threadIdx.x 12 inf inf inf inf
betta threadIdx.x 13 inf inf inf inf
betta threadIdx.x 14 inf inf inf inf
betta threadIdx.x 15 inf inf inf inf
betta threadIdx.x 16 nan nan nan nan
betta threadIdx.x 17 nan nan nan nan
betta threadIdx.x 18 nan nan nan nan
betta threadIdx.x 19 nan nan nan nan
betta threadIdx.x 20 nan nan nan nan
betta threadIdx.x 21 nan nan nan nan
betta threadIdx.x 22

In [3]:
# Reference input gradient, derived by hand in baseline().
dinp


[
	[1.398849 -0.222134 -0.599566 0.312500 -0.162755 0.436848 -0.958701 -0.205040]
	[0.390286 0.377775 -1.159635 -0.389615 0.993946 0.121604 -0.866265 0.531905]
	[1.738728 -0.440887 0.734151 -0.297888 1.167439 -0.996277 -1.625248 -0.280016]
	[-0.090175 0.881272 0.148913 -0.217755 0.441846 -0.633192 -1.079640 0.548731]
	[1.140527 0.936915 -0.769875 -0.260706 -0.312079 -0.500787 -0.425541 0.191547]
	[-0.512240 0.871054 -0.425876 0.187173 0.971125 -1.052704 -1.219969 1.181437]
	[-0.528009 1.204046 -0.331060 0.173711 -0.340129 0.087805 -0.289098 0.022732]
	[0.112246 -0.767554 -0.489325 -0.718132 2.138566 -0.312518 -0.209153 0.245870]
	[0.411308 0.324753 -0.747218 -0.684303 0.134015 0.737362 -0.590518 0.414601]
	[0.749596 -0.007529 0.188810 -0.063040 1.591694 -1.434853 -0.507111 -0.517567]]

In [4]:
# Input gradient from the custom layernorm backward (first element returned by custom()).
# NOTE(review): the saved output is all NaN, while the kernel printf in cell 2 shows
# inf/nan only for threadIdx.x >= 8 (= hidden_dim). This may indicate out-of-range
# reads in the kernel. TODO investigate.
inp_grad_mt


[
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]
	[nan nan nan nan nan nan nan nan]]