In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random, time, os, sys

In [3]:
class BlockLinear(nn.Module):
    def __init__(self, num_blocks, input_block_dim, output_block_dim, bias=True):
        super().__init__()
        self.weight = torch.randn(num_blocks, input_block_dim, output_block_dim)
        
        self.weight = nn.Parameter(self.weight)
        
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.weight.shape[0], 1, output_block_dim))
        
    def forward(self, x):
#         nblocks, bs, dim = x.shape[0], x.shape[1], x.shape[2]
#         print(x.shape)
        x = torch.bmm(x, self.weight)
        if self.bias is not None:
            x = x + self.bias
        return x
    
    def __repr__(self):
        S = f'BlockLinear: [{self.weight.shape}]'
        return S

In [4]:
bl = BlockLinear(256//4, 4, 5)

In [5]:
bl.weight.shape, bl.bias.shape

(torch.Size([64, 4, 5]), torch.Size([64, 1, 5]))

In [6]:
bl(torch.randn(64, 2, 4)).shape

torch.Size([64, 2, 5])

In [7]:
class BlockMLP(nn.Module):
    def __init__(self, input_dim, layer_dims, actf=nn.ELU):
        super().__init__()
        self.block_dim = layer_dims[0]
        
        assert input_dim%self.block_dim == 0, "Input dim must be even number"
        ### Create a block MLP
        self.mlp = []
        n_blocks = input_dim//layer_dims[0]
        for i in range(len(layer_dims)-1):
            l = BlockLinear(n_blocks, layer_dims[i], layer_dims[i+1])
#             print(l.weight.shape)
            a = actf()
            self.mlp.append(l)
            self.mlp.append(a)
        self.mlp = self.mlp[:-1]
        self.mlp = nn.Sequential(*self.mlp)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, self.block_dim).transpose(0,1)
        x = self.mlp(x) + x
        x = x.transpose(1,0).reshape(bs, -1)
        return x
    
#     def __repr__(self):
#         S = f'BlockLinear: [{self.weight.shape}]'
#         return S

In [8]:
mlps = BlockMLP(256, [4, 5, 6])

In [9]:
mlps.mlp[0].weight.shape

torch.Size([64, 4, 5])

In [10]:
mlps(torch.randn(1, 256)).shape

RuntimeError: The size of tensor a (6) must match the size of tensor b (4) at non-singleton dimension 2

In [11]:
64*6

384

In [12]:
class BlockMLP_MixerBlock(nn.Module):
    
    def __init__(self, input_dim, block_dim, hidden_layers_ratio=[2], actf=nn.ELU):
        super().__init__()
        
        assert input_dim%block_dim == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.block_dim = block_dim
        
        def log_base(a, base):
            return np.log(a) / np.log(base)
        
        num_layers = int(np.ceil(log_base(input_dim, base=block_dim)))
        hidden_layers_ratio = [1] + hidden_layers_ratio + [1]
        
        block_layer_dims = [int(a*block_dim) for a in hidden_layers_ratio]
        self.facto_nets = []
        for i in range(num_layers):
            net = BlockMLP(self.input_dim, block_layer_dims, actf)
            self.facto_nets.append(net)
            
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def forward(self, x):
        bs = x.shape[0]
        y = x
        for i, fn in enumerate(self.facto_nets):
            y = y.view(-1, self.block_dim, self.block_dim**i).permute(0, 2, 1).contiguous().view(bs, -1)
            y = fn(y)
            y = y.view(-1, self.block_dim**i, self.block_dim).permute(0, 2, 1).contiguous()

        y = y.view(bs, -1)
        return y

In [13]:
N = 32
M = 8
bmlp = BlockMLP_MixerBlock(N, M, [2]) ## Input dim must be power of 2 and so does the block dim
### Input dim / block dim = (in) 2^I

In [14]:
bmlp

BlockMLP_MixerBlock(
  (facto_nets): ModuleList(
    (0): BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [torch.Size([4, 8, 16])]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [torch.Size([4, 16, 8])]
      )
    )
    (1): BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [torch.Size([4, 8, 16])]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [torch.Size([4, 16, 8])]
      )
    )
  )
)

In [15]:
bmlp(torch.randn(2, N)).shape

