# Пайплайн обучения классификационной модели

В этом ноутбуке мы построим полный пайплайн обучения модели классификации изображений на датасете ImageNet Tiny.

Мы пройдем следующие шаги:
1. Создание, предобработка и загрузка датасета
2. Архитектура модели
3. Lightning-модуль: функция потерь, метрики и оптимизатор
4. Обучение модели
5. Сравнение моделей из библиотеки timm
6. Обучение модели из timm


## Шаг 1: Создание датасета

Начнем с создания и загрузки датасета ImageNet Tiny.


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from PIL import Image
import os
import numpy as np

# Fix the RNG seeds so torch-based randomness (weight init, DataLoader
# shuffling) and NumPy-based sampling are reproducible between runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


In [None]:
# Root of the dataset; the 'train' and 'val' splits are read from subfolders.
data_dir = './data/imagenet_tiny'

print("Setting up ImageNet Tiny dataset...", f"Data directory: {data_dir}", sep="\n")


In [None]:
# Resize every image to 224x224 and convert it to a float tensor in [0, 1];
# no normalization or augmentation is applied.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])


def _load_split(split):
    """Build an ImageFolder dataset from ``data_dir/<split>`` using ``transform``."""
    split_root = os.path.join(data_dir, split)
    return datasets.ImageFolder(root=split_root, transform=transform)


train_dataset = _load_split('train')
val_dataset = _load_split('val')

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Classes: {train_dataset.classes[:10]}...")


In [None]:
batch_size = 32

# Settings shared by both loaders; only shuffling differs between the splits.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")


In [None]:
import matplotlib.pyplot as plt

