In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from tqdm.notebook import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

In [13]:
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

In [14]:
batch_size = 100
n_iters = 3000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)
device = "cuda"
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

valid_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [15]:
from einops import rearrange, repeat

MIN_NUM_PATCHES = 16

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [16]:
model = ViT(
    image_size=28,
    patch_size=4,
    num_classes=10,
    dim=128,
    depth=64,
    heads=4,
    mlp_dim=128,
    channels=1,
).to("cuda")

In [17]:
torch.cuda.is_available()

True

In [18]:
model

ViT(
  (patch_to_embedding): Linear(in_features=16, out_features=128, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (fn): Attention(
              (to_qkv): Linear(in_features=128, out_features=768, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=256, out_features=128, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
          )
        )
        (1): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (fn): FeedForward(
              (net): Sequential(
                (0): Linear(in_features=128, out_features=128, bias=True)
                (1): GELU()
                (2): Dropout(p=0.0, inplace=False)
                

In [19]:
lr = 3e-4
gamma = 0.7
epochs = 10
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [20]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 1 - loss : 1.1405 - acc: 0.5914 - val_loss : 0.4144 - val_acc: 0.8737



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 2 - loss : 0.2740 - acc: 0.9177 - val_loss : 0.2085 - val_acc: 0.9363



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 3 - loss : 0.1847 - acc: 0.9444 - val_loss : 0.1812 - val_acc: 0.9470



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 4 - loss : 0.1458 - acc: 0.9555 - val_loss : 0.1398 - val_acc: 0.9588



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 5 - loss : 0.1187 - acc: 0.9638 - val_loss : 0.1271 - val_acc: 0.9629



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 6 - loss : 0.0998 - acc: 0.9700 - val_loss : 0.0965 - val_acc: 0.9706



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 7 - loss : 0.0919 - acc: 0.9724 - val_loss : 0.0984 - val_acc: 0.9707



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 8 - loss : 0.0828 - acc: 0.9743 - val_loss : 0.0887 - val_acc: 0.9729



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 9 - loss : 0.0749 - acc: 0.9772 - val_loss : 0.1066 - val_acc: 0.9689



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))


Epoch : 10 - loss : 0.0711 - acc: 0.9782 - val_loss : 0.0916 - val_acc: 0.9727



In [57]:
model.transformer.layers[0][0].fn.fn.to_qkv.weight

Parameter containing:
tensor([[ 0.0018,  0.0594,  0.0684,  ...,  0.0808, -0.0304,  0.0774],
        [-0.0115,  0.0919,  0.0489,  ..., -0.0375,  0.0328, -0.0561],
        [-0.0432, -0.1457, -0.0638,  ...,  0.0164,  0.0696, -0.0432],
        ...,
        [-0.0201,  0.0657, -0.0618,  ...,  0.0240, -0.0236, -0.0439],
        [-0.0355, -0.0727,  0.0708,  ..., -0.0118, -0.0200,  0.0528],
        [-0.0696, -0.0612,  0.0940,  ...,  0.0682,  0.0082,  0.0381]],
       device='cuda:0', requires_grad=True)

In [52]:
model.mlp_head[1]

Linear(in_features=128, out_features=10, bias=True)

## SAVING AND LOADING NETWORK

In [65]:
path = "./checkpoints/vision_transformer_pytorch/vit_torch"
torch.save(model.state_dict(), path)

In [66]:
model = ViT(
            image_size=28,
            patch_size=4,
            num_classes=10,
            dim=128,
            depth=64,
            heads=4,
            mlp_dim=128,
            channels=1,
).to("cuda")
model.load_state_dict(torch.load(path))

<All keys matched successfully>