# Evaluate the performance of the models on the test set

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import janitor
import numpy as np
import pandas as pd
import torch
import xarray as xr
from pytorch_hcs.datasets import BBBC021DataModule
from pytorch_hcs.models import ResNet18, ResNet101, ResNet18Embeddings
from pytorch_hcs.vis import set_hv_defaults
from sklearn.metrics import confusion_matrix
from tqdm.notebook import tqdm

# from pyprojroot import here


# Presumably applies project-wide HoloViews plotting defaults — see
# pytorch_hcs.vis.set_hv_defaults for what is actually configured.
set_hv_defaults()

In [None]:
# Local data directory, resolved relative to the current working directory.
# The trailing comment records the pyprojroot-based alternative (its import is
# commented out in the imports cell).
data_path = Path("data")  # here() / "data"
data_path

# Choose GPU or CPU processing

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook also works on machines without CUDA. Assign DEVICE explicitly
# (e.g. DEVICE = "cpu") to override.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Log in to Weights & Biases (model checkpoints are stored as W&B artifacts)

In [None]:
import wandb

# Authenticate with Weights & Biases and start a run. The resulting `run`
# object is used in the model-loading cell (`run.use_artifact`) to download
# the checkpoint artifacts.
wandb.login()
run = wandb.init()

# Load models

In [None]:
# W&B model-artifact IDs within the zbarry/pytorch-hcs project, keyed by run
# name. The run name is also what the loading cell inspects to choose the
# network architecture.
model_artifact_ids = dict(
    [
        ("resnet18", "model-1eyyjpad:v0"),
        ("resnet18-moreaug", "model-3d5kdlrp:v0"),
        ("resnet18-notpretrained", "model-6bsy7dth:v0"),
        ("resnet101", "model-3fizb084:v0"),
    ]
)

In [None]:
# Download each run's checkpoint artifact from W&B and load it in eval mode on
# DEVICE, keyed by run name.
models = {}

for run_name, artifact_id in model_artifact_ids.items():
    artifact = run.use_artifact(f"zbarry/pytorch-hcs/{artifact_id}", type="model")

    artifact_dir = Path(artifact.download())
    ckpt_path = artifact_dir / "model.ckpt"

    # The architecture is inferred from the run name. Runs that match neither
    # pattern are reported and skipped rather than aborting the whole loop.
    if "resnet18" in run_name:
        model_cls = ResNet18
    elif "resnet101" in run_name:
        model_cls = ResNet101
    else:
        print("Could not parse model from run name:", run_name)
        continue

    models[run_name] = model_cls.load_from_checkpoint(str(ckpt_path)).eval().to(DEVICE)

In [None]:
# NOTE: this cell *replaces* the `models` dict built from the W&B artifacts in
# the previous cell, so only the models listed here are evaluated below.
#
# Checkpoints are resolved relative to `data_path` instead of a hard-coded
# absolute home-directory path, so the notebook is not tied to one machine.
# NOTE(review): the original path was
# /home/zachary/projects/pytorch-hcs/notebooks/data/weights/..., which equals
# data_path / "weights/..." when the working directory is notebooks/ — confirm.
embedding_weights_dir = data_path / "weights" / "ResNet18Embeddings"

models = {
    # "round1": ResNet18Embeddings.load_from_checkpoint(
    #     str(embedding_weights_dir / "version_3hxtgk45" / "epoch=22-step=5703.ckpt")
    # )
    # .eval()
    # .to(DEVICE),
    "round2": ResNet18Embeddings.load_from_checkpoint(
        str(embedding_weights_dir / "version_358i4dhs" / "epoch=29-step=7439.ckpt")
    )
    .eval()
    .to(DEVICE),
}

In [None]:
# model_path = data_path / f"weights/{run_name}/{model_version}"

# checkpoint_files = list(model_path.glob("epoch=*.ckpt"))

# if len(checkpoint_files) > 1:
#     raise Exception("Too many checkpoint files")
# if len(checkpoint_files) == 0:
#     raise FileNotFoundError("No checkpoint file exists.")

# checkpoint_file = checkpoint_files[0]

# print(checkpoint_file)

# model = model_cls.load_from_checkpoint(str(checkpoint_file)).eval().to(DEVICE)

