In [None]:
import matplotlib.pyplot as plt
import numpy as np
from optuna.importance import FanovaImportanceEvaluator, get_param_importances
from sklearn.metrics import ConfusionMatrixDisplay

from vacation.data import GalaxyDataset, CLASS_NAMES
from vacation.evaluation.optimization_results import (
    get_best_trial,
    get_model_from_trial,
)
from vacation.evaluation.visualizations import plot_example_matrix

In [None]:
# Evaluate on a random subset of the validation split.
# - A seeded Generator makes the subset (and therefore every metric computed below)
#   reproducible across kernel restarts.
# - Sampling without replacement avoids scoring the same galaxy more than once.
# NOTE(review): 4204 is presumably the number of images in the validation file — TODO confirm.
SUBSET_SEED = 1337
subset_indices = np.random.default_rng(SUBSET_SEED).choice(4204, size=1200, replace=False)
dataset = GalaxyDataset(
    "/scratch/tgross/vacation_data/reduced_size/Galaxy10_DECals_valid.h5",
    device="cuda:1",
    index_collection=subset_indices,
)

In [None]:
# Fetch the best trial of the "vacation_v2" Optuna study from the local SQLite storage.
# The study object is kept as well, because get_param_importances(study=study) needs it later.
best_trial, study = get_best_trial(
    return_study=True,
    study_name="vacation_v2",
    storage_path="../scripts/vacation.sqlite3",
)

# Rebuild the model for that trial from its checkpoint directory.
# The weights are placed at ./build/vacation_v2.pt; overwrite=True presumably replaces
# any existing local copy — TODO confirm.
model = get_model_from_trial(
    overwrite=True,
    download_path="./build/vacation_v2.pt",
    checkpoint_dir="/scratch/tgross/vacation_models/artifacts/",
    trial=best_trial,
)

In [None]:
# Run the model over the sampled subset. With return_true=True the call also returns the
# reference labels, which are compared against y_pred in the next cell.
y_pred, y_true = model.predict_dataset(
    return_true=True,
    dataset=dataset,
)

In [None]:
# Fraction of samples whose prediction equals the reference label (accuracy on the subset).
# The comparison happens on the tensors' device; only the boolean result is moved to the CPU.
n_correct = (y_pred == y_true).cpu().numpy().sum()
n_correct / len(y_true)

In [None]:
# Draw a 3x3 grid of examples from `dataset`, presumably annotated with the predictions
# in `y_pred` — TODO confirm. The figure is saved to ./build/examples.png.
# Calls the imported `plot_example_matrix`. The previous `example_matrix` name was never
# defined, so the cell raised NameError on a fresh kernel.
plot_example_matrix(
    dataset=dataset,
    y_pred=y_pred.cpu().numpy(),
    layout=(3, 3),
    figsize=(7, 7),
    seed=1337,  # presumably fixes which examples are picked — TODO confirm
    save_path="./build/examples.png",
)

In [None]:
# Display the class-name lookup from vacation.data, presumably mapping class index -> label,
# to help read the predictions above — verify against vacation.data.
CLASS_NAMES

In [None]:
from optuna.importance import get_param_importances
import matplotlib.pyplot as plt

In [None]:
# Hyperparameter importances for the study, computed with fANOVA (Optuna's default evaluator).
# fANOVA fits a random forest, so without a seed the scores and ranking change on every run.
# Making the evaluator explicit and seeding it keeps the plot in the next cell reproducible.
param_importances = get_param_importances(
    study=study,
    evaluator=FanovaImportanceEvaluator(seed=1337),
)

In [None]:
# Horizontal bar chart of the fANOVA importance scores on a log-scaled x-axis.
# Both lists are reversed so the dict's first entry is drawn at the top. If
# get_param_importances returns entries sorted by importance (TODO confirm), that puts
# the most important parameter first.
fig, ax = plt.subplots(figsize=(3, 6))
ax.barh(
    y=list(param_importances.keys())[::-1],
    width=list(param_importances.values())[::-1],
    color="#e64553",
    log=True,
)
ax.set_xlabel("fANOVA Importance Score")
ax.set_ylabel("Hyperparameter")
# Fit long parameter names inside the narrow 3-inch canvas. As the last statement this
# also returns None, so the cell shows only the figure, not a stray Text(...) repr.
fig.tight_layout()