ДЗ должно быть выполнено строго на PyTorch.
1. Использовать DataLoader.
2. В цикл обучения добавить сохранение лучшей модели и планировщик (scheduler) скорости обучения.
3. Вывести графики обучения и выводить информацию о ходе обучения в процессе.

Решить задачу предсказания возраста, сведя её к задаче классификации по возрастным группам:
0–9, 10–19, 20–29, 30–39, 40–49, 50–59, ...
 
В качестве экстрактора признаков (feature extractor) используйте любую вариацию ViT.

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from timm.models.vision_transformer import VisionTransformer
from tqdm import tqdm
import pandas as pd
import numpy as np

In [2]:
# Training configuration.
config = {
    'data_root': 'C:/ML/ML01_P_Online/Eugene_Piuta/DZ_24/wiki_crop/cropped_imgs/',
    'csv_path': './imgs_df.csv',
    'batch_size': 32,
    # DataLoader worker processes. On Windows, workers are started with
    # "spawn" and must unpickle the Dataset; a Dataset class defined
    # interactively in a notebook cannot be re-imported by the workers, which
    # makes loading fail or hang. Load in the main process there instead.
    'num_workers': 0 if os.name == 'nt' else 4,
    'lr': 0.001,
    'epochs': 20,
    'img_size': 128,
    'patch_size': 16,
    'embed_dim': 192,
    'depth': 6,
    'num_heads': 6,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'save_path': 'best_model.pth',
}

