# FLORAH Tree Generator: Training Tutorial

Welcome to the FLORAH training tutorial! This notebook provides a step-by-step guide to training your own generative model for dark matter halo merger trees. We will walk through setting up the environment, configuring the model, loading data, and running the training process.

The goal is to provide a clear, modular walkthrough that separates each key stage of the training pipeline.

## 1. Setup and Environment

First, we need to install the necessary Python packages and import the required libraries.

In [1]:
# Install required packages. Uncomment the line below if you haven't installed them yet.
# !pip install torch pytorch-lightning tensorboard ml-collections absl-py torch-geometric pyyaml numpy tqdm

print("Dependencies are assumed to be installed.")

Dependencies are assumed to be installed.
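Rather than assuming the dependencies are present, a short check can report which ones are missing before any import fails. This is a minimal sketch, not part of FLORAH. The dictionary maps each pip name from the install line above to the module name it is imported under; these differ for several packages (for example, `pyyaml` imports as `yaml` and `absl-py` as `absl`).

```python
import importlib.util

# pip package name -> importable module name (they differ for several packages)
REQUIRED_MODULES = {
    "torch": "torch",
    "pytorch-lightning": "pytorch_lightning",
    "tensorboard": "tensorboard",
    "ml-collections": "ml_collections",
    "absl-py": "absl",
    "torch-geometric": "torch_geometric",
    "pyyaml": "yaml",
    "numpy": "numpy",
    "tqdm": "tqdm",
}


def missing_packages(required=REQUIRED_MODULES):
    """Return the pip names of packages whose top-level module cannot be found."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


missing = missing_packages()
if missing:
    print("Missing packages:", " ".join(missing))
    print("Install with: pip install " + " ".join(missing))
else:
    print("All dependencies found.")
```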


In [2]:
# Import necessary libraries
import os
import sys
import yaml
import shutil
import torch
import pytorch_lightning as pl
from ml_collections import config_dict
import numpy as np

import pytorch_lightning.loggers as pl_loggers

# Add the project root to the Python path so that the FLORAH modules below can be found.
# This assumes the notebook is run from the root of the florah-tree repository.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import FLORAH-specific modules (after the path setup above)
import datasets  # the repository's local module, not the Hugging Face `datasets` package
from florah_tree.atg import AutoregTreeGen

print(f"Project Root: {project_root}")
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Project Root: /mnt/home/tnguyen/projects/florah/florah-tree
PyTorch Lightning version: 2.1.3
CUDA available: True
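`datasets` is a very generic module name, and the Hugging Face library of the same name is often installed alongside PyTorch. If that package is imported instead of the repository's module, later cells fail with confusing errors. The helpers below are a hypothetical sketch for checking which file the name resolves to. One caveat: if the repository's `datasets` is a namespace package (a folder without `__init__.py`), `origin` may be `None` and the warning will be a false alarm.

```python
import importlib.util
import os


def module_origin(name):
    """Return the file a top-level module is (or would be) imported from, or None."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


def is_under(path, root):
    """True if `path` lies inside directory `root` (symlinks resolved)."""
    if path is None:
        return False
    path, root = os.path.realpath(path), os.path.realpath(root)
    return os.path.commonpath([path, root]) == root


origin = module_origin("datasets")
print(f"'datasets' resolves to: {origin}")
if not is_under(origin, os.getcwd()):
    print("Warning: 'datasets' does not come from this repository.")
```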


## 2. Configuration

All aspects of the training run are controlled by a single configuration object. This makes experiments reproducible and easy to modify. Here, we define a sample configuration. You should adapt the paths and parameters for your specific setup and dataset.
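Because every setting lives in one nested object, a quick experiment usually only needs to change a few fields. `apply_overrides` is a hypothetical helper, not part of FLORAH or ml_collections. It takes dotted keys and works on any object with attribute access, including a `ConfigDict`. It refuses keys that do not already exist, so a typo fails loudly instead of silently adding a new field. For command-line overrides, ml_collections also provides `config_flags`, which accepts flags of the form `--config.field=value`.

```python
from types import SimpleNamespace


def apply_overrides(cfg, overrides):
    """Set nested attributes from dotted keys, e.g. {"training.train_batch_size": 16}.

    Raises AttributeError if any part of a key does not already exist.
    """
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for name in parents:
            node = getattr(node, name)  # AttributeError on an unknown section
        if not hasattr(node, leaf):
            raise AttributeError(f"Unknown config key: {dotted_key}")
        setattr(node, leaf, value)
    return cfg


# Hypothetical stand-in for a nested config with a few of the fields used in this tutorial
demo = SimpleNamespace(
    data=SimpleNamespace(num_files=5),
    training=SimpleNamespace(train_batch_size=32, max_epochs=50),
)
apply_overrides(demo, {"data.num_files": 1, "training.train_batch_size": 8})
print(demo.data.num_files, demo.training.train_batch_size)  # 1 8
```

Once `config` exists, the same call, for example `apply_overrides(config, {"data.num_files": 1})`, shrinks a run for a smoke test.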

In [3]:
# Create a configuration using ml_collections.config_dict
config = config_dict.ConfigDict()

# -- Logging and Environment --
config.workdir = './training_logs'  # Directory to save logs and checkpoints
config.name = 'florah_tutorial_run' # A name for this specific experiment
config.overwrite = True  # If True, deletes previous run with the same name
config.accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

# -- Seeds for reproducibility --
config.seed = config_dict.ConfigDict()
config.seed.data = 42
config.seed.training = 1337

# -- Dataset Configuration --
config.data = config_dict.ConfigDict()
config.data.root = "./datasets/processed/" # Path to your processed datasets
config.data.name = "vsmdpl-nprog3-dt2_6-z10" # The name of the dataset to use
config.data.num_files = 5  # Number of data files to use for this run (for quick tests)
config.data.train_frac = 0.8  # 80% for training, 20% for validation

# -- Model Architecture --
config.model = config_dict.ConfigDict()
config.model.d_in = 2  # Input feature dimension
config.model.num_classes = 3  # Number of output classes for the progenitor classifier

# Encoder (processes halo history)
config.model.encoder = config_dict.ConfigDict({'name': 'gru', 'd_model': 128, 'd_out': 128, 'num_layers': 4})

# Decoder (generates progenitor properties)
config.model.decoder = config_dict.ConfigDict({'name': 'gru', 'd_model': 128, 'd_out': 128, 'num_layers': 4})

# Neural Posterior Estimation (for sampling continuous properties)
config.model.npe = config_dict.ConfigDict({'hidden_sizes': [128, 128], 'num_transforms': 4})

# Classifier (predicts number of progenitors)
config.model.classifier = config_dict.ConfigDict({'d_context': 1, 'hidden_sizes': [128, 128]})

# -- Optimizer and Scheduler --
config.optimizer = config_dict.ConfigDict({'name': 'AdamW', 'lr': 5e-5, 'weight_decay': 1e-4})
config.scheduler = config_dict.ConfigDict({'name': 'WarmUpCosineAnnealingLR', 'warmup_steps': 5000, 'decay_steps': 100000})

# -- Training Loop --
config.training = config_dict.ConfigDict()
config.training.max_epochs = 50
config.training.max_steps = 100_000  # Training stops at whichever of max_epochs / max_steps is reached first
config.training.train_batch_size = 32  # Adjust based on your GPU memory
config.training.eval_batch_size = 64
config.training.num_workers = 4  # Number of CPU cores for data loading
config.training.gradient_clip_val = 0.5
config.training.monitor = 'val_loss' # Metric to monitor for checkpointing and early stopping
config.training.patience = 10 # Early stopping patience: validation checks without improvement (one check per epoch by default)
config.training.save_top_k = 3 # Save the top 3 models

print(f"Configuration created for experiment: '{config.name}'")
print(f"Training will run on: {config.accelerator}")
print(f"Dataset to be used: {config.data.name}")

Configuration created for experiment: 'florah_tutorial_run'
Training will run on: gpu
Dataset to be used: vsmdpl-nprog3-dt2_6-z10
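The scheduler is configured as `WarmUpCosineAnnealingLR` with `warmup_steps=5000` and `decay_steps=100000`. The sketch below uses one common form of such a schedule to show how the learning rate evolves: a linear ramp from 0 to the base rate, then a cosine decay down to a floor. This is an assumption. FLORAH's own implementation may differ in details such as the starting value of the warmup or the final learning-rate floor (`min_lr` here is a hypothetical parameter).

```python
import math


def warmup_cosine_lr(step, base_lr=5e-5, warmup_steps=5000, decay_steps=100_000, min_lr=0.0):
    """Learning rate at `step` under linear warmup followed by cosine decay (assumed form).

    - steps [0, warmup_steps): linear ramp from 0 to base_lr
    - steps [warmup_steps, decay_steps): cosine decay from base_lr toward min_lr
    - from decay_steps on: held at min_lr
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    if step >= decay_steps:
        return min_lr
    progress = (step - warmup_steps) / (decay_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 2_500, 5_000, 52_500, 100_000):
    print(f"step {step:>7,}: lr = {warmup_cosine_lr(step):.2e}")
```

Under this form, the learning rate peaks at `lr=5e-5` at step 5,000 and reaches half that value midway through the decay.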


## 3. Dataset Preparation

Next, we load the dataset using the parameters from our configuration object. The `datasets.prepare_dataloader` function reads the processed data files, splits them into training and validation sets, and wraps each split in a `DataLoader` for batching.
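To make `train_frac`, the data seed, and the batch sizes concrete, here is a generic sketch of a seeded split and of the resulting number of batches per epoch. It is an illustration under stated assumptions, not FLORAH's code: `prepare_dataloader` may split differently (for example, per file rather than per tree). The batch count assumes a map-style dataset and PyTorch's default `drop_last=False`, where the last partial batch is kept.

```python
import math
import random


def split_indices(n_items, train_frac=0.8, seed=42):
    """Shuffle item indices with a fixed seed and split them into train/val lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    n_train = int(train_frac * n_items)
    return indices[:n_train], indices[n_train:]


def num_batches(n_items, batch_size, drop_last=False):
    """Batches per epoch for a DataLoader over `n_items` samples."""
    return n_items // batch_size if drop_last else math.ceil(n_items / batch_size)


# Hypothetical dataset of 1000 trees, using the batch sizes from the config
train_idx, val_idx = split_indices(1000, train_frac=0.8, seed=42)
print(len(train_idx), len(val_idx))  # 800 200
print(num_batches(len(train_idx), 32), num_batches(len(val_idx), 64))  # 25 4
```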

In [5]:
# Set up the working directory for the experiment
workdir = os.path.join(config.workdir, config.name)
if os.path.exists(workdir) and config.overwrite:
    print(f"Overwriting existing directory: {workdir}")
    shutil.rmtree(workdir)
os.makedirs(workdir, exist_ok=True)

# Save the configuration to the experiment directory for reproducibility
with open(os.path.join(workdir, 'config.yaml'), 'w') as f:
    f.write(config.to_yaml())

print("Loading dataset...")
try:
    # This function handles loading, splitting, and creating DataLoaders
    train_loader, val_loader, norm_dict = datasets.prepare_dataloader(
        dataset_name=config.data.name,
        dataset_root=config.data.root,
        train_batch_size=config.training.train_batch_size,
        eval_batch_size=config.training.eval_batch_size,
        seed=config.seed.data,
        num_data_files=config.data.num_files,
        train_frac=config.data.train_frac,
        num_workers=config.training.num_workers,
    )
    print("Dataset loaded successfully!")
    print(f"  -> Number of training batches: {len(train_loader)}")
    print(f"  -> Number of validation batches: {len(val_loader)}")
    print(f"  -> Normalization dictionary loaded: {list(norm_dict.keys())}")

except FileNotFoundError:
    print(f"ERROR: Dataset not found at {os.path.join(config.data.root, config.data.name)}")
    print("Please update the 'config.data.root' and 'config.data.name' paths in the configuration cell.")
    raise
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    raise  # Stop here: later cells need train_loader, val_loader and norm_dict

Loading dataset...
An unexpected error occurred: prepare_dataloader() got an unexpected keyword argument 'dataset_name'
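The recorded output shows that this call failed. The installed `datasets.prepare_dataloader` does not accept a `dataset_name` keyword, so the argument names used in this cell do not match the version of the function in your checkout. As a result, the cells below that use `train_loader`, `val_loader` or `norm_dict` cannot run until the call is fixed. The correct names are not shown in this notebook. Printing `inspect.signature(datasets.prepare_dataloader)` shows what the function expects. The helper below is a hypothetical convenience, not part of FLORAH, that lists the keywords a function would reject before you call it.

```python
import inspect


def unexpected_kwargs(func, kwargs):
    """Return the keyword names in `kwargs` that `func` would reject."""
    params = inspect.signature(func).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []  # func takes **kwargs, so any keyword name is accepted
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(set(kwargs) - accepted)


# Hypothetical stand-in with a different signature, to show the check in action
def fake_prepare_dataloader(data_root, data_name, batch_size=32, seed=None):
    return None


print(unexpected_kwargs(fake_prepare_dataloader,
                        {"dataset_name": "x", "dataset_root": "y", "seed": 42}))
# ['dataset_name', 'dataset_root']
```

In the notebook, calling `unexpected_kwargs(datasets.prepare_dataloader, {...})` with the dictionary of keywords from the cell above shows which names need to change.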


## 4. Model Definition

Now we instantiate the `AutoregTreeGen` model. This class, defined in `florah_tree/atg.py`, is a `pl.LightningModule` that encapsulates the entire model architecture, including the encoder, decoder, classifier, and the logic for training and validation steps.

In [None]:
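print("Initializing the FLORAH model...")

# Seed before building the model so that the initial weights are reproducible
pl.seed_everything(config.seed.training)

# The model takes the configuration and the normalization dictionary (from the dataset) as input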
model = AutoregTreeGen(
    d_in=config.model.d_in,
    num_classes=config.model.num_classes,
    encoder_args=config.model.encoder,
    decoder_args=config.model.decoder,
    npe_args=config.model.npe,
    classifier_args=config.model.classifier,
    optimizer_args=config.optimizer,
    scheduler_args=config.scheduler,
    training_args=config.training,
    norm_dict=norm_dict,
)

# Print the number of parameters to get a sense of the model's size
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model initialized successfully!")
print(f" -> Total parameters: {total_params:,}")
print(f" -> Trainable parameters: {trainable_params:,}")
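To see where parameter counts come from, the function below counts the parameters of a plain, unidirectional `torch.nn.GRU`. In a GRU, each layer holds `weight_ih` (3H × input size), `weight_hh` (3H × H), and two bias vectors of length 3H. This is only a rough estimate for one component. It assumes the encoder is a bare `nn.GRU` fed `d_in` features. FLORAH's actual encoder may add projections or extra input features, so the total reported above will differ.

```python
def gru_param_count(input_size, hidden_size, num_layers=1, bias=True):
    """Parameters in a unidirectional torch.nn.GRU with these sizes."""
    total = 0
    for layer in range(num_layers):
        in_features = input_size if layer == 0 else hidden_size  # later layers take H inputs
        total += 3 * hidden_size * (in_features + hidden_size)   # weight_ih + weight_hh
        if bias:
            total += 2 * 3 * hidden_size                          # bias_ih + bias_hh
    return total


# Encoder settings from the config (d_model=128, num_layers=4), assuming d_in=2 inputs
print(f"{gru_param_count(2, 128, num_layers=4):,}")  # 347,904
```

If PyTorch is available, `sum(p.numel() for p in torch.nn.GRU(2, 128, num_layers=4).parameters())` gives the same number.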

## 5. Training Setup

With the data and model ready, we now configure the `Trainer`. This PyTorch Lightning object handles the training loop, gradient updates, evaluation, and logging. We also set up **callbacks** to add functionality like saving the best models (`ModelCheckpoint`) and stopping training if performance stagnates (`EarlyStopping`).
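`EarlyStopping` compares the monitored metric at every validation check, not at every training step. With `mode='min'`, a check counts as an improvement only if it beats the best value so far by more than `min_delta`, which defaults to 0. Otherwise a counter increases, and training stops once the counter reaches `patience`. The function below is a simplified re-implementation of that rule for illustration. It is not Lightning's code, and the loss values in the test are made up.

```python
import math


def early_stop_epoch(val_losses, patience=10, min_delta=0.0):
    """Index of the validation check at which training would stop, or None."""
    best = math.inf
    wait = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0  # improvement: remember it and reset the counter
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


# Two improvements, then three checks without one: stops at index 4 with patience=3
print(early_stop_epoch([1.00, 0.90, 0.95, 0.96, 0.97], patience=3))  # 4
```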

In [None]:
# Define callbacks for the trainer
callbacks = [
    # Stop training if the validation loss doesn't improve for a number of epochs
    pl.callbacks.EarlyStopping(
        monitor=config.training.monitor,
        patience=config.training.patience,
        mode='min',
        verbose=True
    ),
    # Save the best K models based on validation loss
    pl.callbacks.ModelCheckpoint(
        filename="best-model-{epoch}-{val_loss:.4f}",
        monitor=config.training.monitor,
        save_top_k=config.training.save_top_k,
        mode='min',
    ),
    # Keep a checkpoint of the most recent epoch (no monitor, so the latest one is kept)
    pl.callbacks.ModelCheckpoint(filename="last-model-{epoch}"),
    # Log the learning rate at every optimizer step
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
]

# Set up the TensorBoard logger; version='' writes logs to <workdir>/lightning_logs/
# instead of a numbered version_N subfolder
train_logger = pl_loggers.TensorBoardLogger(workdir, version='')

# Initialize the PyTorch Lightning Trainer
trainer = pl.Trainer(
    default_root_dir=workdir,
    max_epochs=config.training.max_epochs,
    max_steps=config.training.max_steps,
    accelerator=config.accelerator,
    callbacks=callbacks,
    logger=train_logger,
    gradient_clip_val=config.training.gradient_clip_val,
    enable_progress_bar=True,
)

print("Trainer configured. Ready to start training.")
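Both `max_epochs=50` and `max_steps=100_000` are set, and Lightning stops at whichever limit is reached first. The sketch below, a hypothetical helper, shows which limit applies for a given number of training batches per epoch. It assumes one optimizer step per batch, which holds because this config sets no gradient accumulation. With fewer than 2,000 batches per epoch, the epoch limit ends the run before step 100,000. In that case, if the scheduler's `decay_steps` counts optimizer steps (an assumption about FLORAH's scheduler), its decay would not complete.

```python
def stopping_point(max_epochs, max_steps, batches_per_epoch):
    """Return (optimizer_steps, limit_hit), assuming one optimizer step per batch.

    max_steps=-1 disables the step limit, as in Lightning.
    """
    epoch_limit_steps = max_epochs * batches_per_epoch
    if max_steps is not None and 0 <= max_steps < epoch_limit_steps:
        return max_steps, "max_steps"
    return epoch_limit_steps, "max_epochs"


for n_batches in (500, 2_000, 5_000):
    steps, limit = stopping_point(50, 100_000, n_batches)
    print(f"{n_batches:>5} batches/epoch -> stops after {steps:,} steps ({limit})")
```

In the notebook, `stopping_point(config.training.max_epochs, config.training.max_steps, len(train_loader))` gives the answer for your dataset.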

## 6. Execute Training

This is the final step. Calling `trainer.fit()` will start the training process. The trainer will use the model, training data, and validation data we prepared earlier. Progress will be displayed below, and you can monitor more detailed metrics using TensorBoard.

In [None]:
print(f"🚀 Starting training for '{config.name}'...")

# Seed Python, NumPy and PyTorch RNGs so that data shuffling and other training-time randomness are reproducible
pl.seed_everything(config.seed.training)

# Start the training!
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

print("\n🎉 Training finished!")

## 7. Review Results

After training is complete, you can find the results in the experiment directory, `os.path.join(config.workdir, config.name)`, which is stored in the `workdir` variable. This includes the saved model checkpoints and TensorBoard logs.

In [None]:
# trainer.checkpoint_callback is the first ModelCheckpoint in the callbacks list,
# i.e. the one that tracks the best validation loss
best_model_path = trainer.checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")

# You can list the contents of the experiment directory to see all saved files
print(f"\nContents of the experiment directory ({workdir}):")
for item in os.listdir(workdir):
    print(f"- {item}")

# To view the training metrics, run TensorBoard in your terminal:
print("\nTo monitor training with TensorBoard, run this command in your terminal:")
print(f"tensorboard --logdir={config.workdir}")
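The best-checkpoint callback keeps up to `save_top_k` files. Their names record the epoch and the validation loss, because Lightning (with its default `auto_insert_metric_name=True`) expands `"best-model-{epoch}-{val_loss:.4f}"` to names such as `best-model-epoch=12-val_loss=0.1234.ckpt`. If a name collides with an existing file, Lightning adds a `-v1` suffix. The sketch below lists and ranks those files. It is a hypothetical convenience, not part of FLORAH. It searches the whole directory tree because the exact checkpoint subfolder depends on the logger settings. The filenames in the test are made up.

```python
import os
import re

# Names produced from filename="best-model-{epoch}-{val_loss:.4f}", optionally with a -vN suffix
CKPT_PATTERN = re.compile(r"best-model-epoch=(\d+)-val_loss=(-?\d+(?:\.\d+)?)(?:-v\d+)?\.ckpt$")


def find_checkpoints(root):
    """All .ckpt files anywhere below `root`."""
    return [
        os.path.join(dirpath, name)
        for dirpath, _, filenames in os.walk(root)
        for name in filenames
        if name.endswith(".ckpt")
    ]


def rank_checkpoints(paths):
    """Return (val_loss, epoch, path) for best-model checkpoints, lowest loss first."""
    ranked = []
    for path in paths:
        match = CKPT_PATTERN.search(os.path.basename(path))
        if match:
            ranked.append((float(match.group(2)), int(match.group(1)), path))
    return sorted(ranked)
```

In the notebook, `for val_loss, epoch, path in rank_checkpoints(find_checkpoints(workdir)): print(val_loss, epoch, path)` lists the saved models. Note that the loss can be negative for likelihood-based objectives, which the pattern allows.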

## Next Steps

Congratulations! You have successfully trained a FLORAH model.

You can now use this trained model to generate new merger trees. Proceed to the **inference tutorial** (`tutorial_inference.ipynb`) to learn how to load your checkpoint and generate trees.