def visualize_samples(dataset, num_samples=8, ncols=4):
    """Plot randomly chosen, distinct samples of ``dataset`` with their class names.

    Args:
        dataset: indexable dataset returning ``(image, label)`` pairs and exposing
            a ``classes`` list (e.g. ``torchvision.datasets.ImageFolder``). Images
            are assumed to be CHW float tensors meant to be shown in [0, 1];
            values outside that range are clipped for display.
        num_samples: number of samples to draw, capped at ``len(dataset)`` so
            sampling without replacement cannot fail on small datasets.
        ncols: number of grid columns; rows are added as needed.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # Size the grid from num_samples instead of a fixed 2x4 layout, so more
    # than 8 samples no longer overflow the axes array. The defaults still
    # give a 2x4 grid of the original 12x6 figure size.
    ncols = max(1, min(ncols, num_samples))
    nrows = (num_samples + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.ravel()

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for ax, idx in zip(axes, indices):
        image, label = dataset[int(idx)]
        img = np.clip(image.permute(1, 2, 0).numpy(), 0, 1)  # CHW -> HWC for imshow
        ax.imshow(img)
        ax.set_title(f'Class: {dataset.classes[label]}')

    # Hide frames on every cell, including unused ones in a partial last row.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_samples(train_dataset)


## Шаг 2: Архитектура модели

Создадим простую сверточную нейронную сеть для классификации изображений.


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small CNN: four conv(3x3)-ReLU-maxpool(2x2) stages and a 3-layer MLP head.

    Channel widths grow 3 -> 64 -> 128 -> 256 -> 512. Each 3x3 convolution uses
    ``padding=1`` and keeps the spatial size, and each 2x2 max-pool halves it
    (rounding down). A square ``input_size`` image therefore leaves the last
    stage as a ``512 x (input_size // 16) x (input_size // 16)`` feature map.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input images the model is built
            for. The default of 224 gives the original ``512 * 14 * 14``
            flattened size, so the layer shapes and state-dict are unchanged.

    Raises:
        ValueError: if ``input_size`` is smaller than 16, because no spatial
            extent would survive the four pooling stages.
    """

    def __init__(self, num_classes=200, input_size=224):
        super(SimpleCNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)

        # Four successive floor-halvings equal a single floor division by 16.
        side = input_size // 16
        if side < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")
        self.flat_features = 512 * side * side

        self.fc1 = nn.Linear(self.flat_features, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a ``(batch, 3, input_size, input_size)`` batch to ``(batch, num_classes)`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))

        # Keep the batch dimension explicit. ``view(-1, flat_features)`` can
        # silently fold a wrongly sized input into a different batch size.
        # Flattening per sample makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x


In [None]:
num_classes = len(train_dataset.classes)
model = SimpleCNN(num_classes=num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Model created with {num_classes} classes")
print(f"Device: {device}")

# Shape sanity check only: no_grad avoids building an autograd graph (and
# keeping its activations alive) for a pass we never backpropagate through.
sample_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = model(sample_input)
print(f"Input shape: {sample_input.shape}")
print(f"Output shape: {output.shape}")


## Шаг 3: Lightning модуль

Создадим Lightning-модуль, в котором определены модель, функция потерь, метрики и оптимизатор.


In [None]:
import pytorch_lightning as pl
from torchmetrics import Accuracy
import torch.optim as optim

class ClassificationModule(pl.LightningModule):
    """Lightning wrapper that trains ``SimpleCNN`` with cross-entropy loss.

    Args:
        num_classes: number of target classes (size of the classifier output).
        learning_rate: Adam learning rate; also used as ``max_lr`` of OneCycleLR.
        steps_per_epoch: number of training batches per epoch. When given, a
            OneCycleLR schedule stepped every batch over ``trainer.max_epochs``
            epochs is attached. When ``None``, the learning rate stays constant.

    Logged metrics:
        ``train_loss`` and ``train_acc`` are logged per step and per epoch, so
        Lightning also exposes them as ``train_loss_epoch`` and
        ``train_acc_epoch``. ``val_loss`` and ``val_acc`` are logged per epoch.
        ``val_acc_epoch`` is additionally published at the end of each
        validation epoch.

    Epoch-end work uses ``on_validation_epoch_end``. The
    ``training_epoch_end`` / ``validation_epoch_end(outputs)`` hooks were
    removed in Lightning 2.0, and ``Trainer.fit`` raises if a module still
    overrides them.
    """

    def __init__(self, num_classes=200, learning_rate=0.001, steps_per_epoch=None):
        super().__init__()

        self.model = SimpleCNN(num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()

        # Separate metric objects so training and validation state never mix.
        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        self.learning_rate = learning_rate
        self.steps_per_epoch = steps_per_epoch

        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        preds = torch.argmax(outputs, dim=1)
        self.train_accuracy(preds, labels)

        # on_step + on_epoch makes Lightning emit both *_step and *_epoch keys.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        preds = torch.argmax(outputs, dim=1)
        self.val_accuracy(preds, labels)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def on_validation_epoch_end(self):
        # Replacement for the removed ``validation_epoch_end(outputs)`` hook.
        # ``compute()`` reads the state accumulated by validation_step above.
        # It relies on Lightning resetting the logged metric only after this hook.
        self.log('val_acc_epoch', self.val_accuracy.compute())

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

        if self.steps_per_epoch is not None:
            # OneCycleLR expects exactly epochs * steps_per_epoch scheduler
            # steps, so it is stepped per batch ('interval': 'step').
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=self.steps_per_epoch
            )

            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': scheduler,
                    'interval': 'step'
                }
            }
        return optimizer


In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

num_classes = len(train_dataset.classes)
# The OneCycleLR schedule is stepped once per training batch.
steps_per_epoch = len(train_loader)

lightning_model = ClassificationModule(
    num_classes=num_classes,
    learning_rate=0.001,
    steps_per_epoch=steps_per_epoch,
)

print("Lightning module created successfully")
print(f"Steps per epoch: {steps_per_epoch}")


## Шаг 4: Обучение модели

Запустим обучение модели с помощью Lightning Trainer.


In [None]:
from pytorch_lightning.callbacks import Callback

class PrintMetricsCallback(Callback):
    """Print epoch-level loss and accuracy taken from ``trainer.callback_metrics``.

    Metrics that have not been logged are printed as 0. Validation output is
    suppressed while ``trainer.sanity_checking`` is set, so Lightning's
    pre-training sanity-check pass is not reported as a real epoch.
    """

    @staticmethod
    def _lookup(metrics, *keys):
        """Return the value of the first key of ``keys`` present in ``metrics``, else 0."""
        for key in keys:
            if key in metrics:
                return metrics[key]
        return 0

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        # The *_epoch keys exist because this notebook's modules log
        # train_loss / train_acc with both on_step=True and on_epoch=True.
        train_loss = self._lookup(metrics, 'train_loss_epoch')
        train_acc = self._lookup(metrics, 'train_acc_epoch')

        print(f"Epoch {epoch} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:
            return

        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        val_loss = self._lookup(metrics, 'val_loss')
        # Prefer 'val_acc', which Lightning aggregates from validation_step itself.
        # 'val_acc_epoch' is only a fallback: a module that publishes it from
        # its own epoch-end hook may do so after this callback has run.
        val_acc = self._lookup(metrics, 'val_acc', 'val_acc_epoch')

        print(f"Epoch {epoch} - Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print("-" * 50)


In [None]:
# Accelerator and device count are chosen automatically. No logger is attached;
# per-epoch metrics are printed by PrintMetricsCallback instead.
trainer = Trainer(
    logger=False,
    enable_progress_bar=True,
    callbacks=[PrintMetricsCallback()],
    devices='auto',
    accelerator='auto',
    max_epochs=10,
)

trainer.fit(lightning_model, train_loader, val_loader)


## Шаг 5: Модели из timm

Изучим различные архитектуры из библиотеки timm и сравним их по количеству параметров и скорости инференса. Предобученные веса здесь не загружаются (`pretrained=False`): для этих замеров они не нужны.


In [None]:
import timm
import time

def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def measure_inference_time(model, input_shape, num_iterations=100, device='cuda', warmup_iterations=10):
    """Return the mean wall-clock time in seconds of one forward pass of ``model``.

    The model is switched to eval mode and moved to ``device``. A random input
    of shape ``input_shape`` is fed ``warmup_iterations`` untimed times and then
    ``num_iterations`` timed times, all under ``torch.no_grad``.

    Args:
        model: the ``nn.Module`` to time.
        input_shape: shape of the random input tensor, including the batch dim.
        num_iterations: number of timed forward passes; must be positive.
        device: a ``torch.device`` or device string such as ``'cuda'`` / ``'cpu'``.
        warmup_iterations: untimed passes run first so one-off costs (lazy
            CUDA initialization, cuDNN autotuning, allocator growth) do not
            skew the average.

    Raises:
        ValueError: if ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    device = torch.device(device)
    # A torch.device never compares equal to a plain string such as 'cuda',
    # so test the device type rather than ``device == 'cuda'``.
    is_cuda = device.type == 'cuda'

    model.eval()
    model = model.to(device)
    dummy_input = torch.randn(input_shape, device=device)

    with torch.no_grad():
        for _ in range(warmup_iterations):
            model(dummy_input)

        # CUDA kernels run asynchronously. Synchronizing before and after the
        # timed loop makes the timer bracket finished work, not just launches.
        if is_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        for _ in range(num_iterations):
            model(dummy_input)

        if is_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start_time

    return elapsed / num_iterations

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


In [None]:
model_names = [
    'resnet18',
    'resnet50',
    'efficientnet_b0',
    'efficientnet_b3',
    'vit_base_patch16_224',
    'convnext_tiny'
]

num_classes = len(train_dataset.classes) if 'train_dataset' in globals() else 200
input_shape = (1, 3, 224, 224)
# ``device`` may be a torch.device, which never equals the string 'cuda';
# compare the device type instead.
on_cuda = torch.device(device).type == 'cuda'

results = []

for model_name in model_names:
    model = None
    try:
        print(f"\nLoading {model_name}...")
        model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes
        )

        num_params = count_parameters(model)
        print(f"Parameters: {num_params:,}")

        inference_time = measure_inference_time(model, input_shape, num_iterations=50, device=device)
        print(f"Inference time: {inference_time*1000:.2f} ms")

        results.append({
            'model': model_name,
            'parameters': num_params,
            'inference_time_ms': inference_time * 1000
        })
    except Exception as e:
        # Best effort: one unavailable or failing architecture should not
        # abort the whole comparison, so report it and move on.
        print(f"Error loading {model_name}: {e}")
    finally:
        # Release this model (also after a failure) before building the next.
        del model
        if on_cuda:
            torch.cuda.empty_cache()


In [None]:
import pandas as pd

df_results = pd.DataFrame(results)

# One print call with newline separators, framing the table with '=' rules.
print(
    "\n" + "=" * 60,
    "Comparison of timm models:",
    "=" * 60,
    df_results.to_string(index=False),
    sep="\n",
)


In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, bar lengths, x-axis label, title) for the two horizontal bar charts.
panels = (
    (ax1, df_results['parameters'] / 1e6, 'Parameters (Millions)', 'Number of Parameters'),
    (ax2, df_results['inference_time_ms'], 'Inference Time (ms)', 'Inference Speed'),
)

for ax, values, xlabel, title in panels:
    ax.barh(df_results['model'], values)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()


## Шаг 6: Обучение модели из timm

Обучим модель ConvNeXt Tiny из timm на нашем датасете.


In [None]:
class TimmClassificationModule(pl.LightningModule):
    """Lightning wrapper that trains a ``timm`` model with cross-entropy loss.

    Args:
        model_name: architecture name passed to ``timm.create_model``.
        num_classes: number of target classes, passed to ``timm.create_model``
            as the size of the classifier output.
        learning_rate: Adam learning rate; also used as ``max_lr`` of OneCycleLR.
        steps_per_epoch: number of training batches per epoch. When given, a
            OneCycleLR schedule stepped every batch over ``trainer.max_epochs``
            epochs is attached. When ``None``, the learning rate stays constant.
        pretrained: forwarded to ``timm.create_model``.

    Logged metrics:
        ``train_loss`` and ``train_acc`` are logged per step and per epoch, so
        Lightning also exposes them as ``train_loss_epoch`` and
        ``train_acc_epoch``. ``val_loss`` and ``val_acc`` are logged per epoch.
        ``val_acc_epoch`` is additionally published at the end of each
        validation epoch.

    Epoch-end work uses ``on_validation_epoch_end``. The
    ``training_epoch_end`` / ``validation_epoch_end(outputs)`` hooks were
    removed in Lightning 2.0, and ``Trainer.fit`` raises if a module still
    overrides them.
    """

    def __init__(self, model_name='convnext_tiny', num_classes=200, learning_rate=0.001,
                 steps_per_epoch=None, pretrained=True):
        super().__init__()

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes
        )
        self.criterion = nn.CrossEntropyLoss()

        # Separate metric objects so training and validation state never mix.
        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        self.learning_rate = learning_rate
        self.steps_per_epoch = steps_per_epoch

        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        preds = torch.argmax(outputs, dim=1)
        self.train_accuracy(preds, labels)

        # on_step + on_epoch makes Lightning emit both *_step and *_epoch keys.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        preds = torch.argmax(outputs, dim=1)
        self.val_accuracy(preds, labels)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def on_validation_epoch_end(self):
        # Replacement for the removed ``validation_epoch_end(outputs)`` hook.
        # ``compute()`` reads the state accumulated by validation_step above.
        # It relies on Lightning resetting the logged metric only after this hook.
        self.log('val_acc_epoch', self.val_accuracy.compute())

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

        if self.steps_per_epoch is not None:
            # OneCycleLR expects exactly epochs * steps_per_epoch scheduler
            # steps, so it is stepped per batch ('interval': 'step').
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=self.steps_per_epoch
            )

            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': scheduler,
                    'interval': 'step'
                }
            }
        return optimizer


In [None]:
num_classes = len(train_dataset.classes)
steps_per_epoch = len(train_loader)

timm_model_name = 'convnext_tiny'
timm_model = TimmClassificationModule(
    model_name=timm_model_name,
    num_classes=num_classes,
    learning_rate=0.001,
    steps_per_epoch=steps_per_epoch,
    pretrained=True,
)

for summary_line in (
    "Timm model module created successfully",
    f"Model: {timm_model_name}",
    f"Number of classes: {num_classes}",
    f"Steps per epoch: {steps_per_epoch}",
):
    print(summary_line)


In [None]:
# Automatic accelerator/device selection and no logger; per-epoch metrics are
# printed by PrintMetricsCallback.
trainer_timm = Trainer(
    logger=False,
    enable_progress_bar=True,
    callbacks=[PrintMetricsCallback()],
    devices='auto',
    accelerator='auto',
    max_epochs=10,
)

trainer_timm.fit(timm_model, train_loader, val_loader)
