# Train SimFormer on Mock Galaxy Data

This notebook trains a SimFormer model on mock galaxy photometric data with:
- Multiple surveys (Gaia, 2MASS, WISE, PS1, DECam)
- Error embeddings from measurement uncertainties
- Observed/unobserved mask embeddings
- Age-balanced subsampling (a cap on stars per logAge bin, phased in via a curriculum) to even out logAge coverage

In [1]:
import os, time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR

from prepare_data import load_and_filter, generate_synthetic_errors

from columns import (
    INTRINSIC_COLS, TRUE_MAG_COLS, OBS_COLS, OBS_ERR_COLS,
    ALL_VALUE_COLS, NUM_NODES, N_INTRINSIC, N_TRUE_MAG,
)
from train_mock_galaxy import (
    build_arrays, compute_age_bin_indices, create_model, 
    make_epoch_callback
) 
from transformer import Simformer
from simflower import FlowMatchingTrainer
from utils import make_condition_mask_generator

# Seed NumPy's global RNG and PyTorch's CPU/CUDA generators for reproducibility,
# then pick the compute device.
seed = 0
has_cuda = torch.cuda.is_available()

np.random.seed(seed)
torch.manual_seed(seed)
if has_cuda:
    torch.cuda.manual_seed_all(seed)

device = 'cuda' if has_cuda else 'cpu'
print(f'Using device: {device}')

Using device: cuda


## 1. Load & Filter Data

In [2]:
# Input table. A `.parquet` path is treated as the already-prepared cache
# (filtered + synthetic errors). Any other path is treated as a raw catalogue:
# it is filtered, given synthetic errors, and written to `fname_out`.
fname = '/n/home12/sratzenboeck/data_local/mock/train_data_prepared.parquet'
fname_out = '/n/home12/sratzenboeck/data_local/mock/train_data_prepared.parquet'

# Check the extension rather than a substring, so a raw file that merely lives
# under a directory (or has a name) containing "parquet" is not misrouted.
if fname.endswith('.parquet'):
    df = pd.read_parquet(fname)
else:
    df = load_and_filter(data_path=fname, max_rad=5, min_log_age=5)
    df = generate_synthetic_errors(df, unobs_frac=0.2, seed=0)
    print(f'After distance cut (rad < 5 kpc): {len(df):,} stars')
    df.to_parquet(fname_out, index=False)

print(f'Loaded {len(df):,} stars from {fname}')

In [3]:
# Build the normalized data array, per-value errors, observed/unobserved masks, and the normalization statistics (means, stds, log-error mean/std)

In [None]:
# Convert the table into model inputs: normalized values, per-value errors,
# observed masks, and the statistics needed to undo the normalization.
data, data_errors, data_observed_mask, means, stds, log_err_mean, log_err_std = build_arrays(df)

# Invariants later cells rely on without checking:
#  - rows of every array align with `df`, because the curriculum cell indexes
#    df['logAge'] with positional indices into `data`;
#  - one column per graph node, because sampling builds NUM_NODES-wide masks
#    and node ids around a single row of `data`.
assert len(data) == len(df) == len(data_errors) == len(data_observed_mask), \
    'build_arrays output is not row-aligned with df'
assert data.shape[1] == NUM_NODES, f'expected {NUM_NODES} columns, got {data.shape[1]}'

In [None]:
# Persist the normalization statistics (per-column means/stds, log-error
# mean/std, and the column order they refer to) so the scaling can be
# undone outside this notebook.
norm_dir = '/n/home12/sratzenboeck/data_local/mock/'
norm_path = os.path.join(norm_dir, 'norm_stats.npz')
norm_payload = dict(
    means=means,
    stds=stds,
    columns=ALL_VALUE_COLS,
    log_err_mean=log_err_mean,
    log_err_std=log_err_std,
)
np.savez(norm_path, **norm_payload)
print(f'  Normalization stats saved to {norm_path}')

