# Install Dependencies

In [1]:
!pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-0.22.0-py3-none-any.whl (39 kB)
Collecting einops>=0.3
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, vit-pytorch
Successfully installed einops-0.3.2 vit-pytorch-0.22.0


# Datasets

* CIFAR-10

In [2]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

def DataLoaders(dataset = "CF-10", data_root = "./", batch_size = 16):
  transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p = 0.5),
        transforms.ColorJitter(brightness=0.5, hue = 0.25),
        transforms.ToTensor(),
      ])

  transform_test = transforms.Compose([
      transforms.ToTensor(),
  ])
  root_dir = os.path.join(data_root, dataset)
  if dataset == "CF-10":
    train_dataset = datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
  
    return train_loader, test_loader 

In [3]:
def GetNumberParameters(model):
  return sum(np.prod(p.shape).item() for p in model.parameters())

In [4]:
def save_model(model, file_name):
    from torch import save
    from os import path
    print("saving", file_name)
    return save(model.state_dict(), file_name)


def load_model():
    from torch import load
    from os import path
    r = CNNClassifier()
    r.load_state_dict(load(path.join(path.dirname(path.abspath(__file__)), 'cnn.th'), map_location='cpu'))
    return r

In [5]:
def accuracy(outputs, labels):
    outputs_idx = outputs.max(1)[1].type_as(labels)
    return outputs_idx.eq(labels).float().mean()

In [6]:
train_loader, test_loader = DataLoaders("CF-10", batch_size = 128)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CF-10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./CF-10/cifar-10-python.tar.gz to ./CF-10
Files already downloaded and verified


# Model Architecture

## Alternating Variant

In [6]:
from vit_pytorch import ViT
import torchvision
from vit_pytorch import ViT
import torchvision.models as models

class MMBlock(torch.nn.Module):
  def __init__(self, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 15, stride = 2, padding = 2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=3, padding=1),
            torch.nn.BatchNorm2d(3)
    )

    self.vision_transformer = ViT(
        image_size = image_size,
        patch_size = image_size // 16,
        num_classes = num_labels,
        dim = output_dim,
        depth = depth,
        heads = att_heads,
        mlp_dim = mlp_dim,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    self.vision_transformer = torch.nn.Sequential(*(list(self.vision_transformer.children())[:-1]))

    self.upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))

  def forward(self, x):
    res = self.conv_net(x)
    res = self.vision_transformer(res)
    res = res[:, :, :, None]
    dim = int(res.shape[2] ** 0.5)
    res = res.view(res.shape[0], res.shape[1], dim, dim)

    return self.upsample_vit(res)

class AlternatingMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
        super().__init__()
        self.blocks =  []
        for i in range(num_blocks):
            self.blocks.append(MMBlock(image_size = image_size, num_labels = num_labels, depth = depth, att_heads = att_heads, mlp_dim = mlp_dim, output_dim = output_dim))
        self.ffn = torch.nn.Linear(3 * 32 * 32, 10)
        self.MM = torch.nn.Sequential(*self.blocks)
      
    def forward(self, x):
      return self.ffn(self.MM(x).flatten(start_dim = 1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AlternatingMixtureModel(2).to(device)
img = torch.randn(16, 3, 32, 32).to(device)
print(model(img).shape)

torch.Size([16, 10])


# Alternating Variant with Residual Connections

In [29]:
class ResidualCNNBlock(torch.nn.Module):
    def __init__(self, c_in, c_out, should_stride=False):
        super().__init__()

        if should_stride:
            stride = 2
        else:
            stride = 1

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride),
            torch.nn.BatchNorm2d(c_out),
        ) 
        
        self.relu = torch.nn.ReLU()

        if c_in != c_out or should_stride:
            self.identity = torch.nn.Conv2d(c_in, c_out, 1, stride=stride)
        else:
            self.identity = lambda x: x

    def forward(self, x):
        result = self.block(x)
        x = self.identity(x)

        return self.relu(x + result)

