In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import copy
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import time
from datetime import datetime

# Add the project root to the path
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.utils import extract_trailing_numbers, set_seed
import flowmol
import dgl

In [3]:
from true_reward.dxtb_simulation import compute_energy, compute_energy_grad

  return self.fget.__get__(instance, owner)()


In [4]:
def sampling(config: OmegaConf, model: flowmol.FlowMol, device: torch.device):
    """Draw a batch of molecules with randomly sampled sizes from ``model``.

    Args:
        config: Sampling section of the config; ``num_samples`` and
            ``num_integration_steps`` are read from it.
        model: FlowMol model exposing ``sample_random_sizes``.
        device: Device the sampling runs on.

    Returns:
        Whatever ``model.sample_random_sizes`` returns (the generated molecules).
    """
    # The sampler is given one more timestep than the configured number of
    # integration steps (presumably grid points vs. intervals -- confirm in FlowMol).
    timesteps = config.num_integration_steps + 1
    return model.sample_random_sizes(
        n_molecules=config.num_samples,
        n_timesteps=timesteps,
        device=device,
    )

In [5]:
def setup_gen_model(flow_model: str, device: torch.device):
    """Load a pretrained FlowMol model by name and move it onto ``device``.

    Args:
        flow_model: Name of the pretrained checkpoint passed to ``flowmol.load_pretrained``.
        device: Target device for the model's parameters.

    Returns:
        The loaded model object (the same object ``.to`` was called on).
    """
    pretrained = flowmol.load_pretrained(flow_model)
    pretrained.to(device)  # return value ignored; the loaded object itself is returned
    return pretrained

In [None]:
import argparse

