# Baseline: Обучение Multi-Branch MLP на размеченных данных


In [1]:
import copy
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

from model import MultiBranchMLP
from data_module import DataModule
from lightning_module import BaseLightningModule

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    The same value is applied to PyTorch (CPU and all CUDA devices), NumPy's
    global generator and Python's ``random`` module. cuDNN is then put into
    deterministic mode with benchmark autotuning switched off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    import random

    # Order: torch (CPU), torch (all CUDA devices), NumPy, stdlib random.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed once up front, before any data or model objects are created.
set_seed(42)


## 1. Загрузка данных


In [2]:
# Dataset root, given relative to the notebook's working directory.
data_dir = '../data'

# Loader settings collected in one mapping and handed to the data module.
datamodule_kwargs = dict(data_dir=data_dir, batch_size=128, num_workers=4)
dm = DataModule(**datamodule_kwargs)

# Run setup explicitly: the summary below reads attributes of the data module.
dm.setup()

# Quick summary of what was loaded.
print(f'Input dimension: {dm.input_dim}')
print(f'Number of classes: {dm.n_classes}')
print(f'Labeled train samples: {len(dm.train_labeled_dataset)}')
print(f'Test samples: {len(dm.test_dataset)}')


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000
Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000


## 2. Анализ дисбаланса классов и вычисление весов


In [3]:
# Labels of the labeled training subset; the per-class loss weights are derived from them.
train_labels = dm.train_labeled_dataset.y

unique_labels = np.unique(train_labels)

# CrossEntropyLoss looks up its weight vector by class id, while
# compute_class_weight returns weights in the order of `unique_labels`.
# The two only line up when the labeled subset contains exactly the ids
# 0..n_classes-1, so fail early with a clear message otherwise instead of
# training with misaligned (or wrongly sized) weights.
expected_labels = np.arange(dm.n_classes)
if not np.array_equal(unique_labels, expected_labels):
    raise ValueError(
        f'Labeled training set must contain every class id 0..{dm.n_classes - 1}; '
        f'found labels {unique_labels.tolist()}'
    )

# 'balanced' weights: n_samples / (n_classes * count(class)).
class_weights = compute_class_weight(
    'balanced',
    classes=unique_labels,
    y=train_labels
)

print(f'Class weights: {dict(zip(unique_labels, class_weights))}')

# float32 tensor so the weights match the dtype of the model's logits;
# torch.tensor replaces the legacy torch.FloatTensor constructor.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)


Class weights: {np.int64(0): np.float64(1.0738255033557047), np.int64(1): np.float64(0.963855421686747), np.int64(2): np.float64(1.0256410256410255), np.int64(3): np.float64(1.103448275862069), np.int64(4): np.float64(0.9523809523809523), np.int64(5): np.float64(0.935672514619883), np.int64(6): np.float64(0.9523809523809523), np.int64(7): np.float64(1.0596026490066226), np.int64(8): np.float64(0.9248554913294798), np.int64(9): np.float64(1.0457516339869282)}


## 3. Создание модели


In [4]:
# Architecture settings; input and output sizes come from the data module.
model_kwargs = {
    'input_dim': dm.input_dim,
    'hidden_dim': 256,
    'output_dim': dm.n_classes,
    'num_blocks': 4,
    'dropout': 0.1,
    'combine_mode': 'concat',
}
model = MultiBranchMLP(**model_kwargs)

# Total number of parameters (trainable or not) as a size sanity check.
print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')


Model parameters: 4,080,650


## 4. Создание Lightning модуля


In [5]:
# Cross-entropy with the per-class weights computed earlier in the notebook.
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Everything the Lightning wrapper needs: network, loss and optimizer settings.
lightning_kwargs = {
    'model': model,
    'loss_fn': loss_fn,
    'optimizer_type': 'adamw',
    'learning_rate': 1e-3,
    'task_type': 'multiclass',
}
lightning_model = BaseLightningModule(**lightning_kwargs)


## 5. Обучение модели


In [6]:
# Keep the single best checkpoint by validation accuracy, plus the last epoch.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='best_model-{epoch:02d}-{val_accuracy:.4f}',
    monitor='val_accuracy',
    mode='max',
    save_top_k=1,
    save_last=True
)

