<a href="https://colab.research.google.com/github/saltysallysmine/MIPT-CV-Homeworks/blob/main/CV_HW_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Computer Vision. Homework 1. Тренировочный цикл и linear probe на ViT-Tiny

Датасет: CIFAR-100 (из torchvision)

## CNN

In [255]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random

In [256]:
# SEEDS #
seed = 124
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [257]:
# CONSTANTS #

CIFAR_MEAN = (0.5071, 0.4867, 0.4408)
# CIFAR_STD = (0.2675, 0.2565, 0.2761)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

CLASS_CNT = 5

LR = 1e-3 * 5
BATCH_SIZE = 64

In [258]:
# 5 классов выбора
selected_classes = list(range(CLASS_CNT))

# Трансформации с аугментациями
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [259]:
# DATASET #

full_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)

# Фильтрация по выбранным классам
indices = [i for i, label in enumerate(full_train.targets) if label in selected_classes]
subset = Subset(full_train, indices)

train_size = int(0.8 * len(subset))
val_size = len(subset) - train_size
train_subset, val_subset = random_split(subset, [train_size, val_size])

val_subset.dataset.transform = val_transform

# Даталоадеры
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [260]:
print(train_size, val_size)

2000 500


In [278]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 32x32x32
            nn.LeakyReLU(),
            nn.MaxPool2d(2),                             # 32x16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1), # 64x16x16
            nn.LeakyReLU(),
            nn.MaxPool2d(2),                             # 64x8x8
            nn.Conv2d(64, 128, kernel_size=3, padding=1),# 128x8x8
            nn.LeakyReLU(),
            nn.MaxPool2d(2),                             # 128x4x4
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*4*4, 256),
            nn.LeakyReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [279]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [280]:
def validate(model, epoch):
    criterion = nn.CrossEntropyLoss()
    if isinstance(model, SimpleCNN):
        params = model.parameters()
    else:
        params = model.head.parameters()
    optimizer = optim.Adam(params, lr=LR)

    model.eval()
    running_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    acc = correct / total * 100.0
    avg_loss = running_loss / len(val_loader)
    print(f"Epoch {epoch}, Val loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")

    writer.add_scalar('Val/Loss', avg_loss, epoch)
    writer.add_scalar('Val/Accuracy', acc, epoch)

In [281]:
# TensorBoard логгер
writer = SummaryWriter('runs/cifar100_simple_cnn')

def sanity_check(model):
    criterion = nn.CrossEntropyLoss()
    if isinstance(model, SimpleCNN):
        params = model.parameters()
    else:
        params = model.head.parameters()
    optimizer = optim.Adam(params, lr=LR)

    model.train()
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        if i > 20:
            break
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        print(f'Sanity check batch {i}, loss: {loss.item():.4f}, acc: {correct / total * 100.0:.4f}')
        # validate(i)

In [282]:
model = SimpleCNN(num_classes=len(selected_classes)).to(device)
sanity_check(model)
validate(model, len(train_loader))

Sanity check batch 0, loss: 1.6192, acc: 18.7500
Sanity check batch 1, loss: 2.0576, acc: 20.3125
Sanity check batch 2, loss: 1.5976, acc: 25.5208
Sanity check batch 3, loss: 1.6082, acc: 25.3906
Sanity check batch 4, loss: 1.5891, acc: 27.1875
Sanity check batch 5, loss: 1.5302, acc: 28.6458
Sanity check batch 6, loss: 1.4569, acc: 29.4643
Sanity check batch 7, loss: 1.4550, acc: 29.4922
Sanity check batch 8, loss: 1.3839, acc: 29.1667
Sanity check batch 9, loss: 1.2378, acc: 30.0000
Sanity check batch 10, loss: 1.4472, acc: 31.1080
Sanity check batch 11, loss: 1.2742, acc: 32.1615
Sanity check batch 12, loss: 1.3869, acc: 32.3317
Sanity check batch 13, loss: 1.2760, acc: 33.0357
Sanity check batch 14, loss: 1.1104, acc: 34.2708
Sanity check batch 15, loss: 1.2043, acc: 35.2539
Sanity check batch 16, loss: 1.5215, acc: 35.1103
Sanity check batch 17, loss: 1.1807, acc: 35.4167
Sanity check batch 18, loss: 1.1410, acc: 36.3487
Sanity check batch 19, loss: 1.4322, acc: 36.9531
Sanity che

