In [10]:
import torch
import torch.nn as nn

from GINN.problem_sampler import ProblemSampler
from train.train_utils.latent_sampler import sample_new_z
from utils import get_stateless_net_with_partials, get_model
from neural_clbf.controllers.simple_neural_cbf_controller import SimpleNeuralCBFController
from neural_clbf.systems.simple3d import Simple3DRobot
from configs.get_config import get_config_from_yml
from models.model_utils import tensor_product_xz
from train.train_utils.loss_optims import LossBalancer, GradNormBalancer
from torch.utils.tensorboard import SummaryWriter

import subprocess
import time
from datetime import datetime
import os
from copy import deepcopy
from tqdm import trange
from collections import defaultdict


In [11]:
class AdapterMLP(nn.Module):
    """
    Adapter network to replace the final SIREN layer.
    """
    def __init__(self, in_dim, hidden_dim=16):
        super(AdapterMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output correction term
        )

    def forward(self, x):
        return self.mlp(x)

class ConditionalSIRENWithAdapter(nn.Module):
    """
    Conditional SIREN with an adapter replacing the final layer.
    """
    def __init__(self, siren_model, adapter_model):
        super(ConditionalSIRENWithAdapter, self).__init__()
        
        # Remove final linear layer from SIREN
        self.siren = nn.Sequential(*list(siren_model.network.children())[:-1])
        self.adapter = adapter_model  # Adapter MLP replaces final layer
        self.jacobian = None

        # Freeze all but adapter
        for param in self.siren.parameters():
            param.requires_grad = False

    def forward(self, x, z, calc_jacobian=False):
        xz = torch.cat([x, z], dim=-1)  # Ensure concatenation happens before passing to the model
        features = self.siren(xz)  # Pass through modified SIREN layers
        h_x = self.adapter(features)  # Apply adapter MLP

        if calc_jacobian:
            self.jacobian = torch.autograd.functional.jacobian(lambda x: self.forward(x, z, calc_jacobian=False), x)

        return h_x

class LossTimer:
    def __init__(self):
        self.times = defaultdict(list)  # Store loss computation times
        self.start_times = {}  # Store start times for ongoing loss calculations

    def start(self, loss_name):
        """Start timing for a specific loss."""
        self.start_times[loss_name] = time.time()

    def stop(self, loss_name):
        """Stop timing and log duration for a specific loss."""
        if loss_name in self.start_times:
            elapsed_time = time.time() - self.start_times.pop(loss_name)
            self.times[loss_name].append(elapsed_time)

    def print_summary(self):
        """Prints the average and all recorded times for each loss."""
        print("\n=== Loss Timing Summary ===")
        for loss_name, timings in self.times.items():
            avg_time = sum(timings) / len(timings)
            print(f"{loss_name}: Avg {avg_time:.6f}s | Timings: {timings}")


In [None]:
# Load config
with open("config_adapter.yml", "r") as file:
    config = yaml.safe_load(file)

# Extract values
DATASET_DIR = config["paths"]["dataset_dir"]
SIREN_CONFIG_PATH = config["paths"]["siren_config_path"]
MODEL_PATH = config["paths"]["model_path"]
MODEL_SAVE_PATH = os.path.join(config["paths"]["model_save_path"], datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
TENSORBOARD_PORT = config["paths"]["tensorboard_port"]
LOG_DIR = config["paths"]["tensorboard_log_dir"]

CBF_LAMBDA = config["training"]["cbf_lambda"]
CBF_RELAXATION_PENALTY = config["training"]["cbf_relaxation_penalty"]
MAX_EPOCHS = config["training"]["max_epochs"]
SAVE_N_EPOCHS = config["training"]["save_n_epochs"]
LOSS_THRESH = config["training"]["loss_thresh"]
MIN_LOSS_THRESH = config["training"]["min_loss_thresh"]
MIN_CONTROL_NORM = config["training"]["min_control_norm"]
LAMBDA_RECON = config["training"]["lambda_recon"]
LAMBDA_DESCENT = config["training"]["lambda_descent"]
LAMBDA_CONTROL = config["training"]["lambda_control"]
LOSS_BALANCER_MODEL = config["training"]["loss_balancer_model"]

CONTROLLER_PERIOD = config["simulation"]["controller_period"]
SIMULATION_DT = config["simulation"]["simulation_dt"]

# Ensure model save path exists
if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_PATH)

