## TTIC 31230 - Fundamentals of Deep Learning Final Project
Rui Wang

My project submission and extension are based on the [Vision Transformer (ViT) paper](https://arxiv.org/abs/2010.11929), following the proposed extension of incorporating convolutions into the Vision Transformer. The goal is to bridge the accuracy that convolutional networks bring with the computational efficiency that the original Vision Transformer enjoys.

#### Preliminaries

In [None]:
# Install einops, which provides rearrange/repeat/Rearrange used by the models below.
!pip install einops



In [None]:
# Training batch size and the side length that images are resized to.
bs, size = 512, 32


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import numpy as np
import math

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [None]:
# Training pipeline: random 32x32 crop from a 4px-padded image, resize,
# random horizontal flip, then tensor conversion and per-channel normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    ]
)

# Evaluation pipeline: no augmentation, same resize and normalization.
transform_test = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    ]
)

In [None]:
# CIFAR-10 training split (downloaded into ./data if missing), with augmentation.
trainset = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=bs,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


In [None]:
# Display the training-set summary (size, split, transform pipeline).
trainset

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               Resize(size=32, interpolation=bilinear, max_size=None, antialias=warn)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
           )

In [None]:
# Shape of the raw training image array (output below: 50000 images, 32x32, 3 channels last).
trainset.data.shape

(50000, 32, 32, 3)

In [None]:
# CIFAR-10 test split, evaluated in fixed order with no augmentation.
testset = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified


In [None]:
# Display the test-set summary (size, split, transform pipeline).
testset

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=32, interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
           )

In [None]:
# Shape of the raw test image array (output below: 10000 images, 32x32, 3 channels last).
testset.data.shape

(10000, 32, 32, 3)

In [None]:
# Run on the GPU when one is available; fall back to CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss()
# Mixed precision is wired up but disabled; with enabled=False the scaler
# passes the loss and optimizer step through unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=False)


