# Evaluate the performance of a model on the test set

In [None]:
# IPython magics: automatically reload edited modules (e.g. pytorch_hcs)
# before executing code, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import holoviews as hv
import hvplot.pandas
import janitor
import numpy as np
import pandas as pd
import torch

from pytorch_hcs.datasets import BBBC021DataModule
from pytorch_hcs.models import ResNet18, ResNet101, ResNet18Embeddings

from sklearn.metrics import confusion_matrix
from tqdm.notebook import tqdm

# Activate the Bokeh plotting backend for the HoloViews/hvPlot figures below.
hv.extension('bokeh')

In [None]:
# Local data directory, relative to the notebook's working directory.
data_path = Path().joinpath("data")
data_path

# Choose GPU or CPU processing

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU so the
# notebook still runs on CPU-only machines. Assign DEVICE = "cpu" (or "cuda")
# explicitly to force a particular device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize W&B run

This run exists only so that model checkpoints can be loaded from W&B artifacts.
A future TODO is to also store the results of the evaluation run in W&B.

In [None]:
import wandb

# Authenticate, then open a W&B run; the run handle is needed to fetch
# model artifacts further down.
wandb.login()
run = wandb.init(name='evaluation', project='pytorch-hcs')

# Specify model to load

Model checkpoint artifacts logged during training are accessible under `'model-{version}'`,
where `version` defaults to the class name of the PyTorch Lightning module.
You can find your model artifact names in a tab in the left pane of the W&B run page; append an alias such as `:latest` to select a specific artifact version.

In [None]:
# W&B entity/project that holds the model checkpoint artifacts.
user_project = "zbarry/pytorch-hcs"

# Artifact name (with version alias) and the LightningModule class used to
# load it. Switch to one of the commented-out pairs to evaluate another model.
model_id = "resnet18:latest"
model_cls = ResNet18
# model_id, model_cls = "resnet101:latest", ResNet101
# model_id, model_cls = "resnet18-embeddings:latest", ResNet18Embeddings

# Load model

Download the model `.ckpt` file from W&B.
Alternatively, a local `.ckpt` file (e.g. one in the `data/weights` directory) can be loaded directly with `.load_from_checkpoint`
instead of downloading it from W&B.

In [None]:
# Fetch the checkpoint artifact from W&B, then restore the LightningModule from
# its model.ckpt, switch it to eval mode and move it to DEVICE.
artifact = run.use_artifact(f"{user_project}/{model_id}", type="model")
artifact_dir = artifact.download()

ckpt_path = f"{artifact_dir}/model.ckpt"
model = model_cls.load_from_checkpoint(ckpt_path)
model = model.eval().to(DEVICE)

# Set up `LightningDataModule`

In [None]:
# Loader workers: 8 when running on a GPU, 0 (main process) on CPU.
num_workers = 0 if DEVICE == "cpu" else 8  # h5py pickling error otherwise...

dm = BBBC021DataModule(
    t_batch_size=8,
    tv_batch_size=16,
    num_workers=num_workers,
)
dm.setup()

# Iterate through dataset, extracting predicted class labels from model

The training and validation datasets are also available for exploration (swap in the corresponding dataset/dataloader below).

In [None]:
# Split under evaluation. To explore another split instead, use
# dm.train_dataset / dm.train_dataloader() or dm.val_dataset / dm.val_dataloader().
dataset, dataloader = dm.test_dataset, dm.test_dataloader()

In [None]:
# Run the model over every batch of the chosen dataloader and collect, per
# sample, the true label, the predicted label (argmax over dim 1 of the model
# outputs) and the image index from the batch metadata.
true_labels = []
predicted_labels = []
image_idcs = []

with torch.no_grad():
    for image_batch, true_labels_batch, metadata_batch in tqdm(dataloader):
        image_idcs.extend(metadata_batch.image_idx.numpy())
        outputs = model(image_batch.to(DEVICE))

        # Convert each batch to NumPy once rather than accumulating one 0-d
        # tensor per sample and relying on np.array() to unwrap them later.
        predicted_labels.extend(torch.argmax(outputs, 1).cpu().numpy())
        true_labels.extend(np.asarray(true_labels_batch))

true_labels = np.array(true_labels)
predicted_labels = np.array(predicted_labels)
image_idcs = np.array(image_idcs)

In [None]:
# Display the collected ground-truth labels.
true_labels

In [None]:
# Display the labels predicted by the model.
predicted_labels

# Calculate Matthews correlation coefficient

