In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from condition_prediction.run import ConditionPrediction
import wandb
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import pathlib
from tqdm import tqdm

In [3]:
api = wandb.Api()
wandb_entity="ceb-sre"
wandb_project="orderly"

## Update models with top_n accuracy scores 

In [None]:
# Loop through all relevant runs on wandb to get run_ids, datasets and random seeds
# For each rerun the conditionprediction with skip_training=True and resume=True
DATASETS = ["with_trust_with_map","with_trust_no_map", "no_trust_no_map", "no_trust_with_map"]
BASE_PATH = pathlib.Path("/project/studios/orderly-preprocessing/ORDerly/")
DATASETS_PATH = BASE_PATH / "data/orderly/datasets/"
MODEL_PATH = pathlib.Path("ORDerly/models")
configs =[]
for random_seed in [12345, 54321,98765]:
    for dataset in DATASETS:
        filters = {
            "state": "finished",
            "config.output_folder_path": {"$in":
                [
                    f"models/{dataset}", 
                    str(MODEL_PATH / dataset),
                    f"/Users/Kobi/Documents/Research/phd_code/ORDerly/models/{dataset}"
                ],
            },
            "config.random_seed": random_seed,
            # "config.train_fraction": 1.0,
            "config.dataset_version": "v4",
            "config.train_mode": 0, # Teacher forcing
        }
        runs = api.runs(
            f"{wandb_entity}/{wandb_project}",
            filters=filters
        )
        if not len(runs) == 5: # For 5 training fractions
            raise ValueError(f"Not 5 runs for {dataset} (found {len(runs)})")
        
        for run in runs:
            config = dict(run.config)
            train_data_path = pathlib.Path(f"{DATASETS_PATH}/orderly_{dataset}_train.parquet")
            test_data_path = pathlib.Path(f"{DATASETS_PATH}/orderly_{dataset}_test.parquet")
            fp_directory = train_data_path.parent / "fingerprints"
            train_fp_path = fp_directory / (train_data_path.stem + ".npy")
            test_fp_path = fp_directory / (test_data_path.stem + ".npy")
            output_folder_path = MODEL_PATH / dataset
            output_folder_path.mkdir(parents=True, exist_ok=True)
            tags = dataset.split("_")
            tags = [f"{tags[0]}_{tags[1]}", f"{tags[2]}_{tags[3]}"]
            config.update({
                "train_data_path": train_data_path,
                "test_data_path": test_data_path,
                "train_fp_path": train_fp_path,
                "test_fp_path": test_fp_path,
                "output_folder_path": output_folder_path,
                "skip_training": True,
                "resume": True,
                "resume_from_best": True,
                "generate_fingerprints": False,
                "wandb_run_id": run.id,
                "wandb_tags": tags,
            })
            configs.append(config)
            del config["n_val"]
            del config["n_test"]
            del config["n_train"]
            del config["dataset_version"]
            instance = ConditionPrediction(**config)
            instance.run_model_arguments()
            wandb.finish()

## Generate Table 