class ResMMBlock(torch.nn.Module):
  def __init__(self, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
    super().__init__()
    self.conv_net = torch.nn.Sequential(
              ResidualCNNBlock(3, 32, False),
              ResidualCNNBlock(32, 32, False),
              ResidualCNNBlock(32, 128, False),
              ResidualCNNBlock(128, 128, False),
              ResidualCNNBlock(128, 3, False)
    )
    self.vision_transformer = ViT(
        image_size = image_size,
        patch_size = image_size // 16,
        num_classes = num_labels,
        dim = output_dim,
        depth = depth,
        heads = att_heads,
        mlp_dim = mlp_dim,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    self.vision_transformer = torch.nn.Sequential(*(list(self.vision_transformer.children())[:-1]))

    self.upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))

  def forward(self, x):
    res = self.conv_net(x)
    res = self.vision_transformer(res)
    res = res[:, :, :, None]
    dim = int(res.shape[2] ** 0.5)
    res = res.view(res.shape[0], res.shape[1], dim, dim)

    return self.upsample_vit(res)

class ResAlternatingMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
        super().__init__()
        self.blocks =  []
        for i in range(num_blocks):
            self.blocks.append(ResMMBlock(image_size = image_size, num_labels = num_labels, depth = depth, att_heads = att_heads, mlp_dim = mlp_dim, output_dim = output_dim))
        self.ffn = torch.nn.Linear(3 * 32 * 32, 10)
        self.MM = torch.nn.Sequential(*self.blocks)
      
    def forward(self, x):
      return self.ffn(self.MM(x).flatten(start_dim = 1))

model = ResAlternatingMixtureModel(2)
print(model(torch.rand(16, 3, 32, 32)).shape)

torch.Size([16, 10])


## Joint Variant

