# Face Recognition and Verification Training

This notebook trains a ResNet-50 model for face recognition and verification, tracks the run with Weights & Biases, and saves the checkpoint with the best validation accuracy.

In [None]:
import os

import torch
import wandb

from src.models.resnet import ResNet50
from src.data.datasets import AlbumentationsDataset, get_transforms
from src.utils.config import get_config
from src.utils.train_utils import train_one_epoch, validate, save_checkpoint

In [None]:
# Load the run configuration and pick the compute device:
# CUDA when a GPU is visible to torch, otherwise fall back to CPU.
config = get_config()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Start a Weights & Biases run under the "face-recognition" project,
# recording the whole config dict as the run's configuration.
wandb.init(config=config, project="face-recognition")

In [None]:
# Setup data: build the train/dev datasets with their respective transforms
# and wrap them in DataLoaders.
train_transforms, val_transforms = get_transforms()

train_dataset = AlbumentationsDataset(
    os.path.join(config['data_dir'], 'train'),
    transform=train_transforms,
)

val_dataset = AlbumentationsDataset(
    os.path.join(config['data_dir'], 'dev'),
    transform=val_transforms,
)

# Both loaders share the same worker count from config (the validation loader
# previously hard-coded 4). persistent_workers=True is only legal with at least
# one worker process -- DataLoader raises ValueError when num_workers == 0.
num_workers = config['num_workers']
# Pinned host memory only helps host->GPU copies; on CPU it just emits a warning.
pin_memory = device.type == 'cuda'

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    pin_memory=pin_memory,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

In [None]:
# Build the ResNet-50 classifier and move it onto the selected device.
model = ResNet50(num_classes=config["num_classes"]).to(device)

# AdamW over every model parameter; lr and weight decay come from config.
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=config["weight_decay"],
    lr=config["lr"],
)

# Step schedule: the learning rate is multiplied by `scheduler_gamma`
# once every `scheduler_step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    gamma=config["scheduler_gamma"],
    step_size=config["scheduler_step_size"],
)

In [None]:
# Training loop: for each epoch, train, validate, step the LR schedule,
# checkpoint when validation accuracy improves, and log metrics to WandB.

# Make sure the checkpoint directory exists before the first save attempt.
os.makedirs(config['checkpoint_dir'], exist_ok=True)
best_ckpt_path = os.path.join(config['checkpoint_dir'], 'best_model.pth')
best_val_acc = 0

for epoch in range(config['epochs']):
    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, device, epoch
    )

    # Validate
    val_loss, val_acc = validate(model, val_loader, device)

    # Capture the LR used for this epoch before the scheduler changes it.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # Save best model (strict improvement only)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        save_checkpoint(model, optimizer, epoch, best_ckpt_path)

    # Record per-epoch metrics in the active WandB run; `epoch` is logged as a
    # field rather than via `step=` to avoid clashing with any step counter
    # the helper functions may use.
    wandb.log({
        'epoch': epoch,
        'train/loss': train_loss,
        'train/acc': train_acc,
        'val/loss': val_loss,
        'val/acc': val_acc,
        'val/best_acc': best_val_acc,
        'lr': current_lr,
    })

    print(f"Epoch {epoch}:")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Best Val Acc: {best_val_acc:.2f}%\n")