In [1]:
import torch

from torch import nn
from torch import optim
from torch.utils import data
from torchvision import datasets
from torchvision import transforms

from vit_pytorch.deepvit import DeepViT
from vit_pytorch.vit import ViT
from vit_pytorch.mae import MAE

class ReplicateChannels(object):

    def __init__(self):
        pass

    def __call__(self, sample):
        if sample.shape[0] == 1:
            sample = torch.tile(sample, dims=(3, 1, 1))
        return sample
    
trn_transform = transforms.Compose([
    transforms.Resize(300),
    transforms.RandomCrop(256, 256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    ReplicateChannels(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
tst_transform = transforms.Compose([
    transforms.Resize(300),
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
    ReplicateChannels(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# n_classes = 102
# trn_set = datasets.Flowers102('./data/',
#                               split='train',
#                               transform=trn_transform,
#                               download=True)
# tst_set = datasets.Flowers102('./data/',
#                               split='test',
#                               transform=tst_transform,
#                               download=True)

n_classes = 256
trn_set = datasets.Caltech256('./data/',
                              transform=trn_transform,
                              download=True)

# n_classes = 365
# trn_set = datasets.Places365('./data/',
#                              split='train-standard',
#                              small=True,
#                              transform=trn_transform,
#                              download=False)
# tst_set = datasets.Places365('./data/',
#                              split='val',
#                              small=True,
#                              transform=tst_transform,
#                              download=False)

# n_classes = 101
# trn_set = datasets.Food101('./data/',
#                            split='train',
#                            transform=trn_transform,
#                            download=True)
# tst_set = datasets.Food101('./data/',
#                            split='test',
#                            transform=tst_transform,
#                            download=True)

# batch_size = 384
batch_size = 768
# batch_size = 512
trn_loader = data.DataLoader(trn_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=16)
# tst_loader = data.DataLoader(tst_set,
#                              batch_size=batch_size,
#                              shuffle=False,
#                              num_workers=16)

Files already downloaded and verified


In [2]:
# Instantiating model.
# vit = DeepViT(
#     image_size = 256,
#     patch_size = 32,
#     num_classes = 102,
# #     num_classes = 101,
# #     num_classes = 365,
#     dim = 512,
#     depth = 6,
#     heads = 16,
#     mlp_dim = 2048,
#     dropout = 0.1,
#     emb_dropout = 0.1
# ).cuda()

vit = ViT(
    image_size=256,
    patch_size=32,
    num_classes=n_classes,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
).cuda()

mae = MAE(
    encoder=vit,
    masking_ratio=0.75,   # the paper recommended 75% masked patches
    decoder_dim=512,      # paper showed good results with just 512
    decoder_depth=8       # anywhere from 1 to 8
).cuda()
mae = nn.DataParallel(mae)

print(vit)
print(mae)

n_epochs = 200

# Optimizer and scheduler.
opt = optim.Adam(mae.parameters(), lr=0.005, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(opt, step_size=n_epochs // 20, gamma=0.9)

# Iterating over epochs.
for epoch in range(n_epochs):
    
    print('Epoch %d/%d' % (epoch + 1, n_epochs))
    
    # Iterating over batches.
    for i, batch in enumerate(trn_loader):
        
        img, lab = batch
        
        img = img.cuda()
        lab = lab.cuda()
        
        opt.zero_grad()
        
        loss = mae(img).mean()
        
        loss.backward()
        
        opt.step()
        
        print('  It %d/%d, loss %.4f' % (i + 1, len(trn_loader), loss.item()))
        
    scheduler.step()

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): Linear(in_features=3072, out_features=1024, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.0, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=1024, out_features



  It 1/40, loss 1.6867
  It 2/40, loss 4.2132
  It 3/40, loss 13.0971
  It 4/40, loss 347.8084
  It 5/40, loss 184.3882
  It 6/40, loss 533.8763
  It 7/40, loss 20188.9180
  It 8/40, loss 356.6282
  It 9/40, loss 499.9250
  It 10/40, loss 263.2239
  It 11/40, loss 59.9980
  It 12/40, loss 74.0893
  It 13/40, loss 518.2119
  It 14/40, loss 378.7319
  It 15/40, loss 3395.3826
  It 16/40, loss 1150.1802
  It 17/40, loss 436.6139
  It 18/40, loss 1280.1106
  It 19/40, loss 3251.7148
  It 20/40, loss 2563.8992
  It 21/40, loss 680.8047
  It 22/40, loss 891.2957
  It 23/40, loss 9296.8730
  It 24/40, loss 1146.0613
  It 25/40, loss 648.8156
  It 26/40, loss 278.0064
  It 27/40, loss 1437.6039
  It 28/40, loss 2758.1030
  It 29/40, loss 3520.9580
  It 30/40, loss 494.9175
  It 31/40, loss 629.9749
  It 32/40, loss 1141.4482
  It 33/40, loss 1946.7241
  It 34/40, loss 696.6232
  It 35/40, loss 335.6487
  It 36/40, loss 445.6132
  It 37/40, loss 301.3249
  It 38/40, loss 237.4362
  It 39/40, lo