In [1]:
import torch
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, json
from tqdm import tqdm

In [2]:
import torch.optim as optim

In [3]:
from transformers_lib import TransformerBlock, \
        Mixer_TransformerBlock_Encoder, \
        PositionalEncoding

# Model

In [4]:
### add randomize patches for clear benefit
class Mixer_Transformer(nn.Module):
    
    def __init__(self, seq_size:int, block_seq_size:int, channel:int, num_blocks:int):
        super().__init__()
        self.transformer_blocks = []
        for i in range(num_blocks):
            L = Mixer_TransformerBlock_Encoder(seq_size, block_seq_size, channel, 8, 0.0, 2.0, nn.GELU, None)
            self.transformer_blocks.append(L)
        self.transformer_blocks = nn.Sequential(*self.transformer_blocks)
        
    def forward(self, x):
        x = self.transformer_blocks(x)
        return x

In [5]:
device = torch.device('cuda:0')
# device = torch.device('cpu')

In [6]:
M = 128
model = Mixer_Transformer(256, 16, M, 2)

In [7]:
model

Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=128, out_features=256, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=256, out_features=128, bias=True)
            )
          )
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Line

In [16]:
criterion = torch.nn.CrossEntropyLoss()
def get_time_taken(model, input_shape, bs=64, steps=50):

    time_taken = []
#     with torch.no_grad():
    if True:
        for i in tqdm(range(steps)):
            inputs = torch.randn(bs, *input_shape).to(device)

            start = time.time()

            outputs = model(inputs)
            outputs.mean().backward()

            start = time.time()-start
            time_taken.append(start)
            
    return np.mean(time_taken)

In [19]:
outputs = []
for p in range(6, 11, 2): ## (4, 11)
    seq_len = int(2**p)
    for token_size in range(6, 11):
        token_size = int(2**token_size)
        for sparse_att in [True, False]:
            blocks = 2
            seq_block = seq_len
            if sparse_att:
                blocks = blocks//2
                seq_block = int(2**np.ceil(np.log2(np.sqrt(seq_len))))

            model = Mixer_Transformer(seq_len, seq_block, token_size, blocks).to(device)
#             model.eval()
            print(model)
            num_params = sum(p.numel() for p in model.parameters())
            print("number of params: ", num_params)
    
            input_shape = [seq_len, token_size]
#             print(input_shape)
            t = get_time_taken(model, input_shape)
            
            outputs += [(seq_len, token_size, sparse_att, num_params, t)] 
            del model
            
        print()

Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_f

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 201.03it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 230.52it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=128, out_features=256, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=256, out_features=128, bias=True)
            )
          )
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Lin

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 191.80it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=128, out_features=256, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=256, out_features=128, bias=True)
            )
          )
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=Tru

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 201.23it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=256, out_features=512, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Lin

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 116.78it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=256, out_features=512, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=Tru

100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 122.68it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=1024, out_features=512, bias=True)
            )
          )
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): L

100%|█████████████████████████████████████████████████████| 50/50 [00:00<00:00, 53.44it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=1024, out_features=512, bias=True)
            )
          )
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=T

100%|█████████████████████████████████████████████████████| 50/50 [00:00<00:00, 55.35it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=1024, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=2048, out_features=1024, bias=True)
            )
          )
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
             

100%|█████████████████████████████████████████████████████| 50/50 [00:02<00:00, 16.99it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=1024, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=2048, out_features=1024, bias=True)
            )
          )
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_a

