# Use UMAP / densMAP on CNN embeddings to find dataset outliers

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import janitor
import numpy as np
import pandas as pd
import torch
import umap
import xarray as xr
from pytorch_hcs.datasets import BBBC021DataModule
from pytorch_hcs.models import ResNet18, ResNet101, ResNet18Embeddings
from tqdm.notebook import tqdm
from pathlib import Path

# Load the Bokeh plotting backend so HoloViews / hvPlot objects render with it.
hv.extension('bokeh')

In [None]:
# Data directory, expressed relative to the current working directory.
data_path = Path("data")
data_path  # bare expression so the cell echoes the path

# Choose GPU or CPU processing

In [None]:
# Use the GPU when PyTorch can see one and fall back to the CPU otherwise, so
# the notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Uncomment to force CPU processing even when a GPU is present.
# DEVICE = 'cpu'

# Initialize W&B run

This is only so we can load model checkpoints from W&B artifacts.
Future TODO would be to store the results of the evaluation run in W&B.

In [None]:
import wandb

# Authenticate before creating the run; the run object is what the notebook
# later uses to request the model checkpoint artifact.
wandb.login()
run = wandb.init(project='pytorch-hcs', name='outliers')

# Specify model to load

In [None]:
# Prefix that is prepended to the artifact name when the checkpoint is requested.
user_project = 'zbarry/pytorch-hcs'

# Artifact name and the model class used to load it; change the two together.
model_id = "resnet18:latest"
model_cls = ResNet18
# model_id, model_cls = "resnet101:latest", ResNet101
# model_id, model_cls = "resnet18-embeddings:latest", ResNet18Embeddings

# Load model

Download the model `.ckpt` file from W&B.
Note that, if desired, a model `.ckpt` file in the `data/weights` directory can instead be loaded directly
through `.load_from_checkpoint` rather than downloaded from W&B.

In [None]:
# Request the checkpoint artifact through the active W&B run and download it.
artifact = run.use_artifact(f"{user_project}/{model_id}", type="model")
artifact_dir = artifact.download()

# Checkpoint file inside the downloaded artifact directory.
ckpt_path = f"{artifact_dir}/model.ckpt"

# Restore the model, then put it in evaluation mode on the chosen device.
model = model_cls.load_from_checkpoint(str(ckpt_path))
model = model.eval().to(DEVICE)

# Set up `LightningDataModule`

In [None]:
dm = BBBC021DataModule(
    t_batch_size=16,
    tv_batch_size=4,
    # pickle error with h5py otherwise
    num_workers=0 if DEVICE == "cpu" else 8,
)

# Run setup before the datasets and dataloaders are read from the module.
dm.setup()

# Extract image embeddings with our BBBC021-trained network

- We skip running the features through the final classification layer.
- All images are included (even those from compounds with unknown MoA)

In [None]:
# Select a dataset together with its matching dataloader. The commented pairs
# switch both to the train or test variants; always change them as a pair.
dataset, dataloader = dm.all_dataset, dm.all_dataloader()
# dataset, dataloader = dm.train_dataset, dm.train_dataloader()
# dataset, dataloader = dm.test_dataset, dm.test_dataloader()

In [None]:
# Run every batch through the network with gradient tracking disabled and
# collect one NumPy feature array per batch.
with torch.no_grad():
    batch_features = [
        # features from our BBBC021-trained model
        np.array(model.compute_features(images.to(DEVICE)).cpu())
        for images, _, _ in tqdm(dataloader)
    ]

# Join the per-batch arrays along the first axis.
features_bbbc021 = np.concatenate(batch_features, axis=0)

In [None]:
def _moa_to_label(class_name):
    """Look up the label that the dataset assigns to an MoA class name."""
    return dataset.class_to_label[class_name]


# Derive a "moa_label" column from "moa" via the dataset's class-to-label mapping.
image_df = dataset.image_df.transform_column("moa", _moa_to_label, "moa_label")
image_df

# Perform dimensionality reduction for visualization using UMAP

