In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell executes, so edits to project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mlflow
from hydra import initialize, compose
from rdkit import Chem
from collections import defaultdict
from ergochemics.mapping import rc_to_nest
from ergochemics.draw import draw_reaction, draw_molecule
from IPython.display import SVG
from cgr.ml import bin_label_to_sep_aidx
import yaml
from ast import literal_eval
from tqdm import tqdm
from ergochemics.mapping import (
    get_reaction_center
)

from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    RocCurveDisplay,
    PrecisionRecallDisplay,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score
)

from sklearn.calibration import CalibrationDisplay

# Compose the "filepaths" Hydra config (relative to this notebook) so that `cfg`
# exposes the project's directory layout to the cells below.
with initialize(config_path="../configs/filepaths", version_base=None):
    cfg = compose(config_name="filepaths")

# Point MLflow at the file-based tracking store located at cfg.mlruns.
mlflow.set_tracking_uri(f"file://{cfg.mlruns}")

In [None]:


def get_min_dist_to_rc(am_rxn: str, rc: list[list[list[int]]]) -> list[np.ndarray]:
    """Compute each atom's topological distance to its nearest reaction-center atom.

    Args:
        am_rxn: Reaction SMILES of the form ``"lhs>>rhs"`` where the molecules on
            each side are separated by ``"."``.
        rc: Reaction center given per side and per molecule, i.e.
            ``rc[side][mol]`` is the collection of reaction-center atom indices
            of the ``mol``-th molecule on that side (same order as in ``am_rxn``).

    Returns:
        A two-element list (one entry per reaction side). Each entry is a column
        vector of shape ``(n_atoms_on_side, 1)`` holding, for every atom of every
        molecule on that side (molecules concatenated in order), the number of
        bonds on the shortest path to the closest reaction-center atom of the
        same molecule (0 for reaction-center atoms themselves).

    Raises:
        ValueError: If a molecule's SMILES cannot be parsed, or if a molecule has
            an empty reaction center (no distance is defined in that case).
    """
    sides = [
        [Chem.MolFromSmiles(smi) for smi in side.split('.')]
        for side in am_rxn.split('>>')
    ]
    min_dists = [[], []]

    for side_idx, side_rc in enumerate(rc):
        # Distinct name for the per-molecule reaction center: the original reused
        # `rc` here, shadowing the parameter that is being enumerated.
        for mol_idx, (mol, mol_rc) in enumerate(zip(sides[side_idx], side_rc)):
            if mol is None:
                raise ValueError(
                    f"Could not parse molecule {mol_idx} on side {side_idx} of reaction: {am_rxn}"
                )

            # Materialize once so the indices can be re-iterated for every atom.
            rc_atoms = list(mol_rc)
            if not rc_atoms:
                raise ValueError(
                    f"Empty reaction center for molecule {mol_idx} on side {side_idx} of reaction: {am_rxn}"
                )

            for atom in mol.GetAtoms():
                aidx = atom.GetIdx()
                # GetShortestPath returns the atom indices along the path, so the
                # bond count is len(path) - 1; an atom is at distance 0 from itself.
                min_dist = min(
                    0 if aidx == rcidx else len(Chem.GetShortestPath(mol, aidx, rcidx)) - 1
                    for rcidx in rc_atoms
                )
                min_dists[side_idx].append(min_dist)

    return [np.array(side_dists).reshape(-1, 1) for side_dists in min_dists]


In [None]:
# Fetch every run of the chosen MLflow experiment into `df` (one row per run).
experiment_name = "production"
experiment = mlflow.get_experiment_by_name(experiment_name)

# Fail fast: merely printing would leave `df` undefined (or stale from an earlier
# kernel session), and the cells below would then break or silently use old data.
if experiment is None:
    raise ValueError(f"Experiment '{experiment_name}' not found.")

df = mlflow.search_runs(experiment_ids=[experiment.experiment_id])

