In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import tqdm

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import matplotlib.pyplot as plt
import ml_collections
from ml_collections import config_dict

from models import noise_schedules
from models import transformer, scores, vdm
from models import diffusion_utils

%matplotlib inline

In [2]:
def get_config():

    config = config_dict.ConfigDict()

    config.workdir = 'trained_models/test2'

    config.data = data = config_dict.ConfigDict()
    data.conditioning_parameters = [
        "halo_mvir", "halo_mstar", "center_subhalo_mvir", "center_subhalo_mstar",
        "center_subhalo_vpeculiar", "center_subhalo_vmax_tilde", "log_num_subhalos",
        "inv_wdm_mass", "log_sn1", "log_sn2", "log_agn1"
    ]

    # VDM args
    config.vdm = vdm = config_dict.ConfigDict()
    vdm.d_in = 9
    vdm.d_cond = len(data.conditioning_parameters)
    vdm.d_context_embedding = 16
    vdm.timesteps = 0
    vdm.antithetic_time_sampling = True
    vdm.use_encdec = False
    vdm.embed_context = True
    
    vdm.score_model = score = config_dict.ConfigDict()
    score.name = 'transformer'
    score.d_t_embedding = 16
    score.d_model = 128
    score.d_mlp = 256
    score.d_cond = vdm.d_context_embedding if vdm.embed_context else vdm.d_cond
    score.n_layers = 6
    score.n_heads = 4

    vdm.noise_schedule = noise_schedule = config_dict.ConfigDict()
    noise_schedule.name = "learned_linear"
    noise_schedule.gamma_min = -16.0
    noise_schedule.gamma_max = 10.0

    # training and loss args
    config.training = training = config_dict.ConfigDict()
    training.batch_size = 128
    training.max_steps = 50_000
    training.noise_scale = 1e-3
    training.beta = 1.0
    training.rotation_augmentation = True
    training.n_pos_dim = 3
    training.n_vel_dim = 3
    training.add_mass_recon_loss = True
    training.i_mass_start = 6
    training.i_mass_stop = 8

    # optimizer and scheduler args
    config.optimizer = optimizer = config_dict.ConfigDict()
    optimizer.name = "AdamW"
    optimizer.lr = 5e-4
    optimizer.betas = [0.9, 0.999]
    optimizer.weight_decay = 0.01
    optimizer.grad_clip = 0.5
    config.scheduler = scheduler = config_dict.ConfigDict()
    scheduler.name = "WarmUpCosineDecayLR"
    scheduler.init_value = 0.0
    scheduler.peak_value = optimizer.lr
    scheduler.warmup_steps = 5_000
    scheduler.decay_steps = training.max_steps

    return config

config = get_config()

In [3]:
# read the dataset
datadir = '/mnt/ceph/users/tnguyen/dark_camels/point-cloud-diffusion-datasets/processed_datasets/final-WDM-datasets'

train_data = np.load(os.path.join(datadir, 'train_galprop.npz'))
train_cond_table = pd.read_csv(os.path.join(datadir, 'train_galprop_cond.csv'))
val_data = np.load(os.path.join(datadir, 'val_galprop.npz'))
val_cond_table = pd.read_csv(os.path.join(datadir, 'val_galprop_cond.csv'))

In [4]:
EPS = 1e-6

# preprocess the data
train_x = train_data['features']
train_mask = train_data['mask']
train_cond = train_cond_table[config.data.conditioning_parameters].values
val_x = val_data['features']
val_mask = val_data['mask']
val_cond = val_cond_table[config.data.conditioning_parameters].values

# normalize the data and convert to torch tensors
mask_bool = train_mask.astype(bool)
x_mean = np.mean(train_x, axis=(0, 1), where=mask_bool[..., None])
x_std = np.std(train_x, axis=(0, 1), where=mask_bool[..., None])
cond_mean = np.mean(train_cond, axis=0)
cond_std = np.std(train_cond, axis=0)
norm_dict = {
    'x_mean': x_mean,
    'x_std': x_std,
    'cond_mean': cond_mean,
    'cond_std': cond_std
}

train_x = (train_x - x_mean + EPS) / (x_std + EPS)
train_cond = (train_cond - cond_mean + EPS) / (cond_std + EPS)
val_x = (val_x - x_mean + EPS) / (x_std + EPS)
val_cond = (val_cond - cond_mean + EPS) / (cond_std + EPS)

train_x = torch.tensor(train_x, dtype=torch.float32)
train_cond = torch.tensor(train_cond, dtype=torch.float32)
val_x = torch.tensor(val_x, dtype=torch.float32)
val_cond = torch.tensor(val_cond, dtype=torch.float32)
# invert mask due to torch convention vs jax
train_mask = ~torch.tensor(train_mask, dtype=torch.bool)
val_mask = ~torch.tensor(val_mask, dtype=torch.bool)