def train(net, epoch, optimizer=None):
    """Train ``net`` for one epoch over the global ``trainloader``.

    Uses the module-level ``device``, ``criterion`` and ``scaler``.

    Args:
        net: model to train; its parameters must already live on ``device``.
        epoch: epoch index, used only for the log line.
        optimizer: optimizer to step. When omitted, an Adam optimizer
            (lr=1e-4) is created on the first call for ``net`` and cached on
            it, so its moment estimates carry over between epochs.

    Returns:
        Mean training loss over this epoch's batches (0.0 if ``trainloader``
        yields no batches).
    """
    print('\nEpoch: %d' % epoch)
    if optimizer is None:
        # Re-creating Adam every epoch would discard its running moment
        # estimates (and restart bias correction), so keep one per network.
        optimizer = getattr(net, '_train_optimizer', None)
        if optimizer is None:
            optimizer = optim.Adam(net.parameters(), lr=1e-4)
            net._train_optimizer = optimizer
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    last_idx = len(trainloader) - 1
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # AMP hooks are present but disabled (enabled=False), so this runs in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

        # Log once, after the final batch of the epoch.
        if batch_idx == last_idx:
            print('train', batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return train_loss / num_batches if num_batches else 0.0


def test(net, epoch):
    """Evaluate ``net`` on the global ``testloader`` without gradient tracking.

    Uses the module-level ``device`` and ``criterion``. ``epoch`` is accepted
    for symmetry with ``train`` but is not used.

    Returns:
        ``(test_loss, acc)``. ``test_loss`` is the *sum* of the per-batch
        mean losses (not averaged, unlike ``train``'s return value). ``acc``
        is the accuracy in percent (0.0 if ``testloader`` yields no samples).
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    last_idx = len(testloader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Log once, after the final batch.
            if batch_idx == last_idx:
                print('test', batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                      % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    acc = 100.*correct/total if total else 0.0

    return test_loss, acc

In [None]:
def pair(t):
  """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
  if isinstance(t, tuple):
    return t
  return (t, t)

#### Vision Transformer
This PyTorch implementation of the original Vision Transformer is adapted from [kentaroy47/vision-transformers-cifar10](https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/models/vit.py).

In [None]:
class PreNorm(nn.Module):
  """Wrap ``fn`` so it receives a LayerNorm-ed input (pre-norm residual style).

  LayerNorm is taken over the last dimension of size ``dim``; any keyword
  arguments given to ``forward`` are passed through to ``fn``.
  """
  def __init__(self, dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(dim)
    self.fn = fn

  def forward(self, x, **kwargs):
    normalized = self.norm(x)
    result = self.fn(normalized, **kwargs)
    return result

class FeedForward(nn.Module):
  """Token-wise MLP: Linear(dim->hidden) -> GELU -> Dropout -> Linear(hidden->dim) -> Dropout."""
  def __init__(self, dim, hidden_dim, dropout = 0.):
    super().__init__()
    layers = [
      nn.Linear(dim, hidden_dim),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(hidden_dim, dim),
      nn.Dropout(dropout),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    out = self.net(x)
    return out

In [None]:
class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention (ViT variant).

  ``heads`` heads of width ``dim_head`` share one bias-free projection to
  Q, K and V. The output projection (Linear + Dropout) is replaced by an
  identity only when there is a single head whose width already equals
  ``dim``. ``forward`` maps (batch, tokens, dim) -> (batch, tokens, dim).
  """
  def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
    super().__init__()
    inner_dim = heads * dim_head
    needs_projection = heads != 1 or dim_head != dim

    self.heads = heads
    self.scale = dim_head ** -0.5

    self.attend = nn.Softmax(dim = -1)
    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

    if needs_projection:
      self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
    else:
      self.to_out = nn.Identity()

  def forward(self, x):
    batch, tokens, _ = x.shape
    # Split each of Q/K/V's last axis into (heads, head_dim) and move heads forward:
    # (b, n, h*d) -> (b, h, n, d).
    q, k, v = (
      t.reshape(batch, tokens, self.heads, -1).transpose(1, 2)
      for t in self.to_qkv(x).chunk(3, dim = -1)
    )
    attn = self.attend(q @ k.transpose(-1, -2) * self.scale)
    # Merge heads back: (b, h, n, d) -> (b, n, h*d).
    merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, -1)
    return self.to_out(merged)

In [None]:
class Transformer(nn.Module):
  """Stack of ``depth`` pre-norm encoder layers.

  Each layer applies ``x = attn(x) + x`` and then ``x = ff(x) + x``, where
  both sub-blocks are wrapped in ``PreNorm``. Input and output are
  (batch, tokens, dim).
  """
  def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
    super().__init__()
    self.layers = nn.ModuleList([])
    for _ in range(depth):
      self.layers.append(nn.ModuleList([
        PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
        PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
        ]))

  def forward(self, x):
    for attn, ff in self.layers:
      x = attn(x) + x
      x = ff(x) + x
    # Return only after *all* layers have run; returning inside the loop
    # would silently apply just the first layer regardless of ``depth``.
    return x

In [None]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  The image is cut into non-overlapping ``patch_size`` patches. Each patch is
  flattened and linearly projected to ``dim``, a learnable class token is
  prepended, and learnable position embeddings are added. The sequence then
  passes through a ``Transformer`` encoder. The class token (``pool='cls'``)
  or the mean over all tokens (``pool='mean'``) feeds a LayerNorm + Linear
  head. All arguments are keyword-only; ``forward`` returns
  (batch, num_classes) logits.
  """
  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
    super().__init__()
    img_h, img_w = pair(image_size)
    p_h, p_w = pair(patch_size)
    assert img_h % p_h == 0 and img_w % p_w == 0, 'Image dimensions must be divisible by the patch size.'

    grid_tokens = (img_h // p_h) * (img_w // p_w)
    flat_patch = channels * p_h * p_w
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    # (b, c, H, W) -> (b, num_patches, p_h*p_w*c) -> (b, num_patches, dim)
    self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p_h, p2 = p_w),
      nn.Linear(flat_patch, dim),
    )

    # Creation order (position embedding, then class token) fixes the random init under a seed.
    self.pos_embedding = nn.Parameter(torch.randn(1, grid_tokens + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    self.pool = pool
    self.to_latent = nn.Identity()
    self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

  def forward(self, img):
    tokens = self.to_patch_embedding(img)
    batch, num_tokens, _ = tokens.shape

    # Prepend one copy of the class token per example, then add positions.
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens += self.pos_embedding[:, :num_tokens + 1]
    tokens = self.dropout(tokens)

    tokens = self.transformer(tokens)

    if self.pool == 'mean':
      pooled = tokens.mean(dim = 1)
    else:
      pooled = tokens[:, 0]
    return self.mlp_head(self.to_latent(pooled))

##### Results with ViT

In [None]:
# Baseline ViT: with size = 32 and 4x4 patches this yields 8x8 = 64 patch tokens (+1 class token).
net = ViT(image_size=size, patch_size=4, num_classes=10,
          dim=512, depth=6, heads=8, mlp_dim=512,
          dropout=0.1, emb_dropout=0.1)

In [None]:
net = nn.DataParallel(net)  # split each batch across the visible GPUs

In [None]:
# 25 epochs of training, evaluating on the test split after each one.
for i in range(25):
    train(net, i)
    test(net, i)


Epoch: 0
train 97 98 Loss: 2.074 | Acc: 22.938% (11469/50000)
test 99 100 Loss: 1.830 | Acc: 33.630% (3363/10000)

Epoch: 1
train 97 98 Loss: 1.830 | Acc: 33.116% (16558/50000)
test 99 100 Loss: 1.683 | Acc: 39.430% (3943/10000)

Epoch: 2
train 97 98 Loss: 1.732 | Acc: 36.692% (18346/50000)
test 99 100 Loss: 1.628 | Acc: 41.600% (4160/10000)

Epoch: 3
train 97 98 Loss: 1.679 | Acc: 38.794% (19397/50000)
test 99 100 Loss: 1.570 | Acc: 43.980% (4398/10000)

Epoch: 4
train 97 98 Loss: 1.645 | Acc: 40.314% (20157/50000)
test 99 100 Loss: 1.555 | Acc: 44.570% (4457/10000)

Epoch: 5
train 97 98 Loss: 1.623 | Acc: 41.166% (20583/50000)
test 99 100 Loss: 1.518 | Acc: 46.090% (4609/10000)

Epoch: 6
train 97 98 Loss: 1.601 | Acc: 41.976% (20988/50000)
test 99 100 Loss: 1.511 | Acc: 46.660% (4666/10000)

Epoch: 7
train 97 98 Loss: 1.583 | Acc: 42.636% (21318/50000)
test 99 100 Loss: 1.487 | Acc: 47.390% (4739/10000)

Epoch: 8
train 97 98 Loss: 1.570 | Acc: 43.128% (21564/50000)
test 99 100 Loss:

#### Adding Convolutions - 1: On the patches
Instead of feeding a linear projection of the flattened patches into the Transformer encoder, we can convolve over the patches and feed the result into the encoder. Setting both the kernel size and the stride equal to the patch size makes the convolution act on each non-overlapping patch independently.

In [None]:
class ViTConv1(nn.Module):
  """ViT variant whose patch embedding is a strided convolution.

  A conv with kernel = stride = patch size processes each non-overlapping
  patch independently, producing ``channels`` values per patch. These are
  then linearly projected to ``dim``. The rest (class token, position
  embeddings, ``Transformer``, pooling, head) matches ``ViT``.
  """
  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
    super().__init__()
    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)

    assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

    num_patches = (image_height // patch_height) * (image_width // patch_width)
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    # Honour both patch dimensions and the ``channels`` argument (previously
    # hard-coded to a square kernel and 3 channels).
    patch = (patch_height, patch_width)
    self.to_patch_embedding = nn.Sequential(
      # (b, c, H, W) -> (b, c, H/ph, W/pw): one output vector per patch.
      nn.Conv2d(channels, channels, kernel_size=patch, stride=patch),
      Rearrange('b c h w -> b (h w) c'),
      # NOTE: each patch is summarized by only ``channels`` values before
      # being projected up to ``dim``.
      nn.Linear(channels, dim),
      )

    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    self.pool = pool
    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, num_classes)
    )

  def forward(self, img):
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)

    x = self.transformer(x)

    # Class-token pooling or mean over all tokens.
    x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

    x = self.to_latent(x)
    return self.mlp_head(x)

In [None]:
# Same hyper-parameters as the baseline ViT, with the convolutional patch embedding.
net2 = ViTConv1(
    image_size = size,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)
net2 = torch.nn.DataParallel(net2)

In [None]:
# 25 epochs of training for the conv-patch ViT, with a test pass after each epoch.
for i in range(25):
    train(net2, i)
    test(net2, i)


Epoch: 0
train 97 98 Loss: 2.126 | Acc: 20.312% (10156/50000)
test 99 100 Loss: 1.916 | Acc: 29.790% (2979/10000)

Epoch: 1
train 97 98 Loss: 1.934 | Acc: 28.966% (14483/50000)
test 99 100 Loss: 1.818 | Acc: 34.740% (3474/10000)

Epoch: 2
train 97 98 Loss: 1.858 | Acc: 32.516% (16258/50000)
test 99 100 Loss: 1.745 | Acc: 37.620% (3762/10000)

Epoch: 3
train 97 98 Loss: 1.798 | Acc: 34.682% (17341/50000)
test 99 100 Loss: 1.700 | Acc: 38.910% (3891/10000)

Epoch: 4
train 97 98 Loss: 1.755 | Acc: 36.468% (18234/50000)
test 99 100 Loss: 1.652 | Acc: 40.520% (4052/10000)

Epoch: 5
train 97 98 Loss: 1.727 | Acc: 37.000% (18500/50000)
test 99 100 Loss: 1.634 | Acc: 41.150% (4115/10000)

Epoch: 6
train 97 98 Loss: 1.705 | Acc: 38.270% (19135/50000)
test 99 100 Loss: 1.616 | Acc: 42.020% (4202/10000)

Epoch: 7
train 97 98 Loss: 1.685 | Acc: 39.106% (19553/50000)
test 99 100 Loss: 1.581 | Acc: 42.790% (4279/10000)

Epoch: 8
train 97 98 Loss: 1.666 | Acc: 39.432% (19716/50000)
test 99 100 Loss:

#### Adding Convolutions - 2:
We can push this idea further by first applying convolutions to the raw image and only then tokenizing it. Here we replace the ViT patch-embedding module with a convolutional stem that convolves and max-pools the image. The resulting feature map is then split into patches to form the token embeddings.

We also modify the MLP/feed-forward network of the original ViT architecture to use point-wise and depth-wise convolutions over the spatial grid of patch tokens. Finally, after the transformer encoder, we apply an attention layer over the class tokens collected from every encoder layer, rather than using only the final layer's class token.

In [None]:
class ImageToEmbedding(nn.Module):
  """Convolutional stem: bias-free conv -> BatchNorm -> 3x3/stride-2 max-pool.

  The conv uses ``padding = kernel_size // 2`` and the given ``stride``. With
  the defaults, a (B, 3, 32, 32) input becomes a (B, 64, 8, 8) feature map.
  """
  def __init__(self, in_chans=3, out_chans=64, kernel_size=7, stride=2):
    super().__init__()
    self.conv = nn.Conv2d(
        in_chans, out_chans,
        kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False,
    )
    self.bn = nn.BatchNorm2d(out_chans)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

  def forward(self, x):
    return self.maxpool(self.bn(self.conv(x)))

In [None]:
class FeedForward(nn.Module):
  """Two-layer GELU MLP applied token-wise, with dropout after each hidden/output stage.

  Maps (..., dim) -> (..., dim) through a ``hidden_dim`` hidden layer.
  """
  def __init__(self, dim, hidden_dim, dropout = 0.):
    super().__init__()
    expand = nn.Linear(dim, hidden_dim)
    project = nn.Linear(hidden_dim, dim)
    self.net = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), project, nn.Dropout(dropout))

  def forward(self, x):
    return self.net(x)