In [52]:
DATASETS = ["with_trust_with_map","with_trust_no_map", "no_trust_no_map", "no_trust_with_map"]
lines = ["Solvents", "Agents", "Everything"]
for dataset in DATASETS:
    filters = {
        "state": "finished",
        "config.output_folder_path": f"models/{dataset}",
        "config.random_seed": 12345,
        "config.train_fraction": 1.0,
        "config.train_mode": 0, # Teacher forcing
    }
    runs = api.runs(
        f"{wandb_entity}/{wandb_project}",
        filters=filters
    )
    assert len(runs) == 1
    run = runs[0]

    # Get model solvent, agent and overall accuracy
    test_best = run.summary["test_best"]
    solvent_accuracy = test_best["solvent_accuracy"]
    agent_accuracy = test_best["three_agents_accuracy"]
    overall_accuracy = test_best["overall_accuracy"]

    # Get frequency informed solvent, agent and overall accuracy
    fi_solvent_accuracy = run.summary["frequency_informed_solvent_accuracy"]
    fi_agent_accuracy = run.summary["frequency_informed_agent_accuracy"]
    fi_overall_accuracy = run.summary["frequency_informed_overall_accuracy"]

    # Improvement
    solvent_improvement = (solvent_accuracy-fi_solvent_accuracy)/fi_solvent_accuracy
    solvent_improvement_color = "lessgreen" if solvent_improvement>0 else "red"
    agent_improvement = (agent_accuracy-fi_agent_accuracy)/fi_agent_accuracy
    agent_improvement_color = "lessgreen" if agent_improvement>0 else "red"
    overall_improvement = (overall_accuracy-fi_overall_accuracy)/fi_overall_accuracy
    overall_improvement_color = "lessgreen" if overall_improvement>0 else "red"

    # Create table lines
    lines[0] += f" & {fi_solvent_accuracy*100:.0f} // {solvent_accuracy*100:.0f} // \\textcolor{{{solvent_improvement_color}}}{{{solvent_improvement*100:.0f}\%}} "
    lines[1] += f" & {fi_agent_accuracy*100:.0f} // {agent_accuracy*100:.0f} // \\textcolor{{{agent_improvement_color}}}{{{agent_improvement*100:.0f}\%}} "
    lines[2] += f" & {fi_overall_accuracy*100:.0f} // {overall_accuracy*100:.0f} // \\textcolor{{{overall_improvement_color}}}{{{overall_improvement*100:.0f}\%}} "
print("\\\\ \n".join(lines) + "\\\\")

KeyError: 'test_best'

In [None]:
DATASETS = ["with_trust_with_map","with_trust_no_map", "no_trust_no_map", "no_trust_with_map"]
LABELS = {
    "with_trust_with_map": r"Labelling, rare $\rightarrow$ other",
    "with_trust_no_map": r"Labelling, rare $\rightarrow$ delete rxn",
    "no_trust_no_map": r"Reaction string, rare $\rightarrow$ other",
    "no_trust_with_map": r"Reaction string, rare $\rightarrow$ delete rxn",
}
TRAIN_FRACS =  [0.2, 0.4, 0.6, 0.8, 1.0]
fig, ax = plt.subplots(1)
markers = ["o", "d", "s", "^"]
for i, dataset in enumerate(DATASETS):
    overall_accuracies = []
    for train_fraction in TRAIN_FRACS:
        filters = {
            "state": "finished",
            "config.output_folder_path": f"models/{dataset}",
            "config.random_seed": 12345,
            "config.train_fraction": train_fraction,
            "config.train_mode": 0, # Teacher forcing
        }
        runs = api.runs(
            f"{wandb_entity}/{wandb_project}",
            filters=filters
        )
        assert len(runs) == 1
        run = runs[0]

        # Get overall accuracy
        acc = run.summary["test_best"]["overall_accuracy"]
        overall_accuracies.append(acc)
    
    # Add line to plot
    label = LABELS[dataset]
    ax.plot(
        TRAIN_FRACS, 
        overall_accuracies, 
        label=label, 
        linewidth=3.5, 
        marker=markers[i], 
        markersize=10,
    )

# Formatting
axis_fontsize = 16
heading_fontsize = 18
ax.legend(loc="upper left", fontsize=axis_fontsize)
ax.set_xlabel("Training set fraction", fontsize=heading_fontsize)
ax.set_ylabel("Overall Accuracy",  fontsize=heading_fontsize)
ax.set_xticks(TRAIN_FRACS)
ax.set_xticklabels(TRAIN_FRACS, fontsize=axis_fontsize)
ylabels = np.arange(0.1, 0.35, 0.05)
ax.set_yticks(ylabels)
ax.set_yticklabels([f"{ylabel:0.2f}" for ylabel in ylabels], fontsize=axis_fontsize)
fig.tight_layout()
fig.savefig("scaling_behavior.png", dpi=300)