In [25]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [26]:
from tqdm import tqdm
import random, sys, os

In [27]:
import sparse_linear_lib as sll

In [28]:
device = torch.device("cuda:0")

## Pair Linear approximation

In [29]:
N = 256
seeds = [147, 258, 369, 321, 654, 987, 741, 852, 963, 159]
SEED = seeds[0]

In [30]:
torch.manual_seed(SEED)
## A is a target matrix
# A = torch.randn(N, N).to(device)
A = torch.rand(N, N).to(device)*2-1 

In [31]:
X = torch.eye(N).to(device)

In [32]:
model = sll.PairLinear_MixerBlock(N, N).to(device)

In [33]:
model(X)

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0',
       grad_fn=<ViewBackward0>)

In [34]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

mse = nn.MSELoss()
def mae(A, B):
    return torch.abs(A-B).mean()

criterion = mse
# criterion = mae

In [12]:
### forward propagation
for i in range(20000):
    out = model.forward(X)
    loss = criterion(out, A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(out,A))}")
#         print(f"The MAE loss is : {float(mae(out,A))}")
#         diff = (out.data-A).abs()        
#         plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
#         plt.show()

The MSE loss is : 0.33718395233154297
The MSE loss is : 0.31270721554756165


KeyboardInterrupt: 

In [None]:
torch.det(A)

In [None]:
torch.det(out.data)

In [None]:
diff = (out.data-A).abs()

In [None]:
diff.min(), diff.max()

In [None]:
diff.mean(), diff.std()

In [None]:
plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
plt.show()

## Approximation using SVD / Eigen

In [None]:
U, S, V = torch.svd(A)

In [None]:
# _m = int(np.ceil(np.sqrt(N)))
_m = N // 2
_m

In [None]:
n_params = sum(p.numel() for p in model.parameters())
_m = int(np.ceil(n_params/(U.shape[0]*2)))
_m

In [None]:
_S = S.clone()
_S[_m:] *= 0
_S

In [None]:
out = torch.mm(torch.mm(U, torch.diag(_S)), V.t())
mse(out, A)

In [None]:
diff = (out.data-A).abs()        
plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
plt.show()

In [None]:
diff.min(), diff.max()

In [None]:
diff.mean(), diff.std()

In [None]:
U.shape[0]*_m*2

In [None]:
torch.numel(A)

In [None]:
## for 2x2 linear
print("number of params: ", sum(p.numel() for p in model.parameters()))

## Approximating Low Rank

In [None]:
model = nn.Sequential(nn.Linear(N, _m, bias=False), nn.Linear(_m, N, bias=False)).to(device)
# model = nn.Linear(N, N, bias=False).to(device) ## it can easily approximate to ~ 0 error

In [None]:
model(X)

In [None]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
### forward propagation
for i in range(20000):
    out = model.forward(X)
    loss = criterion(out, A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(out,A))}")
#         print(f"The MAE loss is : {float(mae(out,A))}")
#         diff = (out.data-A).abs()        
#         plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
#         plt.show()

In [None]:
torch.det(A)

In [None]:
torch.det(out.data)

In [None]:
diff = (out.data-A).abs()

In [None]:
diff.min(), diff.max()

In [None]:
diff.mean(), diff.std()

In [None]:
plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
plt.show()

### Testing Factorized Addition of 2x2 Factorization

In [None]:
class Add_PairLinears(nn.Module):
    
    def __init__(self, input_dim, num_adds):
        super().__init__()
        self.pair_mixers = []
        self.perm_indices = []
        for i in range(num_adds):
            m = sll.PairLinear_MixerBlock(input_dim, input_dim)
            self.pair_mixers.append(m)
            if i > 0:
                rm = torch.randperm(input_dim)
                self.perm_indices.append(rm)
                
        self.pair_mixers = nn.ModuleList(self.pair_mixers)
        
    def forward(self, x):
        y = torch.zeros_like(x)
        for i, m in enumerate(self.pair_mixers):
            if i > 0:
                _x = x[:, self.perm_indices[i-1]]
            else:
                _x = x
                
            y += m(_x)
        return y

In [None]:
model = Add_PairLinears(N, 4).to(device)

In [None]:
model

In [None]:
model(X)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
### forward propagation
for i in range(20000):
    out = model.forward(X)
    loss = criterion(out, A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(out,A))}")
#         print(f"The MAE loss is : {float(mae(out,A))}")
#         diff = (out.data-A).abs()        
#         plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
#         plt.show()

In [None]:
'''
The MSE loss is : 0.32291921973228455
'''

In [None]:
torch.det(A)

In [None]:
torch.det(out.data)

In [None]:
diff = (out.data-A).abs()

In [None]:
diff.min(), diff.max()

In [None]:
diff.mean(), diff.std()

In [None]:
plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
plt.show()

### Testing Factorized Stacking of 2x2 Factorization

In [None]:
class Stack_PairLinears(nn.Module):
    
    def __init__(self, input_dim, num_adds):
        super().__init__()
        self.pair_mixers = []
        self.perm_indices = []
        for i in range(num_adds):
            m = sll.PairLinear_MixerBlock(input_dim, input_dim)
            self.pair_mixers.append(m)
            if i > 0:
                rm = torch.randperm(input_dim)
                self.perm_indices.append(rm)
                
        self.pair_mixers = nn.ModuleList(self.pair_mixers)
        
    def forward(self, x):
        for i, m in enumerate(self.pair_mixers):
            if i == 0:
                x = m(x)
            else:
                x = m(x[:, self.perm_indices[i-1]])
        return x

In [None]:
model = Add_PairLinears(N, 4).to(device)

In [None]:
model

In [None]:
model(X)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
### forward propagation
for i in range(20000):
    out = model.forward(X)
    loss = criterion(out, A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(out,A))}")