In [None]:
# Training configuration. The whole dict is also forwarded to wandb as the run
# config, so every knob used below should live here.
train_infos = dict(
    do_compile = True,        # wrap the model with torch.compile
    batch_size = 2**12,       # shared by the trainer and the condition-mask generator
    lr = 1e-3,                # starting LR of the cosine schedule
    inner_loop_size = 500,    # -> FlowMatchingTrainer(inner_train_loop_size=...)
    patience = 20,            # -> early_stopping_patience
    val_split = 0.2,          # fraction held out for validation
    amp = True,               # mixed-precision training
    wandb = True,
    wandb_project = 'test_mock_galaxy',
    n_bins = 30,              # number of logAge bins for the curriculum
    lr_min = 1e-5,            # eta_min of the cosine schedule
    epochs = 500,
    tau_max = 0.7,            # curriculum τ parameters passed to make_epoch_callback
    tau_warmup = 100,
    dense_ratio = 0.9,        # -> FlowMatchingTrainer(dense_ratio=...)
    cap_per_bin = 1000,       # per-bin cap used by the age-balanced epoch sampler
)
# Derive the flag from the config instead of hard-coding a second copy that
# could silently disagree with what is logged to wandb.
do_compile = train_infos['do_compile']

In [None]:
# ---- Age bin indices for curriculum weighting ----
# FlowMatchingTrainer performs its own train_test_split internally. We replay a
# split here (same test_size, random_state=42) so that the per-star age-bin
# indices line up with the trainer's training and validation subsets.
# NOTE(review): this only holds if the trainer splits np.arange(len(data)) with
# exactly these arguments — confirm against its implementation.
n_bins = train_infos['n_bins']
log_age = df['logAge'].values
bin_idx_all, _ = compute_age_bin_indices(log_age, n_bins)

n_total = len(data)
all_indices = np.arange(n_total)
train_indices, val_indices = train_test_split(
    all_indices, test_size=train_infos['val_split'], random_state=42
)


def _bin_counts(bin_idx, n):
    """Return per-bin star counts as float64, always of length ``n``."""
    return np.bincount(bin_idx, minlength=n).astype(np.float64)


bin_idx_train, bin_idx_val = bin_idx_all[train_indices], bin_idx_all[val_indices]
bin_counts_train = _bin_counts(bin_idx_train, n_bins)
bin_counts_val = _bin_counts(bin_idx_val, n_bins)
print(f"  Curriculum: {train_infos['n_bins']} bins, train={len(train_indices):,}, val={len(val_indices):,}")

In [None]:
# --- Diagnostic: visualize cap-based sampling for various τ values ---
# Draws one epoch's worth of training indices with build_epoch_indices for
# several τ values, compares the resulting logAge histograms with the natural
# training distribution, then tabulates per-bin selected counts at τ=0.
from train_mock_galaxy import build_epoch_indices

tau_values = [0.0, 0.2, 0.5, 0.7, 1.0]
cap = train_infos['cap_per_bin']
log_age_train = log_age[train_indices]

# Shared histogram edges so every panel is binned identically
bin_edges_hist = np.linspace(log_age.min(), log_age.max(), 50)

fig, axes = plt.subplots(2, 3, figsize=(16, 9), sharex=True)
axes = axes.ravel()

# First panel: original (natural) distribution
axes[0].hist(log_age_train, bins=bin_edges_hist, alpha=0.8, color='grey', density=True)
axes[0].set_title(f'Original distribution\n({len(log_age_train):,} stars)', fontsize=11)
axes[0].set_ylabel('Density')

# Remaining panels (1..5 of the 2x3 grid): cap-based sampling at each τ
rng_diag = np.random.default_rng(0)
for i, tau in enumerate(tau_values):
    ax = axes[i + 1]
    idx = build_epoch_indices(bin_idx_train, bin_counts_train, tau, cap, rng_diag)
    selected_ages = log_age_train[idx]

    ax.hist(selected_ages, bins=bin_edges_hist, alpha=0.8, density=True,
            color='steelblue' if tau < 0.99 else 'coral')
    ax.set_title(f'τ = {tau:.1f}  ({len(idx):,} stars)', fontsize=11)

    if i + 1 >= 3:          # bottom row
        ax.set_xlabel('logAge')
    if (i + 1) % 3 == 0:    # first column
        ax.set_ylabel('Density')