100%|█████████████████████████████████████████████████████| 50/50 [00:02<00:00, 17.51it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=2048, out_features=4096, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=4096, out_features=2048, bias=True)
            )
          )
          (norm2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
             

100%|█████████████████████████████████████████████████████| 50/50 [00:11<00:00,  4.54it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=2048, out_features=4096, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=4096, out_features=2048, bias=True)
            )
          )
          (norm2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_a

100%|█████████████████████████████████████████████████████| 50/50 [00:10<00:00,  4.63it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_

100%|█████████████████████████████████████████████████████| 50/50 [00:00<00:00, 76.03it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    

100%|█████████████████████████████████████████████████████| 50/50 [00:01<00:00, 41.09it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=128, out_features=256, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=256, out_features=128, bias=True)
            )
          )
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Lin

100%|█████████████████████████████████████████████████████| 50/50 [00:00<00:00, 55.65it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=128, out_features=256, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=256, out_features=128, bias=True)
            )
          )
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:128 heads:8]
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=Tru

100%|█████████████████████████████████████████████████████| 50/50 [00:01<00:00, 33.84it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=256, out_features=512, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Lin

100%|█████████████████████████████████████████████████████| 50/50 [00:01<00:00, 31.95it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=256, out_features=512, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:256 heads:8]
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=Tru

100%|█████████████████████████████████████████████████████| 50/50 [00:02<00:00, 23.78it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=1024, out_features=512, bias=True)
            )
          )
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): L

100%|█████████████████████████████████████████████████████| 50/50 [00:03<00:00, 13.32it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=1024, out_features=512, bias=True)
            )
          )
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:512 heads:8]
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=T

100%|█████████████████████████████████████████████████████| 50/50 [00:04<00:00, 11.83it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=1024, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=2048, out_features=1024, bias=True)
            )
          )
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
             

100%|█████████████████████████████████████████████████████| 50/50 [00:12<00:00,  3.88it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=1024, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=2048, out_features=1024, bias=True)
            )
          )
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:1024 heads:8]
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_a

100%|█████████████████████████████████████████████████████| 50/50 [00:13<00:00,  3.78it/s]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=2048, out_features=4096, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=4096, out_features=2048, bias=True)
            )
          )
          (norm2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
             

100%|█████████████████████████████████████████████████████| 50/50 [00:52<00:00,  1.05s/it]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=2048, out_features=4096, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=4096, out_features=2048, bias=True)
            )
          )
          (norm2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:2048 heads:8]
          (norm1): LayerNorm((2048,), eps=1e-05, elementwise_a

100%|█████████████████████████████████████████████████████| 50/50 [00:52<00:00,  1.06s/it]



Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_

100%|█████████████████████████████████████████████████████| 50/50 [00:02<00:00, 20.37it/s]


Mixer_Transformer(
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (1): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention Sparse: [embed:64 heads:8]
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    

  0%|                                                              | 0/50 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 10.91 GiB total capacity; 8.28 GiB already allocated; 523.00 MiB free; 10.10 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [21]:
2//2

1

In [20]:
outputs

[(64, 64, True, 66944, 0.0025853347778320313),
 (64, 64, False, 66944, 0.002426915168762207),
 (64, 128, True, 264960, 0.0024931812286376953),
 (64, 128, False, 264960, 0.002423653602600098),
 (64, 256, True, 1054208, 0.0025581026077270506),
 (64, 256, False, 1054208, 0.0024839115142822267),
 (64, 512, True, 4205568, 0.002698359489440918),
 (64, 512, False, 4205568, 0.0026272010803222658),
 (64, 1024, True, 16799744, 0.0027275371551513674),
 (64, 1024, False, 16799744, 0.002667121887207031),
 (64, 2048, True, 67153920, 0.0028499650955200197),
 (64, 2048, False, 67153920, 0.002659769058227539),
 (256, 64, True, 66944, 0.0026912593841552732),
 (256, 64, False, 66944, 0.0026161861419677734),
 (256, 128, True, 264960, 0.0026175880432128904),
 (256, 128, False, 264960, 0.002594738006591797),
 (256, 256, True, 1054208, 0.0026099395751953123),
 (256, 256, False, 1054208, 0.0025563621520996093),
 (256, 512, True, 4205568, 0.0029493141174316406),
 (256, 512, False, 4205568, 0.002875447273254394