#         print(f"The MAE loss is : {float(mae(out,A))}")
#         diff = (out.data-A).abs()        
#         plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
#         plt.show()

In [None]:
'''
The MSE loss is : 0.32291921973228455 --> For plain 2x2
The MSE loss is : 0.2933475971221924 --> For 4 parallel added 2x2
'''

In [None]:
torch.det(A)

In [None]:
torch.det(out.data)

In [None]:
diff = (out.data-A).abs()

In [None]:
diff.min(), diff.max()

In [None]:
diff.mean(), diff.std()

In [None]:
plt.hist(diff.cpu().numpy().reshape(-1), bins=100)
plt.show()

### Testing Factorized Multiplication of 2x2 Factorization

It does not seem to work.. removing the codes

In [None]:
'''
The MSE loss is : 0.32291921973228455 --> For plain 2x2
The MSE loss is : 0.2933475971221924 --> For 4 parallel added 2x2
The MSE loss is : 0.293201208114624 --> For 4 serial composed 2x2
  --> For 4 parallel multiplied 2x2
'''
print()

## Creating mxm blocks rather than 2x2

In [13]:
## choice for m
## 1. m = sqrt(N)
## 2. m = log2(N)

In [35]:
## FFT permutation

A = torch.arange(0, 64, 1, dtype=torch.long)

In [36]:
A0 = A.reshape(-1,4,1).permute(0, 2,1)
A0

tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]],

        [[12, 13, 14, 15]],

        [[16, 17, 18, 19]],

        [[20, 21, 22, 23]],

        [[24, 25, 26, 27]],

        [[28, 29, 30, 31]],

        [[32, 33, 34, 35]],

        [[36, 37, 38, 39]],

        [[40, 41, 42, 43]],

        [[44, 45, 46, 47]],

        [[48, 49, 50, 51]],

        [[52, 53, 54, 55]],

        [[56, 57, 58, 59]],

        [[60, 61, 62, 63]]])

In [37]:
A1 = A.reshape(-1,4,4).permute(0, 2,1)
A1

tensor([[[ 0,  4,  8, 12],
         [ 1,  5,  9, 13],
         [ 2,  6, 10, 14],
         [ 3,  7, 11, 15]],

        [[16, 20, 24, 28],
         [17, 21, 25, 29],
         [18, 22, 26, 30],
         [19, 23, 27, 31]],

        [[32, 36, 40, 44],
         [33, 37, 41, 45],
         [34, 38, 42, 46],
         [35, 39, 43, 47]],

        [[48, 52, 56, 60],
         [49, 53, 57, 61],
         [50, 54, 58, 62],
         [51, 55, 59, 63]]])

In [38]:
A2 = A.reshape(-1,4,16).permute(0, 2,1)
A2

tensor([[[ 0, 16, 32, 48],
         [ 1, 17, 33, 49],
         [ 2, 18, 34, 50],
         [ 3, 19, 35, 51],
         [ 4, 20, 36, 52],
         [ 5, 21, 37, 53],
         [ 6, 22, 38, 54],
         [ 7, 23, 39, 55],
         [ 8, 24, 40, 56],
         [ 9, 25, 41, 57],
         [10, 26, 42, 58],
         [11, 27, 43, 59],
         [12, 28, 44, 60],
         [13, 29, 45, 61],
         [14, 30, 46, 62],
         [15, 31, 47, 63]]])

In [49]:
class BlockWeight(nn.Module):
    def __init__(self, input_dim, block_dim):
        super().__init__()
        self.block_dim = block_dim
        
        assert input_dim%block_dim == 0, "Input dim must be even number"
        self.weight = torch.eye(block_dim).unsqueeze(0).repeat_interleave(input_dim//block_dim, dim=0)
        self.weight = nn.Parameter(self.weight)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        print(x.shape, self.weight.shape)
        x = x.view(bs, -1, self.block_dim).transpose(0,1)
        print(x.shape)
        x = torch.bmm(x, self.weight)
        x = x.transpose(1,0).reshape(bs, -1)
        return x

In [50]:
class BlockLinear_MixerBlock(nn.Module):
    
    def __init__(self, input_dim, block_dim):
        super().__init__()
        
        assert input_dim%block_dim == 0, "Input dim must be even number"
        self.input_dim = input_dim
        
        def log_base(a, base):
            return np.log(a) / np.log(base)
        
        num_layers = int(np.ceil(log_base(input_dim, base=block_dim)))
            
        self.facto_nets = []
        for i in range(num_layers):
            net = BlockWeight(self.input_dim, block_dim)
            self.facto_nets.append(net)
            
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def forward(self, x):
        bs = x.shape[0]
        y = x
        for i, fn in enumerate(self.facto_nets):
            y = y.view(-1,4,4**i).permute(0, 2, 1).contiguous().view(bs, -1)
            y = fn(y)
            y = y.view(-1,4**i,4).permute(0, 2, 1).contiguous()

        y = y.view(bs, -1)
        return y

In [55]:
model = BlockLinear_MixerBlock(N, 4).to(device)
model

BlockLinear_MixerBlock(
  (facto_nets): ModuleList(
    (0): BlockWeight()
    (1): BlockWeight()
    (2): BlockWeight()
    (3): BlockWeight()
  )
)

In [56]:
model(X)

torch.Size([256, 256]) torch.Size([64, 4, 4])
torch.Size([64, 256, 4])
torch.Size([256, 256]) torch.Size([64, 4, 4])
torch.Size([64, 256, 4])
torch.Size([256, 256]) torch.Size([64, 4, 4])
torch.Size([64, 256, 4])
torch.Size([256, 256]) torch.Size([64, 4, 4])
torch.Size([64, 256, 4])


tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0',
       grad_fn=<ViewBackward0>)