# FLORAH Tree: Inference Tutorial

This notebook demonstrates how to use a pre-trained `AutoregTreeGen` model to generate synthetic merger trees. We will cover the following steps:
1. **Configuration**: Loading the necessary hyperparameters.
2. **Load Model**: Loading a trained model from a checkpoint.
3. **Prepare Input Data**: Setting up the initial conditions (root halos) for tree generation.
4. **Generate Trees**: Running the inference process.
5. **Save & Analyze Results**: Storing the generated trees and performing some basic analysis.

## 1. Configuration

First, load the configuration file that was used to train the model. This keeps the model architecture and other hyperparameters consistent with the checkpoint. We then add the inference-specific settings under `config.data_infer` and `config.checkpoint_infer`.
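The configuration cell does not parse absl flags. Instead, it executes the config file directly and calls its `get_config()` function.

Below is a minimal, self-contained sketch of that pattern. It assumes only that the config file defines `get_config()`. The temporary file and its contents are hypothetical stand-ins for `../configs/vsmdpl-nprog3-zmax10.py`, and a plain dict replaces `ml_collections.ConfigDict`:

```python
import importlib.util
import os
import tempfile


def load_config_from_file(path, module_name="config"):
    """Execute a Python config file and return the result of its get_config()."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module.get_config()


# Hypothetical config file contents, standing in for a real training config
CONFIG_SOURCE = '''
def get_config():
    return {"name": "example-run", "workdir": "/tmp/florah-example"}
'''

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "example_config.py")
    with open(path, "w") as f:
        f.write(CONFIG_SOURCE)
    config = load_config_from_file(path)

print(config["name"])  # example-run
```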

In [None]:
import os
import sys
import pickle
import glob
import re
import importlib.util
import numpy as np
import torch
import ml_collections

# Put the project root first on the Python path so that the project's local
# `datasets` module takes precedence over any installed package of the same name
sys.path.insert(0, os.path.abspath('..'))

import datasets
from florah_tree import infer_utils, training_utils, models_utils, analysis_utils
from florah_tree.atg import AutoregTreeGen

# --- Configuration ---
# Load the same configuration file used for training
config_path = '../configs/vsmdpl-nprog3-zmax10.py'

# To avoid parsing absl flags in a notebook, load the config module directly
# and call its get_config() function
spec = importlib.util.spec_from_file_location('config', config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
config = config_module.get_config()

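# --- Inference-specific settings ---
# These settings may already be defined in the config file under `config.data_infer`;
# here we (re)create that section explicitly, overwriting any existing values
config.data_infer = ml_collections.ConfigDict()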
config.data_infer.name = 'VMDPL'
config.data_infer.root = '/path/to/your/simulation/data' # IMPORTANT: Update this path
config.data_infer.box = 'VMDPL'
config.data_infer.zmax = 10.0
config.data_infer.step = 1
config.data_infer.num_files = 1
config.data_infer.num_max_trees = 100 # Number of trees to use as roots
config.data_infer.multiplicative_factor = 1 # Generate this many trees per root
config.data_infer.batch_size = 128
config.data_infer.outdir = './inference_output'

# --- Checkpoint settings ---
# 'best', 'last', or a specific checkpoint file (e.g., 'epoch=1-step=100.ckpt')
config.checkpoint_infer = 'best'

print("Configuration loaded and updated for inference.")

## 2. Load Model

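Next, load the trained `AutoregTreeGen` model. If `config.checkpoint_infer` is `'best'` or `'last'`, the cell searches `<workdir>/<name>/lightning_logs/checkpoints` and picks a checkpoint based on values parsed from the filenames:

- `'best'` selects the checkpoint with the lowest validation loss.
- `'last'` selects the checkpoint with the highest training step.

Any other value is treated as the name of a checkpoint file in that directory.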
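The checkpoint-selection logic depends only on the checkpoint filenames. The sketch below assumes filenames embed `step=<int>` and, optionally, `val_loss=<float>`. This matches the PyTorch Lightning default `epoch=<int>-step=<int>` pattern, plus a `val_loss` field that a custom checkpoint filename template would add. The example names are hypothetical:

```python
import os
import re

# Accepts values such as 0.1234, -1.5, .5 or 3 (no exponent notation)
VAL_LOSS_RE = re.compile(r"val_loss=([-+]?\d*\.?\d+)")
STEP_RE = re.compile(r"step=(\d+)")


def loss_key(path):
    """Validation loss parsed from the filename; missing -> treated as worst."""
    m = VAL_LOSS_RE.search(os.path.basename(path))
    return float(m.group(1)) if m else float("inf")


def step_key(path):
    """Training step parsed from the filename; missing -> treated as step 0."""
    m = STEP_RE.search(os.path.basename(path))
    return int(m.group(1)) if m else 0


def select_checkpoint(paths, mode):
    """Return the path with the lowest val_loss ('best') or highest step ('last')."""
    if not paths:
        raise FileNotFoundError("No checkpoints to choose from")
    if mode == "best":
        return min(paths, key=loss_key)
    if mode == "last":
        return max(paths, key=step_key)
    raise ValueError(f"mode must be 'best' or 'last', got {mode!r}")


# Hypothetical checkpoint names
ckpts = [
    "ckpts/epoch=1-step=100-val_loss=0.8500.ckpt",
    "ckpts/epoch=4-step=500-val_loss=0.4200.ckpt",
    "ckpts/epoch=9-step=1000-val_loss=0.4700.ckpt",
]
print(select_checkpoint(ckpts, "best"))  # ckpts/epoch=4-step=500-val_loss=0.4200.ckpt
print(select_checkpoint(ckpts, "last"))  # ckpts/epoch=9-step=1000-val_loss=0.4700.ckpt
```

Using `min`/`max` with a key function keeps each filename paired with its parsed value, so no parallel list of losses or steps is needed.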

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_dir = os.path.join(config.workdir, config.name, "lightning_logs/checkpoints")

if config.checkpoint_infer in ['best', 'last']:
    all_checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "*.ckpt")))
    if not all_checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    if config.checkpoint_infer == 'best':
        # Pick the checkpoint with the lowest validation loss encoded in its filename
        val_losses = []
        for cp in all_checkpoints:
            loss_match = re.search(r"val_loss=([-+]?\d*\.?\d+)", os.path.basename(cp))
            if loss_match:
                val_losses.append(float(loss_match.group(1)))
            else:
                val_losses.append(float('inf'))  # Filename has no loss; treat it as worst
        checkpoint_path = all_checkpoints[np.argmin(val_losses)]
    else:  # 'last'
        # Pick the checkpoint with the highest training step encoded in its filename
        steps = []
        for cp in all_checkpoints:
            step_match = re.search(r"step=(\d+)", os.path.basename(cp))
            steps.append(int(step_match.group(1)) if step_match else 0)
        checkpoint_path = all_checkpoints[np.argmax(steps)]
else:
    # Treat any other value as a checkpoint filename inside checkpoint_dir
    checkpoint_path = os.path.join(checkpoint_dir, config.checkpoint_infer)

print(f'Loading model from checkpoint: {checkpoint_path}')
model = AutoregTreeGen.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()
print("Model loaded successfully.")

## 3. Prepare Input Data

We need to provide the model with a set of root halos at redshift z = 0. The model then generates the merger history of each halo backward in time.

We also need to choose the snapshots, and their scale factors, at which the model predicts progenitors. These come from a metadata table of snapshot times, restricted to redshifts up to `zmax` and thinned by `step`.
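The snapshot-time table is read as three comma-separated columns, unpacked into snapshot number, scale factor and redshift. Below is a stdlib-only sketch of the selection step. It assumes that column order and uses made-up values rather than a real `snapshot_times_*.txt` file. It also checks the standard relation `a = 1 / (1 + z)` between scale factor and redshift:

```python
import csv
import io

# Hypothetical table: snapshot, scale factor a, redshift z (comma-separated)
TABLE = """\
0,0.0625,15.0
1,0.0909,10.0
2,0.1667,5.0
3,0.3333,2.0
4,0.5,1.0
5,1.0,0.0
"""


def read_columns(text):
    """Parse a comma-separated table into (snapshots, aexp, z) column lists."""
    rows = [row for row in csv.reader(io.StringIO(text)) if row]
    snaps = [int(r[0]) for r in rows]
    aexp = [float(r[1]) for r in rows]
    z = [float(r[2]) for r in rows]
    return snaps, aexp, z


def select_times(snaps, aexp, z, zmax, step):
    """Keep rows with z <= zmax, then every `step`-th remaining row."""
    kept = [(s, a) for s, a, zz in zip(snaps, aexp, z) if zz <= zmax]
    kept = kept[::step]
    return [s for s, _ in kept], [a for _, a in kept]


snaps, aexp, z = read_columns(TABLE)
# Scale factor and redshift should satisfy a = 1 / (1 + z)
assert all(abs(a - 1.0 / (1.0 + zz)) < 1e-3 for a, zz in zip(aexp, z))

snap_out, aexp_out = select_times(snaps, aexp, z, zmax=10.0, step=2)
print(snap_out)  # [1, 3, 5]
```

The stride is applied after the redshift cut and starts from the first kept row, so the row order of the metadata file determines which snapshots survive. With `step = 1`, as in the configuration above, every snapshot up to `zmax` is kept.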

In [None]:
# This function reads snapshot times from metadata files.
# You might need to adapt this or provide the times directly.
DEFAULT_METADATA_DIR = "../metadata" # Assumes metadata is in project root
def read_snapshot_times(box_name):
    if "GUREFT" in box_name:
        table_name = "snapshot_times_gureft.txt"
    else:
        table_name = f"snapshot_times_{box_name.lower()}.txt"
    snapshot_times = np.genfromtxt(
        os.path.join(DEFAULT_METADATA_DIR, table_name), delimiter=',', unpack=True)
    return snapshot_times

try:
    # Get the root features from the simulation data
    print(f"Loading root halos from {config.data_infer.root}...")
    sim_data = datasets.read_dataset(
        dataset_name=config.data_infer.name,
        dataset_root=config.data_infer.root,
        index_start=0, # Start from the first file
        max_num_files=config.data_infer.num_files,
    )
    sim_data = sim_data[:config.data_infer.num_max_trees]
    print(f"  -> {len(sim_data)} root halos loaded.")

    # Get the time steps for inference
    snap_table, aexp_table, z_table = read_snapshot_times(config.data_infer.box)
    select = z_table <= config.data_infer.zmax
    snap_times_out = snap_table[select][::config.data_infer.step]
    aexp_times_out = aexp_table[select][::config.data_infer.step]

    # Create the initial input tensor (x0) and the time history tensor (Zhist)
    x0 = torch.stack([sim_data[i].x[0, :-1] for i in range(len(sim_data))], dim=0)
    Zhist = torch.tensor(aexp_times_out, dtype=torch.float32).unsqueeze(1)
    snapshot_list = torch.tensor(snap_times_out, dtype=torch.long)

    # Repeat the input if you want to generate multiple trees from the same root
    x0 = x0.repeat(config.data_infer.multiplicative_factor, 1)

    print(f"Input tensor shape (x0): {x0.shape}")
    print(f"Time history tensor shape (Zhist): {Zhist.shape}")

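except FileNotFoundError as e:
    # Either the simulation data or the snapshot-time metadata file is missing
    print(f"ERROR: {e}")
    print("Check 'config.data_infer.root' in the configuration cell and the metadata "
          f"directory '{DEFAULT_METADATA_DIR}'.")
    raise  # Stop here so later cells do not run with undefined inputs
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    raise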
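`x0.repeat(multiplicative_factor, 1)` tiles the whole block of roots rather than repeating each root in place, which is the same behavior as `numpy.tile`. With `N` roots, input row `k` therefore comes from root `k % N`.

Below is a stdlib sketch of that index bookkeeping. It assumes `generate_forest` returns trees in the same order as the rows of `x0`, which is not confirmed here. The feature values are hypothetical:

```python
def tile_rows(rows, factor):
    """Mimic tensor.repeat(factor, 1) on a list of rows: the whole block is tiled."""
    return [list(row) for _ in range(factor) for row in rows]


def root_index(tree_index, num_roots):
    """Index of the root halo that input row `tree_index` was copied from."""
    return tree_index % num_roots


roots = [[12.0, 0.5], [13.5, 0.7]]  # hypothetical per-root feature rows
tiled = tile_rows(roots, factor=3)
print(len(tiled))  # 6
print([root_index(k, len(roots)) for k in range(len(tiled))])  # [0, 1, 0, 1, 0, 1]
```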

## 4. Generate Trees

With the model loaded and the inputs prepared, we can now run the autoregressive generation process. The `generate_forest` utility function handles the batching and the step-by-step generation.
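To make batched, step-by-step generation concrete, here is a toy stand-in. It is not FLORAH's model. The differences and assumptions are:

- Progenitor counts and mass fractions are drawn from Python's `random` module instead of a trained network.
- The cap of three progenitors per halo is an assumption, suggested only by `nprog3` in the config filename.
- The minimum-mass cut is a hypothetical resolution limit.

The toy shows three things: roots processed in batches, each tree grown one time step at a time with every new node conditioned on its descendant, and reproducible output from a fixed seed.

```python
import random


def grow_tree(root_mass, num_steps, max_prog, min_mass, rng):
    """Grow one toy tree backwards in time, one snapshot per step.

    Returns (masses, edges, levels); each edge is a (descendant, progenitor)
    pair of node indices and levels[i] is the time step of node i.
    """
    masses, edges, levels = [root_mass], [], [0]
    frontier = [0]  # nodes at the latest step that can still get progenitors
    for step in range(1, num_steps):
        next_frontier = []
        for desc in frontier:
            n_prog = rng.randint(1, max_prog)  # 1..max_prog progenitors
            weights = [rng.random() + 1e-6 for _ in range(n_prog)]
            kept = rng.uniform(0.7, 1.0)  # share of descendant mass in progenitors
            norm = sum(weights)
            for w in weights:
                masses.append(masses[desc] * kept * w / norm)
                levels.append(step)
                edges.append((desc, len(masses) - 1))
                if masses[-1] >= min_mass:  # below resolution: stop this branch
                    next_frontier.append(len(masses) - 1)
        frontier = next_frontier
    return masses, edges, levels


def generate_forest_toy(root_masses, num_steps, batch_size, max_prog=3,
                        min_mass=0.0, seed=42):
    """Process roots in batches; one seeded RNG makes the forest reproducible."""
    rng = random.Random(seed)
    forest = []
    for start in range(0, len(root_masses), batch_size):
        batch = root_masses[start:start + batch_size]
        forest.extend(grow_tree(m, num_steps, max_prog, min_mass, rng) for m in batch)
    return forest


roots = [1e12, 5e12, 2e13]  # hypothetical root halo masses
forest_a = generate_forest_toy(roots, num_steps=4, batch_size=2, min_mass=1e11)
forest_b = generate_forest_toy(roots, num_steps=4, batch_size=2, min_mass=1e11)
print(len(forest_a))         # 3
print(forest_a == forest_b)  # True: same seed, same trees
```

In pure Python the batches are processed one after another. A neural model would typically evaluate each batch in parallel on the device.

In the notebook, `pl.seed_everything(42)` plays the role of the seeded RNG: it seeds Python's `random`, NumPy and PyTorch. Some GPU operations can still be nondeterministic even with a fixed seed.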

In [None]:
import pytorch_lightning as pl

# Set a seed for reproducibility
pl.seed_everything(42)

print("Starting tree generation...")
tree_list = infer_utils.generate_forest(
    model,
    x0,
    Zhist,
    norm_dict=model.norm_dict,
    device=device,
    batch_size=config.data_infer.batch_size,
    sort=True,
    snapshot_list=snapshot_list,
    verbose=True,
)

print(f"Finished generation. {len(tree_list)} trees were created.")

## 5. Save & Analyze Results

Finally, save the generated trees. The output is a list of `torch_geometric.data.Data` objects, one per merger tree, which we serialize with `pickle`.

Loading the file later requires `torch_geometric` to be installed, because unpickling re-imports the `Data` class.
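Below is a minimal sketch of the same save/load round trip using only the standard library. The `ToyTree` dataclass is a hypothetical stand-in for `torch_geometric.data.Data` with just `x` (node features) and `edge_index` fields. The temporary directory replaces `config.data_infer.outdir`:

```python
import os
import pickle
import tempfile
from dataclasses import dataclass, field


@dataclass
class ToyTree:
    x: list  # per-node feature rows
    edge_index: list = field(default_factory=lambda: [[], []])  # [sources, targets]


trees = [
    ToyTree(x=[[12.0], [11.8]], edge_index=[[0], [1]]),
    ToyTree(x=[[13.1]]),
]

with tempfile.TemporaryDirectory() as outdir:
    outfile = os.path.join(outdir, "generated_trees.pkl")
    with open(outfile, "wb") as f:
        pickle.dump(trees, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open(outfile, "rb") as f:
        loaded = pickle.load(f)

print(loaded == trees)  # True
```

Pickle stores a reference to each object's class, not the class itself, so the defining module must be importable when loading. Only unpickle files from trusted sources.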

In [None]:
# Create the output directory if it doesn't exist
os.makedirs(config.data_infer.outdir, exist_ok=True)
outfile = os.path.join(config.data_infer.outdir, 'generated_trees.pkl')

print(f"Saving generated trees to {outfile}")
with open(outfile, 'wb') as f:
    pickle.dump(tree_list, f)
print("Save complete.")

### Basic Analysis

Let's load the saved trees and inspect one of them to see what the output looks like.
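Beyond printing shapes, simple structural statistics can be read straight from `edge_index`. PyTorch Geometric stores it as a `[2, num_edges]` array of (source, target) node indices.

Below is a stdlib sketch of such statistics. It assumes edges point from a descendant to its progenitor and that node 0 is the root; check the actual convention in `florah_tree` before relying on it. The example tree is made up. For a real `Data` object, `first_tree.edge_index.tolist()` gives the same nested-list form.

```python
from collections import Counter, deque


def tree_stats(num_nodes, edge_index, root=0):
    """Progenitor counts, leaf count and depth from a [sources, targets] edge list."""
    sources, targets = edge_index
    n_prog = Counter(sources)  # out-degree = number of progenitors per node
    children = {i: [] for i in range(num_nodes)}
    for s, t in zip(sources, targets):
        children[s].append(t)

    # Breadth-first search from the root to get each node's depth (steps back in time)
    depth = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in children[node]:
            depth[child] = depth[node] + 1
            queue.append(child)

    return {
        "num_leaves": sum(1 for i in range(num_nodes) if n_prog[i] == 0),
        "max_progenitors": max(n_prog.values(), default=0),
        "max_depth": max(depth.values()),
        "num_mergers": sum(1 for c in n_prog.values() if c > 1),
    }


# Hypothetical 6-node tree: root 0 has progenitors 1 and 2; node 1 has 3 and 4; node 2 has 5
edge_index = [[0, 0, 1, 1, 2], [1, 2, 3, 4, 5]]
print(tree_stats(6, edge_index))
# {'num_leaves': 3, 'max_progenitors': 2, 'max_depth': 2, 'num_mergers': 2}
```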

In [None]:
# Load the trees back from the file
with open(outfile, 'rb') as f:
    loaded_trees = pickle.load(f)

# Inspect the first tree
if loaded_trees:
    first_tree = loaded_trees[0]
    print("--- First Generated Tree ---")
    print(first_tree)
    print(f"Number of nodes (halos): {first_tree.num_nodes}")
    print(f"Number of edges (progenitor links): {first_tree.num_edges}")
    print(f"Node features shape: {first_tree.x.shape}")
    print(f"Node features (first 5 nodes): {first_tree.x[:5]}")
else:
    print("No trees were generated or loaded.")