See the [Wikipedia entry](https://en.wikipedia.org/wiki/Phi_coefficient).
MCC is a useful summary metric for multiclass classification problems with highly imbalanced classes.

In [None]:
from sklearn.metrics import matthews_corrcoef

In [None]:
# Matthews correlation coefficient over all evaluated samples.
mcc = matthews_corrcoef(y_true=true_labels, y_pred=predicted_labels)
mcc

# Display images which were predicted incorrectly

In [None]:
# Positions (within the evaluation arrays) of the misclassified samples,
# mapped to dataset image indices and to the labels the model predicted.
error_positions = np.flatnonzero(predicted_labels != true_labels)
error_image_idcs = image_idcs[error_positions]
error_predictions = predicted_labels[error_positions]

# Jupyter only displays a cell's last expression, so show the indices last.
error_image_idcs

In [None]:
from pybbbc import BBBC021

# Raw BBBC021 dataset from pybbbc; indexing it with an image index returns an
# (image, metadata) pair, which is how make_layout uses it.
bbbc021 = BBBC021()

In [None]:
def make_layout(image_idx):
    """Build a viewer layout for the ``image_idx``-th misclassified image.

    The raw image is fetched from ``bbbc021`` via ``error_image_idcs`` and the
    predicted class name via ``dataset.label_to_class``. Returns a two-column
    ``hv.Layout`` with one ``hv.Image`` per channel (labelled with the
    compound, its MoA and the predicted class) followed by an ``hv.RGB``
    overlay of all channels.

    Assumes ``image`` is a channels-first ``(C, H, W)`` array, as implied by
    the ``transpose(1, 2, 0)`` used for the overlay.
    """
    # The slider may hand over a non-int number; array indexing needs an int.
    image_idx = int(image_idx)

    image, metadata = bbbc021[error_image_idcs[image_idx]]

    predicted_class = dataset.label_to_class[error_predictions[image_idx]]

    prefix = f"{metadata.compound.compound}, {metadata.compound.moa}, Pred: {predicted_class}"

    # Take the extent from the array itself instead of relying on the loop
    # variable leaking out of the channel loop.
    height, width = image.shape[-2:]
    bounds = (0, 0, width, height)

    cmaps = ["fire", "kg", "kb"]

    plots = [
        hv.Image(im_channel, bounds=bounds, label=prefix).opts(
            cmap=cmaps[channel_idx]
        )
        for channel_idx, im_channel in enumerate(image)
    ]

    plots.append(
        hv.RGB(
            image.transpose(1, 2, 0),
            bounds=bounds,
            label="Channel overlay",
        )
    )

    return hv.Layout(plots).cols(2)


# Interactive viewer: the "image" slider steps through the misclassified
# images (0 .. len(error_image_idcs) - 1), rendering each with make_layout.
n_errors = len(error_image_idcs)
panel_opts = dict(frame_width=450, aspect="equal")

hv.DynamicMap(make_layout, kdims="image").redim.range(image=(0, n_errors - 1)).opts(
    hv.opts.Image(active_tools=["wheel_zoom"], **panel_opts),
    hv.opts.RGB(active_tools=["wheel_zoom"], **panel_opts),
)

# Build and visualize confusion matrix

## Metadata / MoA labels

In [None]:
# Image table exposed by the selected split as `image_df`; displayed here and
# saved to eval_results.pkl at the end of the notebook.
image_df = dataset.image_df
image_df

In [None]:
# MoA table of the selected split; its "moa" column is used below to find the
# MoAs that do not occur in this split.
moa_df = dataset.moa_df
moa_df

In [None]:
# Every distinct MoA across the full dataset, excluding the "null" placeholder.
all_moas = np.array(
    dm.all_dataset.moa_df
    .loc[lambda df: df["moa"] != "null", "moa"]
    .unique()
)
all_moas

## Find the MoAs missing from the test set

In [None]:
# MoAs that appear in the full dataset but never in the evaluated split.
moa_in_test = moa_df["moa"].unique()
missing_moas = set(all_moas) - set(moa_in_test)

print("MoAs not in test set:", missing_moas)

## Construct and normalize confusion matrix

In [None]:
import xarray as xr
import hvplot.xarray

In [None]:
# Class names in label order, so axis position i corresponds to integer
# label i. The names come from dataset.class_to_label (presumably MoA names,
# matching the axis names; verify against the dataset). Labels are assumed to
# be contiguous from 0; a gap raises KeyError here.
label_to_moa = {label: moa for moa, label in dataset.class_to_label.items()}
n_labels = max(label_to_moa) + 1
moas = [label_to_moa[label] for label in range(n_labels)]

# Raw confusion counts: rows are true classes, columns are predicted classes.
cmat = xr.DataArray(
    confusion_matrix(
        true_labels,
        predicted_labels,
        labels=np.arange(n_labels),
    ),
    dims=["moa_true", "moa_predicted"],
    coords=dict(moa_true=moas, moa_predicted=moas),
    name="confusion",
)

cmat

In [None]:
# Normalize each true-class row to sum to 1. Classes with no samples in this
# split have all-zero rows, which give 0/0 = NaN; those cells are set to 0.
cmat_normed = (cmat / cmat.sum(dim="moa_predicted")).fillna(0)

## Visualize confusion matrix

In [None]:
# Heatmap of the row-normalized confusion matrix (colour scale fixed to 0-1).
cmat_normed.hvplot.heatmap(
    x="moa_predicted",
    y="moa_true",
    C="confusion",
    title="Model predictions vs. true MoAs",
    xlabel="Predicted MoA",
    ylabel="True MoA",
    cmap="bjy",
    clim=(0, 1),
    rot=45,
    frame_width=300,
    frame_height=300,
)

# Save results for later use

In [None]:
import pickle as pkl

# Persist everything needed to revisit this evaluation without re-running the
# model. NOTE: pickle files can execute code when loaded; only load
# eval_results.pkl from trusted sources.
with open("eval_results.pkl", "wb") as fhdl:
    pkl.dump(
        dict(
            model_name=model.__class__.__name__,
            mcc=mcc,
            cmat=cmat,
            cmat_normed=cmat_normed,
            image_df=image_df,
            moa_df=moa_df,
            true_labels=true_labels,
            predicted_labels=predicted_labels,
            image_idcs=image_idcs,
            error_image_idcs=error_image_idcs,
            error_predictions=error_predictions,
            # Predicted class name for each misclassified image.
            error_classes=[
                dataset.label_to_class[prediction]
                for prediction in error_predictions
            ],
        ),
        fhdl,
    )