# Training the Multi-Modal Neural Network

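This notebook shows how to train the multi-modal neural network with double-loop learning. It loads the training configuration, initializes the trainer, sanity-checks the data pipeline, and then explains how to launch training and inspect its outputs.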

## Import Libraries

In [None]:
# Make the project root (one level above this notebook) importable so that the
# local `src` package resolves. The guard keeps re-runs of this cell from
# appending the same entry to sys.path again and again.
import sys

PROJECT_ROOT = '..'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Third-party
import torch
import yaml

# Local project modules (must come after the sys.path tweak above)
from src.training.trainer import Trainer
from src.utils.config import load_config

print("Libraries imported successfully")

## Load Configuration

In [None]:
# Read the YAML training configuration (path is relative to this notebook) and
# echo it so the run's settings are visible in the notebook output.
config_path = '../configs/default.yaml'
config = load_config(config_path)

print("Configuration loaded:", yaml.dump(config, default_flow_style=False), sep="\n")

## Initialize Trainer

In [None]:
# Build the Trainer from the same YAML file; later cells use its `model`,
# `device`, `train_loader` and `config` attributes.
trainer = Trainer(config_path=config_path)
print("Trainer initialized")
print(f"Model device: {trainer.device}")

# Total element count over every parameter tensor of the model.
n_params = sum(param.numel() for param in trainer.model.parameters())
print(f"Number of parameters: {n_params:,}")

## Verify Data Loading


In [None]:
# Run this cell before training.
# Sanity-check the training DataLoader: report its size and inspect one batch.
# Note: len(loader) counts batches, while len(loader.dataset) counts samples.
print(f"Number of batches: {len(trainer.train_loader)}")
print(f"Number of samples: {len(trainer.train_loader.dataset)}")

# Pull a single batch to confirm the loader yields data and to see its structure.
try:
    batch = next(iter(trainer.train_loader))
except StopIteration:
    print("ERROR: Train loader is empty!")
else:
    if hasattr(batch, "keys"):
        # Dict-style batch: show field names and the shapes of tensor fields.
        print(f"Batch keys: {batch.keys()}")
        print(f"Batch shapes: {[(k, v.shape) for k, v in batch.items() if isinstance(v, torch.Tensor)]}")
    elif isinstance(batch, torch.Tensor):
        # A single tensor per batch.
        print(f"Batch shape: {batch.shape}")
    elif isinstance(batch, (list, tuple)):
        # Positional batch, e.g. (inputs, targets) from a default collate_fn.
        print(f"Batch type: {type(batch).__name__} of length {len(batch)}")
        print(f"Batch shapes: {[(i, v.shape) for i, v in enumerate(batch) if isinstance(v, torch.Tensor)]}")
    else:
        print(f"Batch type: {type(batch).__name__}")

## Run Training

In [None]:
# Training is not started automatically because it can take a long time on
# modest hardware. To launch it, uncomment the call below (for a quick smoke
# test, lower `training.max_epochs` in the config first):
# trainer.train()

print("To run training, set max_epochs in config and uncomment:")
print("  trainer.train()")
print("\nCurrent training config:")

# Summarize the key hyperparameters from the trainer's own config, falling
# back to a placeholder when a value is absent.
train_settings = trainer.config.get('training', {})
data_settings = trainer.config.get('data', {})
for label, section, key in (
    ("Max epochs", train_settings, 'max_epochs'),
    ("Batch size", data_settings, 'batch_size'),
    ("Learning rate", train_settings, 'learning_rate'),
):
    print(f"  {label}: {section.get(key, 'not set')}")

## Monitor Training Progress

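The cell below lists the saved checkpoints and log files. If TensorBoard or W&B (Weights & Biases) logging is enabled in your configuration, you can also follow training progress live on those dashboards.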

In [None]:
# Check training logs and outputs
import os
from pathlib import Path

# NOTE(review): these locations are hard-coded relative to the notebook;
# confirm they match where the Trainer is configured to write artifacts.
output_dir = Path('../outputs')
checkpoint_dir = output_dir / 'checkpoints'
log_dir = output_dir / 'logs'


def _sorted_by_mtime(paths) -> list:
    """Return `paths` as a list ordered oldest-to-newest by modification time.

    Name order is misleading for run artifacts (e.g. 'epoch_10.pt' sorts before
    'epoch_9.pt'), so recency is taken from the filesystem timestamp instead.
    """
    return sorted(paths, key=lambda p: p.stat().st_mtime)


print("Training artifacts:")

# Path.glob yields nothing for a missing directory, so "missing" and "empty"
# both end up in the "No ... yet" branch.
checkpoints = _sorted_by_mtime(checkpoint_dir.glob('*.pt')) if checkpoint_dir.is_dir() else []
if checkpoints:
    print(f"\n✓ Checkpoints ({len(checkpoints)}):")
    for ckpt in checkpoints[-5:]:  # Show the 5 most recent
        size_mb = ckpt.stat().st_size / (1024 * 1024)
        print(f"  - {ckpt.name} ({size_mb:.1f} MB)")
else:
    print("\n  No checkpoints yet")

logs = _sorted_by_mtime(log_dir.glob('*.log')) if log_dir.is_dir() else []
if logs:
    print(f"\n✓ Logs ({len(logs)}):")
    for log in logs[-3:]:  # Show the 3 most recent
        print(f"  - {log.name}")
else:
    print("\n  No logs yet")

# Point at the directories actually inspected above, not a guessed path.
print("\nTo monitor training in real-time:")
print(f"  - Check logs in: {log_dir}/")
print(f"  - View checkpoints in: {checkpoint_dir}/")
# `or {}` also covers a `logging:` section that is present but empty (None).
if (config.get('logging') or {}).get('use_wandb', False):
    print("  - W&B dashboard: https://wandb.ai/")