In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
import torchvision.transforms.v2 as v2
from torchvision.datasets import CIFAR10
from transformers import get_cosine_schedule_with_warmup
from torchvision.transforms import RandAugment

In [30]:
device ="cuda" if torch.cuda.is_available() else 'cpu'

In [31]:
train_transformation = transforms.Compose([
    transforms.RandomCrop(32, padding =4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    #RandAugment(num_ops=3, magnitude=10),
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447),(0.247, 0.243, 0.261))
])

test_transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447),(0.247, 0.243, 0.261))
])

In [None]:
batch_size = 128
num_classes=10
img_size = 32
patch_size = 4
in_chans = 3
embed_dim = 384
depth = 6
num_heads = 12
mlp_ratio = 4.0
drop_rate = 0.1
epochs = 200
lr = 9e-4
weight_decay = 0.1

In [33]:
cutmix = v2.CutMix(num_classes=num_classes)
mixup = v2.MixUp(num_classes=num_classes)
mix_transform = v2.RandomChoice([cutmix, mixup])


In [34]:
train_dataset = CIFAR10(root='./data', train=True, download=True, transform= train_transformation)
test_dataset  = CIFAR10(root='./data', train=False, download=True, transform=test_transformation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [35]:
steps_per_epoch = len(train_loader)

In [36]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


In [37]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [38]:
class TokenMixer(nn.Module):
    def __init__(self, embed_dim, mix_type='avg'):
        super().__init__()
        if mix_type == 'conv':
            self.mix = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim)
        else:
            self.mix = lambda x: x.mean(dim=1, keepdim=True).repeat(1, x.size(1), 1)
        

    def forward(self, x):
        if hasattr(self, 'mix'):
            x_mixed = self.mix(x.transpose(1,2)).transpose(1,2) if isinstance(self.mix, nn.Conv1d) else self.mix(x)
            return x + x_mixed
        else:
            return x

In [39]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, token_mixing=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True, dropout=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim*mlp_ratio), drop=drop)
        self.token_mixer = TokenMixer(dim) if token_mixing else nn.Identity() 

    def forward(self, x):
        x_attn = self.norm1(x)
        attn_out, _ = self.attn(x_attn, x_attn, x_attn, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        x = self.token_mixer(x)
        return x

In [40]:
class ViT(nn.Module):
    def __init__(self, *, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, drop_rate=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        n_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            TransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        cls = x[:, 0]
        out = self.head(cls)
        return out

In [41]:
model = ViT(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
            embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate)
model = torch.nn.DataParallel(model)
model = model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps =epochs*steps_per_epoch
warmup_steps = int(0.15 * total_steps)  
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [42]:
def train_one_epoch(epoch, lr_schduler):
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for i, (imgs, targets) in enumerate(train_loader):
        imgs, targets = imgs.to(device), targets.to(device)
        #imgs,targets = mix_transform(imgs, targets)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
        if targets.ndim == 1:
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
        else:
            preds = outputs.argmax(dim=1)
            targets_indices = targets.argmax(dim=1)
            correct += preds.eq(targets_indices).sum().item()
        total += imgs.size(0)  
    acc = 100.0 * correct / total
    lr_scheduler.step()
    print(f"Epoch {epoch}: Train Loss {running_loss/total:.4f} Acc {acc:.2f}%")

In [43]:
def evaluate():
    model.eval()
    total = 0
    correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for imgs, targets in test_loader:
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, targets)
            loss_sum += loss.item() * imgs.size(0)
            _, preds = outputs.max(1)
            total += targets.size(0)
            correct += preds.eq(targets).sum().item()
    acc = 100.0 * correct / total
    print(f"Test Loss {loss_sum/total:.4f} Acc {acc:.2f}%")
    return acc

In [44]:
best_acc = 0.
for epoch in range(1, epochs+1):
    train_one_epoch(epoch,lr_scheduler)
    acc = evaluate()
    if acc > best_acc:
        best_acc = acc

        torch.save(model.state_dict(), "best_vit_cifar10.pth")
    print(f"Epoch {epoch} complete...")
print("Best test acc:", best_acc)

Epoch 1: Train Loss 2.3543 Acc 12.94%
Test Loss 2.3364 Acc 13.24%
Epoch 1 complete...
Epoch 2: Train Loss 2.3168 Acc 13.48%
Test Loss 2.2326 Acc 17.30%
Epoch 2 complete...
Epoch 3: Train Loss 2.2648 Acc 15.46%
Test Loss 2.1536 Acc 22.48%
Epoch 3 complete...
Epoch 4: Train Loss 2.2287 Acc 17.84%
Test Loss 2.1114 Acc 23.28%
Epoch 4 complete...
Epoch 5: Train Loss 2.2080 Acc 18.57%
Test Loss 2.0864 Acc 24.97%
Epoch 5 complete...
Epoch 6: Train Loss 2.1901 Acc 19.36%
Test Loss 2.0691 Acc 26.24%
Epoch 6 complete...
Epoch 7: Train Loss 2.1773 Acc 20.35%
Test Loss 2.0565 Acc 26.65%
Epoch 7 complete...
Epoch 8: Train Loss 2.1614 Acc 20.86%
Test Loss 2.0342 Acc 27.78%
Epoch 8 complete...
Epoch 9: Train Loss 2.1434 Acc 21.85%
Test Loss 2.0120 Acc 28.83%
Epoch 9 complete...
Epoch 10: Train Loss 2.1285 Acc 22.45%
Test Loss 1.9912 Acc 29.95%
Epoch 10 complete...
Epoch 11: Train Loss 2.1075 Acc 23.31%
Test Loss 1.9840 Acc 30.00%
Epoch 11 complete...
Epoch 12: Train Loss 2.0976 Acc 23.71%
Test Loss 1