plt.suptitle(f'Cap-based epoch sampling (cap_per_bin={cap}, n_bins={train_infos["n_bins"]})',
             fontsize=13, y=1.01)
plt.tight_layout()
plt.show()

# Print per-bin stats at τ=0
print(f"\n--- Per-bin counts at τ=0 (cap={cap}) ---")
print(f"{'Bin':>4}  {'Total':>8}  {'Selected':>8}  {'logAge range'}")
print("-" * 50)
# NOTE(review): these edges are assumed to mirror the ones used by
# compute_age_bin_indices; if they differ, the printed ranges are approximate.
bin_edges = np.linspace(log_age.min(), log_age.max() + 1e-6, train_infos['n_bins'] + 1)
rng_diag2 = np.random.default_rng(0)
idx_tau0 = build_epoch_indices(bin_idx_train, bin_counts_train, 0.0, cap, rng_diag2)
sel_bin_idx = bin_idx_train[idx_tau0]
sel_counts = np.bincount(sel_bin_idx, minlength=train_infos['n_bins'])
for b in range(train_infos['n_bins']):
    lo, hi = bin_edges[b], bin_edges[b + 1]
    # Cell-local name: reusing `n_total` here would overwrite the dataset size
    # that the curriculum cell stores in the global `n_total`.
    n_in_bin = int(bin_counts_train[b])
    n_sel = int(sel_counts[b])
    flag = " (all)" if n_sel == n_in_bin else ""
    print(f"{b:4d}  {n_in_bin:8,}  {n_sel:8,}  [{lo:.2f}, {hi:.2f}){flag}")

In [8]:
# ---- Model ----
# Build the Simformer and optionally compile it; `model` is what the trainer receives.
print('\n--- Model ---')
base_model = create_model()
if do_compile:
    print('  Compiling model with torch.compile ...')
model = torch.compile(base_model) if do_compile else base_model


--- Model ---
  Model created: 816,225 parameters
  Compiling model with torch.compile ...


In [9]:
# ---- Condition mask generator ----
# Only observed-photometry nodes (everything after the intrinsic and
# true-magnitude blocks) are eligible to be drawn as conditioning variables.
# NOTE(review): percent=(0.1, 0.5) presumably bounds the fraction of eligible
# nodes conditioned on per sample — confirm in make_condition_mask_generator.
first_obs_node = N_INTRINSIC + N_TRUE_MAG
obs_indices = list(range(first_obs_node, NUM_NODES))
cond_gen = make_condition_mask_generator(
    batch_size=train_infos['batch_size'],  # kept in sync with the trainer's batch size
    num_features=NUM_NODES,
    percent=(0.1, 0.5),
    allowed_idx=obs_indices,
    device=device,
)

# ---- Trainer ----
print('\n--- Trainer ---')
# Map train_infos keys onto FlowMatchingTrainer's keyword names.
trainer_hparams = {
    'batch_size': train_infos['batch_size'],
    'lr': train_infos['lr'],
    'dense_ratio': train_infos['dense_ratio'],
    'inner_train_loop_size': train_infos['inner_loop_size'],
    'early_stopping_patience': train_infos['patience'],
    'val_split': train_infos['val_split'],
    'use_amp': train_infos['amp'],
    'use_wandb': train_infos['wandb'],
    'wandb_project': train_infos['wandb_project'],
}
trainer = FlowMatchingTrainer(
    model=model,
    data=data,
    data_errors=data_errors,
    data_observed_mask=data_observed_mask,
    condition_mask_generator=cond_gen,
    wandb_config=train_infos,
    device=device,
    **trainer_hparams,
)
print('  Trainer initialized.')


