In [1]:
%load_ext autoreload
%autoreload 2

from krxns.config import filepaths
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
import torch
from pathlib import Path

In [2]:
def latest_ckpt(ckpt_dir: Path):
    """Return the most recent ``*.ckpt`` file in ``ckpt_dir``.

    Checkpoint stems such as ``epoch=12-step=3456`` are split on ``-``; the text
    after the last ``=`` of each piece forms the sort key, compared left to right
    (e.g. epoch first, then step).

    Numeric values are compared as numbers, so ``epoch=12`` ranks above
    ``epoch=9`` (plain string comparison would rank ``"9"`` above ``"12"``).
    Non-numeric values (e.g. the stem of ``last.ckpt`` or a ``v1`` suffix) rank
    above any numeric value.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` contains no ``*.ckpt`` files.
    """
    def rank(ckpt: Path):
        key = []
        for piece in ckpt.stem.split("-"):
            value = piece.split("=")[-1]
            try:
                # Tag numeric values with 0 and text with 1 so the two kinds are
                # always comparable and text sorts above numbers.
                key.append((0, float(value)))
            except ValueError:
                key.append((1, value))
        return tuple(key)

    ckpts = list(ckpt_dir.glob("*.ckpt"))
    if not ckpts:
        raise FileNotFoundError(f"No *.ckpt files found in {ckpt_dir}")
    return max(ckpts, key=rank)

In [3]:
# Set this: experiment directory under filepaths["spl_cv"]; each sub-directory is one condition.
experiment = "default_chemprop_data_titration"
# Condition names for the dropdown: directories only (skip files and hidden dirs such as
# .ipynb_checkpoints), sorted for a stable order. Use .name rather than .stem so names
# containing a dot (e.g. "lr=0.001") are not truncated into non-existent paths.
# Raises FileNotFoundError early if the experiment directory does not exist.
copts = sorted(
    elt.name
    for elt in (filepaths["spl_cv"] / experiment).iterdir()
    if elt.is_dir() and not elt.name.startswith('.')
)


In [None]:
# Loss curves
def _load_condition_metrics(cond_dir: Path) -> pd.DataFrame:
    """Concatenate the per-split ``metrics.csv`` logs found under ``cond_dir``.

    Each ``split_*/metrics.csv`` is read and tagged with an integer ``split`` column
    parsed from its directory name (``split_<k>``). A ``val/rmse_rescaled`` column is
    added: ``val/rmse`` times the output-transform ``scale`` stored in that split's
    latest checkpoint.

    Returns an empty DataFrame when no ``split_*/metrics.csv`` exists.
    """
    frames = []
    for metrics_fp in sorted(cond_dir.glob("split_*/metrics.csv")):
        split_dir = metrics_fp.parent
        ckpt = latest_ckpt(split_dir / "checkpoints")
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(ckpt, map_location=torch.device('cpu'))
        # Each split's model has its own output scale, so rescale per split rather than
        # applying one split's scale to all of them.
        scale = float(state_dict['hyper_parameters']['predictor']['output_transform'].scale)
        split_metrics = pd.read_csv(metrics_fp, sep=',')
        split_metrics['split'] = int(split_dir.name.split('_')[-1])
        split_metrics['val/rmse_rescaled'] = split_metrics['val/rmse'] * scale
        frames.append(split_metrics)
    return pd.concat(frames, axis=0) if frames else pd.DataFrame()


@interact
def plot_loss_curves(experiment=experiment, condition=copts):
    """Plot cross-validation learning curves for one experiment/condition.

    Reads every ``split_*/metrics.csv`` under ``filepaths["spl_cv"] / experiment / condition``.
    Logged rows are first averaged within each (epoch, split). Mean and std are then taken
    across splits per epoch.

    Top panel: ``val/rmse`` rescaled by each split's checkpoint output scale.
    Bottom panel: every other logged metric (mean +/- std across splits).

    ``@interact`` turns ``experiment`` into a text box and ``condition`` into a dropdown of
    ``copts``. The dropdown options are fixed when the cell runs; editing the experiment
    text does not refresh them.
    """
    stats = ('mean', 'std')
    rescaled = 'val/rmse_rescaled'

    if not condition:
        print("No condition selected.")
        return
    cond_dir = filepaths["spl_cv"] / experiment / condition
    lc = _load_condition_metrics(cond_dir)
    if lc.empty:
        # Print instead of raising: the text widget re-runs this on every keystroke.
        print(f"No split_*/metrics.csv found under {cond_dir}")
        return

    lc = (
        lc.drop(columns=['step', 'train_loss_epoch'], errors='ignore')
        .groupby(['epoch', 'split']).agg('mean').reset_index()  # average logged rows within each (epoch, split)
        .groupby('epoch').agg(stats).drop(columns=['split']).reset_index()  # mean/std across splits
    )

    fig, ax = plt.subplots(nrows=2, figsize=(7, 8), sharex=True)
    epochs = lc['epoch']

    ax[0].errorbar(x=epochs, y=lc[(rescaled, stats[0])], yerr=lc[(rescaled, stats[1])])
    ax[0].set_ylabel("Rescaled RMSE [synthetic distance]")
    ax[0].set_title(f"{experiment} / {condition}")

    # Sorted so the legend order is stable across runs.
    metrics = sorted({m for m, _ in lc.columns} - {'epoch', rescaled})
    for m in metrics:
        ax[1].errorbar(x=epochs, y=lc[(m, stats[0])], yerr=lc[(m, stats[1])], label=m)
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Metric (mean +/- std across splits)')
    ax[1].legend()
    plt.show()

interactive(children=(Text(value='default_chemprop_data_titration', description='experiment'), Dropdown(descri…