In [12]:
import torch
from torchvision import datasets,transforms
from torch.utils.data import DataLoader





In [13]:
trainset_plain=datasets.CIFAR10(root='./data',train=True,download=True,transform=transforms.ToTensor())
loader=DataLoader(trainset_plain,batch_size=50000,shuffle=False,num_workers=2)
data_iter=iter(loader)
images,labels=next(data_iter)
mean=images.mean(dim=[0,2,3])
std=images.std(dim=[0,2,3])
print('mean',mean,'std',std)

mean tensor([0.4914, 0.4822, 0.4465]) std tensor([0.2470, 0.2435, 0.2616])


In [14]:
train_transform=transforms.Compose([
    transforms.RandomCrop(32,padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])

test_transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])

trainset=datasets.CIFAR10(root='./data',train=True,download=False,transform=train_transform)
testset=datasets.CIFAR10(root='./data',train=False,download=True,transform=test_transform)

train_loader=DataLoader(trainset,batch_size=128,shuffle=True,num_workers=2)
test_loader=DataLoader(testset,batch_size=128,shuffle=False,num_workers=2)

print(len(trainset),len(testset))


50000 10000


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    def __init__(self,img_size=32,patch_size=4,in_chans=3,embed_dim=192):
        super().__init__()
        self.img_size=img_size
        self.patch_size=patch_size
        self.embed_dim=embed_dim
        self.grid_size=img_size//patch_size
        self.num_patches=self.grid_size*self.grid_size
        self.proj=nn.Conv2d(in_chans,embed_dim,patch_size,patch_size)
    def forward(self,x):
        x=self.proj(x)
        x=x.flatten(2).transpose(1,2) #[batch, seq_len=64, embed_dim=192] transformer expects this
        return x

class ResidualConnection(nn.Module):
    def __init__(self,fn):
        super().__init__()
        self.fn=fn
    def forward(self,x):
        return x+self.fn(x)

class MHSA(nn.Module):
    def __init__(self,embed_dim,num_heads=12,attn_dropout=0.1,proj_dropout=0.1): #according to wu et al
        super().__init__()
        self.num_heads=num_heads #12
        self.head_dim=embed_dim//num_heads
        self.scale=self.head_dim**-0.5
        self.qkv=nn.Linear(embed_dim,embed_dim*3) # qkv at a time
        self.attn_drop=nn.Dropout(attn_dropout)
        self.proj=nn.Linear(embed_dim,embed_dim)
        self.proj_drop=nn.Dropout(proj_dropout)
    def forward(self,x):
        B,N,C=x.shape
        qkv=self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4) #(B, N, 3*C) to (B, N, 3, num_heads, head_dim) to (3, B, num_heads, N, head_dim)

        q,k,v=qkv[0],qkv[1],qkv[2]
        attn=(q@k.transpose(-2,-1))*self.scale
        attn=attn.softmax(dim=-1)
        attn=self.attn_drop(attn)
        x=(attn@v).transpose(1,2).reshape(B,N,C) # back into embedding space
        x=self.proj(x)
        x=self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self,embed_dim,mlp_ratio=4,drop=0.1):
        super().__init__()
        self.fc1=nn.Linear(embed_dim,embed_dim*mlp_ratio)
        self.fc2=nn.Linear(embed_dim*mlp_ratio,embed_dim)
        self.act=nn.GELU()
        self.drop=nn.Dropout(drop)
    def forward(self,x):
        x=self.fc1(x)
        x=self.act(x)
        x=self.drop(x)
        x=self.fc2(x)
        x=self.drop(x) # each dropout call in independent
        return x

class Block(nn.Module):
    def __init__(self,embed_dim,num_heads=12,mlp_ratio=4,drop=0.1):
        super().__init__()
        self.norm1=nn.LayerNorm(embed_dim)
        self.attn=MHSA(embed_dim,num_heads,drop,drop)
        self.res1=ResidualConnection(lambda x:self.attn(self.norm1(x)))
        self.norm2=nn.LayerNorm(embed_dim)
        self.mlp=MLP(embed_dim,mlp_ratio,drop)
        self.res2=ResidualConnection(lambda x:self.mlp(self.norm2(x)))
    def forward(self,x):
        x=self.res1(x)
        x=self.res2(x)
        return x