In [None]:
# Preview the first few rows of the runs table returned by mlflow.search_runs.
df.head()

In [None]:
# Copy configs to conf dir to train production model
#
# For every run in `df`, rebuild a nested config from its logged "params.<group>/<key>"
# columns, attach the run's checkpoint path (relative to cfg.mlruns), and write the
# result to <cfg.configs>/production/outer_split_<idx>.yaml.
production_dir = Path(cfg.configs) / "production"
production_dir.mkdir(parents=True, exist_ok=True)  # make sure the target dir exists

for i, row in df.iterrows():
    fn = f"outer_split_{row['params.data/outer_split_idx']}.yaml"
    out_path = production_dir / fn

    # Strip the URI scheme rather than replacing a machine-specific prefix, so the
    # lookup works wherever the mlruns directory lives.
    artifact_dir = Path(row['artifact_uri'].removeprefix("file://"))
    model_ckpt_abs = next(artifact_dir.rglob("*.ckpt"), None)
    if model_ckpt_abs is None:
        raise FileNotFoundError(f"No checkpoint (*.ckpt) found under: {artifact_dir}")
    model_ckpt_rel = model_ckpt_abs.relative_to(cfg.mlruns)
    print("checkpoint: ", model_ckpt_rel)

    config = defaultdict(dict)
    for k, v in row.items():
        if k.startswith("params.") and "/" in k:
            # "params.<group>/<key>[/...]" -> config[<group>][<key>]
            k_out, k_in, *_ = k.removeprefix("params.").split("/")

            if k_out == "full":
                continue

            # Params are logged as strings; recover Python literals where possible.
            # literal_eval raises SyntaxError/TypeError (not only ValueError) for
            # values that are not literals, e.g. free text, so keep those as-is.
            try:
                config[k_out][k_in] = literal_eval(v)
            except (ValueError, TypeError, SyntaxError):
                config[k_out][k_in] = v

    config["model"]["ckpt"] = str(model_ckpt_rel)

    with open(out_path, "w") as f:
        yaml.dump(dict(config), f)

    print(f"Saving config to: {out_path}")

In [None]:
# preds = []
# for fn in (Path(cfg.processed_data) / "mech_probas").glob("*.parquet"):
#     print("Loading: ", fn)
#     preds.append(pd.read_parquet(fn))

# dts = [0.9562190771102905, 0.05624692514538765, 0.01461805310100317, 0.0060590095818042755, 0.0028916343580931425]
# pred_df = pd.concat(preds)
# avg = pred_df.groupby(["rxn_id", "aidx"])["probas"].mean().reset_index()
# pop_vote = []
# avg_model = []
# for dt in tqdm(dts, total=len(dts)):
#     pred_df['label'] = (pred_df['probas'] > dt).astype(int)
#     grouped = pred_df.groupby(['rxn_id', 'aidx'])['label'].agg(lambda x: x.mode()[0]).reset_index()
#     pop_vote.append(grouped)
#     avg_model.append((avg['probas'] > dt).astype(int))

In [None]:
# ground_truth = pred_df.drop_duplicates(subset=["rxn_id", "aidx"], keep="first").reset_index(drop=True)
# for dt, pop_ans, avg_ans in zip(dts, pop_vote, avg_model):
#     y = ground_truth['label']
#     pop_y_pred = pop_ans['label']
#     avg_y_pred = avg_ans

#     f1_pop = f1_score(y, pop_y_pred)
#     f1_avg = f1_score(y, avg_y_pred)
#     acc_pop = accuracy_score(y, pop_y_pred)
#     acc_avg = accuracy_score(y, avg_y_pred)

#     print(f"Decision threshold: {dt:.5f}")
#     print(f"  Popularity vote - F1: {f1_pop:.4f}, Acc: {acc_pop:.4f}")
#     print(f"  Average model   - F1: {f1_avg:.4f}, Acc: {acc_avg:.4f}")