In [3]:
# Кастомный Dataset для работы с датафреймом
class FaceDataFrameDataset(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(df['class'].unique())
        self.num_classes = len(self.classes)
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.df.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.df.iloc[idx, 2]  # Берем класс из третьего столбца
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [4]:
# Preprocessing/augmentation pipeline used by the dataset: square resize,
# random horizontal flip, random rotation within +/-10 degrees, conversion to
# a tensor, and per-channel normalization with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(config['img_size'], config['img_size'])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# Load the image/label table and record how many distinct classes it holds;
# this count becomes the model's number of outputs.
df = pd.read_csv(config['csv_path'])
config['num_classes'] = df['class'].nunique(dropna=False)
df  # display the table in the notebook

Unnamed: 0,filename,raw,class
0,00_102100_1970-10-09_20080.jpg,38,4
1,00_1024100_1982-06-07_20110.jpg,29,3
2,00_11328300_1980-06-10_20080.jpg,28,3
3,00_1199800_1976-04-13_20120.jpg,36,4
4,00_12318000_1988-06-12_20080.jpg,20,2
...,...,...,...
6201,99_6986299_1986-05-08_20121.jpg,26,3
6202,99_7319399_1981-08-23_20100.jpg,29,3
6203,99_8419699_1973-07-02_20140.jpg,41,5
6204,99_8551399_1972-02-25_20080.jpg,36,4


In [6]:
# Full dataset over every row of ``df``, using the preprocessing pipeline.
dataset = FaceDataFrameDataset(
    df=df,
    root_dir=config['data_root'],
    transform=transform,
)

In [7]:
# Train/val split (80/20).
# - The permutation is seeded so the split is reproducible across runs.
# - The validation subset is served by a second dataset over the same rows
#   that has NO random augmentation. Otherwise validation images would be
#   randomly flipped/rotated and val metrics would fluctuate from epoch to
#   epoch for reasons unrelated to the model.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

split_generator = torch.Generator().manual_seed(config.get('seed', 42))
permutation = torch.randperm(len(dataset), generator=split_generator).tolist()
train_indices = permutation[:train_size]
val_indices = permutation[train_size:]

# Same resize + normalization as the training pipeline, minus random transforms.
eval_transform = transforms.Compose([
    transforms.Resize((config['img_size'], config['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_dataset = FaceDataFrameDataset(df, config['data_root'], transform=eval_transform)

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                          shuffle=True, num_workers=config['num_workers'])
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                        shuffle=False, num_workers=config['num_workers'])

In [8]:
# Vision Transformer model.
class FaceViT(nn.Module):
    """Vision Transformer classifier for age-group prediction.

    By default it builds a timm ``VisionTransformer`` from scratch (random
    weights). The architecture comes from the ``img_size``, ``patch_size``,
    ``embed_dim``, ``depth``, ``num_heads`` and ``num_classes`` entries of
    ``config``.

    If ``config`` has a truthy ``'pretrained_model'`` entry, it is treated as
    a timm model name (e.g. ``'vit_tiny_patch16_224'``). That model is then
    created with pretrained weights, built for ``config['img_size']`` inputs,
    with a ``config['num_classes']``-way head. The ``patch_size`` /
    ``embed_dim`` / ``depth`` / ``num_heads`` entries are ignored on this
    path. Use it when the ViT should act as a pretrained feature extractor.
    """

    def __init__(self, config):
        super().__init__()
        pretrained_name = config.get('pretrained_model')
        if pretrained_name:
            # Local import: only the pretrained path needs timm's factory.
            import timm

            # NOTE(review): pretrained checkpoints may expect a different
            # input normalization than the ImageNet mean/std used in the
            # transforms -- check ``self.vit.pretrained_cfg``.
            self.vit = timm.create_model(
                pretrained_name,
                pretrained=True,
                num_classes=config['num_classes'],
                img_size=config['img_size'],
            )
        else:
            self.vit = VisionTransformer(
                img_size=config['img_size'],
                patch_size=config['patch_size'],
                in_chans=3,
                num_classes=config['num_classes'],
                embed_dim=config['embed_dim'],
                depth=config['depth'],
                num_heads=config['num_heads'],
            )

    def forward(self, x):
        """Return the classifier outputs (one score per class) for ``x``."""
        return self.vit(x)


model = FaceViT(config).to(config['device'])

In [9]:
# Freeze the ViT backbone only when it carries pretrained weights
# (``config['pretrained_model']`` set). Freezing a randomly initialised ViT
# would leave only the linear head trainable on top of fixed random features,
# so no representation is ever learned and accuracy stays capped.
freeze_backbone = bool(config.get('pretrained_model'))
for param in model.vit.parameters():
    param.requires_grad = not freeze_backbone

# The classification head is always trained.
for param in model.vit.head.parameters():
    param.requires_grad = True

In [10]:
# Loss, optimizer and learning-rate scheduler.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
# Halve the LR when the monitored value (val loss) has not decreased for
# 3 consecutive steps. ``verbose`` is deprecated for LR schedulers since
# PyTorch 2.2, so it is not passed; read the current LR from
# ``optimizer.param_groups[0]['lr']`` instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)



In [11]:
# Функция валидации
def validate(model, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(config['device'])
            labels = labels.to(config['device'])
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_loss = running_loss / len(val_loader)
    val_acc = 100 * correct / total
    
    return val_loss, val_acc

In [12]:
# Plotting of the training curves.
def plot_metrics(train_loss, val_loss, train_acc, val_acc):
    """Plot per-epoch loss and accuracy curves for training and validation.

    The left panel shows the losses and the right panel shows the accuracies
    (in %). The figure is saved to ``training_metrics.png`` and then shown.
    Epochs on the x axis are numbered from 1, matching the ``Epoch k/N``
    lines printed during training.

    Args:
        train_loss, val_loss: per-epoch loss values.
        train_acc, val_acc: per-epoch accuracy values in percent.
    """

    def epochs_for(values):
        # 1-based epoch numbers for one series.
        return range(1, len(values) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(epochs_for(train_loss), train_loss, label='Train Loss')
    ax_loss.plot(epochs_for(val_loss), val_loss, label='Val Loss')
    ax_loss.set_title('Loss History')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    ax_acc.plot(epochs_for(train_acc), train_acc, label='Train Accuracy')
    ax_acc.plot(epochs_for(val_acc), val_acc, label='Val Accuracy')
    ax_acc.set_title('Accuracy History')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig('training_metrics.png')
    plt.show()

In [13]:
# Обучение модели
def train_model():
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []
    best_val_loss = float('inf')
    
    for epoch in range(config['epochs']):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Progress bar
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]}', unit='batch')
        
        for images, labels in pbar:
            images = images.to(config['device'])
            labels = labels.to(config['device'])
            
            optimizer.zero_grad()
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100 * correct / total
            })
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        
        # Валидация
        val_loss, val_acc = validate(model, val_loader, criterion)
        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)
        
        # Шедулер
        scheduler.step(val_loss)
        
        # Сохранение лучшей модели
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), config['save_path'])
            print(f'Model saved with val_loss: {val_loss:.4f}')
        
        print(f'Epoch {epoch+1}/{config["epochs"]}')
        print(f'Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%')
        print('-' * 50)
    
    # Графики
    plot_metrics(train_loss_history, val_loss_history, train_acc_history, val_acc_history)
    
    return model

In [None]:
# Entry point: report the setup, then run training.
if __name__ == '__main__':
    print("Training on {}".format(config['device']))
    print("Number of classes: {}".format(config['num_classes']))
    print("Train samples: {}, Val samples: {}".format(len(train_dataset), len(val_dataset)))

    model = train_model()

Training on cuda
Number of classes: 10
Train samples: 4964, Val samples: 1242


Epoch 1/20:   0%|          | 0/156 [00:00<?, ?batch/s]

Максимально достигнутая accuracy — 42%. Потом всё сломалось и перестало работать.