class ViT(nn.Module):
    def __init__(self,img_size=32,patch_size=4,in_chans=3,embed_dim=192,num_classes=10,depth=12,num_heads=12,mlp_ratio=4,drop=0.1):
        super().__init__()
        self.patch_embed=PatchEmbed(img_size,patch_size,in_chans,embed_dim)
        self.cls_token = nn.Parameter(torch.empty(1,1,embed_dim))
        nn.init.trunc_normal_(self.cls_token, mean=0.0, std=0.02) # as implemented in the ViT paper

        self.pos_embed = nn.Parameter(torch.empty(1, 1+self.patch_embed.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.pos_drop=nn.Dropout(drop)
        self.blocks=nn.ModuleList([Block(embed_dim,num_heads,mlp_ratio,drop) for _ in range(depth)])
        self.norm=nn.LayerNorm(embed_dim)
        self.head=nn.Linear(embed_dim,num_classes)
    def forward(self,x):
        B=x.shape[0]
        x=self.patch_embed(x)
        cls_tokens=self.cls_token.expand(B,-1,-1) #(B, 1, embed_dim)
        x=torch.cat((cls_tokens,x),dim=1)
        x=x+self.pos_embed
        x=self.pos_drop(x)
        for blk in self.blocks:
            x=blk(x)
        x=self.norm(x)
        cls_out=x[:,0]
        x=self.head(cls_out)
        return x

model=ViT()
sample_images, _=next(iter(train_loader))
sample_images=sample_images[:2]
out=model(sample_images)
print(out.shape)


torch.Size([2, 10])


In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViT().to(device)
scaler = torch.amp.GradScaler()

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self,smoothing=0.1):
        super().__init__()
        self.smoothing=smoothing
        self.confidence=1.0-smoothing
    def forward(self,pred,target):
        logprobs=nn.functional.log_softmax(pred,dim=-1)
        nll_loss=-logprobs.gather(dim=-1,index=target.unsqueeze(1)).squeeze(1)
        smooth_loss=-logprobs.mean(dim=-1)
        loss=self.confidence*nll_loss+self.smoothing*smooth_loss
        return loss.mean()

criterion=LabelSmoothingCrossEntropy(0.1)
optimizer=optim.AdamW(model.parameters(),lr=3e-4,weight_decay=0.05)
T_max=200
scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=T_max)
warmup_epochs=10
initial_lr=3e-5

epochs=50

for epoch in range(epochs):
    model.train()
    loop=tqdm(train_loader)
    running_loss=0
    correct=0
    total=0
    if epoch<warmup_epochs:
        lr=initial_lr+(3e-4-initial_lr)*(epoch/warmup_epochs)
        for g in optimizer.param_groups:
            g['lr']=lr
    for imgs,labels in loop:
        imgs,labels=imgs.to(device),labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs=model(imgs)
            loss=criterion(outputs,labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss+=loss.item()*imgs.size(0)
        _,preds=outputs.max(1)
        correct+=preds.eq(labels).sum().item()
        total+=labels.size(0)
        loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
        loop.set_postfix(loss=running_loss/total,acc=100*correct/total)
    if epoch>=warmup_epochs:
        scheduler.step()
    model.eval()
    test_loss=0
    correct=0
    total=0
    with torch.no_grad():
        for imgs,labels in test_loader:
            imgs,labels=imgs.to(device),labels.to(device)
            with torch.cuda.amp.autocast():
                outputs=model(imgs)
                loss=criterion(outputs,labels)
            test_loss+=loss.item()*imgs.size(0)
            _,preds=outputs.max(1)
            correct+=preds.eq(labels).sum().item()
            total+=labels.size(0)
    print(f"Test Loss:{test_loss/total:.4f} Test Acc:{100*correct/total:.2f}%")


  with torch.cuda.amp.autocast():
Epoch [1/50]: 100%|██████████| 391/391 [00:44<00:00,  8.87it/s, acc=29, loss=1.99]
  with torch.cuda.amp.autocast():


Test Loss:1.9496 Test Acc:33.19%


Epoch [2/50]: 100%|██████████| 391/391 [00:43<00:00,  9.07it/s, acc=35.4, loss=1.86]


Test Loss:1.9301 Test Acc:33.97%


Epoch [3/50]: 100%|██████████| 391/391 [00:43<00:00,  8.96it/s, acc=41.2, loss=1.75]


Test Loss:1.7821 Test Acc:42.39%


Epoch [4/50]: 100%|██████████| 391/391 [00:43<00:00,  9.07it/s, acc=46, loss=1.65]


Test Loss:1.6820 Test Acc:45.41%


Epoch [5/50]: 100%|██████████| 391/391 [00:43<00:00,  9.04it/s, acc=49.8, loss=1.58]


Test Loss:1.5936 Test Acc:50.12%


Epoch [6/50]: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s, acc=53.1, loss=1.51]


