# FLORAH Tree Generator - Training Tutorial

This tutorial will guide you through training the FLORAH Tree Generator model step by step. The FLORAH Tree Generator is a machine learning model that generates merger trees for dark matter halos in cosmological simulations.

## Overview

The training process involves:
1. Setting up the environment and dependencies
2. Preparing your dataset
3. Configuring the model parameters
4. Running the training
5. Monitoring training progress

## Prerequisites

- Python 3.8 or higher
- GPU with CUDA support (recommended)
- Sufficient disk space for datasets and model checkpoints

## Step 1: Environment Setup

First, let's install the required dependencies and set up the environment.

In [None]:
# Install required packages
!pip install torch pytorch-lightning tensorboard ml-collections absl-py torch-geometric pyyaml numpy tqdm

In [None]:
# Import necessary libraries
import os
import sys
import yaml
import torch
import pytorch_lightning as pl
from ml_collections import config_dict
import numpy as np

# Make the current working directory (assumed to be the project root) importable,
# then report the environment that training will run in.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

cuda_ok = torch.cuda.is_available()
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Only meaningful when CUDA is present.
    print(f"CUDA devices: {torch.cuda.device_count()}")

## Step 2: Dataset Preparation

The model requires processed merger tree data. Your dataset should be organized as follows:

```
datasets/
├── processed/
│   └── your_dataset_name/
│       ├── file_0.pkl
│       ├── file_1.pkl
│       └── ...
```

Each pickle file should contain merger tree data in the format expected by the project's `datasets.read_dataset` loader, which is used in Step 5.

In [None]:
# Check if dataset directory exists and list available datasets.
# NOTE: the list is stored in `available_datasets`, not `datasets`: Step 5 runs
# `import datasets` (the data-loading module used by `train_model`), and re-running
# this cell afterwards would otherwise shadow that module and break training setup.
dataset_root = "./datasets/processed/"

if os.path.exists(dataset_root):
    available_datasets = sorted(
        d for d in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, d))
    )
    print(f"Available datasets in {dataset_root}:")
    if not available_datasets:
        print("  (none found)")
    for dataset in available_datasets:
        dataset_path = os.path.join(dataset_root, dataset)
        # Only .pkl files count as dataset files (see the layout above).
        num_files = sum(1 for f in os.listdir(dataset_path) if f.endswith('.pkl'))
        print(f"  - {dataset}: {num_files} files")
else:
    print(f"Dataset directory not found: {dataset_root}")
    print("Please create the dataset directory and add your processed data.")

## Step 3: Configuration Setup

Now let's create a configuration for training. We'll start with a basic configuration that you can modify based on your needs.

In [None]:
def create_training_config(dataset_name, experiment_name, dataset_root="./datasets/processed/"):
    """Create the default training configuration used throughout this tutorial.

    Args:
        dataset_name: Name of the dataset directory under ``dataset_root``.
        experiment_name: Experiment name; ``train_model`` writes logs and
            checkpoints to ``<config.workdir>/<experiment_name>``.
        dataset_root: Directory holding the processed datasets. Defaults to
            ``"./datasets/processed/"``, the layout described in Step 2.

    Returns:
        An ``ml_collections.ConfigDict`` with top-level experiment settings plus
        ``seed``, ``data``, ``model``, ``optimizer``, ``scheduler`` and
        ``training`` sections.
    """
    # Single source of truth for the run length: the LR schedule's decay is meant
    # to span the whole run, so scheduler.decay_steps and training.max_steps share it.
    total_steps = 100_000

    config = config_dict.ConfigDict()

    # Basic experiment settings
    config.workdir = './training_logs'  # Where to save training logs and checkpoints
    config.name = experiment_name
    config.overwrite = True  # Overwrite existing experiment directory
    config.enable_progress_bar = True
    config.checkpoint = None  # Start from scratch
    config.accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    config.reset_optimizer = False

    # Random seeds for reproducibility
    config.seed = config_dict.ConfigDict()
    config.seed.data = 42
    config.seed.training = 1337
    config.seed.inference = 9999

    # Dataset configuration
    config.data = config_dict.ConfigDict()
    config.data.root = dataset_root
    config.data.name = dataset_name
    config.data.num_files = 10  # Number of dataset files to use
    config.data.index_file_start = 0
    config.data.train_frac = 0.8  # 80% for training, 20% for validation
    config.data.reverse_time = False

    # Model architecture
    config.model = config_dict.ConfigDict()
    config.model.name = 'atg2'
    config.model.d_in = 2  # Input feature dimension
    config.model.num_classes = 3  # Number of output classes

    # Encoder configuration (GRU-based)
    config.model.encoder = config_dict.ConfigDict()
    config.model.encoder.name = 'gru'
    config.model.encoder.d_model = 128
    config.model.encoder.d_out = 128
    config.model.encoder.dim_feedforward = 128
    config.model.encoder.num_layers = 4
    config.model.encoder.concat = False

    # Decoder configuration (GRU-based)
    config.model.decoder = config_dict.ConfigDict()
    config.model.decoder.name = 'gru'
    config.model.decoder.d_model = 128
    config.model.decoder.d_out = 128
    config.model.decoder.dim_feedforward = 128
    config.model.decoder.num_layers = 4
    config.model.decoder.concat = False

    # Neural Posterior Estimation (NPE) configuration
    config.model.npe = config_dict.ConfigDict()
    config.model.npe.hidden_sizes = [128, 128]
    config.model.npe.num_transforms = 4
    config.model.npe.context_embedding_sizes = None
    config.model.npe.dropout = 0.2

    # Classifier configuration
    config.model.classifier = config_dict.ConfigDict()
    config.model.classifier.d_context = 1
    config.model.classifier.hidden_sizes = [128, 128]

    # Optimizer configuration
    config.optimizer = config_dict.ConfigDict()
    config.optimizer.name = 'AdamW'
    config.optimizer.lr = 5e-5  # Learning rate
    config.optimizer.betas = (0.9, 0.98)
    config.optimizer.weight_decay = 1e-4
    config.optimizer.eps = 1e-9

    # Learning rate scheduler
    config.scheduler = config_dict.ConfigDict()
    config.scheduler.name = 'WarmUpCosineAnnealingLR'
    config.scheduler.decay_steps = total_steps  # Total training steps (keep equal to training.max_steps)
    config.scheduler.warmup_steps = 5_000   # Warmup steps
    config.scheduler.eta_min = 1e-6
    config.scheduler.interval = 'step'

    # Training configuration
    config.training = config_dict.ConfigDict()
    config.training.max_epochs = 100
    config.training.max_steps = total_steps
    config.training.train_batch_size = 32  # Adjust based on your GPU memory
    config.training.eval_batch_size = 32
    config.training.monitor = 'val_loss'
    config.training.patience = 10  # Early stopping patience
    config.training.save_top_k = 3
    config.training.save_last_k = 3
    config.training.gradient_clip_val = 0.5
    config.training.num_workers = 4

    # Training mode settings
    config.training.training_mode = 'all'
    config.training.use_sample_weight = False
    config.training.use_desc_mass_ratio = False
    config.training.num_branches_per_tree = 10

    # Freezing configuration (which parts of the model to freeze)
    config.training.freeze_args = config_dict.ConfigDict()
    config.training.freeze_args.encoder = False
    config.training.freeze_args.decoder = False
    config.training.freeze_args.npe = False
    config.training.freeze_args.classifier = False

    return config

# Build the tutorial's default configuration.
# Set `dataset_name` to the name of your dataset directory under ./datasets/processed/.
dataset_name = "sample_dataset"  # <- change this to your dataset name
experiment_name = "my_first_training"
config = create_training_config(dataset_name, experiment_name)

# Echo the key settings so it is obvious what the next cells will use.
for message in (
    f"Created configuration for experiment: {experiment_name}",
    f"Dataset: {dataset_name}",
    f"Training device: {config.accelerator}",
):
    print(message)

## Step 4: Configuration Customization

You can modify the configuration based on your specific needs:

In [None]:
# Customize configuration based on your needs (uncomment the lines you want).

# If you have a specific dataset, update the name and number of files
# config.data.name = "your_actual_dataset_name"
# config.data.num_files = 20  # Adjust based on your dataset size

# If you have limited GPU memory, reduce batch size
# config.training.train_batch_size = 16
# config.training.eval_batch_size = 16

# If you want to train for longer. Keep the LR schedule in sync with max_steps:
# scheduler.decay_steps is the total number of training steps the schedule spans.
# config.training.max_epochs = 200
# config.training.max_steps = 200_000
# config.scheduler.decay_steps = 200_000

# If you want to adjust the model size (smaller for faster training, larger for better performance)
# config.model.encoder.d_model = 64  # Smaller model
# config.model.decoder.d_model = 64

print("Configuration ready for training!")
print(f"Batch size: {config.training.train_batch_size}")
print(f"Max epochs: {config.training.max_epochs}")
print(f"Model size: {config.model.encoder.d_model}")

## Step 5: Training Function

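Now let's define the training setup function, adapted from the `train_atg.py` script. It saves the configuration, loads the dataset, builds the model and configures a Lightning `Trainer`. It does not start training itself; Step 6 calls `trainer.fit`.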

In [None]:
# Import training modules
import datasets
from florah_tree.atg import AutoregTreeGen
import pytorch_lightning.loggers as pl_loggers
import shutil

def _prepare_workdir(config):
    """Create the experiment directory ``<config.workdir>/<config.name>``.

    An existing directory is deleted first when ``config.overwrite`` is True.

    Returns:
        The path of the experiment directory.

    Raises:
        ValueError: If the directory already exists and ``config.overwrite`` is False.
    """
    workdir = os.path.join(config.workdir, config.name)
    if os.path.exists(workdir):
        if not config.overwrite:
            raise ValueError(f"Directory {workdir} already exists. Set overwrite=True to overwrite.")
        shutil.rmtree(workdir)
        print(f"Removed existing directory: {workdir}")

    os.makedirs(workdir, exist_ok=True)
    print(f"Created training directory: {workdir}")
    return workdir


def _build_callbacks(config):
    """Return the Lightning callbacks: early stopping, best-k / last-k checkpoints, LR monitor."""
    monitor = config.training.get('monitor', 'val_loss')
    return [
        # Stop once `monitor` has not improved for `patience` validation checks.
        pl.callbacks.EarlyStopping(
            monitor=monitor,
            patience=config.training.get('patience', 10),
            mode='min',
            verbose=True
        ),
        # Keep the `save_top_k` checkpoints with the lowest monitored value.
        # Listed first on purpose: `trainer.checkpoint_callback` returns the first
        # ModelCheckpoint, and the launch cell reports its best_model_path.
        pl.callbacks.ModelCheckpoint(
            filename="best-{epoch}-{step}-{val_loss:.4f}",
            monitor=monitor,
            save_top_k=config.training.get('save_top_k', 1),
            mode='min',
            save_weights_only=False,
            save_last=False
        ),
        # Monitoring 'epoch' in 'max' mode keeps the `save_last_k` most recent checkpoints.
        pl.callbacks.ModelCheckpoint(
            filename="last-{epoch}-{step}-{val_loss:.4f}",
            save_top_k=config.training.get('save_last_k', 1),
            monitor='epoch',
            mode='max',
            save_weights_only=False,
            save_last=False
        ),
        pl.callbacks.LearningRateMonitor("step"),
    ]


def train_model(config):
    """Prepare everything needed to train the FLORAH Tree Generator (does not call ``fit``).

    Creates the experiment directory and saves ``config.yaml`` into it. It then
    loads the dataset into train/validation dataloaders, builds the
    ``AutoregTreeGen`` model, and configures a Lightning ``Trainer`` with
    callbacks and a TensorBoard logger.

    Args:
        config: Configuration as returned by ``create_training_config``.

    Returns:
        ``(model, trainer, train_loader, val_loader)``. If the dataset cannot be
        loaded, the error is printed and ``(None, None, None, None)`` is returned,
        so callers can always unpack four values and then test ``model is not None``.

    Raises:
        ValueError: If the experiment directory exists and ``config.overwrite`` is False.
    """
    workdir = _prepare_workdir(config)

    # Save the configuration next to the logs/checkpoints so the run can be reproduced.
    # NOTE(review): optimizer.betas is a tuple, which yaml.dump writes with a
    # python-specific tag; yaml.safe_load cannot read that back.
    config_path = os.path.join(workdir, 'config.yaml')
    with open(config_path, 'w') as f:
        yaml.dump(config.to_dict(), f, default_flow_style=False)
    print(f"Saved configuration to: {config_path}")

    # Load dataset and prepare dataloaders
    print("Loading dataset...")
    try:
        train_loader, val_loader, norm_dict = datasets.prepare_dataloader(
            datasets.read_dataset(
                dataset_root=config.data.root,
                dataset_name=config.data.name,
                max_num_files=config.data.get("num_files", 1),
            ),
            train_frac=config.data.train_frac,
            train_batch_size=config.training.train_batch_size,
            eval_batch_size=config.training.eval_batch_size,
            use_sampler=config.training.get("use_sampler", False),
            sampler_args=config.training.get("sampler_args"),
            num_workers=config.training.get("num_workers", 0),
            seed=config.seed.data,
            norm_dict=None,
            reverse_time=config.data.get("reverse_time", False),
        )
    except Exception as e:
        print(f"Error loading dataset: {e}")
        # A 4-tuple (not a bare None) keeps `model, trainer, ... = train_model(...)` valid.
        return None, None, None, None
    print("Dataset loaded successfully!")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Create the model; norm_dict comes from the dataloader preparation above.
    print("Creating model...")
    model = AutoregTreeGen(
        d_in=config.model.d_in,
        num_classes=config.model.num_classes,
        encoder_args=config.model.encoder,
        decoder_args=config.model.decoder,
        npe_args=config.model.npe,
        classifier_args=config.model.classifier,
        optimizer_args=config.optimizer,
        scheduler_args=config.scheduler,
        training_args=config.training,
        norm_dict=norm_dict,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model created with {total_params:,} total parameters ({trainable_params:,} trainable)")

    # version='' keeps logs directly under <workdir>/lightning_logs/ instead of a
    # version_N subfolder, which is where the monitoring cells look.
    train_logger = pl_loggers.TensorBoardLogger(workdir, version='')
    trainer = pl.Trainer(
        default_root_dir=workdir,
        max_epochs=config.training.max_epochs,
        max_steps=config.training.max_steps,
        accelerator=config.accelerator,
        callbacks=_build_callbacks(config),
        logger=train_logger,
        gradient_clip_val=config.training.get('gradient_clip_val', 0),
        enable_progress_bar=config.get("enable_progress_bar", True),
        num_sanity_val_steps=0,
    )

    return model, trainer, train_loader, val_loader

print("Training function ready!")

## Step 6: Start Training

Now let's start the training process. Make sure you have your dataset ready before running this cell.

In [None]:
# Before running training, let's do a quick check
dataset_path = os.path.join(config.data.root, config.data.name)
if not os.path.exists(dataset_path):
    print(f"❌ Dataset not found at: {dataset_path}")
    print("Please make sure your dataset is properly located.")
    print("\nDataset should be organized as:")
    print(f"  {config.data.root}")
    print(f"  └── {config.data.name}/")
    print("      ├── file_0.pkl")
    print("      ├── file_1.pkl")
    print("      └── ...")
else:
    print(f"✅ Dataset found at: {dataset_path}")

    # Set up training
    try:
        # Seed *before* train_model builds the model, so the initial weights are
        # reproducible too. Seeding after construction only fixes later randomness.
        pl.seed_everything(config.seed.training)

        # train_model signals a dataset-loading failure with None or with a tuple
        # whose model entry is None. Handle both instead of crashing on unpacking.
        setup = train_model(config)
        model = setup[0] if setup is not None else None

        if model is None:
            print("❌ Training setup failed (see the error above).")
        else:
            model, trainer, train_loader, val_loader = setup
            print("\n🚀 Starting training...")
            print(f"Training will run for up to {config.training.max_epochs} epochs or {config.training.max_steps} steps")
            print(f"Early stopping patience: {config.training.patience} epochs")
            print(f"Checkpoints will be saved to: {os.path.join(config.workdir, config.name)}")

            # Start training
            trainer.fit(
                model,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader,
            )

            print("\n🎉 Training completed!")
            # trainer.checkpoint_callback is the first ModelCheckpoint (the best-k one).
            print(f"Best model saved to: {trainer.checkpoint_callback.best_model_path}")

    except Exception as e:
        print(f"❌ Training failed with error: {e}")
        import traceback
        traceback.print_exc()

## Step 7: Monitoring Training Progress

You can monitor the training progress using TensorBoard.

In [None]:
# Point TensorBoard at this experiment's Lightning log directory.
tensorboard_log_dir = os.path.join(config.workdir, config.name, "lightning_logs")

if not os.path.exists(tensorboard_log_dir):
    print("No TensorBoard logs found yet. Start training first.")
else:
    for message in (
        f"TensorBoard logs available at: {tensorboard_log_dir}",
        "\nTo view training progress, run in terminal:",
        f"tensorboard --logdir {tensorboard_log_dir}",
        "\nThen open http://localhost:6006 in your browser",
    ):
        print(message)

    # To launch TensorBoard inline instead, uncomment the two magics below:
    # %load_ext tensorboard
    # %tensorboard --logdir {tensorboard_log_dir}

## Step 8: Training Results and Next Steps

After training completes, run the cell below to list your training directory. It contains the saved `config.yaml`, the `lightning_logs/` TensorBoard logs and any checkpoints.

In [None]:
# Inspect what the run left behind: the experiment directory and any checkpoints.
training_dir = os.path.join(config.workdir, config.name)

if not os.path.exists(training_dir):
    print("Training directory not found. Run training first.")
else:
    print(f"Training directory: {training_dir}")
    print("\nContents:")
    for item in os.listdir(training_dir):
        item_path = os.path.join(training_dir, item)
        # Folders get a trailing slash and a folder icon; everything else a file icon.
        print(f"  📁 {item}/" if os.path.isdir(item_path) else f"  📄 {item}")

    # Checkpoint location used by this tutorial's logger setup.
    checkpoint_dir = os.path.join(training_dir, "lightning_logs", "checkpoints")
    if not os.path.exists(checkpoint_dir):
        print("\nNo checkpoints found.")
    else:
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
        print(f"\nCheckpoints ({len(checkpoints)} found):")
        for ckpt in sorted(checkpoints):
            print(f"  🔖 {ckpt}")

## Troubleshooting

### Common Issues and Solutions:

1. **Dataset not found**: Make sure your dataset is in the correct location and format.

2. **Out of memory**: Reduce batch size in the configuration:
   ```python
   config.training.train_batch_size = 16  # or smaller
   config.training.eval_batch_size = 16
   ```

3. **Training too slow**: 
   - Use GPU if available
   - Reduce model size:
     ```python
     config.model.encoder.d_model = 64
     config.model.decoder.d_model = 64
     ```

4. **Training not converging**: 
   - Adjust learning rate: `config.optimizer.lr = 1e-4`
   - Increase patience: `config.training.patience = 20`
   - Check your data quality

## Next Steps

After successful training:

1. **Evaluate your model**: Use validation metrics from TensorBoard
2. **Run inference**: Use the inference tutorial with your trained model
3. **Fine-tune**: Adjust hyperparameters and retrain if needed
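4. **Save your model**: Checkpoints are saved automatically to `lightning_logs/checkpoints/`. Both the best ones (lowest validation loss) and the most recent ones are kept.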

## Configuration Templates

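Here are some preset configurations for different scenarios. Each one starts from `create_training_config` and overrides a few fields. They use the placeholder dataset name `your_dataset`, so set `config.data.name` to your own dataset before training.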

In [None]:
def get_quick_test_config(dataset_name="your_dataset"):
    """Return a tiny, short-running configuration for debugging the pipeline.

    Args:
        dataset_name: Dataset directory name. Defaults to the placeholder
            ``"your_dataset"``; replace it with your own dataset.

    Returns:
        A ConfigDict built by ``create_training_config`` with small batches, a
        small model and a 100-step run.
    """
    config = create_training_config(dataset_name, "quick_test")
    config.training.max_epochs = 5
    config.training.max_steps = 100
    # Keep the LR schedule consistent with the shortened run. With the default
    # 5,000 warm-up steps, a 100-step run would end before warm-up finished.
    config.scheduler.decay_steps = config.training.max_steps
    config.scheduler.warmup_steps = config.training.max_steps // 10
    config.training.train_batch_size = 8
    config.training.eval_batch_size = 8
    config.model.encoder.d_model = 32
    config.model.decoder.d_model = 32
    return config

def get_high_performance_config(dataset_name="your_dataset"):
    """Return a larger, longer-training configuration intended for production runs.

    Args:
        dataset_name: Dataset directory name. Defaults to the placeholder
            ``"your_dataset"``; replace it with your own dataset.

    Returns:
        A ConfigDict built by ``create_training_config`` with a wider/deeper
        model, larger batches, a lower learning rate and a 500k-step run.
    """
    config = create_training_config(dataset_name, "high_performance")
    config.training.max_epochs = 500
    config.training.max_steps = 500_000
    # decay_steps is the total number of training steps the LR schedule spans;
    # keep it in sync with the longer run instead of the 100k default.
    config.scheduler.decay_steps = config.training.max_steps
    config.training.train_batch_size = 64
    config.training.eval_batch_size = 64
    config.model.encoder.d_model = 256
    config.model.decoder.d_model = 256
    config.model.encoder.num_layers = 6
    config.model.decoder.num_layers = 6
    config.optimizer.lr = 1e-5
    return config

def get_memory_efficient_config(dataset_name="your_dataset"):
    """Return a configuration that fits limited GPU memory / CPU resources.

    Args:
        dataset_name: Dataset directory name. Defaults to the placeholder
            ``"your_dataset"``; replace it with your own dataset.

    Returns:
        A ConfigDict built by ``create_training_config`` with small batches, a
        narrower/shallower model and fewer dataloader workers.
    """
    config = create_training_config(dataset_name, "memory_efficient")
    config.training.train_batch_size = 8
    config.training.eval_batch_size = 8
    config.model.encoder.d_model = 64
    config.model.decoder.d_model = 64
    config.model.encoder.num_layers = 2
    config.model.decoder.num_layers = 2
    config.training.num_workers = 2
    return config

# Each template function returns a new ConfigDict from create_training_config
# with a few fields overridden; call one and adjust it before training.
print("Configuration templates created!")
print("Use get_quick_test_config(), get_high_performance_config(), or get_memory_efficient_config()")

## Conclusion

You've successfully learned how to train the FLORAH Tree Generator! 

**Summary of what you accomplished:**
- Set up the training environment
- Configured the model parameters
- Prepared the dataset
- Ran the training process
- Learned how to monitor progress

**Your trained model is now ready for inference!** 

Check out the inference tutorial (`tutorial_inference.ipynb`) to learn how to generate merger trees using your trained model.

Good luck with your cosmological simulations! 🌌