In [None]:
class ConvFeedForward(nn.Module):
  """Feed-forward block built from pointwise -> depthwise -> pointwise convolutions.

  Input is (B, n, C): one class token at position 0 followed by ``n - 1``
  patch tokens forming a square grid in row-major order. The class token
  bypasses the convolutions and is returned unchanged. The patch tokens are
  reshaped to (B, C, s, s) with ``s = sqrt(n - 1)``, processed, and flattened
  back.

  Args:
    in_features: token dimension ``C`` of the input.
    hidden_features: channels of the expanded representation (defaults to ``in_features``).
    out_features: output channels (defaults to ``in_features``). Must equal
      ``in_features``, because the untouched class token is concatenated back.
    drop: accepted for interface compatibility; no dropout is applied.
    kernel_size: depthwise kernel size (odd sizes keep the grid size unchanged).

  Raises:
    ValueError: if the number of patch tokens ``n - 1`` is not a perfect square.
  """
  def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., kernel_size=3):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    # pointwise
    self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1, padding=0)
    # depthwise (one filter per channel)
    self.conv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=hidden_features)
    # pointwise
    self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1, padding=0)
    self.act = nn.GELU()

    self.bn1 = nn.BatchNorm2d(hidden_features)
    self.bn2 = nn.BatchNorm2d(hidden_features)
    self.bn3 = nn.BatchNorm2d(out_features)

  def forward(self, x):
    b, n, k = x.size()
    side = math.isqrt(n - 1)
    if side * side != n - 1:
      raise ValueError(
          f'ConvFeedForward expects a class token plus a square grid of patch tokens, '
          f'got {n - 1} patch tokens')
    cls_token, tokens = torch.split(x, [1, n - 1], dim=1)
    # (B, s*s, C) -> (B, s, s, C) -> (B, C, s, s)
    x = tokens.reshape(b, side, side, k).permute(0, 3, 1, 2)

    x = self.act(self.bn1(self.conv1(x)))
    x = self.act(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))

    # (B, C, s, s) -> (B, s*s, C), then put the class token back in front.
    tokens = x.flatten(2).permute(0, 2, 1)
    return torch.cat((cls_token, tokens), dim=1)