- [UMAP article](https://pair-code.github.io/understanding-umap/)
- We are using the [densMAP](https://umap-learn.readthedocs.io/en/latest/densmap_demo.html) implementation to push outlier points further away from inliers.
- Try both `'cosine'` and `'euclidean'` as distance metrics
to see the effect on which images are the greatest outliers.
- `n_neighbors` has been tuned ahead of time (see `notebooks/poc/umap_param_sweep.ipynb`).

In [None]:
# Set to True to pass the MoA labels to UMAP as the target `y`.
supervised = False

reducer = umap.UMAP(
    n_components=2,
    n_neighbors=500,
    min_dist=0.0,
    metric='cosine',
    densmap=True,
    random_state=42,
)

# Collapse every sample's features into a single flat vector before fitting.
n_samples = features_bbbc021.shape[0]
flat_features = features_bbbc021.reshape(n_samples, -1)
umap_targets = image_df["moa_label"] if supervised else None

vis_embedding = reducer.fit_transform(flat_features, y=umap_targets)

# Place the 2-D coordinates next to the image metadata, aligned by row position.
umap_df = pd.DataFrame(vis_embedding, columns=["umap_x", "umap_y"])
embedding_df = pd.concat(
    [dataset.image_df.reset_index(drop=True), umap_df],
    axis=1,
)

# Plot UMAP'd embeddings for compounds of known MoA

In [None]:
# Metadata columns shown in the hover tooltip.
hover_cols = ["image_idx", "moa", "compound", "concentration"]

# Scatter-plot options for the embedding view.
kwargs = {
    "x": "umap_x",
    "y": "umap_y",
    "hover_cols": hover_cols,
    "alpha": 0.25,
    "aspect": "equal",
    "cmap": "glasbey",
    "colorbar": False,
    "width": 900,
    "height": 550,
}

# Keep only rows whose MoA is not the literal string "null", coloured by MoA.
known_moa_df = embedding_df.query('moa != "null"')
known_moa_df.hvplot.scatter(c="moa", title="UMAP embedding", **kwargs)

# Calculate k-nearest neighbors distance for each embedded point

In [None]:
from sklearn.neighbors import NearestNeighbors

# Number of nearest neighbours averaged into each point's outlier score.
n_neighbors = 8

umap_coords = embedding_df[["umap_x", "umap_y"]]

nbrs = NearestNeighbors(n_neighbors=n_neighbors)
nbrs.fit(umap_coords)

# Calling kneighbors() without a query returns the neighbours of the fitted
# points themselves and does not count a point as its own neighbour. Querying
# with the training data and dropping the first column instead spends one of
# the k slots on the zero-distance self match, so only k - 1 real neighbours
# would be averaged.
distances, indexes = nbrs.kneighbors()

# Mean distance to the k nearest neighbours; isolated points get high scores.
avg_distances = distances.mean(axis=1)

labeled_embedding_df = embedding_df.add_columns(
    outlier_score=avg_distances,
)

In [None]:
def ecdf(data):
    """Return the empirical cumulative distribution function of *data*.

    Parameters
    ----------
    data : array-like
        Samples; the input is flattened before sorting.

    Returns
    -------
    data_sorted : numpy.ndarray
        The samples in ascending order.
    p : numpy.ndarray
        Cumulative proportions ``(i + 1) / n`` for each sorted position ``i``,
        so the last value is always 1. Both arrays are empty for empty input.
    """
    data_sorted = np.sort(np.asarray(data), axis=None)
    n_samples = data_sorted.size

    if n_samples == 0:
        return data_sorted, np.empty(0, dtype=float)

    # Standard ECDF steps: i-th smallest sample (1-based) maps to i / n. This
    # avoids the 0 / 0 that an (n - 1) denominator produces for one sample.
    p = np.arange(1, n_samples + 1) / n_samples

    return data_sorted, p

In [None]:
cdf_x, cdf_y = ecdf(avg_distances)

# Score of every image, in row order.
score_curve = hv.Curve(
    avg_distances,
    kdims="BBBC021 image index",
    vdims="distance",
    label="Average kNN distance for BBBC021 image UMAP projections",
).opts(width=1000)

# Distribution of the scores.
score_hist = hv.Histogram(
    np.histogram(avg_distances, bins=200), kdims="distance"
).opts(width=1000)

# Cumulative view of the same distribution.
score_ecdf = hv.Curve(
    (cdf_x, cdf_y), kdims="distance", vdims="ECDF"
).opts(width=1000)

# Stack the three views in a single column.
(score_curve + score_hist + score_ecdf).cols(1)

# View images in order of descending kNN distance

In [None]:
from pybbbc import BBBC021
# Dataset handle that the image viewer indexes to get an (image, metadata) pair.
bbbc021 = BBBC021()

In [None]:
# Rank the images from highest to lowest outlier score.
outlier_df = labeled_embedding_df.sort_values(by="outlier_score", ascending=False)

# Pull out the ranked image indices and scores so they can be indexed by rank.
outlier_order, outlier_scores = (
    outlier_df[column].values for column in ("image_idx", "outlier_score")
)


def make_layout(image_idx):
    """Build a two-column layout showing one ranked image.

    Parameters
    ----------
    image_idx : int
        Position in ``outlier_order`` / ``outlier_scores`` (a rank), not the
        raw BBBC021 image index.

    Returns
    -------
    hv.Layout
        One ``hv.Image`` per channel followed by an ``hv.RGB`` overlay of all
        channels, arranged in two columns.
    """
    image, metadata = bbbc021[outlier_order[image_idx]]

    prefix = f"{metadata.compound.compound}, {metadata.compound.moa}, {outlier_scores[image_idx]}"

    # Every channel shares the same spatial extent, so derive the bounds once
    # from the full (channel, row, column) array instead of reading the loop
    # variable after the loop has finished.
    _, height, width = image.shape
    bounds = (0, 0, width, height)

    cmaps = ["fire", "kg", "kb"]

    plots = [
        hv.Image(
            im_channel,
            bounds=bounds,
            label=f"{prefix} | {bbbc021.CHANNELS[channel_idx]}",
        ).opts(cmap=cmaps[channel_idx])
        for channel_idx, im_channel in enumerate(image)
    ]

    # Channels-last copy of the same image for the colour overlay.
    plots.append(
        hv.RGB(
            image.transpose(1, 2, 0),
            bounds=bounds,
            label="Channel overlay",
        )
    )

    return hv.Layout(plots).cols(2)


# The slider value is used to index outlier_order inside make_layout, so bound
# it by the length of that ranking rather than by the size of the full BBBC021
# dataset. The two differ when a train or test dataset is selected, and an
# out-of-range position would raise an IndexError.
hv.DynamicMap(make_layout, kdims="image").redim.range(
    image=(0, len(outlier_order) - 1)
).opts(
    hv.opts.Image(frame_width=450, aspect='equal'),
    hv.opts.RGB(frame_width=450, aspect='equal'),
)