# Install Dependencies

In [None]:
!pip install vit-pytorch
!pip install tabulate
!pip install perceiver-pytorch


# Datasets

* STL-10

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

def DataLoaders(dataset = "CF-10", data_root = "./", batch_size = 16):
  transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p = 0.5),
        transforms.ColorJitter(brightness=0.5, hue = 0.25),
        transforms.ToTensor(),
      ])

  transform_test = transforms.Compose([
      transforms.ToTensor(),
  ])
  root_dir = os.path.join(data_root, dataset)
  if dataset == "CF-10":
    train_dataset = datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
  
  if dataset == "STL-10":
    train_dataset = datasets.STL10(root=root_dir, split='train', transform=transform_train, download=True)
    test_dataset = datasets.STL10(root=root_dir, split='test', transform=transform_test, download=True)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, test_loader

In [None]:
def GetNumberParameters(model):
  return sum(np.prod(p.shape).item() for p in model.parameters())

In [None]:
def save_model(model, file_name):
    from torch import save
    from os import path
    print("saving", file_name)
    return save(model.state_dict(), file_name)


def load_model():
    from torch import load
    from os import path
    r = CNNClassifier()
    r.load_state_dict(load(path.join(path.dirname(path.abspath(__file__)), 'cnn.th'), map_location='cpu'))
    return r

In [None]:
def accuracy(outputs, labels):
    outputs_idx = outputs.max(1)[1].type_as(labels)
    return outputs_idx.eq(labels).float().mean()

In [None]:
train_loader, test_loader = DataLoaders("STL-10", batch_size = 128)

Files already downloaded and verified
Files already downloaded and verified


# Model Architecture

## Alternating Variant

In [None]:
from vit_pytorch import ViT
import torchvision
from vit_pytorch import ViT
import torchvision.models as models
from perceiver_pytorch import Perceiver

class MMBlock(torch.nn.Module):
  def __init__(self, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 15, stride = 2, padding = 2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=3, padding=1),
            torch.nn.BatchNorm2d(3)
    )

    self.vision_transformer = ViT(
        image_size = image_size,
        patch_size = image_size // 16,
        num_classes = num_labels,
        dim = output_dim,
        depth = depth,
        heads = att_heads,
        mlp_dim = mlp_dim,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    self.vision_transformer = torch.nn.Sequential(*(list(self.vision_transformer.children())[:-1]))

    self.upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))

  def forward(self, x):
    res = self.conv_net(x)
    res = self.vision_transformer(res)
    res = res[:, :, :, None]
    dim = int(res.shape[2] ** 0.5)
    res = res.view(res.shape[0], res.shape[1], dim, dim)

    return self.upsample_vit(res)

class AlternatingMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
        super().__init__()
        self.blocks =  []
        for i in range(num_blocks):
            self.blocks.append(MMBlock(image_size = image_size, num_labels = num_labels, depth = depth, att_heads = att_heads, mlp_dim = mlp_dim, output_dim = output_dim))
        self.ffn = torch.nn.Linear(3 * 32 * 32, 10)
        self.MM = torch.nn.Sequential(*self.blocks)
      
    def forward(self, x):
      return self.ffn(self.MM(x).flatten(start_dim = 1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AlternatingMixtureModel(2).to(device)
img = torch.randn(16, 3, 32, 32).to(device)
print(model(img).shape)

# Alternating Variant with Residual Connections

In [None]:
class ResidualCNNBlock(torch.nn.Module):
    def __init__(self, c_in, c_out, should_stride=False):
        super().__init__()

        if should_stride:
            stride = 2
        else:
            stride = 1

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride),
            torch.nn.BatchNorm2d(c_out),
        ) 
        
        self.relu = torch.nn.ReLU()

        if c_in != c_out or should_stride:
            self.identity = torch.nn.Conv2d(c_in, c_out, 1, stride=stride)
        else:
            self.identity = lambda x: x

    def forward(self, x):
        result = self.block(x)
        x = self.identity(x)

        return self.relu(x + result)

class ResMMBlock(torch.nn.Module):
  def __init__(self, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
              ResidualCNNBlock(3, 32, False),
              ResidualCNNBlock(32, 32, False),
              ResidualCNNBlock(32, 128, False),
              ResidualCNNBlock(128, 128, False),
              ResidualCNNBlock(128, 3, False)
    )
    self.vision_transformer = ViT(
        image_size = image_size,
        patch_size = image_size // 16,
        num_classes = num_labels,
        dim = output_dim,
        depth = depth,
        heads = att_heads,
        mlp_dim = mlp_dim,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    self.vision_transformer = torch.nn.Sequential(*(list(self.vision_transformer.children())[:-1]))

    self.upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))

  def forward(self, x):
    res = self.conv_net(x)
    res = self.vision_transformer(res)
    res = res[:, :, :, None]
    dim = int(res.shape[2] ** 0.5)
    res = res.view(res.shape[0], res.shape[1], dim, dim)

    return self.upsample_vit(res)

class ResAlternatingMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
        super().__init__()
        self.blocks =  []
        for i in range(num_blocks):
            self.blocks.append(ResMMBlock(image_size = image_size, num_labels = num_labels, depth = depth, att_heads = att_heads, mlp_dim = mlp_dim, output_dim = output_dim))
        self.ffn = torch.nn.Linear(3 * 32 * 32, 10)
        self.MM = torch.nn.Sequential(*self.blocks)
      
    def forward(self, x):
      return self.ffn(self.MM(x).flatten(start_dim = 1))

