In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs top-to-bottom on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Single source of truth for RNG seeding (other seeds tried earlier: 147, 258).
SEED = 369

# Seed every RNG library the notebook imports, not just torch and numpy, so a
# fresh Restart & Run All draws the same random numbers each time.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

# Model

In [5]:
# Output of the first Linear layer from the most recent MlpBLock.forward call.
# Debug hook: it is overwritten on every forward pass of any MlpBLock instance.
la1 = None


class MlpBLock(nn.Module):
    """MLP whose input and output width are both ``input_dim``.

    Hidden widths are expressed as multiples of ``input_dim``. Every Linear
    layer except the last one is followed by an activation.

    Args:
        input_dim: Width of the input (and of the output) features.
        hidden_layers_ratio: Width multiplier(s) of the hidden layers. Either
            a single number or any sequence of numbers (list, tuple, ...).
            Each hidden width is ``int(ratio * input_dim)``.
        actf: Zero-argument callable returning the activation module placed
            between Linear layers.
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim

        # A bare scalar means "one hidden layer with this ratio".
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        # list(...) accepts tuples and other sequences, and builds a fresh list
        # so the shared default argument is never aliased or mutated.
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        layers = []
        # N hidden layers need N+1 Linear layers: one per consecutive ratio pair.
        for ratio_in, ratio_out in zip(self.hlr[:-1], self.hlr[1:]):
            in_features = int(ratio_in * self.input_dim)
            out_features = int(ratio_out * self.input_dim)
            layers.append(nn.Linear(in_features, out_features))
            layers.append(actf())

        # Drop the trailing activation so the block ends on a Linear layer.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP to ``x`` and record the first Linear output in ``la1``."""
        global la1
        for i, layer in enumerate(self.mlp):
            x = layer(x)
            if i == 0:
                # Expose the first-layer output for inspection outside the module.
                la1 = x
        return x

In [6]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU()
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU()
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

## Patch Mixer