Test Loss:1.4433 Test Acc:56.31%


Epoch [7/50]: 100%|██████████| 391/391 [00:43<00:00,  9.07it/s, acc=56.2, loss=1.45]


Test Loss:1.4025 Test Acc:58.86%


Epoch [8/50]: 100%|██████████| 391/391 [00:43<00:00,  9.05it/s, acc=58.2, loss=1.4]


Test Loss:1.3900 Test Acc:59.18%


Epoch [9/50]: 100%|██████████| 391/391 [00:43<00:00,  9.01it/s, acc=60.6, loss=1.36]


Test Loss:1.3058 Test Acc:63.09%


Epoch [10/50]: 100%|██████████| 391/391 [00:43<00:00,  9.02it/s, acc=62.3, loss=1.32]


Test Loss:1.2610 Test Acc:64.29%


Epoch [11/50]: 100%|██████████| 391/391 [00:44<00:00,  8.88it/s, acc=64.7, loss=1.27]


Test Loss:1.2414 Test Acc:66.14%


Epoch [12/50]: 100%|██████████| 391/391 [00:43<00:00,  9.03it/s, acc=66.7, loss=1.23]


Test Loss:1.2131 Test Acc:68.81%


Epoch [13/50]: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s, acc=68, loss=1.2]


Test Loss:1.1725 Test Acc:69.35%


Epoch [14/50]: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s, acc=69.5, loss=1.17]


Test Loss:1.1305 Test Acc:71.51%


Epoch [15/50]: 100%|██████████| 391/391 [00:43<00:00,  9.06it/s, acc=70.9, loss=1.14]


Test Loss:1.1059 Test Acc:72.07%


Epoch [16/50]: 100%|██████████| 391/391 [00:44<00:00,  8.85it/s, acc=71.9, loss=1.12]


Test Loss:1.0949 Test Acc:73.18%


Epoch [17/50]: 100%|██████████| 391/391 [00:43<00:00,  9.01it/s, acc=73, loss=1.09]


Test Loss:1.0881 Test Acc:73.78%


Epoch [18/50]: 100%|██████████| 391/391 [00:43<00:00,  9.03it/s, acc=73.9, loss=1.08]


Test Loss:1.0761 Test Acc:73.93%


Epoch [19/50]: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=74.9, loss=1.06]


Test Loss:1.0376 Test Acc:75.69%


Epoch [20/50]: 100%|██████████| 391/391 [00:43<00:00,  9.07it/s, acc=75.8, loss=1.04]


Test Loss:1.0577 Test Acc:74.90%


Epoch [21/50]: 100%|██████████| 391/391 [00:43<00:00,  9.03it/s, acc=76.3, loss=1.03]


Test Loss:1.0802 Test Acc:74.30%


Epoch [22/50]: 100%|██████████| 391/391 [00:43<00:00,  9.04it/s, acc=76.9, loss=1.01]


Test Loss:0.9972 Test Acc:77.77%


Epoch [23/50]: 100%|██████████| 391/391 [00:43<00:00,  9.05it/s, acc=77.8, loss=0.993]


Test Loss:1.0112 Test Acc:77.48%


Epoch [24/50]: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s, acc=78.3, loss=0.976]


Test Loss:1.0089 Test Acc:77.17%


Epoch [25/50]: 100%|██████████| 391/391 [00:44<00:00,  8.87it/s, acc=78.7, loss=0.969]


Test Loss:1.0342 Test Acc:76.03%


Epoch [26/50]: 100%|██████████| 391/391 [00:43<00:00,  8.96it/s, acc=79.4, loss=0.955]


Test Loss:0.9989 Test Acc:78.03%