In [None]:
class Attention(nn.Module):
  """Multi-head self-attention with a fused, bias-free QKV projection.

  Head width is ``dim // num_heads``. Attention weights pass through
  ``attn_drop``, and the output projection through ``proj_drop``. Maps
  (B, N, C) -> (B, N, C).

  NOTE: this redefines the earlier ``Attention`` class, which has a
  different constructor (``heads``/``dim_head``/``dropout``). ``Transformer``
  calls that older signature, so models built from ``Transformer`` must be
  instantiated before this cell runs.
  """
  def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
    super().__init__()
    self.num_heads = num_heads
    self.scale = (dim // num_heads) ** -0.5

    self.qkv = nn.Linear(dim, dim * 3, bias=False)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)
    self.attention_map = None

  def forward(self, x):
    batch, seq_len, channels = x.shape
    head_dim = channels // self.num_heads
    # (B, N, 3C) -> (3, B, heads, N, head_dim)
    packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
    q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

    weights = (q @ k.transpose(-2, -1)) * self.scale
    weights = self.attn_drop(weights.softmax(dim=-1))

    mixed = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
    return self.proj_drop(self.proj(mixed))

In [None]:
class LayerAttention(Attention):
  """Attention with a single query: the last token of the input sequence.

  Reuses the parent's fused ``qkv`` weight. Its first ``dim`` rows project the
  query (from the last token only), and the remaining ``2 * dim`` rows project
  keys/values for every token. No biases are used. Maps (B, N, C) -> (B, 1, C).
  """
  def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
    super().__init__(dim, num_heads, attn_drop, proj_drop)
    self.dim = dim

  def forward(self, x):
    B, N, C = x.shape
    heads = self.num_heads
    head_dim = C // heads
    w_q = self.qkv.weight[:self.dim]
    w_kv = self.qkv.weight[self.dim:]

    query_token = x[:, -1:]
    q = F.linear(query_token, w_q).reshape(B, 1, heads, head_dim).permute(0, 2, 1, 3)
    k, v = F.linear(x, w_kv).reshape(B, N, 2, heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

    scores = (q @ k.transpose(-2, -1)) * self.scale
    probs = self.attn_drop(scores.softmax(dim=-1))

    out = (probs @ v).transpose(1, 2).reshape(B, 1, C)
    return self.proj_drop(self.proj(out))



In [None]:
class ConvTransformerBlock(nn.Module):
  """Pre-norm transformer block with two modes, selected by ``use_cff``.

  * ``use_cff=True``: ``Attention`` + ``ConvFeedForward`` on the full
    sequence. ``forward`` returns ``(x, x[:, 0])``, i.e. the updated sequence
    and its class token.
  * ``use_cff=False``: ``LayerAttention`` (query = last token) + ``FeedForward``.
    The residual stream is the last token only, so ``forward`` returns a
    (B, 1, dim) tensor.

  The feed-forward hidden width is ``int(dim * mlp_ratio)``.
  """
  def __init__(self, dim, num_heads, mlp_ratio=4, drop=0., attn_drop=0., kernel_size=3, use_cff=True):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.norm2 = nn.LayerNorm(dim)
    hidden = int(dim * mlp_ratio)
    self.use_cff = use_cff
    attn_cls = Attention if use_cff else LayerAttention
    self.attn = attn_cls(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
    if use_cff:
      self.ff = ConvFeedForward(in_features=dim, hidden_features=hidden, drop=drop, kernel_size=kernel_size)
    else:
      self.ff = FeedForward(dim, hidden, dropout=drop)

  def forward(self, x):
    if not self.use_cff:
      # Layer-attention mode: the last token carries the residual.
      out = x[:, -1:] + self.attn(self.norm1(x))
      return out + self.ff(self.norm2(out))
    x = x + self.attn(self.norm1(x))
    x = x + self.ff(self.norm2(x))
    return x, x[:, 0]

In [None]:
class HybridEmbed(nn.Module):
  """Turn images into tokens by patchifying a CNN backbone's feature map.

  When ``feature_size`` is None, the backbone's output channels and spatial
  size are found with a dummy forward pass on a zero image, run in eval mode
  under ``no_grad``; the original train/eval mode is restored afterwards.
  Otherwise the channel count is read from
  ``backbone.feature_info.channels()[-1]``. A conv with kernel = stride =
  ``patch_size`` then maps each feature-map patch to ``embed_dim``.
  ``forward`` returns (B, num_patches, embed_dim).
  """
  def __init__(self, backbone, img_size=32, patch_size=16, feature_size=None, in_chans=3, embed_dim=512):
    super().__init__()
    assert isinstance(backbone, nn.Module)
    img_size = pair(img_size)
    self.img_size = img_size
    self.backbone = backbone
    if feature_size is not None:
      feature_size = pair(feature_size)
      feature_dim = self.backbone.feature_info.channels()[-1]
    else:
      with torch.no_grad():
        was_training = backbone.training
        if was_training:
          backbone.eval()
        probe = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
        if isinstance(probe, (list, tuple)):
          probe = probe[-1]
        feature_dim = probe.shape[1]
        feature_size = probe.shape[-2:]
        backbone.train(was_training)

    grid_h = feature_size[0] // patch_size
    grid_w = feature_size[1] // patch_size
    self.num_patches = grid_h * grid_w
    self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

  def forward(self, x):
    feats = self.backbone(x)
    if isinstance(feats, (list, tuple)):
      feats = feats[-1]
    # (B, embed_dim, h, w) -> (B, h*w, embed_dim)
    return self.proj(feats).flatten(2).transpose(1, 2)

In [None]:
class ViTConv2(nn.Module):
  """Hybrid convolution/transformer classifier.

  Pipeline:
  1. ``ImageToEmbedding`` conv stem, then ``HybridEmbed`` patch projection to
     ``embed_dim`` tokens.
  2. A class token and learned position embeddings are added.
  3. ``depth`` ``ConvTransformerBlock`` layers with convolutional
     feed-forward; the class token from each layer is collected.
  4. A final ``ConvTransformerBlock(use_cff=False)`` attends over those
     class tokens, followed by LayerNorm and a Linear head.

  ``drop_path_rate`` is accepted for API compatibility, but no stochastic
  depth is applied.
  """
  def __init__(self, img_size=32, patch_size=16, in_chans=3, num_classes=10, embed_dim=512, depth=6, num_heads=8,
               mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., cff_local_size=3):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim

    # The stem is registered both here and as ``self.i2t.backbone`` (same module object).
    self.i2thead = ImageToEmbedding(in_chans=in_chans)
    # Forward in_chans/embed_dim: HybridEmbed's own defaults (3 and 512) would
    # otherwise disagree with cls_token/pos_embed whenever embed_dim != 512.
    self.i2t = HybridEmbed(self.i2thead, img_size=img_size, patch_size=patch_size,
                           in_chans=in_chans, embed_dim=embed_dim)

    num_patches = self.i2t.num_patches
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    self.pos_drop = nn.Dropout(p=drop_rate)

    self.blocks = nn.ModuleList([
        ConvTransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             kernel_size=cff_local_size)
        for _ in range(depth)])

    # use_cff=False selects the layer-attention mode over the stacked class tokens.
    self.layerattention = ConvTransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                                               drop=drop_rate, attn_drop=attn_drop_rate, use_cff=False)
    # Learned per-layer offset added to the stacked class tokens.
    self.pos_layer_embed = nn.Parameter(torch.zeros(1, depth, embed_dim))

    self.norm = nn.LayerNorm(embed_dim)

    # Classifier head
    self.mlp_head = nn.Linear(embed_dim, num_classes)

  def forward_features(self, x):
    """Return the normalized, layer-attended class-token features, shape (B, embed_dim)."""
    B = x.shape[0]
    x = self.i2t(x)

    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + self.pos_embed
    x = self.pos_drop(x)

    # Collect the class token emitted by every block.
    cls_token_list = []
    for blk in self.blocks:
      x, curr_cls_token = blk(x)
      cls_token_list.append(curr_cls_token)

    all_cls_token = torch.stack(cls_token_list, dim=1)  # (B, depth, embed_dim)
    all_cls_token = all_cls_token + self.pos_layer_embed

    last_cls_token = self.layerattention(all_cls_token)
    last_cls_token = self.norm(last_cls_token)

    return last_cls_token.view(B, -1)

  def forward(self, x):
    x = self.forward_features(x)
    x = self.mlp_head(x)
    return x