In [7]:
class PatchMixerBlock(nn.Module):
    """Residual block that mixes the values inside each image patch.

    The input is cut into tiles with ``kernel_size == stride == patch_size``.
    Each tile, with all of its channels flattened together, is passed through
    one shared MlpBLock, and the re-assembled result is added to the input.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        # Pixels per patch: a square for an int, height * width for a pair.
        if isinstance(patch_size, int):
            pixels_per_patch = patch_size * patch_size
        else:
            pixels_per_patch = patch_size[0] * patch_size[1]

        # A flattened patch carries every channel of every pixel it covers.
        self.mlp_patch = MlpBLock(pixels_per_patch * num_channel, [2])

    def forward(self, x):
        """Mix each patch of ``x`` (N, C, H, W) and add the result back onto ``x``."""
        F = nn.functional
        height, width = x.shape[-2], x.shape[-1]

        # One column per patch.
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)

        # Put the flattened-patch axis last for the MLP, then restore the layout.
        mixed = self.mlp_patch(patches.transpose(-1, -2)).transpose(-1, -2)

        restored = F.fold(
            mixed,
            (height, width),
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        return x + restored

In [8]:
pmb = PatchMixerBlock(8, 3)
pmb

PatchMixerBlock(
  (mlp_patch): MlpBLock(
    (mlp): Sequential(
      (0): Linear(in_features=192, out_features=384, bias=True)
      (1): GELU()
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
)

In [9]:
pmb(torch.randn(5, 3, 16, 16)).shape

torch.Size([5, 3, 16, 16])

## Convert Unfold+Linear+Fold to conv2d+reshape

In [10]:
# h scales the number of output channels of the patch convolution below.
h = 2
# kernel_size == stride == 8, so every output position is computed from one
# non-overlapping 8x8 patch of the 3-channel input; 3*8*8*h output channels.
conv2d = nn.Conv2d(3, 3*8*8*h, kernel_size=8, stride=8, groups=1)

In [11]:
conv2d.weight.shape

torch.Size([384, 3, 8, 8])

In [12]:
conv2d.bias.shape

torch.Size([384])

In [13]:
# 1x1 convolution: maps 3*8*8*h channels back to 3*8*8 channels independently
# at each spatial position of its input.
conv2d_1x1 = nn.Conv2d(3*8*8*h, 3*8*8, kernel_size=1, stride=1)

In [14]:
conv2d(torch.randn(5, 3, 16, 16)).shape

torch.Size([5, 384, 2, 2])

In [15]:
yt = conv2d_1x1(conv2d(torch.randn(5, 3, 16, 16)))
yt.shape

torch.Size([5, 192, 2, 2])

In [16]:
yt.view(-1, 3, 8, 8, 2, 2).shape

torch.Size([5, 3, 8, 8, 2, 2])

In [17]:
# After view the axes are (N, C, kH, kW, nH, nW). Reorder to
# (N, C, nH, kH, nW, kW) so the reshape lays each 8x8 patch out contiguously
# inside the 16x16 image. permute(0,1,2,4,3,5) yields the same shape but
# scrambles the pixels.
yt.view(-1, 3, 8, 8, 2, 2).permute(0,1,4,2,5,3).reshape(-1, 3, 16, 16).shape

torch.Size([5, 3, 16, 16])

In [18]:
pmb.mlp_patch.mlp[0], pmb.mlp_patch.mlp[2]

(Linear(in_features=192, out_features=384, bias=True),
 Linear(in_features=384, out_features=192, bias=True))

#### Test the first layer

In [19]:
# Reinterpret each row of the first Linear weight (384, 192) as a (3, 8, 8)
# conv kernel so the convolution reuses the Linear layer's parameters.
# NOTE(review): reshape presumably returns a view here, so conv2d would share
# storage with pmb's Linear layer rather than own a copy — confirm, and add
# .clone() if independent parameters are needed.
conv2d.weight.data = pmb.mlp_patch.mlp[0].weight.data.reshape(-1, 3, 8, 8)

In [20]:
conv2d.bias.data = pmb.mlp_patch.mlp[0].bias.data

In [21]:
x = torch.randn(5, 3, 16, 16)

In [22]:
# flatten(2) merges the patch-grid axes into one without hard-coding the batch
# size or channel count; transpose then gives (N, num_patches, channels), the
# same layout as the first Linear layer's output inside MlpBLock.
y = conv2d(x).flatten(2).transpose(1,2)
print(y)
y.shape

tensor([[[-1.0938e-01,  3.0875e-01,  1.1378e-01,  ..., -3.9103e-01,
           3.8987e-02, -2.9499e-01],
         [-3.6289e-02,  4.2033e-01, -1.5503e+00,  ..., -5.1576e-01,
          -8.1168e-01,  4.0114e-02],
         [-4.7239e-01,  3.5274e-02,  1.0130e+00,  ...,  8.1015e-03,
           6.5814e-01,  5.3013e-01],
         [ 1.1671e+00, -5.4763e-01,  4.2892e-01,  ..., -9.5840e-01,
           1.0319e+00,  4.3571e-01]],

        [[-4.6715e-02, -2.4464e-01,  3.8992e-01,  ...,  4.6404e-01,
          -1.0017e-01,  4.0540e-01],
         [ 1.8285e-01,  2.9831e-01, -3.1520e-01,  ..., -6.1594e-01,
          -7.1937e-01,  5.7502e-01],
         [ 7.9037e-01, -1.4237e+00, -9.7666e-02,  ..., -1.0188e+00,
          -5.3295e-01,  6.5312e-02],
         [-6.1959e-01,  8.4910e-01,  8.8066e-02,  ..., -4.1976e-02,
          -1.3695e+00, -1.4766e-01]],

        [[ 8.5568e-02,  3.6752e-01,  3.7391e-01,  ..., -9.2734e-02,
          -5.8245e-01, -4.6754e-01],
         [-6.9989e-01,  1.2125e-01, -4.1096e-01,  .

torch.Size([5, 4, 384])

In [23]:
# Called only for its side effect: MlpBLock.forward stores the output of its
# first Linear layer in the global `la1`. The returned tensor is discarded.
pmb(x)
print()




In [24]:
# The two paths differ only by float32 round-off, which the default
# tolerances (rtol=1e-5, atol=1e-8) reject for values near zero. Compare with
# an absolute tolerance suited to single precision instead.
torch.allclose(la1, y, atol=1e-5)

True

In [25]:
diff = (la1-y).abs().data
diff.mean(), diff.std(), diff.min(), diff.max()

(tensor(1.2817e-07), tensor(1.2062e-07), tensor(0.), tensor(1.1921e-06))

##### Putting it together

In [26]:
conv2d_1x1.weight.data.shape, pmb.mlp_patch.mlp[-1].weight.data.shape

(torch.Size([192, 384, 1, 1]), torch.Size([192, 384]))

In [27]:
conv2d_1x1.bias.shape, pmb.mlp_patch.mlp[-1].bias.shape

(torch.Size([192]), torch.Size([192]))

In [28]:
# The last Linear weight (192, 384) becomes a 1x1 conv kernel (192, 384, 1, 1).
conv2d_1x1.weight.data = pmb.mlp_patch.mlp[-1].weight.data.reshape(192, 384, 1, 1)
conv2d_1x1

Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1))

In [29]:
conv2d_1x1.bias.data = pmb.mlp_patch.mlp[-1].bias.data

In [30]:
### for first layer
conv2d.weight.data.shape, pmb.mlp_patch.mlp[0].weight.data.shape

(torch.Size([384, 3, 8, 8]), torch.Size([384, 192]))

In [31]:
conv2d.bias.shape, pmb.mlp_patch.mlp[0].bias.shape

(torch.Size([384]), torch.Size([384]))

In [32]:
conv2d.weight.data = pmb.mlp_patch.mlp[0].weight.data.reshape(-1, 3, 8, 8)

In [33]:
conv2d.bias.data = pmb.mlp_patch.mlp[0].bias.data

In [34]:
## test
x = torch.randn(5, 3, 16, 16)
actf = nn.GELU()

In [35]:
%timeit x + conv2d_1x1(actf(conv2d(x))).view(-1, 3, 8, 8, 2, 2).permute(0,1,4,2,5,3).reshape(-1, 3, 16, 16)

425 µs ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [36]:
%timeit pmb(x)

388 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [37]:
# Conv path: patch conv -> GELU -> 1x1 conv. view splits the channel axis into
# (3, 8, 8) next to the 2x2 patch grid, permute moves each grid axis in front
# of its in-patch axis, and reshape merges them into a (N, 3, 16, 16) image
# before the residual add.
a = x + conv2d_1x1(actf(conv2d(x))).view(-1, 3, 8, 8, 2, 2).permute(0,1,4,2,5,3).reshape(-1, 3, 16, 16)
# Reference path: the unfold + MLP + fold implementation.
b = pmb(x)

In [38]:
# The two paths differ only by float32 round-off, which the default
# tolerances (rtol=1e-5, atol=1e-8) reject for values near zero. Compare with
# an absolute tolerance suited to single precision instead.
torch.allclose(a, b, atol=1e-5)

True

In [39]:
diff = (a-b).abs().data
diff.mean(), diff.std(), diff.min(), diff.max()

(tensor(7.1437e-08), tensor(6.7618e-08), tensor(0.), tensor(4.7684e-07))