# Use trained model to extract image features

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import janitor
import numpy as np
import pandas as pd
import torch
import umap
import xarray as xr
from pytorch_hcs.datasets import BBBC021DataModule
from pytorch_hcs.models import ResNet18, ResNet18Embeddings
from pytorch_hcs.vis import set_hv_defaults
from tqdm.notebook import tqdm

# from pyprojroot import here


# Apply the plotting defaults helper imported from pytorch_hcs.vis
# (presumably HoloViews styling only — see that module for the exact settings).
set_hv_defaults()

In [None]:
# Directory holding the model weights and the saved UMAP results, given
# relative to the current working directory (the pyprojroot-based
# alternative is left disabled in the trailing comment).
data_path = Path("data")  # here() / "data"
data_path

# Choose GPU or CPU processing

In [None]:
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook still runs on machines without a usable GPU.
# To force CPU processing regardless, set DEVICE = "cpu" instead.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Specify model to load

In [None]:
# Alternative run (plain ResNet18 classifier), kept disabled for reference:
# run_name = "resnet18-cleandata-moreaug"
# model_version = "version_3d5kdlrp"

# model_cls = ResNet18

# Run and version names select the weights directory
# data/weights/<run_name>/<model_version> that the checkpoint is loaded from.
run_name = "ResNet18Embeddings"
# model_version = "version_3hxtgk45"  # version 1
model_version = "version_358i4dhs"  # version 2 with dropout before last layer

# Class used both to load the checkpoint and to build the ImageNet baseline.
model_cls = ResNet18Embeddings

# Set up `LightningDataModule`

In [None]:
# Build the BBBC021 data module.
# NOTE(review): tv_batch_size / t_batch_size look like the train+val and test
# batch sizes — confirm against BBBC021DataModule.
dm = BBBC021DataModule(
    num_workers=8,
    tv_batch_size=4,
    t_batch_size=32,
)

# Run the data module's setup step before its datasets are accessed below.
dm.setup()

# Load model

In [None]:
# Directory holding the checkpoint for the selected run and version.
model_path = data_path / "weights" / run_name / model_version

checkpoint_files = [*model_path.glob("epoch=*.ckpt")]

# Exactly one checkpoint is expected; anything else is an error.
if not checkpoint_files:
    raise FileNotFoundError("No checkpoint file exists.")
if len(checkpoint_files) > 1:
    raise Exception("Too many checkpoint files")

(checkpoint_file,) = checkpoint_files

print(checkpoint_file)

# BBBC021-trained network restored from the checkpoint, in eval mode on DEVICE.
model_bbbc021 = model_cls.load_from_checkpoint(
    str(checkpoint_file), num_channels=3
)
model_bbbc021 = model_bbbc021.eval().to(DEVICE)

# Baseline built with pretrained=True (the ImageNet-trained model), set up
# the same way.
model_imagenet = model_cls(
    num_classes=dm.num_classes, pretrained=True, num_channels=3
)
model_imagenet = model_imagenet.eval().to(DEVICE)

# Iterate through all images across the train/val/test sets, passing each batch through both models and extracting features

We skip running the features through the final classification layer.

In [None]:
# Work on the data module's "all" dataset and its matching dataloader; the
# train-only alternatives are left commented out.
dataset = dm.all_dataset
# dataset = dm.train_dataset

dataloader = dm.all_dataloader()
# dataloader = dm.train_dataloader()

In [None]:
# Per-batch feature arrays from each model; stacked into one array below.
features_bbbc021 = []
features_imagenet = []

with torch.no_grad():  # inference only, so no autograd graph is needed
    for image_batch, _, _ in tqdm(dataloader):
        # Move the batch to the target device once and reuse it for both
        # models instead of transferring it twice per iteration.
        image_batch = image_batch.to(DEVICE)

        # features from our BBBC021-trained model
        features_bbbc021.append(
            model_bbbc021.compute_features(image_batch).cpu().numpy()
        )

        # features from ImageNet-trained model
        features_imagenet.append(
            model_imagenet.compute_features(image_batch).cpu().numpy()
        )

# Join the per-batch arrays along the first (sample) axis.
features_bbbc021 = np.concatenate(features_bbbc021, axis=0)
features_imagenet = np.concatenate(features_imagenet, axis=0)

In [None]:
# Display the dataset's class-name -> label mapping used in the next cell.
dataset.class_to_label

In [None]:
# Add a "moa_label" column by looking each row's "moa" class name up in the
# dataset's class-to-label mapping.
image_df = dataset.image_df.transform_column(
    column_name="moa",
    function=lambda moa_name: dataset.class_to_label[moa_name],
    dest_column_name="moa_label",
)
image_df

# Perform dimensionality reduction for visualization using UMAP

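- We will run UMAP over a grid of distance metrics, neighborhood sizes, densMAP settings, and supervised/unsupervised modes for both feature sets.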

Sources: 
- https://pair-code.github.io/understanding-umap/

# TODO: add replicates

In [None]:
from itertools import product

In [None]:
# UMAP search grid: each key lists the values tried for one setting.
params = {
    "datasets": [
        ("BBBC021", features_bbbc021),
        ("ImageNet", features_imagenet),
    ],
    "metrics": ["euclidean", "cosine"],
    "n_neighbors": [35, 100, 500, 1000],
    "densmap": [False, True],
    "supervised": [False, True],
}

In [None]:
# One UMAP projection is fitted per combination in the parameter grid; each
# result is stored as a data frame tagged with the settings that produced it.
embedding_dfs = []

param_grid = product(
    params["datasets"],
    params["metrics"],
    params["n_neighbors"],
    params["densmap"],
    params["supervised"],
)
n_combinations = np.prod([len(values) for values in params.values()])

for (dataset_name, features), metric, n_neighbors, densmap, supervised in tqdm(
    param_grid, total=n_combinations
):
    reducer = umap.UMAP(
        n_components=2,
        metric=metric,
        n_neighbors=n_neighbors,
        min_dist=0.0,
        densmap=densmap,
        random_state=42,
    )

    # Flatten every sample to a vector; labels are passed only for the
    # supervised runs.
    flat_features = features.reshape(features.shape[0], -1)
    targets = image_df["moa_label"] if supervised else None
    vis_embedding = reducer.fit_transform(flat_features, y=targets)

    # Image metadata side by side with the 2-D coordinates.
    umap_coords = pd.DataFrame(vis_embedding, columns=["umap_x", "umap_y"])
    embedding_df = pd.concat(
        [dataset.image_df.reset_index(drop=True), umap_coords],
        axis=1,
    )
    embedding_df = embedding_df.add_columns(
        dataset=dataset_name,
        metric=metric,
        n_neighbors=n_neighbors,
        densmap=densmap,
        supervised=supervised,
    )
    embedding_df = embedding_df.reorder_columns(
        ["dataset", "metric", "n_neighbors", "densmap", "supervised"]
    )

    embedding_dfs.append(embedding_df)

all_embedding_df = pd.concat(embedding_dfs, ignore_index=True)
all_embedding_df = all_embedding_df.astype(
    {"dataset": "category", "metric": "category", "concentration": float}
)
all_embedding_df

In [None]:
# Save the full grid of UMAP results to a parquet file under data_path.
all_embedding_df.to_parquet(data_path / 'umap_results.parquet')

# Visualize the results

In [None]:
# Columns passed to the plots as hover_cols.
# ("tvt_assignment" is currently left out.)
hover_cols = ["image_idx", "moa", "compound", "concentration"]

# Grid dimensions used to group and lay out the overview plots.
groups = ["weights", "metric", "n_neighbors", "densmap", "supervised"]

# Scatter options shared by the UMAP plots below.
kwargs = {
    "x": "umap_x",
    "y": "umap_y",
    "hover_cols": hover_cols,
    "alpha": 0.25,
    "aspect": "equal",
    "cmap": "glasbey",
    "colorbar": False,
    "width": 900,
    "height": 550,
}

In [None]:
# Rename "dataset" to "weights" to match `groups`, and drop rows whose moa is
# the string "null".
plot_df = all_embedding_df.rename_column("dataset", "weights").query(
    'moa != "null"'
)

# One scatter per parameter combination, coloured by moa.
grid_scatter = plot_df.hvplot.scatter(
    c="moa",
    title="UMAP embedding | {dimensions}",
    groupby=groups,
    **kwargs
).opts(fontsize={"title": 10})

# Two-column layout with independent axes.
grid_scatter.layout(groups).cols(2).opts(shared_axes=False)

# Select the one we want

In [None]:
# Chosen configuration, refit on the BBBC021 features.
dataset_name, features = "BBBC021", features_bbbc021
metric, n_neighbors = "cosine", 500
densmap, supervised = True, False

reducer = umap.UMAP(
    n_components=2,
    metric=metric,
    n_neighbors=n_neighbors,
    min_dist=0.0,
    densmap=densmap,
    random_state=42,
)

# Flatten every sample to a vector; labels are passed only when supervised.
flat_features = features.reshape(features.shape[0], -1)
targets = image_df["moa_label"] if supervised else None
vis_embedding = reducer.fit_transform(flat_features, y=targets)

# Image metadata plus the 2-D coordinates, tagged with the settings used.
umap_coords = pd.DataFrame(vis_embedding, columns=["umap_x", "umap_y"])
embedding_df = pd.concat(
    [dataset.image_df.reset_index(drop=True), umap_coords],
    axis=1,
)
embedding_df = embedding_df.add_columns(
    dataset=dataset_name,
    metric=metric,
    n_neighbors=n_neighbors,
    densmap=densmap,
    supervised=supervised,
)
embedding_df = embedding_df.reorder_columns(
    ["dataset", "metric", "n_neighbors", "densmap", "supervised"]
)

In [None]:
# embedding_df = (
#     all_embedding_df.query('dataset == "BBBC021" and metric == "cosine" and n_neighbors == 500 and densmap == True and supervised == False')
#     .copy()
#     .reset_index(drop=True)
# )
# embedding_df

In [None]:
# Scatter of the selected projection coloured by "moa", leaving out rows
# whose moa is the string "null".
known_moa_df = embedding_df.query('moa != "null"')
known_moa_df.hvplot.scatter(c="moa", title="UMAP embedding", **kwargs)

In [None]:
# Copy the shared options and remove the colour settings, which are supplied
# explicitly for this plot.
kwargs_ = dict(kwargs)
del kwargs_["cmap"], kwargs_["colorbar"]

# Same projection, coloured by plate.
embedding_df.query('moa != "null"').hvplot.scatter(
    c="plate",
    title="UMAP embedding",
    cmap="glasbey",
    **kwargs_
)

# (
#     embedding_df.query('moa != "null"').hvplot.scatter(
#         c="site", title="UMAP embedding", cmap='glasbey', colorbar=True, **kwargs_
#     )
# )

# Make another plot but color the dots by whether they were predicted correctly in the test set or not

# Outlier analysis

# FINDING: COSINE DISTANCE ESSENTIAL TO FINDING USEFUL OUTLIERS

## KNN mean distance

In [None]:
def ecdf(data):
    """Return the points of the empirical CDF of a sample.

    Parameters
    ----------
    data : array-like
        Sample values.

    Returns
    -------
    data_sorted : numpy.ndarray
        The sample sorted in ascending order.
    p : numpy.ndarray
        Proportion assigned to each sorted value, evenly spaced from 0 for
        the smallest value to 1 for the largest. A single-sample input
        yields ``[0.0]`` and an empty input yields an empty array.
    """
    data_sorted = np.sort(data)
    n = len(data)

    # Ranks are scaled by (n - 1) so the proportions span [0, 1]. The divisor
    # is clamped to 1 so a single sample gives 0.0 instead of 0/0 = NaN.
    p = np.arange(n) / max(n - 1, 1)

    return data_sorted, p

In [None]:
from sklearn.neighbors import NearestNeighbors

# Nearest-neighbour search over the 2-D UMAP coordinates.
umap_xy = embedding_df[["umap_x", "umap_y"]]

nbrs = NearestNeighbors(n_neighbors=8)
nbrs.fit(umap_xy)

distances, indexes = nbrs.kneighbors(umap_xy)

# Drop the first neighbour column (presumably each point matched to itself,
# since the query points are the fitted points) and average the rest.
distances = distances[:, 1:]
avg_distances = distances.mean(axis=1)

cdf_x, cdf_y = ecdf(avg_distances)

# Three stacked panels: per-image curve, histogram, and ECDF of the distances.
distance_curve = hv.Curve(
    avg_distances,
    kdims="BBBC021 image index",
    vdims="distance",
    label="Average kNN distance for BBBC021 image UMAP projections",
).opts(width=1000)
distance_hist = hv.Histogram(
    np.histogram(avg_distances, bins=200), kdims="distance"
).opts(width=1000)
ecdf_curve = hv.Curve(
    (cdf_x, cdf_y), kdims="distance", vdims="ECDF"
).opts(width=1000)

(distance_curve + distance_hist + ecdf_curve).cols(1)

In [None]:
# Attach each image's average kNN distance as its "outlier_score".
labeled_embedding_df = embedding_df.add_columns(
    outlier_score=avg_distances,
)

In [None]:
# Display the current contents of the reduced plotting options.
kwargs_

In [None]:
# Copy the shared options and remove the ones overridden for this plot.
kwargs_ = dict(kwargs)
del kwargs_["cmap"], kwargs_["colorbar"], kwargs_["alpha"]

# Projection coloured by outlier score on a log colour scale.
labeled_embedding_df.hvplot.scatter(
    c="outlier_score",
    title="Average kNN distance",
    cmap="jet",
    colorbar=True,
    logz=True,
    alpha=0.5,
    **kwargs_
)

In [None]:
from pybbbc import BBBC021

# BBBC021 handle; indexed below to fetch an (image, metadata) pair.
bbbc021 = BBBC021()

In [None]:
# Rank the images from highest to lowest outlier score and keep the image
# indices and scores in that order.
outlier_df = labeled_embedding_df.sort_values(
    by="outlier_score", ascending=False
)
outlier_order, outlier_scores = (
    outlier_df[column].values for column in ("image_idx", "outlier_score")
)


def make_layout(image_idx):
    """Build the channel-by-channel view of one image in the outlier ranking.

    Parameters
    ----------
    image_idx : int
        Position in ``outlier_order`` (0 is the highest outlier score).

    Returns
    -------
    hv.Layout
        Two-column layout with one ``hv.Image`` per channel followed by an
        ``hv.RGB`` overlay of all channels.
    """
    image, metadata = bbbc021[outlier_order[image_idx]]

    #     prefix = f"{metadata.compound.compound} @ {metadata.compound.concentration:.2e} μM, {metadata.compound.moa}"

    compound = metadata.compound
    prefix = f"{compound.compound}, {compound.moa}, {outlier_scores[image_idx]}"

    # One colormap per channel, in channel order.
    cmaps = ["fire", "kg", "kb"]

    plots = []
    for channel_idx, im_channel in enumerate(image):
        # Bounds as (left, bottom, right, top) in pixel units.
        bounds = (0, 0, im_channel.shape[1], im_channel.shape[0])
        channel_plot = hv.Image(
            im_channel,
            bounds=bounds,
            label=f"{prefix} | {bbbc021.CHANNELS[channel_idx]}",
        )
        plots.append(channel_plot.opts(cmap=cmaps[channel_idx]))

    # The overlay moves the channel axis last and reuses the bounds of the
    # last channel.
    overlay = hv.RGB(
        image.transpose(1, 2, 0),
        bounds=bounds,
        label="Channel overlay",
    )
    plots.append(overlay)

    return hv.Layout(plots).cols(2)


# Interactive browser over positions in the outlier ranking. The upper bound
# comes from the ranking itself because make_layout indexes outlier_order,
# whose length need not equal len(bbbc021); using the latter could let the
# slider run past the end of the ranking and raise IndexError.
hv.DynamicMap(make_layout, kdims="image").redim.range(
    image=(0, len(outlier_order) - 1)
)