class ConvNet(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
              ResidualCNNBlock(3, 32, False),
              ResidualCNNBlock(32, 32, False),
              ResidualCNNBlock(32, 128, False),
              ResidualCNNBlock(128, 128, False),
              ResidualCNNBlock(128, 3, False)
    )
    self.ffn = torch.nn.Linear(3 * 96 * 96, 10)

  def forward (self, x):
    return self.ffn(self.conv_net(x).flatten(start_dim = 1))

model = ConvNet()
print(model(torch.rand(16, 3, 96, 96)).shape)

# Alternating Variant (CNN + Perceiver)



In [None]:
class PerceiverMMBlock(torch.nn.Module):
  def __init__(self, image_size = 32, depth = 1, att_heads = 1, output_dim = 1024):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
              ResidualCNNBlock(3, 32, False),
              ResidualCNNBlock(32, 32, False),
              ResidualCNNBlock(32, 128, False),
              ResidualCNNBlock(128, 128, False),
              ResidualCNNBlock(128, 3, False)
    )
    self.perceiver = Perceiver(
          input_channels = 3,          # number of channels for each token of the input
          input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
          num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
          max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
          depth = depth,                   # depth of net. The shape of the final attention mechanism will be:
                                      #   depth * (cross attention -> self_per_cross_attn * self attention)
          num_latents = 32,           # number of latents, or induced set points, or centroids. different papers giving it different names
          latent_dim = 32,            # latent dimension
          cross_heads = att_heads,             # number of heads for cross attention. paper said 1
          latent_heads = 8,            # number of heads for latent self attention, 8
          cross_dim_head = 16,         # number of dimensions per cross attention head
          latent_dim_head = 16,        # number of dimensions per latent self attention head
          num_classes = output_dim,    # output number of classes
          attn_dropout = 0.2,
          ff_dropout = 0.2,
          weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
          fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
          self_per_cross_attn = 2      # number of self attention blocks per cross attention
    )

  def forward(self, x):
    res = self.conv_net(x)
    res = res.permute(0, 2, 3, 1)
    res = self.perceiver(res)
    res = res[:, :, None, None]
    res = res.view(res.shape[0], 3, 96, 96)

    return res

class PerceiverAltMM(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, depth = 1, att_heads = 1, output_dim = 1024):
        super().__init__()
        self.blocks =  []
        for i in range(num_blocks):
            self.blocks.append(PerceiverMMBlock(image_size = image_size, depth = depth, att_heads = att_heads, output_dim = output_dim))
        self.ffn = torch.nn.Linear(3 * image_size * image_size, 10)
        self.MM = torch.nn.Sequential(*self.blocks)
      
    def forward(self, x):
      return self.ffn(self.MM(x).flatten(start_dim = 1))

model = PerceiverAltMM(num_blocks=2, image_size=96, depth = 2, att_heads=3, output_dim=3 * 96 * 96)
print(model(torch.rand(16, 3, 96, 96)).shape)

# Joint Variant (CNN + Perceiver)