In [None]:
class JointMixtureModel(torch.nn.Module):
    def __init__(self, num_blocks = 1, image_size = 32, num_labels = 10, depth = 1, att_heads = 1, mlp_dim = 2048, output_dim = 1024):
      super().__init__()
      self.blocks = []
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      for i in range(num_blocks):
        conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(3, image_size, 15, stride = 2, padding = 2),
            torch.nn.BatchNorm2d(image_size),
            torch.nn.ReLU(),
            torch.nn.Conv2d(image_size, image_size, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(image_size),
            torch.nn.ReLU(),
            torch.nn.Conv2d(image_size, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride = 1, padding = 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=3, padding=1),
            torch.nn.BatchNorm2d(3)
        )
        vit = ViT(
          image_size = image_size,
          patch_size = image_size // 16,
          num_classes = num_labels,
          dim = output_dim,
          depth = depth,
          heads = att_heads,
          mlp_dim = mlp_dim,
          dropout = 0.1,
          emb_dropout = 0.1
        )
        vit = torch.nn.Sequential(*(list(vit.children())[:-1]))

        upsample_vit = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 3, kernel_size=3, stride=1, padding = 1),
            torch.nn.BatchNorm2d(3))
        conv_net = conv_net.to(device)
        vit = vit.to(device)
        upsample_vit = upsample_vit.to(device)
        self.blocks.append([conv_net, vit, upsample_vit])

      self.ffn = torch.nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        input = x
        for conv_net, vit, upsample in self.blocks:
          res1 = conv_net(input)
          res2 = vit(input)
          res2 = res2[:, :, :, None]
          dim = int(res2.shape[2] ** 0.5)
          res2 = res2.view(res2.shape[0], res2.shape[1], dim, dim)
          res2 = upsample(res2)

          input = res1 + res2
        
        return self.ffn(input.flatten(start_dim = 1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = JointMixtureModel(2).to(device)
# img = torch.randn(16, 3, 32, 32).to(device)
# print(model(img).shape)
print(model)
GetNumberParameters(model)

JointMixtureModel(
  (ffn): Linear(in_features=3072, out_features=10, bias=True)
)


30730

# Training Code

In [None]:
lr = 1e-3
epochs = 15
batch_size = 32
num_workers = 2
weight_decay = 1e-4
# Set up the cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = JointMixtureModel(1).to(device)
# model = models.resnet18(pretrained=False).to(device)
model = ResAlternatingMixtureModel(num_blocks=1, depth = 2, att_heads=2)
print(GetNumberParameters(model))
# model.load_state_dict(torch.load('det30.th', map_location='cpu'))
model = model.to(device)

# Set up loss function and optimizer
loss_func = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)

# Set up training data and validation data
data_train, data_val = DataLoaders("CF-10", batch_size = batch_size)

# Set up loggers
# log_dir = '.'
# log_time = '{}'.format(time.strftime('%H-%M-%S'))
# log_name = 'lr=%s_epoch=%s_batch_size=%s_wd=%s' % (lr, epochs, batch_size, weight_decay)
# logger = tb.SummaryWriter()
# train_logger = tb.SummaryWriter(path.join(log_dir, 'train') + '/%s_%s' % (log_name, log_time))
# valid_logger = tb.SummaryWriter(path.join(log_dir, 'test') + '/%s_%s' % (log_name, log_time))
# global_step = 0

# Wrap in a progress bar.
for epoch in range(epochs):
    print("!!!!!!!!!!!!!!EPOCH {}!!!!!!!!!!!!!".format(epoch))
    # Set the model to training mode.
    model.train()

    train_accuracy_val = list()
    for x, y in tqdm(data_train):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        train_accuracy_val.append(accuracy(y_pred, y))

        # Compute loss and update model weights.
        loss = loss_func(y_pred, y)

        loss.backward()
        optim.step()
        optim.zero_grad()
        
        # Add loss to TensorBoard.
        # train_logger.add_scalar('Loss', loss.item(), global_step=global_step)
        # global_step += 1

    train_accuracy_total = torch.FloatTensor(train_accuracy_val).mean().item()
    # train_logger.add_scalar('Train Accuracy', train_accuracy_total, global_step=global_step)
    print("Train Accuracy: {:.4f}".format(train_accuracy_total))

    if epoch % 2 == 0:
      save_model(model, 'det' + str(epoch) + '.th')

    # Set the model to eval mode and compute accuracy.
    # No need to change this, but feel free to implement additional logging.
    model.eval()

    accuracys_val = list()

    for x, y in data_val:
        x = x.to(device)
        y = y.to(device)  
        y_pred = model(x)
        accuracys_val.append(accuracy(y_pred, y))
 
    accuracy_total = torch.FloatTensor(accuracys_val).mean().item()
    print("Validation Accuracy: {:.4f}".format(accuracy_total))
    # valid_logger.add_scalar('Validation Accuracy', accuracy_total, global_step=global_step)


9708095
Files already downloaded and verified
Files already downloaded and verified
!!!!!!!!!!!!!!EPOCH 0!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.3460
saving det0.th
Validation Accuracy: 0.4508
!!!!!!!!!!!!!!EPOCH 1!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.4814
Validation Accuracy: 0.5146
!!!!!!!!!!!!!!EPOCH 2!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.5392
saving det2.th
Validation Accuracy: 0.5618
!!!!!!!!!!!!!!EPOCH 3!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.5796
Validation Accuracy: 0.6039
!!!!!!!!!!!!!!EPOCH 4!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6112
saving det4.th
Validation Accuracy: 0.6403
!!!!!!!!!!!!!!EPOCH 5!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6317
Validation Accuracy: 0.6510
!!!!!!!!!!!!!!EPOCH 6!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6498
saving det6.th
Validation Accuracy: 0.6603
!!!!!!!!!!!!!!EPOCH 7!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6636
Validation Accuracy: 0.6749
!!!!!!!!!!!!!!EPOCH 8!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6767
saving det8.th
Validation Accuracy: 0.7004
!!!!!!!!!!!!!!EPOCH 9!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6873
Validation Accuracy: 0.6987
!!!!!!!!!!!!!!EPOCH 10!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.6987
saving det10.th
Validation Accuracy: 0.7125
!!!!!!!!!!!!!!EPOCH 11!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

Train Accuracy: 0.7072
Validation Accuracy: 0.7028
!!!!!!!!!!!!!!EPOCH 12!!!!!!!!!!!!!


  0%|          | 0/1563 [00:00<?, ?it/s]

!!!!!!!!!!!!!!EPOCH 0!!!!!!!!!!!!!
100%
391/391 [07:36<00:00, 1.17s/it]
Train Accuracy: 0.2272
saving det0.th
Validation Accuracy: 0.2416
!!!!!!!!!!!!!!EPOCH 1!!!!!!!!!!!!!
100%
391/391 [07:39<00:00, 1.18s/it]
Train Accuracy: 0.2493
Validation Accuracy: 0.3089
!!!!!!!!!!!!!!EPOCH 2!!!!!!!!!!!!!
100%
391/391 [07:39<00:00, 1.18s/it]
Train Accuracy: 0.3103
saving det2.th
Validation Accuracy: 0.3562
!!!!!!!!!!!!!!EPOCH 3!!!!!!!!!!!!!
100%
391/391 [07:39<00:00, 1.18s/it]
Train Accuracy: 0.3403
Validation Accuracy: 0.3302
!!!!!!!!!!!!!!EPOCH 4!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.3626
saving det4.th
Validation Accuracy: 0.3685
!!!!!!!!!!!!!!EPOCH 5!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.3847
Validation Accuracy: 0.4074
!!!!!!!!!!!!!!EPOCH 6!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.4088
saving det6.th
Validation Accuracy: 0.4354
!!!!!!!!!!!!!!EPOCH 7!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.4320
Validation Accuracy: 0.4342
!!!!!!!!!!!!!!EPOCH 8!!!!!!!!!!!!!
100%
391/391 [07:41<00:00, 1.18s/it]
Train Accuracy: 0.4476
saving det8.th
Validation Accuracy: 0.4768
!!!!!!!!!!!!!!EPOCH 9!!!!!!!!!!!!!
100%
391/391 [07:41<00:00, 1.18s/it]
Train Accuracy: 0.4725
Validation Accuracy: 0.4845
!!!!!!!!!!!!!!EPOCH 10!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.17s/it]
Train Accuracy: 0.4905
saving det10.th
Validation Accuracy: 0.4929
!!!!!!!!!!!!!!EPOCH 11!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.5083
Validation Accuracy: 0.5254
!!!!!!!!!!!!!!EPOCH 12!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.5277
saving det12.th
Validation Accuracy: 0.5397
!!!!!!!!!!!!!!EPOCH 13!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.5515
Validation Accuracy: 0.5611
!!!!!!!!!!!!!!EPOCH 14!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.5733
saving det14.th
Validation Accuracy: 0.5756