In [None]:
# For 32x32 inputs the conv stem yields an 8x8 map; 4x4 patches on it give 2x2 = 4 tokens.
net3 = torch.nn.DataParallel(ViTConv2(img_size=size, patch_size=4, drop_rate=0.1))

In [None]:
# 25 epochs of training for the hybrid model, with a test pass after each epoch.
for i in range(25):
    train(net3, i)
    test(net3, i)


Epoch: 0
train 97 98 Loss: 1.650 | Acc: 39.978% (19989/50000)
test 99 100 Loss: 1.336 | Acc: 52.210% (5221/10000)

Epoch: 1
train 97 98 Loss: 1.333 | Acc: 51.586% (25793/50000)
test 99 100 Loss: 1.195 | Acc: 56.740% (5674/10000)

Epoch: 2
train 97 98 Loss: 1.212 | Acc: 56.600% (28300/50000)
test 99 100 Loss: 1.066 | Acc: 62.070% (6207/10000)

Epoch: 3
train 97 98 Loss: 1.129 | Acc: 59.644% (29822/50000)
test 99 100 Loss: 1.031 | Acc: 63.770% (6377/10000)

Epoch: 4
train 97 98 Loss: 1.077 | Acc: 61.148% (30574/50000)
test 99 100 Loss: 0.969 | Acc: 65.620% (6562/10000)

Epoch: 5
train 97 98 Loss: 1.020 | Acc: 63.546% (31773/50000)
test 99 100 Loss: 0.951 | Acc: 66.470% (6647/10000)

Epoch: 6
train 97 98 Loss: 0.970 | Acc: 65.470% (32735/50000)
test 99 100 Loss: 0.940 | Acc: 67.090% (6709/10000)

Epoch: 7
train 97 98 Loss: 0.933 | Acc: 66.740% (33370/50000)
test 99 100 Loss: 0.871 | Acc: 68.730% (6873/10000)

Epoch: 8
train 97 98 Loss: 0.892 | Acc: 68.408% (34204/50000)
test 99 100 Loss: