In [74]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from einops import rearrange, repeat
from einops.layers.torch import Rearrange





class multi_head_kron(nn.Module):
    """Multi-head two-sided linear layer followed by ReLU.

    For an input ``x`` of shape ``(batch, l_in, dim_in)`` the layer computes

        relu( sum_h  A_h @ x @ B_h  +  bias )

    where ``B_h`` (``dim_in x dim_out``) is the h-th slice of ``mat1`` and
    ``A_h = mat2[h]`` (``l_out x l_in``) mixes along the length axis.  The
    result has shape ``(batch, l_out, dim_out)``.

    Args:
        dim_in: size of the last (feature) axis of the input.
        dim_out: size of the last axis of the output.
        l_in: size of the length axis (dim 1) of the input.
        l_out: size of the length axis of the output.
        heads: number of (A_h, B_h) pairs that are summed.
        layer_num: label used only in the diagnostic printouts.
        verbose: when True (default) the forward pass prints the variance of
            the incoming, pre-activation and outgoing tensors.
    """

    def __init__(self, dim_in, dim_out, l_in, l_out, heads, layer_num = 0, verbose = True):
        super().__init__()
        self.heads = heads
        self.layer_num = layer_num
        self.verbose = verbose

        self.mat1 = nn.Linear(dim_in, heads * dim_out, bias = False)
        # nn.Linear stores its weight as (out_features, in_features); the
        # replacement must have that exact shape or the layer only works when
        # dim_in == dim_out.
        self.mat1.weight = nn.Parameter(self._scaled_uniform((heads * dim_out, dim_in), dim_in, heads))
        # mat2[h] is applied on the left of an (l_in, dim_out) slice, so it has
        # to be (l_out, l_in) for the product to yield l_out rows.
        self.mat2 = nn.Parameter(self._scaled_uniform((heads, l_out, l_in), l_in, heads))
        self.activation = nn.ReLU()
        self.bias = nn.Parameter(torch.zeros(l_out, dim_out))

    @staticmethod
    def _scaled_uniform(shape, fan_in, heads):
        """Return U(-sqrt(3), sqrt(3)) samples (unit variance) times the layer's scale factor."""
        bound = 3 ** 0.5
        # Same expression as the original initialisation: 2**0.4 / sqrt(fan_in * sqrt(heads)).
        scale = (2 ** 0.4) / (fan_in * (heads ** 0.5)) ** 0.5
        return torch.empty(*shape).uniform_(-bound, bound) * scale

    def forward(self, x):
        """Map ``x`` of shape (batch, l_in, dim_in) to (batch, l_out, dim_out)."""
        if self.verbose:
            print(f'incoming var at layer {self.layer_num}: {torch.var(x)}')
        x = self.mat1(x)
        # Split the projected feature axis into heads and move the head axis
        # forward; equivalent to einops 'b l (h d) -> b h l d'.
        batch, length = x.shape[0], x.shape[1]
        x = x.reshape(batch, length, self.heads, -1).permute(0, 2, 1, 3)
        # (h, l_out, l_in) @ (b, h, l_in, d) broadcasts over the batch axis.
        x = torch.matmul(self.mat2, x)
        # Sum the per-head contributions.
        x = torch.sum(x, dim = 1)
        x = x + self.bias
        if self.verbose:
            print(f'pre activation var at layer {self.layer_num}: {torch.var(x)}')
        x = self.activation(x)
        if self.verbose:
            print(f'outgoing var at layer {self.layer_num}:  {torch.var(x)}')
        return x


In [87]:
  
# Variance sanity check: push unit-variance noise through one multi_head_kron
# layer and print the variance before and after.
x = torch.randn(3, 100, 100)
model = multi_head_kron(dim_in=100, dim_out=100, l_in=100, l_out=100, heads=8)
print(x.var())

y = model(x)
print(y.var())

tensor(1.0000)
incoming var at layer 0: 0.9999833106994629
pre activation var at layer 0: 3.070035219192505
outgoing var at layer 0:  1.0589414834976196
tensor(1.0589, grad_fn=<VarBackward0>)


In [65]:
# Baseline for comparison: the same variance printout for a stock nn.Linear
# (default initialisation) followed by ReLU.
x = torch.randn(3, 100, 100)
print(x.var())

l = nn.Linear(in_features=100, out_features=100)
r = nn.ReLU()

print(l(x).var())
print(r(l(x)).var())


tensor(1.0151)


tensor(0.3305, grad_fn=<VarBackward0>)
tensor(0.1140, grad_fn=<VarBackward0>)


In [92]:
# Toy input of shape (3, 4, 5) drawn from a standard normal.
x = torch.randn(3,4,5)

In [95]:
# Build one all-zero token of shape (1, 1, 5) and copy it along the leading
# axis, one per batch element (b = 3), giving shape (3, 1, 5).
cls = torch.zeros(1,1,5)
cls_tokens = repeat(cls, '() l d -> b l d', b = 3)


In [98]:
# Show the shape of the repeated token tensor.
print(cls_tokens.shape)

torch.Size([3, 1, 5])


In [99]:
# Prepend cls_tokens to x along dim 1, so they occupy index 0 on that axis.
# NOTE(review): x is rebound to a function of itself, so this cell is not
# idempotent: re-running it without re-creating x prepends another zero row.
x = torch.cat((cls_tokens, x), dim=1)
print(x)

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.5340, -0.3835, -1.2868, -0.4672, -0.2817],
         [ 1.3504, -0.7269,  0.3176,  1.8474, -0.7700],
         [-0.2543, -0.2582, -0.3357,  0.3720,  0.7625],
         [-2.4791,  0.3627,  0.4921, -1.5963, -0.8217]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 1.1769, -0.5190, -0.0534, -0.3373, -2.8793],
         [-1.8928,  0.4567, -0.4157, -0.7158,  0.8897],
         [-0.4947, -0.2192,  0.4749,  0.7171,  0.4025],
         [ 1.0574,  0.9918,  0.4095,  1.1486,  0.0139]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.0551, -0.3210,  0.4185,  2.0281,  1.2126],
         [ 0.2033, -1.0261,  1.8899,  0.1244, -1.1643],
         [-0.3432, -0.3824, -0.2204, -0.0763, -1.4141],
         [ 1.2047, -1.0015, -0.1798, -0.2705,  1.3075]]])


In [101]:
# Select index 0 along dim 1 for every batch element: the slot filled by
# cls_tokens when the concatenation cell has been run once.
print(x[:,0])

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
