<a href="https://colab.research.google.com/github/saltysallysmine/MIPT-CV-Homeworks/blob/main/CV_HW_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Computer Vision. Homework 1. Тренировочный цикл и linear probe на ViT-Tiny

Датасет: CIFAR-100 (из torchvision)

## CNN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random

In [None]:
# SEEDS #
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) with one value.
seed = 124
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# cuDNN: use deterministic kernels and disable autotuning so runs are repeatable.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# CONSTANTS #

# Per-channel CIFAR-100 training-set statistics for transforms.Normalize.
# Mean and std must come from the same dataset: (0.2470, 0.2435, 0.2616) is the
# CIFAR-10 std and must not be paired with this CIFAR-100 mean.
CIFAR_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR_STD = (0.2675, 0.2565, 0.2761)

# Number of CIFAR-100 classes used in the experiments.
CLASS_CNT = 5

LR = 1e-3 * 5  # Adam learning rate (5e-3)
BATCH_SIZE = 64

In [None]:
# Work with the first CLASS_CNT CIFAR-100 classes (label ids 0 .. CLASS_CNT-1).
selected_classes = [*range(CLASS_CNT)]

# Training pipeline: random crop-and-resize back to 32x32 and horizontal flips
# for augmentation, then tensor conversion and per-channel normalisation.
# NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) may keep as little
# as 8% of a 32x32 image; consider a milder scale or RandomCrop(32, padding=4).
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
)

# Validation pipeline: deterministic, tensor conversion and normalisation only.
val_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]
)

In [None]:
# DATASET #

# Two views of the same CIFAR-100 train split: one with augmentation for
# training, one with the deterministic val_transform for validation. A Subset
# always uses its parent's transform, so a single dataset object cannot serve
# both splits (and setting `.transform` on a Subset has no effect).
full_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
full_val = datasets.CIFAR100(root='./data', train=True, download=True, transform=val_transform)

# Keep only images whose label is one of the selected classes.
selected_set = set(selected_classes)
indices = [i for i, label in enumerate(full_train.targets) if label in selected_set]
subset = Subset(full_train, indices)  # filtered, augmented view of all selected images

# 80/20 split of the filtered indices. A dedicated seeded generator makes the
# split reproducible regardless of how much global RNG state was consumed.
train_size = int(0.8 * len(indices))
val_size = len(indices) - train_size
split_generator = torch.Generator().manual_seed(seed)
perm = torch.randperm(len(indices), generator=split_generator).tolist()
train_subset = Subset(full_train, [indices[j] for j in perm[:train_size]])
val_subset = Subset(full_val, [indices[j] for j in perm[train_size:]])

# Data loaders
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
print(f"{train_size} {val_size}")  # sizes of the train / validation splits

