In [None]:
import torch
import gc
import numpy as np
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import StratifiedKFold
from dataset import DeepfakeDataset
from train_utils import train_epoch, validate
from config import Config
from model import EfficientNetModel

def train_model():
    """Train and evaluate the model with stratified k-fold cross-validation.

    For every fold a fresh ``EfficientNetModel`` is trained for
    ``Config.epochs`` epochs. After each epoch the model is validated, and
    whenever the validation accuracy improves the weights are written to
    ``best_model_fold{fold}.pth`` (``fold`` is zero-based). When all folds are
    done, the per-fold best accuracies and their mean/std are printed.

    Training samples go through the augmenting transform; validation samples
    go through the deterministic resize/normalize transform only, so the
    reported validation accuracy is not perturbed by random augmentation.

    Raises:
        RuntimeError: if the two dataset instances (train/val view) do not
            list the same image paths in the same order, because fold
            indices would then point at different samples.
    """
    # Normalisation statistics shared by both pipelines.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize((Config.img_size, Config.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    val_transform = transforms.Compose([
        transforms.Resize((Config.img_size, Config.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    # Two views over the same data: one augmented (training), one not
    # (validation). Previously a single dataset with the training transform
    # backed both subsets, so validation images were randomly augmented and
    # val_transform was never used.
    train_dataset = DeepfakeDataset(Config.train_path, transform=train_transform)
    val_dataset = DeepfakeDataset(Config.train_path, transform=val_transform)

    # The fold indices are computed once and applied to both views, so both
    # must enumerate the samples identically.
    if list(train_dataset.image_paths) != list(val_dataset.image_paths):
        raise RuntimeError(
            'DeepfakeDataset returned a different sample ordering on the '
            'second construction; train/val fold indices would not align.'
        )

    kfold = StratifiedKFold(n_splits=Config.n_folds, shuffle=True, random_state=42)
    fold_scores = []

    splits = kfold.split(train_dataset.image_paths, train_dataset.labels)
    for fold, (train_idx, val_idx) in enumerate(splits):
        print(f'\nFold {fold + 1}/{Config.n_folds}')
        train_loader = DataLoader(
            torch.utils.data.Subset(train_dataset, train_idx),
            batch_size=Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
            pin_memory=Config.pin_memory,
        )
        val_loader = DataLoader(
            torch.utils.data.Subset(val_dataset, val_idx),
            batch_size=Config.batch_size,
            shuffle=False,
            num_workers=Config.num_workers,
            pin_memory=Config.pin_memory,
        )

        # Fresh model/optimizer/scheduler per fold so folds are independent.
        model = EfficientNetModel().to(Config.device)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=Config.base_lr, weight_decay=Config.weight_decay)
        # NOTE(review): T_max=Config.epochs presumes train_epoch steps the
        # scheduler once per epoch, not once per batch — confirm in train_utils.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)

        best_acc = 0
        for epoch in range(Config.epochs):
            print(f'\nEpoch {epoch + 1}/{Config.epochs}')
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler)
            val_loss, val_acc = validate(model, val_loader, criterion)

            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            # Keep only the checkpoint with the best validation accuracy.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), f'best_model_fold{fold}.pth')

        fold_scores.append(best_acc)

        # Drop references first, collect garbage, then release the cached
        # GPU blocks, so memory freed by the collector can be returned too.
        del model, optimizer, scheduler
        gc.collect()
        torch.cuda.empty_cache()

    print(f'\nCross-validation scores: {fold_scores}')
    print(f'Mean accuracy: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}')

# Script entry point: run the cross-validation training only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    print("Starting training...")
    train_model()
