In [1]:
import logging
import time
from typing import Dict, Tuple, Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Module-level logger; DistillationTrainer instances log through it and set
# its level from config['log_level'] (default logging.INFO).
logger = logging.getLogger(__name__)


class DistillationTrainer:
    """Train a student model to mimic a teacher model (knowledge distillation).

    ``run()`` performs ``config['epochs']`` epochs. Each epoch is one training
    pass over ``train_loader`` followed by one validation pass over
    ``val_loader``. Per-epoch results are appended to ``self.metrics``.

    Args:
        teacher_model: Model whose outputs serve as soft targets. It is only
            run in eval mode and under ``torch.no_grad()``.
        student_model: Model being optimised.
        train_loader: Yields ``(data, target)`` batches for training.
        val_loader: Yields ``(data, target)`` batches for validation.
        optimizer: Optimizer that updates the student.
        loss_fn: Called as ``loss_fn(student_output=..., teacher_output=...,
            target=..., config=...)``; must return a scalar tensor.
        device: Device every batch is moved to.
        config: Settings dict. Keys read here: ``'epochs'`` (required by
            ``run``), ``'log_interval'`` (default 100; ``<= 0`` disables
            per-batch logging), ``'log_level'`` (default ``logging.INFO``).
            The whole dict is also forwarded to ``loss_fn``.
    """

    def __init__(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        loss_fn: Callable,
        device: torch.device,
        config: Dict
    ):
        self.teacher = teacher_model
        self.student = student_model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device
        self.config = config

        # Logging setup. NOTE: this sets the level of the module-level logger,
        # so it affects every other user of that logger as well.
        self.logger = logger
        self.logger.setLevel(config.get('log_level', logging.INFO))
        self.metrics = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    def _run_epoch(self, epoch: int) -> float:
        """Run one training epoch and return the mean per-batch training loss.

        Returns 0.0 (and logs a warning) if ``train_loader`` yields no batches.
        """
        # Teacher goes to eval mode every epoch. In train mode it would apply
        # dropout and update BatchNorm running statistics; torch.no_grad()
        # does not prevent the latter. The student is switched to train mode
        # afterwards, so any modules shared with the teacher still train.
        self.teacher.eval()
        self.student.train()

        log_interval = self.config.get('log_interval', 100)
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Teacher forward pass: it only provides targets, so no graph.
            with torch.no_grad():
                teacher_output = self.teacher(data)

            # Student forward pass.
            student_output = self.student(data)

            loss = self.loss_fn(
                student_output=student_output,
                teacher_output=teacher_output,
                target=target,
                config=self.config
            )

            # Optimisation step.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            if log_interval > 0 and batch_idx % log_interval == 0:
                self.logger.info(
                    f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
                    f"Loss: {batch_loss:.6f}"
                )

        if num_batches == 0:
            self.logger.warning(
                f"Train Epoch: {epoch} - train_loader yielded no batches"
            )
            avg_loss = 0.0
        else:
            avg_loss = total_loss / num_batches
        self.metrics['train_loss'].append(avg_loss)
        return avg_loss

    def _validate(self) -> Tuple[float, float]:
        """Evaluate the student on ``val_loader`` using plain cross-entropy.

        Returns:
            ``(val_loss, accuracy)``. ``val_loss`` is the mean per-element
            cross-entropy. ``accuracy`` is a percentage. Both are computed over
            the elements actually yielded by ``val_loader``. If the loader
            yields nothing, both are 0.0 and a warning is logged.
        """
        self.student.eval()
        loss_sum = 0.0
        correct = 0
        num_samples = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.student(data)

                # Validation loss is student-only (no teacher). Summing rather
                # than averaging per batch keeps a smaller final batch from
                # being over-weighted.
                loss_sum += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                # Number of labelled elements; equals the batch size for
                # class-index targets of shape (N,).
                num_samples += target.numel()

        if num_samples == 0:
            self.logger.warning("Validation loader yielded no samples")
            val_loss, accuracy = 0.0, 0.0
        else:
            val_loss = loss_sum / num_samples
            accuracy = 100. * correct / num_samples

        self.metrics['val_loss'].append(val_loss)
        self.metrics['val_acc'].append(accuracy)
        return val_loss, accuracy

    def run(self) -> Dict:
        """Train for ``config['epochs']`` epochs, validating after each one.

        Returns:
            ``self.metrics``: per-epoch lists under ``'train_loss'``,
            ``'val_loss'`` and ``'val_acc'``.
        """
        # perf_counter is monotonic, so wall-clock adjustments cannot skew it.
        start_time = time.perf_counter()

        for epoch in range(1, self.config['epochs'] + 1):
            train_loss = self._run_epoch(epoch)
            val_loss, val_acc = self._validate()

            self.logger.info(
                f"\nEpoch: {epoch} | Train Loss: {train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%\n"
            )

        training_time = time.perf_counter() - start_time
        self.logger.info(f"Training completed in {training_time:.2f} seconds")
        return self.metrics