Epoch [27/50]: 100%|██████████| 391/391 [00:44<00:00,  8.89it/s, acc=80, loss=0.946]


Test Loss:1.0390 Test Acc:76.76%


Epoch [28/50]: 100%|██████████| 391/391 [00:42<00:00,  9.10it/s, acc=80.4, loss=0.934]


Test Loss:0.9946 Test Acc:78.62%


Epoch [29/50]: 100%|██████████| 391/391 [00:43<00:00,  8.99it/s, acc=81, loss=0.92]


Test Loss:0.9784 Test Acc:79.06%


Epoch [30/50]: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=81.5, loss=0.91]


Test Loss:0.9733 Test Acc:79.00%


Epoch [31/50]: 100%|██████████| 391/391 [00:44<00:00,  8.87it/s, acc=81.9, loss=0.903]


Test Loss:0.9617 Test Acc:79.69%


Epoch [32/50]: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s, acc=82.3, loss=0.893]


Test Loss:0.9440 Test Acc:80.56%


Epoch [33/50]: 100%|██████████| 391/391 [00:44<00:00,  8.81it/s, acc=83, loss=0.879]


Test Loss:0.9736 Test Acc:79.35%


Epoch [34/50]: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s, acc=83.3, loss=0.869]


Test Loss:0.9470 Test Acc:80.65%


Epoch [35/50]: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=83.7, loss=0.861]


Test Loss:0.9461 Test Acc:80.86%


Epoch [36/50]: 100%|██████████| 391/391 [00:44<00:00,  8.72it/s, acc=84.3, loss=0.85]


Test Loss:0.9586 Test Acc:79.70%


Epoch [37/50]: 100%|██████████| 391/391 [00:44<00:00,  8.69it/s, acc=84.4, loss=0.845]


Test Loss:0.9392 Test Acc:81.19%


Epoch [38/50]: 100%|██████████| 391/391 [00:44<00:00,  8.80it/s, acc=85, loss=0.831]


Test Loss:0.9761 Test Acc:79.99%


Epoch [39/50]: 100%|██████████| 391/391 [00:44<00:00,  8.84it/s, acc=85.3, loss=0.823]


Test Loss:0.9391 Test Acc:81.13%


Epoch [40/50]: 100%|██████████| 391/391 [00:44<00:00,  8.69it/s, acc=85.7, loss=0.817]


Test Loss:0.9182 Test Acc:82.08%


Epoch [41/50]: 100%|██████████| 391/391 [00:44<00:00,  8.69it/s, acc=86.2, loss=0.805]


Test Loss:0.9349 Test Acc:81.39%


Epoch [42/50]: 100%|██████████| 391/391 [00:44<00:00,  8.69it/s, acc=86.5, loss=0.797]


Test Loss:0.9233 Test Acc:82.34%


Epoch [43/50]: 100%|██████████| 391/391 [00:44<00:00,  8.80it/s, acc=87, loss=0.791]


Test Loss:0.9502 Test Acc:81.17%


Epoch [44/50]: 100%|██████████| 391/391 [00:45<00:00,  8.61it/s, acc=87.1, loss=0.784]


Test Loss:0.9506 Test Acc:81.08%


Epoch [45/50]: 100%|██████████| 391/391 [00:44<00:00,  8.73it/s, acc=87.6, loss=0.775]


Test Loss:0.9466 Test Acc:81.45%


Epoch [46/50]: 100%|██████████| 391/391 [00:44<00:00,  8.84it/s, acc=87.9, loss=0.769]


Test Loss:0.9525 Test Acc:81.18%


Epoch [47/50]: 100%|██████████| 391/391 [00:44<00:00,  8.74it/s, acc=88.5, loss=0.758]


Test Loss:0.9203 Test Acc:82.73%


Epoch [48/50]: 100%|██████████| 391/391 [00:46<00:00,  8.47it/s, acc=88.5, loss=0.755]


Test Loss:0.9242 Test Acc:82.56%


Epoch [49/50]: 100%|██████████| 391/391 [00:44<00:00,  8.79it/s, acc=88.9, loss=0.747]


Test Loss:0.9354 Test Acc:82.38%


Epoch [50/50]: 100%|██████████| 391/391 [00:44<00:00,  8.81it/s, acc=89, loss=0.745]


Test Loss:0.9368 Test Acc:82.45%
