In [None]:
import sys

# Make the project root (parent of this notebook's working directory) importable,
# so model/network classes referenced by the Hydra configs can be instantiated.
# Guarded so that re-running this cell does not append duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
import glob
import os

import hydra
import torch
import yaml
from hydra import compose, initialize
from lightning import LightningModule
from omegaconf import OmegaConf

# Select which experiment family to analyse. Runs are logged under
# ../logs/LF_<activation>_<model>_<rho_name>; change the indices to switch family.
activation_list = ["linear", "hardtanh", "tanh"]
model_list = ["diagonalrnn", "rnn", "softplusrnn"]
rho_name_list = ["exp", "pol"]

index_0, index_1, index_2 = 0, 2, 1  # -> "linear", "softplusrnn", "pol"

activation = activation_list[index_0]
# NOTE: `model` is later rebound to a LightningModule by the checkpoint-loading
# cells; `directory` below is built from the string value chosen here.
model = model_list[index_1]
rho_name = rho_name_list[index_2]

directory = f"../logs/LF_{activation}_{model}_{rho_name}"  # Stable

# Recursively collect every run's last.ckpt under the selected directory.
ckpt_files = glob.glob(os.path.join(directory, "**", "last.ckpt"), recursive=True)
if not ckpt_files:
    # Fail here with a clear message instead of an IndexError on pathlist[0] later.
    raise FileNotFoundError(f"No last.ckpt files found under {directory!r}")

# Sort by last-modified time, oldest first.
ckpt_files.sort(key=os.path.getmtime)

print(ckpt_files)
pathlist = ckpt_files

In [None]:
# Compose the Hydra config of the first (oldest-modified) run as a sanity check.
# The run's config directory is taken to be `<grandparent of last.ckpt>/.hydra/`.
run_dir = os.path.dirname(os.path.dirname(pathlist[0]))
cfg_path = os.path.join(run_dir, ".hydra/")
with initialize(version_base="1.3", config_path=cfg_path):
    cfg = compose(config_name="config", overrides=[])

In [None]:
# Evaluate the stability margin of every run's final (last.ckpt) weights, print it,
# and record one entry per run in <directory>/stability_margin.txt, in `pathlist` order.
stability_margin_path = os.path.join(directory, "stability_margin.txt")

with open(stability_margin_path, "w") as file:
    for path in pathlist:
        # Each run's Hydra config lives in `.hydra/` two levels above its checkpoint.
        cfg_path = os.path.join(os.path.dirname(os.path.dirname(path)), ".hydra/")

        with initialize(version_base="1.3", config_path=cfg_path):
            cfg = compose(config_name="config", overrides=[])

            model: LightningModule = hydra.utils.instantiate(cfg["model"])
            # map_location="cpu" lets checkpoints written on a GPU load on CPU-only machines.
            # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
            # may reject checkpoints that pickle non-tensor objects; pass weights_only=False
            # (trusted files only) if loading fails.
            model_ckpt = torch.load(path, map_location="cpu")["state_dict"]
            model.load_state_dict(model_ckpt)
            stability_margin = model.net.stability_margin()
            # Assumes stability_margin() returns a scalar (or small) tensor, so that
            # str() of the array fits on one line of the output file — TODO confirm.
            margin_value = stability_margin.detach().cpu().numpy()
            print("Current stability margin: ", margin_value)
            file.write(str(margin_value) + "\n")

In [None]:
# For each run, load the best checkpoint tracked by the val/loss ModelCheckpoint
# callback (as recorded in last.ckpt) and print its stability margin.

# Lightning keys callback state by class name + init args; this is the key for
# ModelCheckpoint(monitor="val/loss", mode="min") with default save frequency.
MODEL_CHECKPOINT_KEY = (
    "ModelCheckpoint{'monitor': 'val/loss', 'mode': 'min', 'every_n_train_steps': 0, "
    "'every_n_epochs': 1, 'train_time_interval': None}"
)