# devices=1: inside a notebook Lightning falls back to a single device anyway
# and emits a warning when several GPUs are visible; asking for one device
# explicitly keeps the same behavior without the warning.
trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    logger=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    accelerator='auto',
    devices=1
)

# Train on the data module; checkpoints are written by the callback above.
trainer.fit(lightning_model, dm)


Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000



  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | MultiBranchMLP   | 4.1 M  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
2 | metrics | ModuleDict       | 0      | train
-----------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.323    Total estimated model params size (MB)
70        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 0: accuracy=0.1250, f1_macro=0.0485


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0: accuracy=0.0984, f1_macro=0.0320


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1: accuracy=0.1252, f1_macro=0.0728


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2: accuracy=0.1413, f1_macro=0.1026


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3: accuracy=0.1686, f1_macro=0.1458


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4: accuracy=0.1835, f1_macro=0.1643


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5: accuracy=0.1864, f1_macro=0.1724


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6: accuracy=0.1990, f1_macro=0.1922


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7: accuracy=0.1978, f1_macro=0.1910


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8: accuracy=0.2008, f1_macro=0.1967


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9: accuracy=0.1967, f1_macro=0.1946


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10: accuracy=0.1982, f1_macro=0.1956


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11: accuracy=0.1982, f1_macro=0.1966


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12: accuracy=0.2042, f1_macro=0.2034


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13: accuracy=0.2042, f1_macro=0.2036


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14: accuracy=0.2062, f1_macro=0.2048


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15: accuracy=0.2069, f1_macro=0.2050


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16: accuracy=0.2103, f1_macro=0.2082


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17: accuracy=0.2127, f1_macro=0.2111


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18: accuracy=0.2124, f1_macro=0.2117


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19: accuracy=0.2146, f1_macro=0.2147


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20: accuracy=0.2185, f1_macro=0.2192


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21: accuracy=0.2212, f1_macro=0.2216


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22: accuracy=0.2228, f1_macro=0.2233


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23: accuracy=0.2258, f1_macro=0.2264


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24: accuracy=0.2276, f1_macro=0.2284


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25: accuracy=0.2301, f1_macro=0.2310


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26: accuracy=0.2335, f1_macro=0.2343


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27: accuracy=0.2358, f1_macro=0.2370


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28: accuracy=0.2374, f1_macro=0.2389


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29: accuracy=0.2398, f1_macro=0.2416


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30: accuracy=0.2417, f1_macro=0.2434


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31: accuracy=0.2432, f1_macro=0.2448


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32: accuracy=0.2448, f1_macro=0.2462


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33: accuracy=0.2479, f1_macro=0.2493


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34: accuracy=0.2494, f1_macro=0.2509


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35: accuracy=0.2515, f1_macro=0.2530


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36: accuracy=0.2526, f1_macro=0.2541


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37: accuracy=0.2533, f1_macro=0.2547


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38: accuracy=0.2550, f1_macro=0.2560


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39: accuracy=0.2556, f1_macro=0.2567


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40: accuracy=0.2563, f1_macro=0.2575


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41: accuracy=0.2572, f1_macro=0.2583


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42: accuracy=0.2587, f1_macro=0.2597


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43: accuracy=0.2598, f1_macro=0.2609


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44: accuracy=0.2612, f1_macro=0.2621


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45: accuracy=0.2629, f1_macro=0.2638


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46: accuracy=0.2639, f1_macro=0.2649


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47: accuracy=0.2651, f1_macro=0.2661


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48: accuracy=0.2664, f1_macro=0.2675


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49: accuracy=0.2669, f1_macro=0.2679


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50: accuracy=0.2683, f1_macro=0.2693


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51: accuracy=0.2698, f1_macro=0.2706


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52: accuracy=0.2704, f1_macro=0.2710


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53: accuracy=0.2710, f1_macro=0.2717


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54: accuracy=0.2711, f1_macro=0.2718


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55: accuracy=0.2716, f1_macro=0.2725


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56: accuracy=0.2725, f1_macro=0.2733


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57: accuracy=0.2733, f1_macro=0.2741


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58: accuracy=0.2747, f1_macro=0.2755


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59: accuracy=0.2757, f1_macro=0.2764


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60: accuracy=0.2766, f1_macro=0.2773


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61: accuracy=0.2774, f1_macro=0.2779


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62: accuracy=0.2778, f1_macro=0.2784


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63: accuracy=0.2784, f1_macro=0.2789


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64: accuracy=0.2787, f1_macro=0.2792


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65: accuracy=0.2787, f1_macro=0.2795


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66: accuracy=0.2785, f1_macro=0.2796


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67: accuracy=0.2783, f1_macro=0.2792


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68: accuracy=0.2759, f1_macro=0.2774


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69: accuracy=0.2740, f1_macro=0.2752


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70: accuracy=0.2720, f1_macro=0.2728


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71: accuracy=0.2697, f1_macro=0.2710


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72: accuracy=0.2673, f1_macro=0.2682


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73: accuracy=0.2655, f1_macro=0.2663


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74: accuracy=0.2638, f1_macro=0.2648


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75: accuracy=0.2624, f1_macro=0.2635


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76: accuracy=0.2615, f1_macro=0.2627


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77: accuracy=0.2609, f1_macro=0.2621


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78: accuracy=0.2605, f1_macro=0.2618


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79: accuracy=0.2602, f1_macro=0.2614


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80: accuracy=0.2602, f1_macro=0.2612


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81: accuracy=0.2605, f1_macro=0.2616


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82: accuracy=0.2602, f1_macro=0.2612


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83: accuracy=0.2603, f1_macro=0.2612


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84: accuracy=0.2602, f1_macro=0.2609


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85: accuracy=0.2606, f1_macro=0.2613


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86: accuracy=0.2608, f1_macro=0.2615


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87: accuracy=0.2608, f1_macro=0.2616


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88: accuracy=0.2607, f1_macro=0.2615


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89: accuracy=0.2599, f1_macro=0.2610


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90: accuracy=0.2599, f1_macro=0.2609


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91: accuracy=0.2600, f1_macro=0.2611


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92: accuracy=0.2599, f1_macro=0.2610


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93: accuracy=0.2600, f1_macro=0.2612


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94: accuracy=0.2602, f1_macro=0.2614


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95: accuracy=0.2596, f1_macro=0.2610


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96: accuracy=0.2596, f1_macro=0.2611


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97: accuracy=0.2600, f1_macro=0.2614


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98: accuracy=0.2604, f1_macro=0.2618


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99: accuracy=0.2608, f1_macro=0.2621