# TensorBoard setup
log_dir = os.path.join(LOG_DIR, f"tensorboard_log_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

In [12]:
##### Using config
config = get_config_from_yml(SIREN_CONFIG_PATH)
config['device'] = DEVICE

siren_model = get_model(config).to(DEVICE)
siren_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
netp = get_stateless_net_with_partials(siren_model, use_x_and_z_arg=True)

p_sampler = ProblemSampler(config)
z = sample_new_z(config, is_init=True).to(DEVICE)

adapter_model = AdapterMLP(in_dim=256, hidden_dim=16).to(DEVICE)
model = ConditionalSIRENWithAdapter(siren_model, adapter_model).to(DEVICE)
opt = torch.optim.Adam(model.adapter.parameters(), lr=1e-3)
##### Using config

##### Tensor Board
tensorboard_port = 6006
writer = SummaryWriter(log_dir=log_dir)

tensorboard_log_file = f"tensorboard_log_{datetime_str}"
log_dir = os.path.join("all_runs/adapter", tensorboard_log_file)

if not os.path.exists(log_dir):
    os.makedirs(log_dir)

writer = SummaryWriter(log_dir=log_dir)
tensorboard_process = subprocess.Popen(["tensorboard", "--logdir", log_dir, "--port", str(tensorboard_port)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

time.sleep(3)

tensorboard_url = f"http://localhost:{tensorboard_port}/"
print(f"TensorBoard is running at: {tensorboard_url}")
##### Tensor Board



########

tensorboard_log_file = f"tensorboard_log_{datetime_str}"
log_dir = os.path.join("all_runs/adapter", tensorboard_log_file)

# Ensure the directory exists
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# Initialize TensorBoard SummaryWriter with custom log file
writer = SummaryWriter(log_dir=log_dir)

########

controller_period = 0.05
simulation_dt = 0.01
nominal_params = {}
scenarios = [nominal_params]

dynamics_model = Simple3DRobot(
    nominal_params,
    dt=simulation_dt,
    controller_dt=controller_period,
    scenarios=scenarios,
)

cbf_controller = SimpleNeuralCBFController(
    dynamics_model,
    scenarios,
    model,
    cbf_lambda=CBF_LAMBDA,
    cbf_relaxation_penalty=CBF_RELAXATION_PENALTY,
    z = z,
    device = DEVICE
)

loss_timer = LossTimer()

prev_lost = 1
min_loss = 1
best_epoch = 1
best_model = None

lambda_recon = 1.0
lambda_descent = 1.0
lambda_control = 1.0

loss_balancer_model = 'gradnorm' # fixed
if loss_balancer_model == 'gradnorm':
    loss_balancer = GradNormBalancer(num_losses=3).to(DEVICE)  # 6 loss terms
else:
    loss_balancer = None

for epoch in trange(MAX_EPOCHS, leave=True, position=0, colour="yellow"):
    # print("\n=============", str(epoch), "=============")
    opt.zero_grad()
    cbf_controller.set_V_nn(model)

    # Reconstruction Loss
    loss_timer.start("Reconstruction Loss")
    recon_inps = torch.vstack([
        p_sampler.sample_from_interface()[0],
        p_sampler.sample_from_domain(),
        p_sampler.sample_from_outer()
    ])
    siren_ys = siren_model(*tensor_product_xz(recon_inps, z)).squeeze(1)
    my_ys = model(*tensor_product_xz(recon_inps, z)).squeeze(1)
    recon_loss = (siren_ys - my_ys).square().mean()
    loss_timer.stop("Reconstruction Loss")

    # Descent Loss
    loss_timer.start("Descent Loss")
    loss_descent = torch.tensor(0.0, device=DEVICE)
    xs_start, u_refs = p_sampler.sample_for_descent()
    losses_list, u_opt = cbf_controller.descent_loss(xs_start, u_ref=u_refs, get_us=True)
    loss_values = torch.stack([torch.clamp(l, min=0) for _, l in losses_list if not l.isnan()], dim=0)
    if loss_values.numel() > 0:
        loss_descent = loss_values.mean()
    loss_timer.stop("Descent Loss")

    # Small Control Loss
    loss_timer.start("Small Control Loss")
    loss_small_control = torch.tensor(0.0, device=DEVICE)
    u_norm = torch.norm(u_opt, p=2, dim=1)
    loss_small_controls = torch.clamp(MIN_CONTROL_NORM - u_norm, min=0)
    loss_small_control = loss_small_controls.mean()
    loss_timer.stop("Small Control Loss")

    # Loss Balancer
    loss_timer.start("Loss Balancer Computation")
    losses = torch.stack([recon_loss, loss_descent, loss_small_control])
    if loss_balancer_model == 'gradnorm':
        loss = loss_balancer(losses, model.adapter.parameters())
        lambdas = [l.item() for l in loss_balancer.loss_weights]
        for i, lam in enumerate(lambdas):
            writer.add_scalar(f"Lambda/lambda_{i}", loss.item(), epoch)    
    else:
        lambdas = torch.tensor([lambda_recon, lambda_descent, lambda_control])
        loss = (losses * lambdas).sum()
    loss_timer.stop("Loss Balancer Computation")

    loss.backward()
    opt.step()

    writer.add_scalar("Loss/Total", loss.item(), epoch)
    writer.add_scalar("Loss/Reconstruction", recon_loss.item(), epoch)
    writer.add_scalar("Loss/Descent", loss_descent.item(), epoch)
    writer.add_scalar("Loss/Small_Control", loss_small_control.item(), epoch)

    # print("Current loss:", loss.item(), "Epoch:", epoch, "Delta (%):", (abs(loss - prev_lost) / loss).item() * 100)
    if prev_lost < loss:
        best_epoch = epoch
        best_model = deepcopy(model.state_dict())
    
    prev_lost = loss
    if epoch % SAVE_N_EPOCHS == 0 and epoch > 1:
        print(epoch)
        loss_timer.print_summary()
        savename = os.path.join(MODEL_SAVE_PATH, f"model_{epoch}.pth")
        torch.save(model.state_dict(), savename)
        print("Saving...", savename)
        print("Loss", loss)

savename = os.path.join(MODEL_SAVE_PATH, f"model_best_{best_epoch}.pth")
torch.save(best_model, savename)
print("Best epoch:", best_epoch)

# if os.path.exists(MODEL_SAVE_PATH) and not os.listdir(MODEL_SAVE_PATH):  # Check if folder exists and is empty
#     os.rmdir(folder_path)
#     print(f"Deleted empty folder: {folder_path}")

  siren_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


TensorBoard is running at: http://localhost:6006/
{}


 60%|[33m██████    [0m| 6/10 [00:30<00:20,  5.03s/it]

5

=== Loss Timing Summary ===
Reconstruction Loss: Avg 0.002368s | Timings: [0.0025861263275146484, 0.0022258758544921875, 0.002318143844604492, 0.0023987293243408203, 0.0023462772369384766, 0.0023310184478759766]
Descent Loss: Avg 2.469436s | Timings: [2.238929033279419, 3.270869493484497, 2.3278675079345703, 2.5791780948638916, 2.196256399154663, 2.2035152912139893]
Small Control Loss: Avg 0.000223s | Timings: [0.00022935867309570312, 0.0002219676971435547, 0.00021505355834960938, 0.00022292137145996094, 0.00022602081298828125, 0.00022530555725097656]
Loss Balancer Computation: Avg 1.267314s | Timings: [0.8722145557403564, 1.4422128200531006, 0.9493632316589355, 2.6168084144592285, 0.7698776721954346, 0.9534075260162354]
Saving... /scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/adapter/2025-02-24_16-39-43/model_5.pth
Loss tensor(0.6772, device='cuda:0', grad_fn=<SumBackward0>)


100%|[33m██████████[0m| 10/10 [00:46<00:00,  4.62s/it]

Best epoch: 9