# Set up `LightningDataModule`

In [None]:
# Build and set up the BBBC021 data module. Eight loader workers are used, and
# the same batch size (16) is passed for both `tv_batch_size` and
# `t_batch_size`.
datamodule_kwargs = dict(num_workers=8, tv_batch_size=16, t_batch_size=16)
dm = BBBC021DataModule(**datamodule_kwargs)
dm.setup()

# Iterate through the test set, collecting each model's predicted class labels

In [None]:
# Split to evaluate: "train", "val" or "test". The data module exposes a
# `<split>_dataset` attribute and a `<split>_dataloader()` method for each.
split = "test"

dataset = getattr(dm, f"{split}_dataset")
dataloader = getattr(dm, f"{split}_dataloader")()

In [None]:
# Run every model over `dataloader` and collect one DataFrame per model with
# the true label, predicted label and image index of each sample.
label_dfs = []

for run_name, model in models.items():
    # Per-run accumulators. They stay defined as notebook globals after the
    # loop, so later cells can inspect the last run's results.
    true_labels = []
    predicted_labels = []
    image_idcs = []

    with torch.no_grad():
        for image_batch, true_labels_batch, metadata_batch in tqdm(dataloader):
            image_idcs.extend(metadata_batch.image_idx.numpy())
            outputs = model(image_batch.to(DEVICE))

            # Predicted class = index of the largest output along dim 1.
            labels = torch.argmax(outputs, 1).cpu()

            # Convert each batch to NumPy once (as done for image_idx above)
            # instead of extending the lists with per-element 0-d tensors that
            # np.array must later unpack one by one.
            predicted_labels.extend(labels.numpy())
            true_labels.extend(np.asarray(true_labels_batch))

    true_labels = np.array(true_labels)
    predicted_labels = np.array(predicted_labels)
    image_idcs = np.array(image_idcs)

    df = pd.DataFrame(
        dict(
            run_name=run_name,
            true_label=true_labels,
            predicted_label=predicted_labels,
            image_idx=image_idcs,
        )
    )

    label_dfs.append(df)

In [None]:
# Debug: ground-truth labels of the last batch seen by the inference loop.
true_labels_batch

In [None]:
# Debug: true labels collected for the last model evaluated by the inference loop.
true_labels

In [None]:
# Debug: raw model outputs for the last batch of the last model evaluated.
outputs

In [None]:
# Stack the per-run label tables into one frame with a fresh index. run_name
# is stored as a categorical column.
label_df = pd.concat(label_dfs, ignore_index=True)
label_df = label_df.astype({"run_name": "category"})
label_df

In [None]:
from sklearn.metrics import matthews_corrcoef

In [None]:
# Matthews correlation coefficient of predicted vs. true labels, one row per run.
# - `observed=True`: run_name is categorical, so only run names that actually
#   occur are grouped. This also avoids the pandas FutureWarning about the
#   `observed` default.
# - Selecting the label columns before `apply` keeps the grouping column out
#   of the frame passed to the lambda, which is deprecated in pandas 2.2.
mcc_df = (
    label_df.groupby("run_name", observed=True)[["true_label", "predicted_label"]]
    .apply(lambda df: matthews_corrcoef(df["true_label"], df["predicted_label"]))
    .to_frame("mcc_score")
    .reset_index()
)

mcc_df

In [None]:
# Labels and image indices for a single evaluated run. Each column is
# extracted on its own so it keeps its dtype: `.values` on a multi-column
# selection upcasts all columns to one common dtype.
selected_run_df = label_df.query('run_name == "round2"')
true_labels, predicted_labels, image_idcs = (
    selected_run_df[column].to_numpy()
    for column in ("true_label", "predicted_label", "image_idx")
)

In [None]:
# Image indices of the samples whose predicted label differs from the true label.
is_error = predicted_labels != true_labels
error_image_idcs = image_idcs[is_error]
error_image_idcs

In [None]:
from pybbbc import BBBC021

# BBBC021 dataset handle. `make_layout` indexes it with an image index to get
# an (image, metadata) pair for each misclassified sample.
bbbc021 = BBBC021()