for path in pathlist:
    cfg_path = os.path.join(os.path.dirname(os.path.dirname(path)), ".hydra/")

    # Load last.ckpt once and reuse it for both the best path and the best score.
    last_ckpt = torch.load(path, map_location="cpu")
    callbacks_state = last_ckpt["callbacks"]
    ckpt_state = callbacks_state.get(MODEL_CHECKPOINT_KEY)
    if ckpt_state is None:
        # The key changes if any other callback arg differs; fall back to the
        # single ModelCheckpoint entry that monitors val/loss.
        candidates = [
            state
            for key, state in callbacks_state.items()
            if key.startswith("ModelCheckpoint{'monitor': 'val/loss'")
        ]
        if len(candidates) != 1:
            raise KeyError(
                f"Expected exactly one val/loss ModelCheckpoint state in {path!r}, "
                f"found {len(candidates)}"
            )
        ckpt_state = candidates[0]

    ckpt_best_path = ckpt_state["best_model_path"]
    if ckpt_best_path and not os.path.exists(ckpt_best_path):
        # The recorded path is where the file was written at training time; if the
        # logs were moved, look for the same file name next to last.ckpt.
        ckpt_best_path = os.path.join(os.path.dirname(path), os.path.basename(ckpt_best_path))

    with initialize(version_base="1.3", config_path=cfg_path):
        cfg = compose(config_name="config", overrides=[])

        model: LightningModule = hydra.utils.instantiate(cfg["model"])
        model_ckpt = torch.load(ckpt_best_path, map_location="cpu")["state_dict"]
        model.load_state_dict(model_ckpt)
        stability_margin = model.net.stability_margin()

    # Best monitored val/loss; not printed in this cell but still bound as before.
    loss = ckpt_state["best_model_score"]
    print("Best stability margin:", float(stability_margin))

In [None]:
# For each run, report the best val/loss, the stability margin of that best
# checkpoint, and the loss regularized by the margin.

# Lightning keys callback state by class name + init args; this is the key for
# ModelCheckpoint(monitor="val/loss", mode="min") with default save frequency.
MODEL_CHECKPOINT_KEY = (
    "ModelCheckpoint{'monitor': 'val/loss', 'mode': 'min', 'every_n_train_steps': 0, "
    "'every_n_epochs': 1, 'train_time_interval': None}"
)
# Weight on the stability margin in the reported regularized loss.
# NOTE(review): presumably mirrors the training-time regularization weight — confirm.
STABILITY_REG_WEIGHT = 0.01

for path in pathlist:
    cfg_path = os.path.join(os.path.dirname(os.path.dirname(path)), ".hydra/")

    # Load last.ckpt once and reuse it for both the best path and the best score.
    last_ckpt = torch.load(path, map_location="cpu")
    callbacks_state = last_ckpt["callbacks"]
    ckpt_state = callbacks_state.get(MODEL_CHECKPOINT_KEY)
    if ckpt_state is None:
        # The key changes if any other callback arg differs; fall back to the
        # single ModelCheckpoint entry that monitors val/loss.
        candidates = [
            state
            for key, state in callbacks_state.items()
            if key.startswith("ModelCheckpoint{'monitor': 'val/loss'")
        ]
        if len(candidates) != 1:
            raise KeyError(
                f"Expected exactly one val/loss ModelCheckpoint state in {path!r}, "
                f"found {len(candidates)}"
            )
        ckpt_state = candidates[0]

    ckpt_best_path = ckpt_state["best_model_path"]
    if ckpt_best_path and not os.path.exists(ckpt_best_path):
        # The recorded path is where the file was written at training time; if the
        # logs were moved, look for the same file name next to last.ckpt.
        ckpt_best_path = os.path.join(os.path.dirname(path), os.path.basename(ckpt_best_path))

    with initialize(version_base="1.3", config_path=cfg_path):
        cfg = compose(config_name="config", overrides=[])

        model: LightningModule = hydra.utils.instantiate(cfg["model"])
        model_ckpt = torch.load(ckpt_best_path, map_location="cpu")["state_dict"]
        model.load_state_dict(model_ckpt)
        stability_margin = model.net.stability_margin()

    loss = ckpt_state["best_model_score"]
    print("Loss:            ", float(loss))
    print("Stability margin:", float(stability_margin))
    print("Regularized loss:", float(loss - stability_margin * STABILITY_REG_WEIGHT), "\n")