In [None]:
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier for 3x32x32 images.

    Each stage is a 3x3 "same" convolution, LeakyReLU and 2x2 max-pooling, so
    the spatial size goes 32 -> 16 -> 8 -> 4. The resulting 128x4x4 map is
    flattened and classified by a two-layer MLP.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # Conv (keeps H, W) -> activation -> pool (halves H, W).
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(
            *conv_stage(3, 32),    # -> 32x16x16
            *conv_stage(32, 64),   # -> 64x8x8
            *conv_stage(64, 128),  # -> 128x4x4
        )
        flat_dim = 128 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x))

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def validate(model, epoch):
    """Evaluate ``model`` on ``val_loader``, print and log the metrics.

    Uses the module-level ``val_loader``, ``device`` and ``writer``. The model
    is switched to eval mode and left in it.

    Args:
        model: network to evaluate.
        epoch: step used in the printed message and for the TensorBoard
            scalars ``Val/Loss`` and ``Val/Accuracy``.

    Returns:
        ``(avg_loss, acc)``: mean per-sample cross-entropy and accuracy in
        percent over the whole validation set.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    # Summed loss, divided by the sample count below, so a short last batch
    # is not over-weighted in the average.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError('val_loader yielded no samples')

    acc = correct / total * 100.0
    avg_loss = total_loss / total
    print(f"Epoch {epoch}, Val loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")

    writer.add_scalar('Val/Loss', avg_loss, epoch)
    writer.add_scalar('Val/Accuracy', acc, epoch)
    return avg_loss, acc

In [None]:
# TensorBoard logger; the training/validation helpers write to it via the global name `writer`.
writer = SummaryWriter(log_dir='runs/cifar100_simple_cnn')

def sanity_check(model, max_batches=21):
    """Train briefly on ``train_loader`` and print running loss/accuracy per batch.

    A quick smoke test that model, loss and optimizer are wired up: the loss
    is expected to trend down and the running accuracy up across the batches.
    A fresh Adam optimizer (lr=``LR``) is created on every call. Uses the
    module-level ``train_loader``, ``device`` and ``LR``.

    Args:
        model: network trained in place. A ``SimpleCNN`` is trained end to end;
            for any other model only ``model.head`` parameters are optimized.
        max_batches: number of training batches to process (default 21).
    """
    criterion = nn.CrossEntropyLoss()
    # Full fine-tuning for the CNN; linear probe (head only) otherwise.
    if isinstance(model, SimpleCNN):
        params = model.parameters()
    else:
        params = model.head.parameters()
    optimizer = optim.Adam(params, lr=LR)

    model.train()
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        if i >= max_batches:
            break
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        # Accuracy is cumulative over all batches seen so far in this call.
        print(f'Sanity check batch {i}, loss: {loss.item():.4f}, acc: {correct / total * 100.0:.4f}')

In [None]:
# Quick sanity run on a throwaway SimpleCNN.
model = SimpleCNN(num_classes=len(selected_classes)).to(device)
sanity_check(model)
# Logged at step 0 so this point precedes epoch 1 on the shared Val/* curves
# instead of landing at an arbitrary step (the batch count) on the epoch axis.
validate(model, 0)

In [None]:
def _make_train_optimizer(model):
    """Return a new Adam optimizer (lr=LR) over the parameters ``train_epoch`` updates."""
    # SimpleCNN is trained end to end; other models (linear probe) only train `head`.
    if isinstance(model, SimpleCNN):
        params = model.parameters()
    else:
        params = model.head.parameters()
    return optim.Adam(params, lr=LR)


def train_epoch(model, epoch, optimizer=None):
    """Train ``model`` for one pass over ``train_loader``, then print and log metrics.

    Uses the module-level ``train_loader``, ``device`` and ``writer``.

    Args:
        model: network trained in place (set to train mode).
        epoch: step used in the printed message and for every TensorBoard
            scalar/histogram written here.
        optimizer: optimizer to step. If omitted, an Adam optimizer is created
            on the first call for this model and stored on it as
            ``model._train_epoch_optimizer``; later calls reuse it so Adam's
            moment estimates and step count carry over between epochs.

    Returns:
        ``(avg_loss, acc)``: mean per-sample training loss and accuracy (%).

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if optimizer is None:
        optimizer = getattr(model, '_train_epoch_optimizer', None)
        if optimizer is None:
            optimizer = _make_train_optimizer(model)
            model._train_epoch_optimizer = optimizer
    criterion = nn.CrossEntropyLoss()

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short last batch is not over-counted.
        running_loss += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError('train_loader yielded no samples')

    # Weight and gradient histograms for all parameters, once per epoch: every
    # histogram uses step=epoch, so logging per batch only re-wrote the same
    # step and dominated runtime. Gradients are those of the last batch.
    for name, param in model.named_parameters():
        writer.add_histogram(name, param, epoch)
        if param.grad is not None:
            writer.add_histogram(name + '/grad', param.grad, epoch)

    acc = correct / total * 100.0
    avg_loss = running_loss / total
    print(f"Epoch {epoch}, Train loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")

    writer.add_scalar('Train/Loss', avg_loss, epoch)
    writer.add_scalar('Train/Accuracy', acc, epoch)
    writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)
    return avg_loss, acc

In [None]:
# TRAINING CNN #

# Fresh model for the real run (independent of the sanity-check model).
model = SimpleCNN(num_classes=len(selected_classes))
model = model.to(device)

num_epochs = 20
for epoch in range(1, num_epochs + 1):
    # One optimisation pass over train_loader, then evaluation on val_loader.
    train_epoch(model, epoch)
    validate(model, epoch)

# Flush pending TensorBoard events to disk.
writer.close()

## ViT-Tiny

In [None]:
import timm
import torch.nn as nn
import torch.profiler

In [None]:
# ImageNet-pretrained ViT-Tiny from timm, used as a frozen feature extractor.
model_vit = timm.create_model('vit_tiny_patch16_224', pretrained=True)

# Linear probe: freeze every backbone parameter...
for param in model_vit.parameters():
    param.requires_grad = False

# ...and replace the classifier with a fresh, trainable linear head for CLASS_CNT classes.
in_features = model_vit.head.in_features
model_vit.head = nn.Linear(in_features, CLASS_CNT)

# Same device the training helpers move each batch to.
model_vit = model_vit.to(device)

# A list, not the one-shot generator from .parameters(): a generator is
# exhausted after being passed to one optimizer and would be empty afterwards.
params_to_optimize = list(model_vit.head.parameters())

# NOTE(review): the loaders above yield 32x32 images normalised with CIFAR
# statistics, while this model is the 224-input variant; confirm inputs are
# resized (and normalised per the model's pretrained config) before training.

In [None]:
def train_with_profiling(model, dataloader, criterion, optimizer, device, steps=50,
                         log_dir='./log/profiler'):
    """Train for up to ``steps`` batches under ``torch.profiler`` and save the trace.

    Args:
        model: network trained in place (set to train mode).
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped after every batch.
        device: device each batch is moved to.
        steps: maximum number of batches to train and profile (must be >= 1).
        log_dir: directory the TensorBoard trace handler writes traces to.

    Returns:
        The finished ``torch.profiler.profile`` object, e.g. for
        ``prof.key_averages().table()``.

    Raises:
        ValueError: if ``steps`` < 1 (the profiler schedule requires ``active >= 1``).
    """
    if steps < 1:
        raise ValueError(f'steps must be >= 1, got {steps}')

    model.train()
    # The context manager stops the profiler (flushing its trace) even if a
    # training step raises, so no profiling session is left running.
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=steps, repeat=0),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for step, (inputs, labels) in enumerate(dataloader):
            if step >= steps:
                break
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            prof.step()  # advance the profiler schedule once per batch
    return prof


## Сравнение моделей