In [283]:
def train_epoch(model, epoch):
    criterion = nn.CrossEntropyLoss()
    if isinstance(model, SimpleCNN):
        params = model.parameters()
    else:
        params = model.head.parameters()
    optimizer = optim.Adam(params, lr=LR)

    model.train()
    running_loss = 0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Логируем гистограммы весов и градиентов для первых слоев
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, epoch)
            if param.grad is not None:
                writer.add_histogram(name + '/grad', param.grad, epoch)

    acc = correct / total * 100.0
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch}, Train loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")

    writer.add_scalar('Train/Loss', avg_loss, epoch)
    writer.add_scalar('Train/Accuracy', acc, epoch)
    writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)

In [284]:
# TRAINING CNN #

model = SimpleCNN(num_classes=len(selected_classes)).to(device)

num_epochs = 20
for epoch in range(1, num_epochs + 1):
    train_epoch(model, epoch)
    validate(model, epoch)

writer.close()

Epoch 1, Train loss: 1.4767, Accuracy: 37.10%
Epoch 1, Val loss: 1.3889, Accuracy: 42.40%
Epoch 2, Train loss: 1.2240, Accuracy: 47.45%
Epoch 2, Val loss: 1.1617, Accuracy: 49.80%
Epoch 3, Train loss: 1.1341, Accuracy: 53.90%
Epoch 3, Val loss: 1.0105, Accuracy: 57.40%
Epoch 4, Train loss: 1.1497, Accuracy: 52.50%
Epoch 4, Val loss: 1.1214, Accuracy: 54.00%
Epoch 5, Train loss: 1.1493, Accuracy: 54.35%
Epoch 5, Val loss: 1.0666, Accuracy: 57.20%
Epoch 6, Train loss: 1.0134, Accuracy: 60.15%
Epoch 6, Val loss: 0.9873, Accuracy: 64.00%
Epoch 7, Train loss: 0.9986, Accuracy: 58.80%
Epoch 7, Val loss: 0.9913, Accuracy: 60.40%
Epoch 8, Train loss: 0.9501, Accuracy: 62.10%
Epoch 8, Val loss: 0.9628, Accuracy: 61.20%
Epoch 9, Train loss: 1.0051, Accuracy: 60.90%
Epoch 9, Val loss: 0.9555, Accuracy: 62.80%
Epoch 10, Train loss: 0.9497, Accuracy: 61.00%
Epoch 10, Val loss: 0.9306, Accuracy: 64.80%
Epoch 11, Train loss: 0.9958, Accuracy: 61.35%
Epoch 11, Val loss: 1.0324, Accuracy: 60.60%
Epoch 

## ViT-Tiny

In [285]:
import timm
import torch.nn as nn
import torch.profiler

In [286]:
model_vit = timm.create_model('vit_tiny_patch16_224', pretrained=True)

for param in model_vit.parameters():
    param.requires_grad = False

in_features = model_vit.head.in_features
model_vit.head = nn.Linear(in_features, CLASS_CNT)

params_to_optimize = model_vit.head.parameters()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

In [241]:
def train_with_profiling(model, dataloader, criterion, optimizer, device, steps=50):
    model.train()
    profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=steps, repeat=0),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
        record_shapes=True,
        with_stack=True
    )

    profiler.start()
    for step, (inputs, labels) in enumerate(dataloader):
        if step >= steps:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        profiler.step()
    profiler.stop()


## Сравнение моделей