# Create dataloaders
train_dset = TensorDataset(train_x, train_cond, train_mask)
val_dset = TensorDataset(val_x, val_cond, val_mask)
train_loader = DataLoader(
    train_dset, batch_size=config.training.batch_size, shuffle=True)
val_loader = DataLoader(
    val_dset, batch_size=config.training.batch_size, shuffle=False)

In [5]:
model = vdm.VariationalDiffusionModel(
    d_in=config.vdm.d_in,
    d_cond=config.vdm.d_cond,
    d_context_embedding=config.vdm.d_context_embedding,
    embed_context=config.vdm.embed_context,
    score_model_args=config.vdm.score_model,
    noise_schedule_args=config.vdm.noise_schedule,
    timesteps=config.vdm.timesteps,
    antithetic_time_sampling=config.vdm.antithetic_time_sampling,
    use_encdec=config.vdm.use_encdec,
    training_args=config.training,
    optimizer_args=config.optimizer,
    scheduler_args=config.scheduler,
    norm_dict=norm_dict
)

callbacks = [
    pl.callbacks.EarlyStopping(
        monitor='val_loss', patience=1000, mode='min', verbose=True),
    pl.callbacks.ModelCheckpoint(
        monitor='val_loss', save_top_k=10, mode='min', save_weights_only=False),
    pl.callbacks.LearningRateMonitor("step"),
]
train_logger = pl_loggers.TensorBoardLogger(config.workdir)

trainer = pl.Trainer(
    default_root_dir=config.workdir,
    max_steps=config.training.max_steps,
    accelerator='gpu',
    callbacks=callbacks,
    logger=train_logger,
    enable_progress_bar=True,
    inference_mode=False,
    gradient_clip_val=config.optimizer.grad_clip,
)

# train the model
trainer.fit(model, train_loader, val_loader)

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/home/tnguyen/miniconda3/envs/geometric/lib/pyth ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/home/tnguyen/miniconda3/envs/geometric/lib/pyth ...
Missing logger folder: trained_models/test2/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | 

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 7/7 [00:00<00:00, 18.62it/s, v_num=0]

Metric val_loss improved. New best score: 225.326


Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 17.12it/s, v_num=0]

Metric val_loss improved by 17.676 >= min_delta = 0.0. New best score: 207.650


Epoch 3: 100%|██████████| 7/7 [00:00<00:00, 18.32it/s, v_num=0]

Metric val_loss improved by 1.514 >= min_delta = 0.0. New best score: 206.136


Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 17.34it/s, v_num=0]

Metric val_loss improved by 2.299 >= min_delta = 0.0. New best score: 203.837


Epoch 5: 100%|██████████| 7/7 [00:00<00:00, 16.54it/s, v_num=0]

Metric val_loss improved by 6.450 >= min_delta = 0.0. New best score: 197.387


Epoch 6: 100%|██████████| 7/7 [00:00<00:00, 16.66it/s, v_num=0]

Metric val_loss improved by 2.892 >= min_delta = 0.0. New best score: 194.495


Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 16.43it/s, v_num=0]

Metric val_loss improved by 1.255 >= min_delta = 0.0. New best score: 193.240


Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 17.48it/s, v_num=0]

Metric val_loss improved by 4.927 >= min_delta = 0.0. New best score: 188.313


Epoch 14: 100%|██████████| 7/7 [00:00<00:00, 17.55it/s, v_num=0]

Metric val_loss improved by 2.576 >= min_delta = 0.0. New best score: 185.738


Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 17.63it/s, v_num=0]

Metric val_loss improved by 3.102 >= min_delta = 0.0. New best score: 182.635


Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 17.25it/s, v_num=0]

Metric val_loss improved by 4.612 >= min_delta = 0.0. New best score: 178.024


Epoch 31: 100%|██████████| 7/7 [00:00<00:00, 18.76it/s, v_num=0]

Metric val_loss improved by 6.008 >= min_delta = 0.0. New best score: 172.016


Epoch 32: 100%|██████████| 7/7 [00:00<00:00, 19.34it/s, v_num=0]

Metric val_loss improved by 0.638 >= min_delta = 0.0. New best score: 171.378


Epoch 38: 100%|██████████| 7/7 [00:00<00:00, 18.37it/s, v_num=0]

Metric val_loss improved by 1.157 >= min_delta = 0.0. New best score: 170.221


Epoch 39: 100%|██████████| 7/7 [00:00<00:00, 18.61it/s, v_num=0]

Metric val_loss improved by 7.339 >= min_delta = 0.0. New best score: 162.882


Epoch 53: 100%|██████████| 7/7 [00:00<00:00, 19.88it/s, v_num=0]