def parse_args(argv=None):
    """Parse command-line overrides for the adjoint-matching fine-tuning run.

    Every override defaults to ``None`` (boolean flags default to ``False``);
    ``None`` means "keep the value from the YAML config" when the namespace is
    passed to ``update_config_with_args``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as ``argparse`` does by default.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Run ALM with optional parameter overrides")
    # Settings
    parser.add_argument("--config", type=str, default="../configs/adjoint_matching.yaml",
                        help="Path to config file")
    parser.add_argument("--save_model", action='store_true',
                        help="Save the model, default: false")
    parser.add_argument("--save_samples", action='store_true',
                        help="Create animation of the samples and save the samples, default: false")
    parser.add_argument("--save_plots", action='store_true',
                        help="Save plots of rewards and constraints, default: false")
    parser.add_argument("--plotting_freq", type=int,
                        help="Plotting frequency")
    # FlowMol arguments. The list must contain the model named in the YAML config
    # (qm9_ctmc) so that value can also be selected explicitly from the CLI.
    flowmol_choices = ['qm9_ctmc', 'qm9_gaussian', 'geom_ctmc', 'geom_gaussian']
    parser.add_argument('--flow_model', type=str, choices=flowmol_choices,
                        help='Pretrained model to be used (overrides flow_model in config)')
    # Adjoint Matching Parameters
    parser.add_argument("--reward_lambda", type=float,
                        help="Override reward_lambda in config")
    parser.add_argument("--lr", type=float,
                        help="Override adjoint_matching.lr in config")
    parser.add_argument("--clip_grad_norm", type=float,
                        help="Override adjoint_matching.clip_grad_norm in config")
    parser.add_argument("--batch_size", type=int,
                        help="Override adjoint_matching.batch_size in config")
    parser.add_argument("--samples_per_update", type=int,
                        help="Override adjoint_matching.sampling.num_samples in config")
    parser.add_argument("--num_integration_steps", type=int,
                        help="Override adjoint_matching.sampling.num_integration_steps in config")
    parser.add_argument("--finetune_steps", type=int,
                        help="Override adjoint_matching.finetune_steps in config")
    parser.add_argument("--num_iterations", type=int,
                        help="Override adjoint_matching.num_iterations in config")
    return parser.parse_args(argv)

# Inside Jupyter, sys.argv presumably holds the kernel launcher's own arguments;
# reset it so parse_args() sees no CLI flags and every override stays None
# (i.e. the YAML config values are used unchanged).
sys.argv = [""]

In [7]:
def update_config_with_args(config, args):
    """Apply the non-``None`` command-line overrides in ``args`` onto ``config``.

    The config is mutated in place. Overrides are written to the same keys the
    rest of the notebook reads (e.g. ``config.flow_model``), so a CLI value
    actually takes effect.

    Args:
        config: Loaded experiment config (an OmegaConf ``DictConfig`` here).
            Nested sections are reached by attribute access.
        args: Namespace from ``parse_args``. Attributes that are ``None`` or
            missing leave the corresponding config value unchanged.

    Returns:
        The same ``config`` object, for convenient chaining.
    """
    # argparse dest -> attribute path inside the config.
    overrides = {
        # FlowMol arguments (top-level key, read later as config.flow_model)
        "flow_model": ("flow_model",),
        # Adjoint Matching Parameters
        "reward_lambda": ("reward_lambda",),
        "lr": ("adjoint_matching", "lr"),
        "clip_grad_norm": ("adjoint_matching", "clip_grad_norm"),
        "batch_size": ("adjoint_matching", "batch_size"),
        "samples_per_update": ("adjoint_matching", "sampling", "num_samples"),
        "num_integration_steps": ("adjoint_matching", "sampling", "num_integration_steps"),
        "finetune_steps": ("adjoint_matching", "finetune_steps"),
        "num_iterations": ("adjoint_matching", "num_iterations"),
    }
    for arg_name, path in overrides.items():
        value = getattr(args, arg_name, None)
        if value is None:
            continue  # not given on the CLI: keep the config value
        *parents, leaf = path
        node = config
        for key in parents:
            node = getattr(node, key)
        setattr(node, leaf, value)
    return config

In [None]:
# Parse CLI overrides (none inside the notebook) and layer them over the YAML config.
args = parse_args()
config_path = Path(args.config)
config = update_config_with_args(OmegaConf.load(config_path), args)

# Reproducibility, compute device and the experiment's output directory.
set_seed(config.seed)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
root_dir = Path(config.root) / Path(config.experiment)

# Default plotting frequency: num_iterations // 20, but never below 1.
if args.plotting_freq is None:
    args.plotting_freq = max(1, config.adjoint_matching.num_iterations // 20)

# Convenience aliases read by later cells.
flowmol_model = config.flow_model
reward_lambda = config.reward_lambda
learning_rate = config.adjoint_matching.lr
clip_grad_norm = config.adjoint_matching.clip_grad_norm
traj_samples_per_stage = config.adjoint_matching.sampling.num_samples
traj_len = config.adjoint_matching.sampling.num_integration_steps
finetune_steps = config.adjoint_matching.finetune_steps
num_iterations = config.adjoint_matching.num_iterations

# Sampler types are forced here regardless of what the YAML file says.
config.adjoint_matching.sampling.sampler_type = "memoryless"
config.reward_sampling.sampler_type = "euler"

In [16]:
# Display the device selected above ("cuda:0" if CUDA is available, otherwise "cpu").
device

device(type='cpu')

In [9]:
# Print the effective config (YAML file + CLI overrides + forced sampler types).
print(OmegaConf.to_yaml(config))

seed: 42
verbose: false
root: /Users/svlg/MasterThesis/v02/
experiment: first_test
flow_model: qm9_ctmc
reward_lambda: 100
adjoint_matching:
  num_iterations: 50
  batch_size: 6
  clip_grad_norm: 0.5
  clip_loss: 100000.0
  lr: 5.0e-07
  finetune_steps: 2
  sampling:
    sampler_type: memoryless
    num_samples: 24
    num_integration_steps: 40
reward_sampling:
  sampler_type: euler
  num_samples: 20
  num_integration_steps: 100



In [10]:
# Load the pretrained FlowMol model named in the config onto the selected device.
gen_model = setup_gen_model(flow_model=config.flow_model, device=device)

In [11]:
def reward_fn(x_new):
    """Evaluate the dxtb-based reward for sampled molecules ``x_new``.

    Delegates to ``compute_energy`` with the notebook-level ``reward_lambda``
    and ``device``.
    """
    energy = compute_energy(x_new, reward_lambda, device=device)
    return energy

def grad_reward_fn(x_new):
    """Return ``compute_energy_grad`` for ``x_new`` (presumably d reward / d x_new).

    Gradient tracking is force-enabled so this also works when the caller is
    running under ``torch.no_grad()``.
    """
    with torch.enable_grad():
        return compute_energy_grad(x_new, reward_lambda, device=device)

In [12]:
from finetuning.flow_adjoint import AdjointMatchingFinetuningTrainerFlowMol

In [13]:
# Set up adjoint matching. The trainer gets two independent deep copies of the
# pretrained model: one passed as `model` and one as `base_model`, so
# `gen_model` itself is never modified by the trainer.
trainer = AdjointMatchingFinetuningTrainerFlowMol(
    config=config.adjoint_matching,
    model=copy.deepcopy(gen_model),
    base_model=copy.deepcopy(gen_model),
    grad_reward_fn=grad_reward_fn,
    device=device,
    verbose=False,
)

In [14]:
# Per-iteration histories, filled by the evaluation and fine-tuning cells below.
losses, rewards = [], []
if args.save_samples:  # sample storage only exists when requested
    new_samples = []

In [15]:
# Display whether generated samples are being collected into `new_samples`.
args.save_samples

False

In [None]:
# Baseline: sample from a throwaway copy of the pretrained model before any fine-tuning.
new_molecules = sampling(config.reward_sampling, copy.deepcopy(gen_model), device=device)
if args.save_samples:
    # Store the individual (unbatched) molecule graphs, moved to the CPU.
    new_samples.extend(dgl.unbatch(new_molecules.cpu()))

In [None]:
# Reward of the pretrained model (baseline, stored as iteration 0).
# Uses `new_molecules` produced by the baseline sampling cell above; `x_new` is
# only defined later inside the fine-tuning loop.
reward = reward_fn(new_molecules).item()
rewards.append(reward / (reward_lambda + 1e-8))  # normalize by reward_lambda; 1e-8 guards lambda == 0
current_best_reward = rewards[-1]
best_epoch = 0
rewards[-1]

In [None]:
# Run finetuning loop.
# Each iteration: build an adjoint-matching dataset, take a few optimizer steps,
# then sample from the model being fine-tuned to track the reward.
with tqdm(range(1, num_iterations + 1), desc="Finetuning Progress", dynamic_ncols=True) as pbar:
    for i in pbar:

        # Solves lean adjoint ODE to create dataset
        dataset = trainer.sample_dataset()

        # Fine-tune the model with adjoint matching loss.
        loss = trainer.finetune(dataset, steps=config.adjoint_matching.finetune_steps, verbose=False)
        # Normalize by reward_lambda and half the trajectory length; max(..., 1)
        # avoids a ZeroDivisionError when num_integration_steps < 2.
        losses.append(loss / (reward_lambda + 1e-8) / max(traj_len // 2, 1))

        # Sample from the trainer's model, not `gen_model`: the trainer fine-tunes a
        # deepcopy, so `gen_model` still holds the untouched pretrained weights.
        # NOTE(review): assumes the trainer exposes its fine-tuned network as `trainer.model`.
        x_new = sampling(
            config.reward_sampling,
            copy.deepcopy(trainer.model),
            device=device
        )
        if args.save_samples:
            # Same storage format as the baseline samples (unbatched graphs on the CPU).
            new_samples.extend(dgl.unbatch(x_new.cpu()))

        # Compute appropriate reward for evaluation
        tmp_reward = reward_fn(x_new).item()
        rewards.append(tmp_reward / (reward_lambda + 1e-8))

        if rewards[-1] > current_best_reward:
            current_best_reward = rewards[-1]
            best_epoch = i

        # float() also accepts single-element tensors, in case the trainer returns tensors.
        pbar.set_postfix(loss=float(losses[-1]), reward=rewards[-1])

        # Periodic log, independent of whether this iteration improved the best reward.
        if i % (args.plotting_freq * 4) == 0:
            print(f"Iteration {i}: Loss: {losses[-1]:.4f}, Reward: {rewards[-1]:.4f}", flush=True)
            print(f"Best reward: {current_best_reward:.4f} at epoch {best_epoch}", flush=True)

In [None]:
def plot_graphs(data, titles):
    """Plot each series in ``data`` in its own subplot, side by side.

    Args:
        data: Sequence of 1-D series (lists of numbers or single-element tensors),
            one per subplot.
        titles: Subplot titles, aligned one-to-one with ``data``.

    Returns:
        The matplotlib ``Figure`` containing the subplots.

    Raises:
        ValueError: If ``data`` is empty or its length differs from ``titles``.
    """
    if not data or len(data) != len(titles):
        raise ValueError("plot_graphs needs a non-empty `data` with one title per series")
    # squeeze=False keeps `axes` 2-D even for a single series.
    fig, axes = plt.subplots(1, len(data), figsize=(6 * len(data), 4), squeeze=False)
    for ax, series, title in zip(axes[0], data, titles):
        # float() also converts single-element tensors to plain numbers for matplotlib.
        ax.plot([float(v) for v in series])
        ax.set(title=title, xlabel="Iteration")
    fig.tight_layout()
    return fig

# Plot rewards and losses (`plot_graphs` was previously used without being defined or imported).
tmp_data = [rewards, losses]
tmp_titles = ["Rewards", "Losses"]
fig = plot_graphs(tmp_data, tmp_titles)  # assigned so the figure is not rendered twice

In [None]:
from datetime import datetime
# Current local time formatted as YYYY-MM-DD-HH-MM-SS. Stored as `timestamp`
# rather than `time` so the `time` module imported at the top is not shadowed.
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
print(timestamp)