# Training the Multi-Modal Neural Network

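This notebook walks through training with the multi-dataset selector: inspecting the configured datasets, building the data loaders, initializing the trainer, checking a first batch, and starting a training run.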

**Prerequisites:**
- Complete `01_getting_started.ipynb` to verify your environment
- Configure datasets in `../configs/default.yaml` under `data.datasets`

**Testing:** Run `make test` or `pytest tests/test_integration.py` to verify training components.
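
For orientation, the sketch below shows the shape that the loaded `data.datasets` section is assumed to take, written as the Python dictionary that `load_config` presumably returns. Only the key names (`name`, `type`, `enabled`, `splits`) come from this notebook. The dataset names, type strings, and ratios are hypothetical placeholders, and the real schema in `../configs/default.yaml` may carry more fields.

```python
# Hypothetical example of a loaded config; only the key names come from this notebook.
example_config = {
    "data": {
        "datasets": [
            {
                "name": "dataset_a",       # placeholder name
                "type": "image_text",      # placeholder type string
                "enabled": True,
                "splits": {"train": 0.8, "val": 0.1, "test": 0.1},
            },
            {
                "name": "dataset_b",       # placeholder name
                "type": "image_text",
                "enabled": False,          # filtered out below
                "splits": {"train": 0.7, "val": 0.2, "test": 0.1},
            },
        ]
    }
}


def enabled_dataset_names(config):
    """Return the names of datasets whose `enabled` flag is true (default: True)."""
    datasets = config.get("data", {}).get("datasets", [])
    return [ds.get("name") for ds in datasets if ds.get("enabled", True)]


print(enabled_dataset_names(example_config))
```

Treating a missing `enabled` key as `True` mirrors the `ds.get('enabled', True)` default used in the inspection cell of this notebook.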

## Inspect Configured Datasets

In [None]:
import sys

# Add the parent directory to the import path so that `src` can be imported
sys.path.append('..')

from src.utils.config import load_config

config_path = '../configs/default.yaml'
config = load_config(config_path)
print('Datasets defined:')
for ds in config.get('data', {}).get('datasets', []):
    print(f"- {ds.get('name')} type={ds.get('type')} enabled={ds.get('enabled', True)} splits={ds.get('splits')}")

## Build DataLoaders (Selector)

In [None]:
from src.data import build_dataloaders
train_loader, val_loader, test_loader = build_dataloaders(config)
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader) if val_loader else 0}')
print(f'Test batches: {len(test_loader) if test_loader else 0}')

## Initialize Trainer (Auto selector fallback)

In [None]:
from src.training.trainer import Trainer
trainer = Trainer(config_path=config_path)
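# A missing or empty loader counts as 0 batches; test_loader is read defensively via getattr
val_n = len(trainer.val_loader) if trainer.val_loader else 0
test_n = len(trainer.test_loader) if getattr(trainer, 'test_loader', None) else 0
print('Trainer loaders: train', len(trainer.train_loader), 'val', val_n, 'test', test_n)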

## Inspect First Batch

In [None]:
batch = next(iter(trainer.train_loader))
print('Batch keys:', batch.keys())
print('Tensor shapes:', [(k, v.shape) for k, v in batch.items() if hasattr(v, 'shape')])
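
The shape listing above skips any batch entry that has no `.shape` attribute, such as a list of raw captions. A minimal sketch of a helper that reports those entries as well is shown below. The helper name `summarize_batch` and the keys in the stand-in batch are hypothetical, and it assumes the batch is a dictionary, as the cell above does.

```python
from types import SimpleNamespace


def summarize_batch(batch):
    """Map each key to its shape if it has one, otherwise to its type name."""
    summary = {}
    for key, value in batch.items():
        if hasattr(value, "shape"):
            summary[key] = tuple(value.shape)
        else:
            summary[key] = type(value).__name__
    return summary


# Stand-in batch: SimpleNamespace objects mimic tensors by exposing `.shape`.
fake_batch = {
    "image": SimpleNamespace(shape=(4, 3, 224, 224)),
    "input_ids": SimpleNamespace(shape=(4, 32)),
    "caption": ["a", "b", "c", "d"],
}
for key, description in summarize_batch(fake_batch).items():
    print(f"{key}: {description}")
```

With a real loader, the call would be `summarize_batch(next(iter(trainer.train_loader)))`.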

## Train

In [None]:
# Uncomment to start training
# trainer.train()

## Notes
- Edit the `data.datasets` entries in `configs/default.yaml` to enable or disable datasets.
- Ensure that the split ratios (`train`, `val`, `test`) of each dataset sum to 1.0.
- Legacy keys (`train_dataset`, `val_dataset`) are ignored when `datasets` is present.
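
The two config rules in these notes can be checked before building the loaders. The sketch below is one way to do that, not the project's own validation: `check_data_config` is a hypothetical helper, and it assumes that the legacy keys sit next to `datasets` inside the `data` section and that each dataset stores its ratios in a `splits` dictionary.

```python
import math


def check_data_config(data_cfg, tol=1e-6):
    """Return a list of human-readable problems found in a `data` config dict."""
    problems = []
    datasets = data_cfg.get("datasets", [])

    # Rule 1: the split ratios of every dataset should sum to 1.0.
    for ds in datasets:
        splits = ds.get("splits") or {}
        total = sum(splits.values())
        if not math.isclose(total, 1.0, abs_tol=tol):
            problems.append(
                f"{ds.get('name')}: splits sum to {total:.3f}, expected 1.0"
            )

    # Rule 2: legacy keys have no effect once `datasets` is present.
    if datasets:
        for legacy in ("train_dataset", "val_dataset"):
            if legacy in data_cfg:
                problems.append(
                    f"legacy key '{legacy}' is ignored because 'datasets' is present"
                )
    return problems


example = {
    "train_dataset": "old_value",  # hypothetical legacy entry
    "datasets": [
        {"name": "dataset_a", "splits": {"train": 0.7, "val": 0.2, "test": 0.1}},
        {"name": "dataset_b", "splits": {"train": 0.8, "val": 0.3, "test": 0.1}},
    ],
}
for problem in check_data_config(example):
    print(problem)
```

In this notebook the check could be run as `check_data_config(config['data'])` just before `build_dataloaders(config)`. The tolerance guards against floating-point rounding in sums such as `0.7 + 0.2 + 0.1`.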

In [None]:
# (Optional) Advanced: Rebuild loaders after modifying config in-memory
# Example:
# config['data']['datasets'][0]['splits']['train'] = 0.7
# config['data']['datasets'][0]['splits']['val'] = 0.2
# config['data']['datasets'][0]['splits']['test'] = 0.1
# train_loader, val_loader, test_loader = build_dataloaders(config)
# trainer.train_loader = train_loader
# trainer.val_loader = val_loader
# trainer.test_loader = test_loader
# print('Rebuilt loaders: train', len(train_loader), 'val', len(val_loader or []), 'test', len(test_loader or []))