In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score

from vacation.data import GalaxyDataset, CLASS_NAMES
from vacation.evaluation.visualizations import (
    plot_example_matrix,
    plot_confusion_matrix,
    plot_hyperparameter_importance,
)
from vacation.evaluation.optimization_results import (
    get_best_trial,
    get_model_from_trial,
)

In [None]:
# Location of the Galaxy10 DECaLS test-split file and the device to load it on.
# The defaults reproduce the original author's setup (absolute scratch path,
# second GPU). Override them through environment variables so the notebook can
# run on other machines without editing this cell.
DATA_DIR = os.environ.get(
    "VACATION_DATA_DIR", "/scratch/tgross/vacation_data/reduced_size"
)
DEVICE = os.environ.get("VACATION_DEVICE", "cuda:1")

dataset = GalaxyDataset(
    # With the default DATA_DIR this resolves to the same path string as before.
    os.path.join(DATA_DIR, "Galaxy10_DECals_proc_test.h5"),
    device=DEVICE,
    cache_loaded=True,
)

In [None]:
# Plot the distribution of the loaded test dataset (presumably per-class
# sample counts; see GalaxyDataset.plot_distribution to confirm).
dataset.plot_distribution()

In [None]:
# Hyperparameter-search bookkeeping: the study storage file, the study to read,
# the checkpoint directory to load from, and the local path the selected model
# file is written to.
STUDY_STORAGE = "../scripts/vacation.sqlite3"
STUDY_NAME = "vacation_v2"
CHECKPOINT_DIR = "/scratch/tgross/vacation_models/artifacts/"
MODEL_DOWNLOAD_PATH = "./build/vacation_v2.pt"

# Pick the best trial and also keep the study object, which is reused later
# for the hyperparameter-importance plot.
best_trial, study = get_best_trial(
    storage_path=STUDY_STORAGE,
    study_name=STUDY_NAME,
    return_study=True,
)

# Rebuild the model belonging to that trial. overwrite=True is passed so that
# an existing file at MODEL_DOWNLOAD_PATH is presumably replaced; confirm in
# get_model_from_trial.
model = get_model_from_trial(
    trial=best_trial,
    checkpoint_dir=CHECKPOINT_DIR,
    download_path=MODEL_DOWNLOAD_PATH,
    overwrite=True,
)

In [None]:
# Run the model over the whole test dataset. With return_true=True the call
# returns (predictions, ground-truth labels), in that order. Both are tensors:
# later cells call .cpu().numpy() on them before handing them to sklearn.
y_pred, y_true = model.predict_dataset(dataset=dataset, return_true=True)

In [None]:
# Per-class precision / recall / F1. The tensors are moved to host memory and
# converted to NumPy for sklearn. print() is used because
# classification_report returns a preformatted string.
print(classification_report(y_true=y_true.cpu().numpy(), y_pred=y_pred.cpu().numpy()))

In [None]:
# Overall fraction of correctly classified test samples (shown as cell output).
accuracy_score(y_true=y_true.cpu().numpy(), y_pred=y_pred.cpu().numpy())

In [None]:
# Display options for the example grid. The fixed seed keeps the selection of
# shown examples reproducible (assuming plot_example_matrix samples randomly),
# and the figure is also written to disk.
example_plot_kwargs = dict(
    layout=(3, 3),
    figsize=(7, 7),
    seed=42,
    save_path="./build/examples.png",
)

# Kept as the last expression so the cell still displays the call's result.
plot_example_matrix(dataset=dataset, y_pred=y_pred, **example_plot_kwargs)

In [None]:
# Confusion matrix of test predictions. Unlike the sklearn cells, the tensors
# are passed directly, without .cpu().numpy(). normalize=True presumably
# normalizes counts per true class; confirm in plot_confusion_matrix.
plot_confusion_matrix(y_true=y_true, y_pred=y_pred, normalize=True)

In [None]:
# Show the class-name lookup, to help read the label indices in the report and
# confusion matrix (assumes labels index into CLASS_NAMES; verify).
CLASS_NAMES

In [None]:
# Importance of each tuned hyperparameter for the study loaded above, plotted
# on a linear (non-log) scale.
plot_hyperparameter_importance(study=study, log=False)

In [None]:
# Show the summary of the model that was rebuilt from the best trial.
model.summarize()