In [None]:
import os

import matplotlib.pyplot as plt
import optuna
from optuna.artifacts import FileSystemArtifactStore, download_artifact, upload_artifact
from optuna.storages import RetryFailedTrialCallback

from vacation.model import VCNN

In [None]:
# Locations of the checkpoint artifacts and of the Optuna study database.
# Either can be overridden through an environment variable so the notebook can
# run on another machine; the defaults are the original hardcoded locations.
CHECKPOINT_DIR = os.environ.get(
    "VACATION_CHECKPOINT_DIR", "/scratch/tgross/vacation_models/artifacts"
)
# sqlite:/// (three slashes) is a relative path, resolved against the kernel's
# current working directory.
STORAGE_URL = os.environ.get(
    "VACATION_STORAGE_URL", "sqlite:///../scripts/vacation.sqlite3"
)

# File-system store from which trial checkpoints are downloaded below.
artifact_store = FileSystemArtifactStore(base_path=CHECKPOINT_DIR)

# NOTE(review): heartbeat_interval / RetryFailedTrialCallback presumably mirror
# the settings of the script that ran the study. This notebook only reads
# results, so they likely have no effect here — confirm before removing.
storage = optuna.storages.RDBStorage(
    STORAGE_URL,
    heartbeat_interval=1,
    failed_trial_callback=RetryFailedTrialCallback(),
)

In [None]:
# Summaries of every study recorded in the storage; useful to confirm the
# study name loaded in the next cell.
all_studies = storage.get_all_studies()
all_studies

In [None]:
# Open the existing "vacation" study and derive the local file name under
# which the best trial's checkpoint is stored.
study = optuna.load_study(storage=storage, study_name="vacation")
best_trial_number = study.best_trial.number
best_artifact = f"./artifact-{best_trial_number}.pt"

In [None]:
# Fetch the best trial once, so the artifact id and the local file name refer to
# the same trial. If the study is still being optimised against the same DB,
# the best trial may have changed since `best_artifact` was computed; fail
# loudly instead of saving one trial's weights under another trial's name.
best_trial = study.best_trial
if best_artifact != f"./artifact-{best_trial.number}.pt":
    raise RuntimeError(
        f"Best trial is now #{best_trial.number}, but best_artifact is "
        f"{best_artifact!r}; re-run the previous cell."
    )

# download_artifact will not overwrite an existing file (it raises
# FileExistsError), so reuse a previously downloaded copy explicitly.
# NOTE(review): a cached file is not checked against the artifact id; delete it
# if the study was recreated and trial numbers were reused.
if not os.path.exists(best_artifact):
    download_artifact(
        artifact_store=artifact_store,
        file_path=best_artifact,
        artifact_id=best_trial.user_attrs["artifact_id"],
    )

In [None]:
# Rebuild the model from the local checkpoint file at `best_artifact`; later
# cells read its recorded accuracy and loss histories.
model = VCNN.load(path=best_artifact)

In [None]:
# Plot the accuracy history using the model's own plotting helper.
# NOTE(review): plot_metric's return value is not visible here; if it returns a
# figure/axes object, Jupyter will also display that object's repr.
model.plot_metric(key="accuracy")

In [None]:
# Per-epoch accuracy for the training and validation splits.
# NOTE(review): reads the private `_metrics` attribute; no public accessor for
# the raw values is visible from here.
accuracy_history = model._metrics["accuracy"]

fig, ax = plt.subplots()
ax.plot(accuracy_history.train_vals, label="Train")
ax.plot(accuracy_history.valid_vals, label="Valid")
ax.set(title="Accuracy", xlabel="Epochs", ylabel="Accuracy")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();

In [None]:
# Per-epoch loss for the training and validation splits.
# NOTE(review): reads the private `_loss_metric` attribute; no public accessor
# for the raw values is visible from here.
loss_history = model._loss_metric

fig, ax = plt.subplots()
ax.plot(loss_history.train_vals, label="Train")
ax.plot(loss_history.valid_vals, label="Valid")
ax.set(title="Loss", xlabel="Epochs", ylabel="Loss")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();