`Trainer.fit` stopped: `max_epochs=100` reached.


## 6. Оценка на тестовой выборке


In [7]:
# Path of the best checkpoint recorded during training (empty string if none).
best_model_path = checkpoint_callback.best_model_path

if best_model_path:
    print(f'Loading best model from: {best_model_path}')
    # load_from_checkpoint builds the module from the objects passed here and
    # then loads the checkpoint state into them. Deep copies are passed so that
    # `model` / `lightning_model` (holding the last-epoch weights) are not
    # silently overwritten in place by the best-checkpoint weights.
    best_model = BaseLightningModule.load_from_checkpoint(
        best_model_path,
        model=copy.deepcopy(model),
        loss_fn=copy.deepcopy(loss_fn),
        optimizer_type='adamw',
        learning_rate=1e-3,
        task_type='multiclass'
    )
else:
    # No checkpoint was saved: evaluate the in-memory module from the last epoch.
    print('No best checkpoint found; evaluating the last-epoch model instead.')
    best_model = lightning_model

# trainer.test returns a list of metric dicts; the first entry is reported below.
test_results = trainer.test(best_model, dm)

print('\n=== Финальные результаты на тестовой выборке ===')
for key, value in test_results[0].items():
    print(f'{key}: {value:.4f}')


Loading best model from: /home/tam2511/mounts/0/arcadia/market/robotics/cv/ml/user_data/shad/dl2025/lesson7/homework/baseline/checkpoints/best_model-epoch=65-val_accuracy=0.2787.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000


Testing: |          | 0/? [00:00<?, ?it/s]

Test results: accuracy=0.2830, f1_macro=0.2737



=== Финальные результаты на тестовой выборке ===
test_loss: 13.6383
test_accuracy: 0.2830
test_f1_macro: 0.2737