def distillation_loss(
    student_output: torch.Tensor,
    teacher_output: torch.Tensor,
    target: torch.Tensor,
    config: Dict
) -> torch.Tensor:
    """
    Combined knowledge-distillation loss:
    loss = alpha * soft_loss + (1 - alpha) * hard_loss

    soft_loss is the KL divergence between the temperature-softened teacher
    and student distributions. It is multiplied by T^2, the conventional
    correction (Hinton et al., 2015) that keeps soft-target gradients on a
    similar scale as T varies. hard_loss is ordinary cross-entropy against
    the true labels.

    Args:
        student_output: Student logits, with classes along dim 1.
        teacher_output: Teacher logits, same shape as ``student_output``.
            They are detached here, so no gradient ever flows into the teacher.
        target: Ground-truth labels in a form accepted by ``F.cross_entropy``.
        config: Settings dict. Reads ``'temperature'`` (T, default 5.0) and
            ``'alpha'`` (weight of the soft loss, default 0.7).

    Returns:
        Scalar combined loss tensor.

    Raises:
        ValueError: If the temperature is not strictly positive.
    """
    T = config.get('temperature', 5.0)
    alpha = config.get('alpha', 0.7)
    if T <= 0:
        # T == 0 would divide by zero and produce NaN; negative T inverts
        # the distributions. Fail loudly instead of training on garbage.
        raise ValueError(f"temperature must be > 0, got {T}")

    # The teacher only supplies targets. Detaching protects against callers
    # that computed teacher_output outside torch.no_grad().
    teacher_probs = F.softmax(teacher_output.detach() / T, dim=1)

    # Soft targets (distillation). kl_div takes log-probabilities as input
    # and probabilities as target. 'batchmean' divides the summed KL by the
    # batch size.
    soft_loss = F.kl_div(
        input=F.log_softmax(student_output / T, dim=1),
        target=teacher_probs,
        reduction='batchmean'
    ) * (T * T)

    # Hard targets (ordinary cross-entropy).
    hard_loss = F.cross_entropy(student_output, target)

    return alpha * soft_loss + (1 - alpha) * hard_loss


def distill(
    teacher_model: nn.Module,
    student_model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    config: Dict,
    loss_fn: Optional[Callable] = distillation_loss,
    device: Optional[torch.device] = None
) -> Tuple[nn.Module, Dict]:
    """
    Main entry point for knowledge distillation.

    Moves both models to ``device`` in place, puts the teacher in eval mode,
    and trains the student with ``DistillationTrainer``.

    Args:
        teacher_model: Trained teacher model.
        student_model: Student model to train.
        train_loader: DataLoader for training.
        val_loader: DataLoader for validation.
        optimizer: Optimizer for the student.
        config: Configuration dict. It is passed to the trainer and to
            ``loss_fn``; ``'epochs'`` is required.
        loss_fn: Loss function. Defaults to ``distillation_loss``; an
            explicit ``None`` also selects that default.
        device: Compute device. Defaults to CUDA when available, else CPU.

    Returns:
        The trained student model (the same object that was passed in) and
        the metrics dict returned by ``DistillationTrainer.run``.
    """
    # The signature advertises Optional. Treat None as "use the default"
    # instead of handing the trainer a non-callable.
    if loss_fn is None:
        loss_fn = distillation_loss

    # Device selection.
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    teacher_model.to(device)
    student_model.to(device)
    teacher_model.eval()

    # Trainer initialisation.
    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        config=config
    )

    # Run training.
    metrics = trainer.run()
    return student_model, metrics

ModuleNotFoundError: No module named 'torch'