!!!!!!!!!!!!!!EPOCH 0!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.5861
saving det0.th
Validation Accuracy: 0.5900
!!!!!!!!!!!!!!EPOCH 1!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.6078
Validation Accuracy: 0.6211
!!!!!!!!!!!!!!EPOCH 2!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.6260
saving det2.th
Validation Accuracy: 0.6295
!!!!!!!!!!!!!!EPOCH 3!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.18s/it]
Train Accuracy: 0.6409
Validation Accuracy: 0.6452
!!!!!!!!!!!!!!EPOCH 4!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.19s/it]
Train Accuracy: 0.6549
saving det4.th
Validation Accuracy: 0.6539
!!!!!!!!!!!!!!EPOCH 5!!!!!!!!!!!!!
100%
391/391 [07:40<00:00, 1.17s/it]
Train Accuracy: 0.6677
Validation Accuracy: 0.6552
!!!!!!!!!!!!!!EPOCH 6!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.6803
saving det6.th
Validation Accuracy: 0.6754
!!!!!!!!!!!!!!EPOCH 7!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.18s/it]
Train Accuracy: 0.6884
Validation Accuracy: 0.6861
!!!!!!!!!!!!!!EPOCH 8!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.18s/it]
Train Accuracy: 0.7004
saving det8.th
Validation Accuracy: 0.6779
!!!!!!!!!!!!!!EPOCH 9!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.7087
Validation Accuracy: 0.6846
!!!!!!!!!!!!!!EPOCH 10!!!!!!!!!!!!!
100%
391/391 [07:35<00:00, 1.17s/it]
Train Accuracy: 0.7197
saving det10.th
Validation Accuracy: 0.6792
!!!!!!!!!!!!!!EPOCH 11!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.7307
Validation Accuracy: 0.6765
!!!!!!!!!!!!!!EPOCH 12!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.7394
saving det12.th
Validation Accuracy: 0.6761
!!!!!!!!!!!!!!EPOCH 13!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.18s/it]
Train Accuracy: 0.7481
Validation Accuracy: 0.6956
!!!!!!!!!!!!!!EPOCH 14!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.7539
saving det14.th
Validation Accuracy: 0.7141