--- Trainer ---


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msebastian-ratzenboeck[0m ([33msebastian-ratzenboeck-harvard-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  Trainer initialized.


In [10]:
# Marker that data, model and trainer are all set up.
status_msg = 'ready to train...'
print(status_msg)

ready to train...


In [None]:
# ---- LR scheduler ----
# Cosine-anneal the learning rate from train_infos['lr'] down to
# train_infos['lr_min'] over T_max scheduler steps.
# NOTE(review): T_max = number of epochs assumes trainer.fit steps the
# scheduler once per epoch — confirm in FlowMatchingTrainer.fit.
n_epochs = train_infos['epochs']
lr_scheduler = CosineAnnealingLR(
    trainer.optimizer,
    T_max=n_epochs,
    eta_min=train_infos['lr_min'],
)

# ---- Epoch callback (curriculum) ----
# The callback gets the replicated train/val age-bin assignments together with
# the τ schedule and per-bin cap from train_infos.
curriculum_kwargs = {
    'bin_idx_train': bin_idx_train,
    'bin_counts_train': bin_counts_train,
    'bin_idx_val': bin_idx_val,
    'bin_counts_val': bin_counts_val,
    'tau_max': train_infos['tau_max'],
    'tau_warmup': train_infos['tau_warmup'],
    'cap_per_bin': train_infos['cap_per_bin'],
    'use_wandb': train_infos['wandb'],
}
epoch_cb = make_epoch_callback(**curriculum_kwargs)

# ---- Train ----
print(f"\n--- Training ({train_infos['epochs']} epochs) ---")
t0 = time.time()
best_model = trainer.fit(
    epochs=n_epochs,
    verbose=True,
    epoch_callback=epoch_cb,
    lr_scheduler=lr_scheduler,
)
elapsed = time.time() - t0
print(f'\nTraining completed in {elapsed / 60:.1f} minutes.')

In [None]:
# ---- Save best model ----
output_dir = '/n/home12/sratzenboeck/data_local/mock/'
os.makedirs(output_dir, exist_ok=True)
ckpt_path = os.path.join(output_dir, 'best_model.pt')
# torch.compile wraps the module in an OptimizedModule, whose state_dict keys
# carry an "_orig_mod." prefix and will not load into a plain (uncompiled)
# model. Unwrap it first; getattr falls back to best_model when it isn't compiled.
model_to_save = getattr(best_model, '_orig_mod', best_model)
torch.save(model_to_save.state_dict(), ckpt_path)
print(f'  Best model saved to {ckpt_path}')

## 10. Quick Validation: Sample Posterior

Condition on the observed Gaia and 2MASS photometry plus parallax, and sample the intrinsic parameters.

In [None]:
from sampling import sample_batched_flow, build_inference_edge_mask

# Pick a held-out test star. Row 0 of the full array is most likely a training
# star (the split keeps 1 - val_split of the rows for training), so take the
# first row of the replicated validation split instead.
# NOTE(review): assumes the trainer's internal split matches the one replayed
# in the curriculum cell (same val_split, random_state=42) — confirm.
test_idx = int(val_indices[0])
x_test = torch.tensor(data[test_idx], dtype=torch.float32)
x_test_errors = torch.tensor(data_errors[test_idx], dtype=torch.float32)
x_test_observed = torch.tensor(data_observed_mask[test_idx], dtype=torch.float32)

# Condition on observed photometry columns (plus parallax); names not present
# in ALL_VALUE_COLS are skipped.
cond_col_names = [
    'GAIA_GAIA3.G_mag_obs', 'GAIA_GAIA3.Gbp_mag_obs', 'GAIA_GAIA3.Grp_mag_obs',
    '2MASS_H_mag_obs', '2MASS_J_mag_obs', '2MASS_Ks_mag_obs',
    'parallax_obs',
]
cond_indices = [ALL_VALUE_COLS.index(c) for c in cond_col_names if c in ALL_VALUE_COLS]

# Safety check: only condition on columns that are actually observed for this star
obs_flags = x_test_observed[cond_indices]
unobs_cond = [ALL_VALUE_COLS[cond_indices[i]] for i, o in enumerate(obs_flags) if o == 0]
if unobs_cond:
    print(f"WARNING: dropping unobserved conditioned columns: {unobs_cond}")
    cond_indices = [idx for idx, o in zip(cond_indices, obs_flags) if o == 1]
if not cond_indices:
    print("WARNING: none of the requested conditioning columns are observed; sampling is unconditional.")

print(f"Conditioning on {len(cond_indices)} columns: {[ALL_VALUE_COLS[i] for i in cond_indices]}")

n_samples = 512
condition_mask = torch.zeros(n_samples, NUM_NODES, dtype=torch.float32)
condition_mask[:, cond_indices] = 1.0

condition_values = x_test.unsqueeze(0).repeat(n_samples, 1)
node_ids = torch.arange(NUM_NODES).unsqueeze(0).repeat(n_samples, 1)

# Errors: full vector from the test star (NaN for unobserved bands → zero embedding
# via ErrorEmbed's nan_to_num; real errors for observed bands → informative embedding).
# This matches training, where the model always sees the star's full error vector.
sample_errors = x_test_errors.unsqueeze(0).repeat(n_samples, 1)

# Observed mask: the star's actual observation pattern (which bands were measured).
sample_observed = x_test_observed.unsqueeze(0).repeat(n_samples, 1)

# Build fully-connected edge mask (no observed filtering at inference)
edge_mask = build_inference_edge_mask(n_samples, NUM_NODES, device=device)

best_model.eval()
# Pure inference: disable autograd in case sample_batched_flow does not already
# do so. Otherwise graphs accumulate across the integration steps, and
# `.numpy()` below would raise on a tensor that requires grad.
with torch.no_grad():
    samples = sample_batched_flow(
        model_fn=best_model,
        shape=(n_samples,),
        condition_mask=condition_mask,
        condition_values=condition_values,
        node_ids=node_ids,
        edge_masks=edge_mask,
        errors=sample_errors,
        observed_mask=sample_observed,
        steps=50,
        device=device,
    )

# NOTE(review): assumes the sampler returns (n_samples, NUM_NODES, 1).
samples = samples.squeeze(-1).cpu().numpy()  # (n_samples, NUM_NODES)
# Denormalize back to physical units
samples_denorm = samples * stds + means
print(f'Samples shape: {samples_denorm.shape}')

In [None]:
# Corner plot of intrinsic parameters. The diagonal shows marginal histograms,
# the lower triangle shows pairwise sample scatter, and the upper triangle is
# hidden. Red marks (line / cross) show the star's true denormalized value.
plot_cols = ['logAge', 'feh', 'logT', 'logg', 'logL', 'Av']
plot_indices = [ALL_VALUE_COLS.index(c) for c in plot_cols]
n_params = len(plot_cols)

# True values for this star, mapped back to physical units
true_vals = (x_test.numpy() * stds + means)[plot_indices]

fig, axes = plt.subplots(n_params, n_params, figsize=(14, 14))

# Hide the upper triangle up front
for row in range(n_params):
    for col in range(row + 1, n_params):
        axes[row, col].axis('off')

# Fill the diagonal and the lower triangle
for row in range(n_params):
    y_idx = plot_indices[row]
    for col in range(row + 1):
        ax = axes[row, col]
        if row == col:
            ax.hist(samples_denorm[:, y_idx], bins=30, alpha=0.7, density=True)
            ax.axvline(true_vals[row], color='red', lw=2, label='True')
            if row == 0:
                ax.legend(fontsize=8)
        else:
            x_idx = plot_indices[col]
            ax.scatter(samples_denorm[:, x_idx], samples_denorm[:, y_idx],
                       s=1, alpha=0.3)
            ax.scatter(true_vals[col], true_vals[row], color='red', s=50, zorder=10, marker='x')
        # Axis labels only on the outer edges of the grid
        if col == 0:
            ax.set_ylabel(plot_cols[row], fontsize=10)
        if row == n_params - 1:
            ax.set_xlabel(plot_cols[col], fontsize=10)

plt.suptitle('Posterior Samples (conditioned on Gaia + 2MASS + parallax)', fontsize=14)
plt.tight_layout()
plt.show()