torch.Size([2, 32])

### Finding Valid sizes

In [18]:
valids = {}
for p in range(1, 15):
    N = int(2**p)
    valids[N] = []
    for q in range(1, p+1):
        M = int(2**q)
        net = BlockMLP_MixerBlock(N, M, [2])
        try:
            net(torch.randn(1, N))
            print(f"Valid for M:{M} N:{N}")
            valids[N].append(M)
        except RuntimeError as e:
#             print(e)
            pass

Valid for M:2 N:2
Valid for M:2 N:4
Valid for M:4 N:4
Valid for M:2 N:8
Valid for M:8 N:8
Valid for M:2 N:16
Valid for M:4 N:16
Valid for M:16 N:16
Valid for M:2 N:32
Valid for M:32 N:32
Valid for M:2 N:64
Valid for M:4 N:64
Valid for M:8 N:64
Valid for M:64 N:64
Valid for M:2 N:128
Valid for M:128 N:128
Valid for M:2 N:256
Valid for M:4 N:256
Valid for M:16 N:256
Valid for M:256 N:256
Valid for M:2 N:512
Valid for M:8 N:512
Valid for M:512 N:512
Valid for M:2 N:1024
Valid for M:4 N:1024
Valid for M:32 N:1024
Valid for M:1024 N:1024
Valid for M:2 N:2048
Valid for M:2048 N:2048
Valid for M:2 N:4096
Valid for M:4 N:4096
Valid for M:8 N:4096
Valid for M:16 N:4096
Valid for M:64 N:4096
Valid for M:4096 N:4096
Valid for M:2 N:8192
Valid for M:8192 N:8192
Valid for M:2 N:16384
Valid for M:4 N:16384
Valid for M:128 N:16384
Valid for M:16384 N:16384


In [21]:
valids

{2: [2],
 4: [2, 4],
 8: [2, 8],
 16: [2, 4, 16],
 32: [2, 32],
 64: [2, 4, 8, 64],
 128: [2, 128],
 256: [2, 4, 16, 256],
 512: [2, 8, 512],
 1024: [2, 4, 32, 1024],
 2048: [2, 2048],
 4096: [2, 4, 8, 16, 64, 4096],
 8192: [2, 8192],
 16384: [2, 4, 128, 16384]}

In [24]:
[(2**i, np.sqrt(2**i)) for i in range(15)]

[(1, 1.0),
 (2, 1.4142135623730951),
 (4, 2.0),
 (8, 2.8284271247461903),
 (16, 4.0),
 (32, 5.656854249492381),
 (64, 8.0),
 (128, 11.313708498984761),
 (256, 16.0),
 (512, 22.627416997969522),
 (1024, 32.0),
 (2048, 45.254833995939045),
 (4096, 64.0),
 (8192, 90.50966799187809),
 (16384, 128.0)]

In [20]:
#### this is wrong
{2: [2],
 4: [2, 4],
 8: [2, 4, 8],
 16: [2, 4, 16],
 32: [2, 4, 8, 32],
 64: [2, 4, 8, 64],
 128: [2, 4, 16, 128],
 256: [2, 4, 8, 16, 256],
 512: [2, 4, 8, 32, 512],
 1024: [2, 4, 32, 1024],
 2048: [2, 4, 8, 16, 64, 2048],
 4096: [2, 4, 8, 16, 64, 4096]}

{2: [2],
 4: [2, 4],
 8: [2, 4, 8],
 16: [2, 4, 16],
 32: [2, 4, 8, 32],
 64: [2, 4, 8, 64],
 128: [2, 4, 16, 128],
 256: [2, 4, 8, 16, 256],
 512: [2, 4, 8, 32, 512],
 1024: [2, 4, 32, 1024],
 2048: [2, 4, 8, 16, 64, 2048],
 4096: [2, 4, 8, 16, 64, 4096]}

In [271]:
a = [print(a, 2**a, np.sqrt(2**a)) for a in range(1, 10)]

1 2 1.4142135623730951
2 4 2.0
3 8 2.8284271247461903
4 16 4.0
5 32 5.656854249492381
6 64 8.0
7 128 11.313708498984761
8 256 16.0
9 512 22.627416997969522