!!!!!!!!!!!!!!EPOCH 0!!!!!!!!!!!!!
100%
391/391 [07:38<00:00, 1.17s/it]
Train Accuracy: 0.7643
saving det0.th
Validation Accuracy: 0.6928
!!!!!!!!!!!!!!EPOCH 1!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.7742
Validation Accuracy: 0.6847
!!!!!!!!!!!!!!EPOCH 2!!!!!!!!!!!!!
100%
391/391 [07:35<00:00, 1.17s/it]
Train Accuracy: 0.7812
saving det2.th
Validation Accuracy: 0.6996
!!!!!!!!!!!!!!EPOCH 3!!!!!!!!!!!!!
100%
391/391 [07:35<00:00, 1.17s/it]
Train Accuracy: 0.7892
Validation Accuracy: 0.7182
!!!!!!!!!!!!!!EPOCH 4!!!!!!!!!!!!!
100%
391/391 [07:38<00:00, 1.17s/it]
Train Accuracy: 0.7948
saving det4.th
Validation Accuracy: 0.7044
!!!!!!!!!!!!!!EPOCH 5!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.8035
Validation Accuracy: 0.7001
!!!!!!!!!!!!!!EPOCH 6!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.8092
saving det6.th
Validation Accuracy: 0.7106
!!!!!!!!!!!!!!EPOCH 7!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.8163
Validation Accuracy: 0.7104
!!!!!!!!!!!!!!EPOCH 8!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.8218
saving det8.th
Validation Accuracy: 0.7052
!!!!!!!!!!!!!!EPOCH 9!!!!!!!!!!!!!
100%
391/391 [07:36<00:00, 1.16s/it]
Train Accuracy: 0.8288
Validation Accuracy: 0.7007
!!!!!!!!!!!!!!EPOCH 10!!!!!!!!!!!!!
100%
391/391 [07:36<00:00, 1.17s/it]
Train Accuracy: 0.8364
saving det10.th
Validation Accuracy: 0.7109
!!!!!!!!!!!!!!EPOCH 11!!!!!!!!!!!!!
100%
391/391 [07:37<00:00, 1.17s/it]
Train Accuracy: 0.8435
Validation Accuracy: 0.7217
!!!!!!!!!!!!!!EPOCH 12!!!!!!!!!!!!!

# What to Test on CIFAR-10:

## Sai (Senior Dev):
* Convolutional Network Individually

resnet18 = models.resnet18(pretrained=False)

* Alternating Variant (num_blocks = 1, 2)

num_blocks = 1: model = AlternatingMixtureModel(num_blocks=1, depth = 2, att_heads=2) - Best Validation Accuracy => 72.17%



## Ian (Junior Dev):
* ViT Individually

vision_transformer = ViT(
    image_size = 32,
    patch_size = 32 // 16,
    num_classes = 10,
    dim = 2048,
    depth = 2,
    heads = 1,
    mlp_dim = 1024,
    dropout = 0.1,
    emb_dropout = 0.1
)

* Joint Variant (num_blocks = 1, 2)

num_blocks = 1: model = JointMixtureModel(num_blocks=1, depth = 2, att_heads=2)