# 1. Augmentacja

In [1]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

cifar_trainset = CIFAR10(root='./data', train=True, download=False)
data = cifar_trainset.data / 255

mean = data.mean(axis=(0, 1, 2))
std = data.std(axis=(0, 1, 2))

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# 2, 3. Przypięcie augmentacji, załadowanie CIFAR-10

In [2]:
train_dataset = CIFAR10(root='./data', train=True,
                        transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64,
                          shuffle=True, num_workers=4)

test_dataset = CIFAR10(root='./data', train=False,
                       transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=64,
                         shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


# 4, 5, 6. Patching i model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, hidden_dim, dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()

        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
        self.linear1 = nn.Linear(embed_size, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.activation = nn.GELU()

    def forward(self, x):
        y = self.norm1(x)
        y, _ = self.multihead_attn(y, y, y)
        y = self.dropout(y)
        x = x + y

        y = self.norm2(x)
        y = self.linear1(y)
        y = self.activation(y)
        y = self.dropout(y)
        y = self.linear2(y)
        y = self.dropout(y)
        x = x + y

        return x

class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_size):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.linear = nn.Linear(patch_size * patch_size * in_channels, embed_size)

    def forward(self, x):
        num_patches_h = self.img_size // self.patch_size
        num_patches_w = self.img_size // self.patch_size
        x = x.permute(0, 2, 3, 1)
        x = x.view(x.shape[0], num_patches_h, self.patch_size, num_patches_w, self.patch_size, x.shape[3])
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(x.shape[0], -1, self.patch_size * self.patch_size * x.shape[-1])
        x = self.linear(x)
        return x
    

class ViT(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_size, num_heads, depth, n_classes):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size)

        self.positional_encoding = nn.Parameter(torch.randn(1, self.patch_embedding.n_patches, embed_size))

        transformers = [
            TransformerEncoderLayer(
                embed_size=embed_size, 
                num_heads=num_heads, 
                hidden_dim=512
            )
        ] * depth
        
        self.transformers = nn.Sequential(*transformers)

        self.linear = nn.Linear(embed_size, n_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embedding(x)
        x += self.positional_encoding
        x = self.transformers(x)
        x = x.mean(dim=1)
        x = self.linear(x)
        return x

# 7. Trening

In [4]:
import torch.optim as optim

model = ViT(
    img_size=32, 
    patch_size=4,    
    in_channels=3, 
    embed_size=256, 
    num_heads=8, 
    depth=6,
    n_classes=10
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.002)
num_epochs = 50
lr_drop_epochs = [35, 45]
best_test_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss, train_correct = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_correct += (predicted == labels).sum().item()


    model.eval()
    test_loss, test_correct = 0.0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_correct += (predicted == labels).sum().item()


    if epoch in lr_drop_epochs:
        print("LR drop")
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10

    logs = {
        "epoch": epoch,
        "lr": optimizer.param_groups[0]['lr'], 
        "loss/train": train_loss / len(train_loader),
        "loss/test": test_loss / len(test_loader),
        "acc/train": train_correct / len(train_dataset),
        "acc/test": test_correct / len(test_dataset),
    }

    if logs["loss/test"] < best_test_loss:
        best_test_loss = logs["loss/test"]
        model_path = f"best_model_epoch_{epoch:03d}_test_loss{best_test_loss:.1e}.pth"
        torch.save(model.state_dict(), model_path)
        
    print(", ".join([f"{k}: {v:}" for k, v in logs.items()]))


epoch: 0, lr: 0.002, loss/train: 1.7006267143027556, loss/test: 1.5123579972868513, acc/train: 0.37496, acc/test: 0.4315
epoch: 1, lr: 0.002, loss/train: 1.4076622809900348, loss/test: 1.350803022931336, acc/train: 0.49014, acc/test: 0.5057
epoch: 2, lr: 0.002, loss/train: 1.2864463456603876, loss/test: 1.298111631991757, acc/train: 0.53732, acc/test: 0.53
epoch: 3, lr: 0.002, loss/train: 1.1975068883670261, loss/test: 1.1641122824067522, acc/train: 0.57464, acc/test: 0.5901
epoch: 4, lr: 0.002, loss/train: 1.1298167079001131, loss/test: 1.0969695452671901, acc/train: 0.59814, acc/test: 0.6023
epoch: 5, lr: 0.002, loss/train: 1.088044859564213, loss/test: 1.1205699360294707, acc/train: 0.61692, acc/test: 0.6048
epoch: 6, lr: 0.002, loss/train: 1.0571758250903596, loss/test: 1.0525666744845688, acc/train: 0.62754, acc/test: 0.6222
epoch: 7, lr: 0.002, loss/train: 1.0315475216149674, loss/test: 1.0242400788197852, acc/train: 0.63498, acc/test: 0.6387
epoch: 8, lr: 0.002, loss/train: 1.01

In [None]:
# from torchviz import make_dot

# # Instantiate the model
# # Dummy input tensor (batch size 1, 3 channels, 32x32 image)
# dummy_input = torch.randn(1, 3, 32, 32)

# # Forward pass and graph visualization
# output = model(dummy_input)
# graph = make_dot(output, params=dict(model.named_parameters()))

# # Save or display the graph
# graph.render("visual_transformer_graph", format="png", cleanup=True)
