# In [None]:
import torch
from tqdm.auto import tqdm
import torch.nn as nn

def train_epoch(model, loader, criterion, optimizer, scheduler):
    """Run one training epoch over ``loader`` and return ``(mean_loss, accuracy)``.

    Each batch is moved to the model's device and forwarded. The loss is
    back-propagated, gradients are clipped to a global norm of 1.0, and the
    optimizer steps. When ``scheduler`` is not None it is stepped once per
    batch (not once per epoch).

    Args:
        model: Network being trained. Its output is squeezed on dim 1 and
            treated as a binary-classification logit (``sigmoid > 0.5``).
        loader: Iterable of ``(images, labels)`` batches that supports
            ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler stepped per batch, or None.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the unweighted mean of
        the per-batch losses. ``accuracy`` is the fraction of samples
        predicted correctly, in [0, 1].

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.train()

    # A plain ``nn.Module`` has no ``.device`` attribute. Honour one if the
    # model defines it; otherwise use the device its parameters live on.
    # Resolved once here instead of per batch.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)

        loss.backward()
        # Clip the global gradient norm to limit the effect of outlier batches.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # Read the scalar once: every ``.item()`` forces a device->host sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        # Metrics only: detach so no autograd graph is built for them.
        pred = (torch.sigmoid(outputs.detach()) > 0.5).float()
        correct += (pred == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    return running_loss / len(loader), correct / total

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy)``.

    The model is put in eval mode and every batch is forwarded under
    ``torch.no_grad()``.

    Args:
        model: Network to evaluate. Its output is squeezed on dim 1 and
            treated as a binary-classification logit (``sigmoid > 0.5``).
        loader: Iterable of ``(images, labels)`` batches that supports
            ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the unweighted mean of
        the per-batch losses. ``accuracy`` is the fraction of samples
        predicted correctly, in [0, 1].

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.eval()

    # A plain ``nn.Module`` has no ``.device`` attribute. Honour one if the
    # model defines it; otherwise use the device its parameters live on.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            pred = (torch.sigmoid(outputs) > 0.5).float()
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    return running_loss / len(loader), correct / total