In [None]:
class JointPerceiverMM(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
      super().__init__()
      self.blocks = []
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      self.conv_net = torch.nn.Sequential(
          ResidualCNNBlock(3, 32, False),
          ResidualCNNBlock(32, 32, False),
          ResidualCNNBlock(32, 128, False),
          ResidualCNNBlock(128, 128, False),
          ResidualCNNBlock(128, 3, False)
      )
      self.perceiver = Perceiver(
        input_channels = 3,          # number of channels for each token of the input
        input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
        num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
        max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
        depth = depth,                   # depth of net. The shape of the final attention mechanism will be:
                                    #   depth * (cross attention -> self_per_cross_attn * self attention)
        num_latents = 32,           # number of latents, or induced set points, or centroids. different papers giving it different names
        latent_dim = 32,            # latent dimension
        cross_heads = att_heads,             # number of heads for cross attention. paper said 1
        latent_heads = 8,            # number of heads for latent self attention, 8
        cross_dim_head = 16,         # number of dimensions per cross attention head
        latent_dim_head = 16,        # number of dimensions per latent self attention head
        num_classes = 3 * 96 * 96,    # output number of classes
        attn_dropout = 0.2,
        ff_dropout = 0.2,
        weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
        fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
        self_per_cross_attn = 2      # number of self attention blocks per cross attention
      )   

      self.conv_net = self.conv_net.to(device)
      self.perceiver = self.perceiver.to(device)

      if num_blocks == 2:
        self.conv_net2 = torch.nn.Sequential(
            ResidualCNNBlock(3, 32, False),
            ResidualCNNBlock(32, 32, False),
            ResidualCNNBlock(32, 128, False),
            ResidualCNNBlock(128, 128, False),
            ResidualCNNBlock(128, 3, False)
        )
        self.perceiver2 = Perceiver(
          input_channels = 3,          # number of channels for each token of the input
          input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
          num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
          max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
          depth = depth,                   # depth of net. The shape of the final attention mechanism will be:
                                      #   depth * (cross attention -> self_per_cross_attn * self attention)
          num_latents = 32,           # number of latents, or induced set points, or centroids. different papers giving it different names
          latent_dim = 32,            # latent dimension
          cross_heads = att_heads,             # number of heads for cross attention. paper said 1
          latent_heads = 8,            # number of heads for latent self attention, 8
          cross_dim_head = 16,         # number of dimensions per cross attention head
          latent_dim_head = 16,        # number of dimensions per latent self attention head
          num_classes = 3 * 96 * 96,    # output number of classes
          attn_dropout = 0.2,
          ff_dropout = 0.2,
          weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
          fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
          self_per_cross_attn = 2      # number of self attention blocks per cross attention
        )   

        self.conv_net2 = self.conv_net2.to(device)
        self.perceiver2 = self.perceiver2.to(device)

      self.ffn = torch.nn.Linear(3 * 96 * 96, 10)
      self.norm = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        input = x
        res1 = self.conv_net(input)
        res2 = input.permute(0, 2, 3, 1)
        res2 = self.perceiver(res2)
        res2 = res2[:, :, None, None]
        res2 = res2.view(res2.shape[0], 3, 96, 96)

        input = self.norm(res1) + self.norm(res2)

        res1 = self.conv_net2(input)
        res2 = input.permute(0, 2, 3, 1)
        res2 = self.perceiver2(res2)
        res2 = res2[:, :, None, None]
        res2 = res2.view(res2.shape[0], 3, 96, 96)

        input = self.norm(res1) + self.norm(res2)
        
        return self.ffn(input.flatten(start_dim = 1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = JointPerceiverMM(2).to(device)
print(model(torch.rand(16, 3, 96, 96).to(device)).shape)
# print(model)

torch.Size([16, 10])


## Joint Variant

In [None]:
class JointMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
      super().__init__()
      self.blocks = []
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      for i in range(num_blocks):
        conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(3, image_size, 15, stride = 2, padding = 2),
            torch.nn.BatchNorm2d(image_size),
            torch.nn.ReLU(),
            torch.nn.Conv2d(image_size, image_size, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(image_size),
            torch.nn.ReLU(),
            torch.nn.Conv2d(image_size, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=3, padding=1),
            torch.nn.BatchNorm2d(3)
        )
        vit = ViT(
          image_size = image_size,
          patch_size = image_size // 16,
          num_classes = num_labels,
          dim = output_dim,
          depth = depth,
          heads = att_heads,
          mlp_dim = mlp_dim,
          dropout = 0.1,
          emb_dropout = 0.1
        )
        vit = torch.nn.Sequential(*(list(vit.children())[:-1]))

        upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))
        conv_net = conv_net.to(device)
        vit = vit.to(device)
        upsample_vit = upsample_vit.to(device)
        self.blocks.append([conv_net, vit, upsample_vit])

      self.ffn = torch.nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        input = x
        for conv_net, vit, upsample in self.blocks:
          res1 = conv_net(input)
          res2 = vit(input)
          res2 = res2[:, :, :, None]
          dim = int(res2.shape[2] ** 0.5)
          res2 = res2.view(res2.shape[0], res2.shape[1], dim, dim)
          res2 = upsample(res2)

          input = res1 + res2
        
        return self.ffn(input.flatten(start_dim = 1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = JointMixtureModel(2).to(device)
# img = torch.randn(16, 3, 32, 32).to(device)
# print(model(img).shape)
print(model)
GetNumberParameters(model)

JointMixtureModel(
  (ffn): Linear(in_features=3072, out_features=10, bias=True)
)


30730

# Training Code

In [None]:
lr = 3e-4
epochs = 50
batch_size = 16
num_workers = 2
weight_decay = 1e-4
# Set up the cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = PerceiverAltMM(num_blocks=2, image_size=96, depth = 4, att_heads=2, output_dim=3 * 96 * 96)
# print(model)
# model = Perceiver(
#           input_channels = 3,          # number of channels for each token of the input
#           input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
#           num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
#           max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
#           depth = 4,                   # depth of net. The shape of the final attention mechanism will be:
#                                       #   depth * (cross attention -> self_per_cross_attn * self attention)
#           num_latents = 64,           # number of latents, or induced set points, or centroids. different papers giving it different names
#           latent_dim = 64,            # latent dimension
#           cross_heads = 8,             # number of heads for cross attention. paper said 1
#           latent_heads = 8,            # number of heads for latent self attention, 8
#           cross_dim_head = 16,         # number of dimensions per cross attention head
#           latent_dim_head = 16,        # number of dimensions per latent self attention head
#           num_classes = 10,    # output number of classes
#           attn_dropout = 0.2,
#           ff_dropout = 0.2,
#           weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
#           fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
#           self_per_cross_attn = 4      # number of self attention blocks per cross attention
#     )

# model = ConvNet()
model = JointPerceiverMM(2)
print(GetNumberParameters(model))
# model.load_state_dict(torch.load('det128.th', map_location='cpu'))
# model = models.resnet18(pretrained=False)
model = model.to(device)

# Set up loss function and optimizer
loss_func = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)

# Set up training data and validation data
data_train, data_val = DataLoaders("STL-10", batch_size = batch_size)

# Wrap in a progress bar.
for epoch in range(epochs):
    print("EPOCH {}".format(epoch))
    # Set the model to training mode.
    model.train()

    train_accuracy_val = list()
    for x, y in tqdm(data_train):
        x = x.to(device)
        y = y.to(device)
        # x = x.permute(0, 2, 3, 1)
        y_pred = model(x)
        train_accuracy_val.append(accuracy(y_pred, y))

        # Compute loss and update model weights.
        loss = loss_func(y_pred, y)

        loss.backward()
        optim.step()
        optim.zero_grad()
        
        # Add loss to TensorBoard.
        # train_logger.add_scalar('Loss', loss.item(), global_step=global_step)
        # global_step += 1

    train_accuracy_total = torch.FloatTensor(train_accuracy_val).mean().item()
    # train_logger.add_scalar('Train Accuracy', train_accuracy_total, global_step=global_step)
    print("Train Accuracy: {:.4f}".format(train_accuracy_total))

    if epoch % 4 == 0:
      save_model(model, 'det128_' + str(epoch) + '.th')

    # Set the model to eval mode and compute accuracy.
    model.eval()

    accuracys_val = list()
    torch.cuda.empty_cache()
    for x, y in data_val:
        x = x.to(device)
        y = y.to(device)  
        # x = x.permute(0, 2, 3, 1)        
        y_pred = model(x)
        accuracys_val.append(accuracy(y_pred, y))
 
    accuracy_total = torch.FloatTensor(accuracys_val).mean().item()
    print("Validation Accuracy: {:.4f}".format(accuracy_total))
    # valid_logger.add_scalar('Validation Accuracy', accuracy_total, global_step=global_step)


# Results

In [29]:
#1693278 - 100 min
temp = "Train Accuracy: 0.1889;Validation Accuracy: 0.2278;Train Accuracy: 0.2466;Validation Accuracy: 0.2759;Train Accuracy: 0.2684;Validation Accuracy: 0.2755;Train Accuracy: 0.2875;Validation Accuracy: 0.2833;Train Accuracy: 0.3193;Validation Accuracy: 0.3279;Train Accuracy: 0.3442;Validation Accuracy: 0.4025;Train Accuracy: 0.3828;Validation Accuracy: 0.3336;Train Accuracy: 0.3986;Validation Accuracy: 0.4116;Train Accuracy: 0.4269;Validation Accuracy: 0.4956;Train Accuracy: 0.4599;Validation Accuracy: 0.4837;Train Accuracy: 0.4722;Validation Accuracy: 0.4760;Train Accuracy: 0.4782;Validation Accuracy: 0.5289;Train Accuracy: 0.4942;Validation Accuracy: 0.4901;Train Accuracy: 0.4998;Validation Accuracy: 0.5336;Train Accuracy: 0.5234;Validation Accuracy: 0.5013;Train Accuracy: 0.5152;Validation Accuracy: 0.5239;Train Accuracy: 0.5224;Validation Accuracy: 0.5260;Train Accuracy: 0.5341;Validation Accuracy: 0.5592;Train Accuracy: 0.5321;Validation Accuracy: 0.5475;Train Accuracy: 0.5427;Validation Accuracy: 0.5446;Train Accuracy: 0.5405;Validation Accuracy: 0.5644;Train Accuracy: 0.5529;Validation Accuracy: 0.5362;Train Accuracy: 0.5547;Validation Accuracy: 0.5681;Train Accuracy: 0.5627;Validation Accuracy: 0.5403;Train Accuracy: 0.5765;Validation Accuracy: 0.5779;Train Accuracy: 0.5683;Validation Accuracy: 0.5667;Train Accuracy: 0.5855;Validation Accuracy: 0.5860;Train Accuracy: 0.5731;Validation Accuracy: 0.5767;Train Accuracy: 0.5865;Validation Accuracy: 0.5539;Train Accuracy: 0.5875;Validation Accuracy: 0.5523;Train Accuracy: 0.5940;Validation Accuracy: 0.5702;Train Accuracy: 0.5970;Validation Accuracy: 0.5782;Train Accuracy: 0.6064;Validation Accuracy: 0.6079;Train Accuracy: 0.6126;Validation Accuracy: 0.5454;Train Accuracy: 0.6142;Validation Accuracy: 0.6083;Train Accuracy: 0.6140;Validation Accuracy: 0.5940;Train Accuracy: 0.6164;Validation Accuracy: 0.5803;Train Accuracy: 0.6272;Validation Accuracy: 0.6080;Train Accuracy: 0.6336;Validation Accuracy: 0.6131;Train Accuracy: 0.6290;Validation Accuracy: 0.6173;Train Accuracy: 0.6370;Validation Accuracy: 0.6069;Train Accuracy: 0.6368;Validation Accuracy: 0.6036;Train Accuracy: 0.6468;Validation Accuracy: 0.5626;Train Accuracy: 0.6583;Validation Accuracy: 0.5698;Train Accuracy: 0.6454;Validation Accuracy: 0.5849;Train Accuracy: 0.6444;Validation Accuracy: 0.6110;Train Accuracy: 0.6556;Validation Accuracy: 0.6438;Train Accuracy: 0.6623;Validation Accuracy: 0.6259;Train Accuracy: 0.6641;Validation Accuracy: 0.5885;Train Accuracy: 0.6611;Validation Accuracy: 0.6080"

splitted = temp.split(";")
train_alt = ["Alternating 1 Block"]
val_alt = ["Alternating 1 Block"]
for i in splitted:
    if "Train" in i:
        train_alt.append(float(i[15:]))
    
    else:
        val_alt.append(float(i[20:]))

In [35]:
# Perceiver - 50 min, 301170
temp = "Train Accuracy: 0.1162;Validation Accuracy: 0.1119;Train Accuracy: 0.1693;Validation Accuracy: 0.2457;Train Accuracy: 0.2226;Validation Accuracy: 0.2404;Train Accuracy: 0.2392;Validation Accuracy: 0.2785;Train Accuracy: 0.2600;Validation Accuracy: 0.2928;Train Accuracy: 0.2644;Validation Accuracy: 0.2934;Train Accuracy: 0.2835;Validation Accuracy: 0.2810;Train Accuracy: 0.2961;Validation Accuracy: 0.3039;Train Accuracy: 0.2945;Validation Accuracy: 0.3154;Train Accuracy: 0.2979;Validation Accuracy: 0.3252;Train Accuracy: 0.3133;Validation Accuracy: 0.3161;Train Accuracy: 0.3139;Validation Accuracy: 0.3261;Train Accuracy: 0.3207;Validation Accuracy: 0.3360;Train Accuracy: 0.3229;Validation Accuracy: 0.3327;Train Accuracy: 0.3319;Validation Accuracy: 0.3442;Train Accuracy: 0.3399;Validation Accuracy: 0.3405;Train Accuracy: 0.3460;Validation Accuracy: 0.3579;Train Accuracy: 0.3510;Validation Accuracy: 0.3569;Train Accuracy: 0.3586;Validation Accuracy: 0.3696;Train Accuracy: 0.3640;Validation Accuracy: 0.3778;Train Accuracy: 0.3712;Validation Accuracy: 0.3750;Train Accuracy: 0.3764;Validation Accuracy: 0.3923;Train Accuracy: 0.3870;Validation Accuracy: 0.3832;Train Accuracy: 0.3790;Validation Accuracy: 0.3900;Train Accuracy: 0.3896;Validation Accuracy: 0.3909;Train Accuracy: 0.3954;Validation Accuracy: 0.3944;Train Accuracy: 0.3976;Validation Accuracy: 0.3902;Train Accuracy: 0.3996;Validation Accuracy: 0.4026;Train Accuracy: 0.4004;Validation Accuracy: 0.4085;Train Accuracy: 0.4050;Validation Accuracy: 0.4002;Train Accuracy: 0.4109;Validation Accuracy: 0.4106;Train Accuracy: 0.4093;Validation Accuracy: 0.4008;Train Accuracy: 0.4261;Validation Accuracy: 0.4145;Train Accuracy: 0.4205;Validation Accuracy: 0.4091;Train Accuracy: 0.4223;Validation Accuracy: 0.4025;Train Accuracy: 0.4227;Validation Accuracy: 0.4157;Train Accuracy: 0.4265;Validation Accuracy: 0.4083;Train Accuracy: 0.4283;Validation Accuracy: 0.4115;Train Accuracy: 0.4271;Validation Accuracy: 0.4141;Train Accuracy: 0.4313;Validation Accuracy: 0.4186;Train Accuracy: 0.4343;Validation Accuracy: 0.4185;Train Accuracy: 0.4341;Validation Accuracy: 0.4162;Train Accuracy: 0.4379;Validation Accuracy: 0.4202;Train Accuracy: 0.4395;Validation Accuracy: 0.4175;Train Accuracy: 0.4475;Validation Accuracy: 0.4206;Train Accuracy: 0.4485;Validation Accuracy: 0.4310;Train Accuracy: 0.4445;Validation Accuracy: 0.4129;Train Accuracy: 0.4505;Validation Accuracy: 0.4295;Train Accuracy: 0.4483;Validation Accuracy: 0.4243;Train Accuracy: 0.4503;Validation Accuracy: 0.4286"

splitted = temp.split(";")
train_perceiver = ["Perceiver"]
val_perceiver = ["Perceiver"]
for i in splitted:
    if "Train" in i:
        train_perceiver.append(float(i[15:]))
    
    else:
        val_perceiver.append(float(i[20:]))

In [34]:
# CNN - 33 Minutes 480054
temp = "Train Accuracy: 0.2588;Validation Accuracy: 0.3828;Train Accuracy: 0.4679;Validation Accuracy: 0.4499;Train Accuracy: 0.6234;Validation Accuracy: 0.4574;Train Accuracy: 0.7298;Validation Accuracy: 0.4535;Train Accuracy: 0.8053;Validation Accuracy: 0.4499;Train Accuracy: 0.8776;Validation Accuracy: 0.4355;Train Accuracy: 0.9119;Validation Accuracy: 0.4451;Train Accuracy: 0.9371;Validation Accuracy: 0.4471;Train Accuracy: 0.9631;Validation Accuracy: 0.4416;Train Accuracy: 0.9655;Validation Accuracy: 0.4419;Train Accuracy: 0.9692;Validation Accuracy: 0.4256;Train Accuracy: 0.9746;Validation Accuracy: 0.4286;Train Accuracy: 0.9766;Validation Accuracy: 0.3996;Train Accuracy: 0.9800;Validation Accuracy: 0.4179;Train Accuracy: 0.9858;Validation Accuracy: 0.4206;Train Accuracy: 0.9824;Validation Accuracy: 0.4184;Train Accuracy: 0.9796;Validation Accuracy: 0.4288;Train Accuracy: 0.9848;Validation Accuracy: 0.4136;Train Accuracy: 0.9818;Validation Accuracy: 0.4219;Train Accuracy: 0.9884;Validation Accuracy: 0.4087;Train Accuracy: 0.9878;Validation Accuracy: 0.4080;Train Accuracy: 0.9872;Validation Accuracy: 0.4212;Train Accuracy: 0.9854;Validation Accuracy: 0.4210;Train Accuracy: 0.9894;Validation Accuracy: 0.4264;Train Accuracy: 0.9912;Validation Accuracy: 0.4205;Train Accuracy: 0.9888;Validation Accuracy: 0.4103;Train Accuracy: 0.9870;Validation Accuracy: 0.4134;Train Accuracy: 0.9894;Validation Accuracy: 0.4106;Train Accuracy: 0.9894;Validation Accuracy: 0.4119;Train Accuracy: 0.9896;Validation Accuracy: 0.4185;Train Accuracy: 0.9932;Validation Accuracy: 0.4135;Train Accuracy: 0.9938;Validation Accuracy: 0.4204;Train Accuracy: 0.9890;Validation Accuracy: 0.4084;Train Accuracy: 0.9904;Validation Accuracy: 0.4055;Train Accuracy: 0.9928;Validation Accuracy: 0.4169;Train Accuracy: 0.9954;Validation Accuracy: 0.4149;Train Accuracy: 0.9932;Validation Accuracy: 0.4156;Train Accuracy: 0.9944;Validation Accuracy: 0.4004;Train Accuracy: 0.9898;Validation Accuracy: 0.4176;Train Accuracy: 0.9918;Validation Accuracy: 0.4139;Train Accuracy: 0.9878;Validation Accuracy: 0.4042;Train Accuracy: 0.9902;Validation Accuracy: 0.4059;Train Accuracy: 0.9912;Validation Accuracy: 0.4083;Train Accuracy: 0.9966;Validation Accuracy: 0.4126;Train Accuracy: 0.9944;Validation Accuracy: 0.4157;Train Accuracy: 0.9956;Validation Accuracy: 0.4149;Train Accuracy: 0.9980;Validation Accuracy: 0.4196;Train Accuracy: 0.9936;Validation Accuracy: 0.4209;Train Accuracy: 0.9928;Validation Accuracy: 0.4002;Train Accuracy: 0.9932;Validation Accuracy: 0.4119"

splitted = temp.split(";")
train_cnn = ["CNN"]
val_cnn = ["CNN"]
for i in splitted:
    if "Train" in i:
        train_cnn.append(float(i[15:]))
    
    else:
        val_cnn.append(float(i[20:]))

In [33]:
# Baseline (Perceiver): 125 Minutes 1625586
temp = "Train Accuracy: 0.1194;Validation Accuracy: 0.1785;Train Accuracy: 0.1771;Validation Accuracy: 0.2262;Train Accuracy: 0.2234;Validation Accuracy: 0.2244;Train Accuracy: 0.2442;Validation Accuracy: 0.2574;Train Accuracy: 0.2486;Validation Accuracy: 0.2747;Train Accuracy: 0.2652;Validation Accuracy: 0.2784;Train Accuracy: 0.2604;Validation Accuracy: 0.2955;Train Accuracy: 0.2792;Validation Accuracy: 0.2896;Train Accuracy: 0.2760;Validation Accuracy: 0.2959;Train Accuracy: 0.2907;Validation Accuracy: 0.2916;Train Accuracy: 0.2905;Validation Accuracy: 0.3225;Train Accuracy: 0.2983;Validation Accuracy: 0.3145;Train Accuracy: 0.3105;Validation Accuracy: 0.3061;Train Accuracy: 0.3033;Validation Accuracy: 0.3239;Train Accuracy: 0.3061;Validation Accuracy: 0.3160;Train Accuracy: 0.3097;Validation Accuracy: 0.3296;Train Accuracy: 0.3225;Validation Accuracy: 0.3229;Train Accuracy: 0.3169;Validation Accuracy: 0.3259;Train Accuracy: 0.3233;Validation Accuracy: 0.3331;Train Accuracy: 0.3277;Validation Accuracy: 0.3363;Train Accuracy: 0.3365;Validation Accuracy: 0.3483;Train Accuracy: 0.3373;Validation Accuracy: 0.3426;Train Accuracy: 0.3440;Validation Accuracy: 0.3646;Train Accuracy: 0.3506;Validation Accuracy: 0.3627;Train Accuracy: 0.3644;Validation Accuracy: 0.3663;Train Accuracy: 0.3544;Validation Accuracy: 0.3724;Train Accuracy: 0.3580;Validation Accuracy: 0.3634;Train Accuracy: 0.3572;Validation Accuracy: 0.3685;Train Accuracy: 0.3746;Validation Accuracy: 0.3735;Train Accuracy: 0.3632;Validation Accuracy: 0.3643;Train Accuracy: 0.3690;Validation Accuracy: 0.3770;Train Accuracy: 0.3774;Validation Accuracy: 0.3573;Train Accuracy: 0.3796;Validation Accuracy: 0.3811;Train Accuracy: 0.3874;Validation Accuracy: 0.3824;Train Accuracy: 0.3736;Validation Accuracy: 0.3672;Train Accuracy: 0.3776;Validation Accuracy: 0.3823;Train Accuracy: 0.3808;Validation Accuracy: 0.3949;Train Accuracy: 0.3928;Validation Accuracy: 0.3915;Train Accuracy: 0.3850;Validation Accuracy: 0.3925;Train Accuracy: 0.3972;Validation Accuracy: 0.3873;Train Accuracy: 0.3922;Validation Accuracy: 0.3988;Train Accuracy: 0.3934;Validation Accuracy: 0.3849;Train Accuracy: 0.3914;Validation Accuracy: 0.3898;Train Accuracy: 0.4000;Validation Accuracy: 0.3991;Train Accuracy: 0.4077;Validation Accuracy: 0.4017;Train Accuracy: 0.4085;Validation Accuracy: 0.4040;Train Accuracy: 0.4127;Validation Accuracy: 0.4039"

splitted = temp.split(";")
train_baseline = ["Baseline"]
val_baseline = ["Baseline"]
for i in splitted:
    if "Train" in i:
        train_baseline.append(float(i[15:]))
    
    else:
        val_baseline.append(float(i[20:]))

In [32]:
# Alt 2 Block: 300 Minutes 3110066
temp = "Train Accuracy: 0.2993;Validation Accuracy: 0.3063;Train Accuracy: 0.3001;Validation Accuracy: 0.3268;Train Accuracy: 0.3259;Validation Accuracy: 0.3226;Train Accuracy: 0.3181;Validation Accuracy: 0.3147;Train Accuracy: 0.3363;Validation Accuracy: 0.3500;Train Accuracy: 0.3526;Validation Accuracy: 0.3749;Train Accuracy: 0.3622;Validation Accuracy: 0.4010;Train Accuracy: 0.3780;Validation Accuracy: 0.3844;Train Accuracy: 0.3736;Validation Accuracy: 0.3483;Train Accuracy: 0.4054;Validation Accuracy: 0.3985;Train Accuracy: 0.4063;Validation Accuracy: 0.4190;Train Accuracy: 0.4153;Validation Accuracy: 0.3831;Train Accuracy: 0.4269;Validation Accuracy: 0.4034;Train Accuracy: 0.4159;Validation Accuracy: 0.4518;Train Accuracy: 0.4337;Validation Accuracy: 0.4354;Train Accuracy: 0.4485;Validation Accuracy: 0.4504;Train Accuracy: 0.4485;Validation Accuracy: 0.4288;Train Accuracy: 0.4609;Validation Accuracy: 0.4270;Train Accuracy: 0.4734;Validation Accuracy: 0.4796;Train Accuracy: 0.4694;Validation Accuracy: 0.4734;Train Accuracy: 0.4916;Validation Accuracy: 0.4807;Train Accuracy: 0.4960;Validation Accuracy: 0.5004;Train Accuracy: 0.5070;Validation Accuracy: 0.5321;Train Accuracy: 0.5288;Validation Accuracy: 0.4991;Train Accuracy: 0.5321;Validation Accuracy: 0.5071;Train Accuracy: 0.5355;Validation Accuracy: 0.5185;Train Accuracy: 0.5294;Validation Accuracy: 0.5496;Train Accuracy: 0.5483;Validation Accuracy: 0.5163;Train Accuracy: 0.5491;Validation Accuracy: 0.5374;Train Accuracy: 0.5565;Validation Accuracy: 0.5586;Train Accuracy: 0.5679;Validation Accuracy: 0.5677;Train Accuracy: 0.5773;Validation Accuracy: 0.5794;Train Accuracy: 0.5815;Validation Accuracy: 0.5576;Train Accuracy: 0.5835;Validation Accuracy: 0.5400;Train Accuracy: 0.5901;Validation Accuracy: 0.5691;Train Accuracy: 0.6052;Validation Accuracy: 0.5232;Train Accuracy: 0.6032;Validation Accuracy: 0.5698;Train Accuracy: 0.6170;Validation Accuracy: 0.5804;Train Accuracy: 0.6244;Validation Accuracy: 0.5805;Train Accuracy: 0.6250;Validation Accuracy: 0.5829;Train Accuracy: 0.6148;Validation Accuracy: 0.5058;Train Accuracy: 0.6212;Validation Accuracy: 0.6030;Train Accuracy: 0.6310;Validation Accuracy: 0.5991;Train Accuracy: 0.6358;Validation Accuracy: 0.5136;Train Accuracy: 0.6294;Validation Accuracy: 0.5370;Train Accuracy: 0.6372;Validation Accuracy: 0.5929;Train Accuracy: 0.6458;Validation Accuracy: 0.6146;Train Accuracy: 0.6494;Validation Accuracy: 0.5811;Train Accuracy: 0.6396;Validation Accuracy: 0.6055;Train Accuracy: 0.6595;Validation Accuracy: 0.5738"

splitted = temp.split(";")
train_2 = ["Alternating - 2 Block"]
val_2 = ["Alternating - 2 Block"]
for i in splitted:
    if "Train" in i:
        train_2.append(float(i[15:]))
    
    else:
        val_2.append(float(i[20:]))

In [31]:
# Joint Block - 1466518, 30 min
temp = "Train Accuracy: 0.2823;Validation Accuracy: 0.3825;Train Accuracy: 0.4523;Validation Accuracy: 0.4565;Train Accuracy: 0.5769;Validation Accuracy: 0.4391;Train Accuracy: 0.6765;Validation Accuracy: 0.4726;Train Accuracy: 0.7438;Validation Accuracy: 0.4891;Train Accuracy: 0.7959;Validation Accuracy: 0.4699;Train Accuracy: 0.8295;Validation Accuracy: 0.4629;Train Accuracy: 0.8564;Validation Accuracy: 0.4691;Train Accuracy: 0.8658;Validation Accuracy: 0.4700;Train Accuracy: 0.9077;Validation Accuracy: 0.4633;Train Accuracy: 0.9207;Validation Accuracy: 0.4696;Train Accuracy: 0.9377;Validation Accuracy: 0.4900;Train Accuracy: 0.8958;Validation Accuracy: 0.4855;Train Accuracy: 0.9287;Validation Accuracy: 0.4836;Train Accuracy: 0.9379;Validation Accuracy: 0.4636;Train Accuracy: 0.9383;Validation Accuracy: 0.4860;Train Accuracy: 0.9489;Validation Accuracy: 0.4871;Train Accuracy: 0.9297;Validation Accuracy: 0.4760;Train Accuracy: 0.9687;Validation Accuracy: 0.4734;Train Accuracy: 0.9289;Validation Accuracy: 0.4576;Train Accuracy: 0.9561;Validation Accuracy: 0.4874;Train Accuracy: 0.9605;Validation Accuracy: 0.4724;Train Accuracy: 0.9497;Validation Accuracy: 0.4515;Train Accuracy: 0.9605;Validation Accuracy: 0.4627;Train Accuracy: 0.9601;Validation Accuracy: 0.4807;Train Accuracy: 0.9748;Validation Accuracy: 0.4688;Train Accuracy: 0.9649;Validation Accuracy: 0.4740;Train Accuracy: 0.9631;Validation Accuracy: 0.4655;Train Accuracy: 0.9595;Validation Accuracy: 0.4836"

splitted = temp.split(";")
train_joint = ["Joint 1 Block"]
val_joint = ["Joint 1 Block"]
for i in splitted:
    if "Train" in i:
        train_joint.append(float(i[15:]))
    
    else:
        val_joint.append(float(i[20:]))

In [30]:
# Joint Block2 - 68 min, 2656540

temp = "Train Accuracy: 0.2550;Validation Accuracy: 0.3294;Train Accuracy: 0.3828;Validation Accuracy: 0.4243;Train Accuracy: 0.4986;Validation Accuracy: 0.4185;Train Accuracy: 0.5553;Validation Accuracy: 0.3821;Train Accuracy: 0.6254;Validation Accuracy: 0.4216;Train Accuracy: 0.7250;Validation Accuracy: 0.4181;Train Accuracy: 0.7588;Validation Accuracy: 0.3902;Train Accuracy: 0.7965;Validation Accuracy: 0.4196;Train Accuracy: 0.8387;Validation Accuracy: 0.3960;Train Accuracy: 0.8810;Validation Accuracy: 0.4011;Train Accuracy: 0.8802;Validation Accuracy: 0.3895;Train Accuracy: 0.8986;Validation Accuracy: 0.3641;Train Accuracy: 0.9018;Validation Accuracy: 0.3870;Train Accuracy: 0.9241;Validation Accuracy: 0.3924;Train Accuracy: 0.9313;Validation Accuracy: 0.3971;Train Accuracy: 0.9219;Validation Accuracy: 0.3898;Train Accuracy: 0.8996;Validation Accuracy: 0.3717;Train Accuracy: 0.9113;Validation Accuracy: 0.3785;Train Accuracy: 0.9239;Validation Accuracy: 0.3639;Train Accuracy: 0.9525;Validation Accuracy: 0.3830;Train Accuracy: 0.9545;Validation Accuracy: 0.3876;Train Accuracy: 0.9463;Validation Accuracy: 0.3714;Train Accuracy: 0.9487;Validation Accuracy: 0.3934;Train Accuracy: 0.9541;Validation Accuracy: 0.3991;Train Accuracy: 0.9515;Validation Accuracy: 0.3808;Train Accuracy: 0.9569;Validation Accuracy: 0.3516;Train Accuracy: 0.9471;Validation Accuracy: 0.3650"

splitted = temp.split(";")
train_joint2 = ["Joint 2 Block"]
val_joint2 = ["Joint 2 Block"]
for i in splitted:
    if "Train" in i:
        train_joint2.append(float(i[15:]))
    
    else:
        val_joint2.append(float(i[20:]))

In [42]:
import csv
a = [val_alt, val_2, val_baseline, val_cnn, val_joint, val_perceiver, val_joint2]
with open("perceiver_validation_acc.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(a)

In [None]:
table = [['Model', 'Top-1 Test Accuracy', 'Total Training Time (Minutes)', 'Number of Parameters'], 
         ['Perceiver Baseline', 0.404, 125, 1625586],
         ['Perceiver', 0.431, 50, 301170],
         ['Joint (1 Block)', 0.49, 30, 1466518],
         ['Joint (2 Block)', 0.4243, 70, 2656540],
         ['Alternating (1 Block)', 0.6438, 100, 1693278],
         ['Alternating (2 Block)', 0.6146, 300, 3110066],
         ['CNN', 0.4574, 33, 480054],
         ['ResNet-18', 0.5699, '--', 11689512]]

print(tabulate(table, headers='firstrow', tablefmt='fancy_grid'))
