# Model Training Example

This notebook demonstrates how to train a neural network using PyTorch Lightning, for learning purposes. For small models or datasets this approach may be fine, but for large-scale projects you should use `train.py` instead.

## 1. Setup and Imports

In [None]:
import sys
sys.path.insert(0, '..')  # Add parent directory to path for imports

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from datamodule import LazyDataModule
from model import MLPLightningModule, MLP

# Seed Python, NumPy and PyTorch for reproducibility. workers=True asks
# Lightning to also derive per-worker seeds for DataLoader worker processes,
# which matters here because the data module is configured with num_workers > 0.
pl.seed_everything(42, workers=True)

# Report the environment so results can be tied to library versions.
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration

In [None]:
# Every tunable value for this notebook lives in this single mapping;
# later cells read from it by key.
args = dict(
    # --- Data ---
    train_files="/path/to/your/data",  # TODO: Update this path
    batch_size=32,
    num_workers=4,
    # --- Model ---
    in_channels=10,  # TODO: Adjust based on your input features
    out_channels=1,  # TODO: Adjust based on your target
    hidden_dims=[64, 128, 64],
    # --- Training ---
    learning_rate=1e-3,
    max_epochs=50,
    seed=42,
    # --- Output ---
    output_dir="../outputs",
)

## 3. Create Output Directory and Data Module

In [None]:
# Create the output directory (and any missing parents) if it is not there yet.
output_dir = Path(args["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)

# Forward the data-related configuration entries to the data module by keyword.
datamodule_kwargs = {
    key: args[key]
    for key in ("train_files", "batch_size", "num_workers", "seed")
}
datamodule = LazyDataModule(**datamodule_kwargs)

## 4. Create Model, Callbacks, and Trainer

In [None]:
# Build the bare network from the configured input and output sizes.
# NOTE(review): args["hidden_dims"] is set in the configuration but is not
# forwarded here — confirm whether MLP accepts it and should receive it.
mlp_kwargs = {
    "in_channels": args["in_channels"],
    "out_channels": args["out_channels"],
}
model = MLP(**mlp_kwargs)

# Hand the network and the learning rate to the Lightning wrapper.
lightning_module = MLPLightningModule(model=model, learning_rate=args["learning_rate"])

# NOTE(review): the checkpoint and early-stopping callbacks both monitor
# "val_loss"; this assumes MLPLightningModule logs a metric under that name — confirm.

# Checkpoints are written under <output_dir>/checkpoints, ranked by minimum
# val_loss, keeping the top 3.
checkpoint_callback = ModelCheckpoint(
    dirpath=output_dir / "checkpoints",
    filename="mlp-{epoch:02d}-{val_loss:.4f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
)

# Early stopping on val_loss (lower is better) with a patience of 10.
early_stopping = EarlyStopping(monitor="val_loss", patience=10, mode="min")

# Log the learning rate at epoch granularity.
lr_monitor = LearningRateMonitor(logging_interval="epoch")

callbacks = [checkpoint_callback, early_stopping, lr_monitor]

# TensorBoard logger rooted at the output directory, under the name "logs".
# (Weights & Biases would be an alternative: https://wandb.ai/home)
logger = TensorBoardLogger(save_dir=output_dir, name="logs")

# Collect the trainer options in one mapping; accelerator and device count
# are both left on "auto".
trainer_kwargs = {
    "max_epochs": args["max_epochs"],
    "callbacks": callbacks,
    "logger": logger,
    "accelerator": "auto",
    "devices": "auto",
}
trainer = pl.Trainer(**trainer_kwargs)

## 7. Train Model

In [None]:
# Run the training loop; batches come from the data module.
trainer.fit(model=lightning_module, datamodule=datamodule)

## 8. Save and Load Model

In [None]:
# Save the Lightning module's weights manually, next to the checkpoints and
# logs in the configured output directory rather than in the current working
# directory.
weights_path = output_dir / "model_weights.pt"
torch.save(lightning_module.state_dict(), weights_path)

# To restore the best checkpoint tracked by the ModelCheckpoint callback
# (left commented out: it needs a completed run that logged "val_loss"):
# loaded_model = MLPLightningModule.load_from_checkpoint(
#     trainer.checkpoint_callback.best_model_path, model=model
# )