In [None]:
def make_layout(image_idx):
    """Return a HoloViews layout for the ``image_idx``-th misclassified image.

    ``image_idx`` indexes ``error_image_idcs``. The resulting image index is
    looked up in ``bbbc021``. The layout contains one panel per channel
    (iterating over the image's first axis), plus an RGB overlay of all
    channels, arranged in two columns.
    """
    dataset_idx = error_image_idcs[image_idx]
    image, metadata = bbbc021[dataset_idx]

    # Channels lie along axis 0 (see the per-channel loop and the RGB transpose
    # below), so the last two axes give the spatial (height, width). Computing
    # the bounds once avoids reading the loop variable after the loop ends.
    height, width = image.shape[-2:]
    bounds = (0, 0, width, height)

    prefix = f"{metadata.compound.compound}, {metadata.compound.moa}, {dataset_idx}"

    # One colormap per channel. NOTE(review): assumes at most 3 channels.
    cmaps = ["fire", "kg", "kb"]

    plots = [
        hv.Image(
            im_channel,
            bounds=bounds,
            label=f"{prefix} | {bbbc021.CHANNELS[channel_idx]}",
        ).opts(cmap=cmaps[channel_idx])
        for channel_idx, im_channel in enumerate(image)
    ]

    plots.append(
        hv.RGB(
            image.transpose(1, 2, 0),
            bounds=bounds,
            label="Channel overlay",
        )
    )

    return hv.Layout(plots).cols(2)


hv.DynamicMap(make_layout, kdims="image").redim.range(
    image=(0, len(error_image_idcs) - 1)
)

# Load metadata

In [None]:
# Per-image metadata table of the split selected earlier. Its "moa" column is
# used below to find MoAs missing from this split.
image_df = dataset.image_df
image_df

# Build list of MoAs for visualization

In [None]:
# MoA metadata table of the selected split. Its "moa" column is used to build
# the list of MoAs.
moa_df = dataset.moa_df
moa_df

In [None]:
# Distinct MoA names, in order of first appearance, with the "null"
# placeholder removed.
all_moas = np.array(moa_df["moa"].unique())
is_real_moa = all_moas != "null"
moas = all_moas[is_real_moa]
moas

# Find the MoAs that are absent from the test set

In [None]:
moa_in_test = image_df["moa"].unique()

# MoAs from the full MoA list that have no image in this split.
missing_moas = set(moas) - set(moa_in_test)

print(f"MoAs not in test set: {missing_moas}")

# Calculate and normalize confusion matrix

In [None]:
# True labels of the run selected via `label_df.query(...)`; these feed the
# confusion matrix.
true_labels

In [None]:
# Confusion-matrix rows/columns are ordered by integer label (0 .. n_labels-1),
# so the axis names are taken from the same mapping that defines those labels.
# The order of `moa_df["moa"].unique()` is not guaranteed to match it.
# NOTE(review): assumes `dataset.class_to_label` maps MoA name -> int label.
label_to_class = {label: name for name, label in dataset.class_to_label.items()}
n_labels = max(label_to_class) + 1
class_names = [label_to_class.get(label, str(label)) for label in range(n_labels)]

cmat = xr.DataArray(
    confusion_matrix(
        true_labels,
        predicted_labels,
        labels=np.arange(n_labels),
    ),
    dims=["moa_true", "moa_predicted"],
    coords=dict(moa_true=class_names, moa_predicted=class_names),
    name="confusion",
)

cmat

In [None]:
# Row-normalize: divide each true-MoA row by its total, so each row gives the
# fraction of that MoA's samples predicted as each MoA. Rows with no samples
# divide by zero; the resulting NaNs are replaced with 0.
row_totals = cmat.sum("moa_predicted")
cmat_normed = (cmat / row_totals).fillna(0)

# Visualize confusion matrix

In [None]:
# Plot options for the normalized confusion matrix. The colour scale is fixed
# to [0, 1] because every row of `cmat_normed` holds fractions.
heatmap_opts = dict(
    rot=45,
    frame_width=300,
    frame_height=300,
    cmap="bjy",
    clim=(0, 1),
    xlabel="Predicted MoA",
    ylabel="True MoA",
    title="Model predictions vs. true MoAs",
)

cmat_normed.hvplot.heatmap("moa_predicted", "moa_true", "confusion", **heatmap_opts)