In [1]:
# %%
import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps
# Tensor backend built from CudaKernelOps; every tensor created in this
# notebook is constructed with `backend=backend`.
backend = minitorch.TensorBackend(CudaKernelOps)
import time
import numpy as np

In [2]:
# %%
# Problem size: `inp` and `out_grad` are (rows, hidden_dim); `gamma` and `beta`
# are (hidden_dim,). The baseline normalises along dim=1 (the hidden dimension).
rows = 10
hidden_dim = 8

# Seed NumPy's global RNG so that Restart Kernel -> Run All draws the same
# inputs every time; all random data in this notebook comes from np.random.
SEED = 0
np.random.seed(SEED)

def rand(shape):
    """Draw an array of the given shape from NumPy's global RNG.

    Args:
        shape: iterable of dimension sizes, unpacked into ``np.random.rand``.

    Returns:
        Whatever ``np.random.rand(*shape)`` returns (uniform samples in [0, 1)).
    """
    dims = tuple(shape)
    return np.random.rand(*dims)

# Random test data shared by custom() and baseline().
# Draw order is inp, out_grad, gamma, beta (tuple elements evaluate left to right).
inp, out_grad = rand((rows, hidden_dim)), rand((rows, hidden_dim))
gamma, beta = rand((hidden_dim,)), rand((hidden_dim,))


def custom():
    """Run ``Tensor.layernorm`` forward and backward on the module-level inputs.

    Wraps the NumPy arrays ``inp``, ``gamma``, ``beta`` and ``out_grad`` as
    minitorch tensors on ``backend``, calls ``layernorm`` on the input and
    back-propagates ``out_grad`` through the result.
    NOTE(review): presumably this exercises the custom CUDA layernorm kernel
    (the cell output contains kernel-side prints) -- confirm in minitorch.

    Returns:
        Tuple ``(inp.grad, gamma.grad, beta.grad, inp, gamma, beta)`` where the
        last three are the tensors the gradients belong to.
    """
    def as_tensor(array):
        # Every tensor here is created the same way: on `backend`, tracking grads.
        return minitorch.tensor_from_numpy(array, backend=backend, requires_grad=True)

    # Same creation order as the arrays are listed: inp, gamma, beta, out_grad.
    inp_mt, gamma_mt, beta_mt, out_grad_mt = map(as_tensor, (inp, gamma, beta, out_grad))

    result = inp_mt.layernorm(gamma_mt, beta_mt)
    result.backward(out_grad_mt)

    grads = (inp_mt.grad, gamma_mt.grad, beta_mt.grad)
    return grads + (inp_mt, gamma_mt, beta_mt)

def baseline():
    """Reference layernorm backward pass written with elementwise tensor ops.

    Reads the module-level ``inp``, ``gamma``, ``out_grad``, ``hidden_dim`` and
    ``backend``. With ``xhat = (x - mean) / std`` and ``dxhat = out_grad * gamma``
    (mean/var taken along dim=1) it evaluates::

        d_inp   = (dxhat - (sum_1(dxhat) + xhat * sum_1(dxhat * xhat)) / hidden_dim) / std
        d_gamma = sum_0(out_grad * xhat)
        d_beta  = sum_0(out_grad)

    where ``sum_k`` is ``.sum(dim=k)``.

    NOTE(review): no epsilon is added to the variance before the square root,
    so a constant input row would divide by zero -- confirm this matches the
    epsilon (if any) used by ``Tensor.layernorm``.

    Returns:
        Tuple ``(d_inp, d_gamma, d_beta)`` of minitorch tensors.
    """
    f_input = minitorch.tensor_from_numpy(inp, backend=backend, requires_grad=True)
    f_gamma = minitorch.tensor_from_numpy(gamma, backend=backend, requires_grad=True)
    f_out_grad = minitorch.tensor_from_numpy(out_grad, backend=backend, requires_grad=True)

    f_means = f_input.mean(dim=1)
    f_vars = f_input.var(dim=1)
    # The square root is taken in NumPy and the result is re-wrapped as a new
    # tensor reshaped to one column, so it can divide f_input row by row.
    f_stds = minitorch.tensor(
        np.sqrt(f_vars.to_numpy()).reshape(-1, 1).tolist(),
        backend=backend,
        requires_grad=True,
    )

    # Normalised input and the gradient flowing into it.
    xhat = (f_input - f_means) / f_stds
    dxhat = f_out_grad * f_gamma

    # Parameter gradients: reduce over dim=0.
    f_beta_grad = f_out_grad.sum(dim=0)
    f_gamma_grad = (f_out_grad * xhat).sum(dim=0)

    # Input gradient: subtract the two per-row correction terms, then rescale by 1/std.
    dinp = dxhat.sum(dim=1) + xhat * (dxhat * xhat).sum(dim=1)
    dinp = dxhat - dinp / hidden_dim
    dinp = dinp / f_stds
    return dinp, f_gamma_grad, f_beta_grad




#     return res


# %%


# %%
# Run the layernorm op and the reference implementation on the same inputs.
# NOTE: despite their names, `gamma_mt` and `beta_mt` hold the gamma/beta
# *gradients* returned by custom(); `my_inp`, `g` and `b` are the input, gamma
# and beta tensors themselves.
inp_grad_mt, gamma_mt, beta_mt, my_inp, g, b = custom()
dinp, f_gamma_grad, f_betta_grad = baseline()

# Fail loudly if the two input gradients diverge instead of relying on a
# visual comparison of the printed tensors.
np.testing.assert_allclose(
    inp_grad_mt.to_numpy(), dinp.to_numpy(), rtol=1e-5, atol=1e-5
)


betta threadIdx.x 0 -2.410292 -0.193901 0.788223 -0.359631
betta threadIdx.x 1 0.827177 0.706287 0.240376 0.401762
betta threadIdx.x 0 -1.232705 0.729160 -0.904135 0.184053
betta threadIdx.x 1 0.269257 -0.791408 2.071791 -0.326014
betta threadIdx.x 0 -0.492562 1.727008 -1.047601 -0.141781
betta threadIdx.x 1 -1.315368 1.143262 -0.489627 0.616669
betta threadIdx.x 0 -1.710606 0.772420 -1.495746 0.503250
betta threadIdx.x 1 0.121531 1.197968 0.723928 -0.112746
betta threadIdx.x 0 0.784291 -2.200967 -0.116408 0.131355
betta threadIdx.x 1 0.676478 0.934529 0.655939 -0.865217
betta threadIdx.x 0 -0.394311 -1.536533 1.092929 0.715345
betta threadIdx.x 1 0.212143 -0.075685 1.358073 -1.371962
betta threadIdx.x 0 0.576854 -1.756566 0.850936 0.699395
betta threadIdx.x 1 -1.500286 0.306515 -0.173604 0.996756
betta threadIdx.x 0 0.841621 -1.225876 -1.064749 0.529245
betta threadIdx.x 1 0.976593 -1.163514 -0.299666 1.406345
betta threadIdx.x 0 -0.132305 -0.558172 -0.252767 -0.942241
betta threadIdx

In [3]:
# Input gradient returned by baseline() (the reference implementation).
dinp


[
	[0.363159 -0.973019 0.567064 -1.343835 0.311715 0.419770 0.539747 0.115399]
	[0.896951 -0.698703 -0.395386 -1.072312 0.843488 -0.885398 0.366174 0.945186]
	[-0.068733 -0.738745 1.149902 -1.119724 0.607371 -0.916680 0.485833 0.600775]
	[0.136526 -0.894889 -0.821604 -0.203212 -0.449466 -0.352787 0.677535 1.907897]
	[1.878921 -0.510564 1.074184 -0.929947 -0.122341 -0.401434 0.235396 -1.224215]
	[2.284641 -1.410709 1.486890 -0.864101 -0.780670 -1.269956 -1.014273 1.568181]
	[0.354428 -0.399201 1.114915 -0.660827 -0.220450 -0.426768 -0.321817 0.559721]
	[1.091704 -0.154234 -0.273991 -0.715171 0.971482 -0.876906 -0.662936 0.620052]
	[0.538298 -0.839852 1.704438 -0.910067 -0.567035 -0.324313 -0.011291 0.409821]
	[0.836467 -0.291591 -0.642730 -0.508428 -0.046064 -0.498319 0.027369 1.123297]]

In [4]:
# Input gradient returned by custom() (inp_mt.grad after backward); compare with `dinp`.
inp_grad_mt


[
	[0.363159 -0.973018 0.567064 -1.343835 0.311715 0.419770 0.539747 0.115399]
	[0.896951 -0.698703 -0.395386 -1.072311 0.843488 -0.885398 0.366174 0.945186]
	[-0.068733 -0.738745 1.149902 -1.119724 0.607371 -0.916680 0.485833 0.600775]
	[0.136525 -0.894889 -0.821604 -0.203211 -0.449466 -0.352786 0.677535 1.907897]
	[1.878920 -0.510564 1.074184 -0.929947 -0.122341 -0.401434 0.235396 -1.224215]
	[2.284640 -1.410709 1.486890 -0.864101 -0.780670 -1.269956 -1.014273 1.568180]
	[0.354428 -0.399201 1.114915 -0.660828 -0.220450 -0.426768 -0.321817 0.559721]
	[1.091704 -0.154234 -0.273991 -0.715171 0.971482 -0.876906 -0.662935 0.620052]
	[0.538298 -0.839852 1.704437 -0.910067 -0.567035 -0.324313 -0.011291 0.409821]
	[0.836467 -0.291591 -0.642730 -0.508428 -0.046064 -0.498320 0.027369 1.123297]]