# Analysis of the inferred image embeddings

In this notebook, we assess the inferred image embeddings for the previously identified impactful gene perturbation settings. To this end, we use the image embeddings computed during the training of the convolutional neural network in the 4-fold Group CV setup.

---

## 0. Environment setup

First, we import the required software packages and libraries.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from umap import UMAP
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import seaborn as sns
from sklearn.cluster import KMeans
from scipy.cluster import hierarchy as hc
from scipy.spatial.distance import pdist, euclidean, cosine
from tqdm import tqdm
from scipy.spatial.distance import squareform
import sys
from sklearn.metrics import (
    mutual_info_score,
    adjusted_mutual_info_score,
    adjusted_rand_score,
    rand_score,
    v_measure_score,
    normalized_mutual_info_score,
)
import matplotlib as mpl
from collections import Counter
from yellowbrick.cluster.elbow import kelbow_visualizer
from yellowbrick.cluster import KElbowVisualizer
from IPython.display import Image
from statannot import add_stat_annotation
import ot

sys.path.append("../../..")
from src.utils.notebooks.ppi.embedding import *
from src.utils.notebooks.images.embedding import *
from src.utils.notebooks.translation.analysis import *
from src.utils.basic.io import get_genesets_from_gmt_file

# Render all figures in this notebook at high resolution.
mpl.rcParams["figure.dpi"] = 600

# Notebook-wide random seed; the same value is used as random_state for the
# embedding/plotting helpers further below.
seed = 1234

# (Re)load the nb_black IPython extension (presumably auto-formats code cells).
%reload_ext nb_black

In [None]:
def assess_cluster_topk(reg_nn_dict, struct_nn_dict, cluster_df):
    """Measure how well two neighbor rankings recover each sample's cluster.

    For every query sample, walk down its ranked neighbor lists (position 0 is
    skipped, presumably because it is the query itself -- TODO confirm) and
    record, for every cut-off k, the fraction of the query's fellow cluster
    members found among its first k neighbors.

    Args:
        reg_nn_dict: dict mapping each sample to its ranked list of neighbors.
        struct_nn_dict: dict with (at least) the same keys, holding a second
            ranked neighbor list per sample; must be as long as the first.
        cluster_df: DataFrame indexed by sample. The value in the first column
            of a sample's row is taken as its cluster and compared against the
            ``cluster`` column to determine cluster membership.

    Returns:
        Tuple ``(samples, struct_topks, reg_topks)``: the query samples that
        were kept (those whose cluster has at least two members) and, for the
        structural and regulatory rankings respectively, an array with one row
        per kept sample holding the cumulative recall at k = 1 .. len - 1.
    """

    def _cluster_of(name):
        # Cluster label = first entry of the sample's row in cluster_df.
        return np.array(cluster_df.loc[name])[0]

    kept, struct_rows, reg_rows = [], [], []
    for query in reg_nn_dict.keys():
        reg_ranked = reg_nn_dict[query]
        struct_ranked = struct_nn_dict[query]
        query_cluster = _cluster_of(query)
        cluster_size = len(cluster_df.loc[cluster_df.cluster == query_cluster])
        if cluster_size < 2:
            # A singleton cluster has no other members to recover.
            continue
        kept.append(query)

        reg_hits, struct_hits = [], []
        for rank in range(1, len(reg_ranked)):
            reg_hits.append(int(_cluster_of(reg_ranked[rank]) == query_cluster))
            struct_hits.append(
                int(_cluster_of(struct_ranked[rank]) == query_cluster)
            )

        # Normalise by the number of *other* members of the query's cluster.
        other_members = cluster_size - 1
        struct_rows.append(np.cumsum(struct_hits) / other_members)
        reg_rows.append(np.cumsum(reg_hits) / other_members)
    return kept, np.array(struct_rows), np.array(reg_rows)

In [None]:
def get_neighbor_dict(data, metric="euclidean"):
    """Rank all samples in ``data`` by distance to each sample.

    Args:
        data: DataFrame indexed by sample name. For a regular ``metric`` its
            rows are feature vectors. For ``metric="precomputed"`` it must be a
            square distance matrix whose row for a sample holds that sample's
            distances to all samples (in index order).
        metric: any metric accepted by sklearn's ``NearestNeighbors``, or
            ``"precomputed"``.

    Returns:
        dict mapping each sample name to an array of all sample names ordered
        from nearest to farthest. Because the query is passed explicitly, the
        sample itself is part of its own neighbor list.
    """
    # Imported locally: NearestNeighbors is not part of the notebook's import cell.
    from sklearn.neighbors import NearestNeighbors

    samples = np.array(data.index)
    nn = NearestNeighbors(n_neighbors=len(data), metric=metric)
    nn.fit(np.array(data))

    sample_neighbor_dict = {}
    for sample in samples:
        # For a precomputed matrix the sample's row already contains its
        # distances to every fitted sample, which is exactly the query format
        # kneighbors expects -- so no special-casing is needed.
        query = np.array(data.loc[sample]).reshape(1, -1)
        pred_idx = nn.kneighbors(query, return_distance=False)[0]
        sample_neighbor_dict[sample] = samples[pred_idx]
    return sample_neighbor_dict

In [None]:
def get_emd_for_embs(embs, label_col, metric="euclidean"):
    """Pairwise earth mover's distance (EMD) between the label groups of ``embs``.

    Each distinct value of ``label_col`` defines an empirical distribution with
    uniform weights over the rows carrying that label. For every pair of labels
    the ground-cost matrix is computed with ``ot.dist`` and rescaled by its
    maximum, so its entries lie in [0, 1]. Note that this per-pair rescaling
    means EMD values of different pairs are measured on different scales.

    Args:
        embs: DataFrame with a label column plus feature columns; all numeric
            columns (selected via pandas' private ``_get_numeric_data``) are
            used as features.
        label_col: name of the column holding the group labels.
        metric: metric passed to ``ot.dist``.

    Returns:
        Symmetric DataFrame indexed and columned by the sorted unique labels,
        with a zero diagonal.
    """
    targets = np.unique(embs.loc[:, label_col])
    n_targets = len(targets)

    # Extract each group's feature matrix once instead of re-filtering the
    # whole frame for every (source, target) pair.
    group_features = {
        target: np.array(
            embs.loc[embs.loc[:, label_col] == target]._get_numeric_data()
        )
        for target in targets
    }

    # np.inf (not the np.infty alias, removed in NumPy 2.0) as a placeholder;
    # every entry is overwritten below.
    wd_mtx = np.full((n_targets, n_targets), np.inf)
    for i in tqdm(range(n_targets), desc="Compute EMD"):
        source = targets[i]
        xs = group_features[source]
        ps = np.ones((len(xs),)) / len(xs)
        for j in range(i, n_targets):
            target = targets[j]
            if source == target:
                wd_st = 0
            else:
                xt = group_features[target]
                pt = np.ones((len(xt),)) / len(xt)
                m = ot.dist(xs, xt, metric=metric)
                m_max = m.max()
                # Guard against 0/0: if all pairwise distances are zero the
                # cost matrix stays all-zero and the EMD is 0 instead of NaN.
                if m_max > 0:
                    m = m / m_max
                wd_st = ot.emd2(ps, pt, m, numItermax=int(1e9))
            wd_mtx[i, j] = wd_st
            wd_mtx[j, i] = wd_st
    return pd.DataFrame(wd_mtx, columns=list(targets), index=list(targets))

---

## 1. Read in data

Second, we read in the latent embeddings of the individual images that are part of the respective held-out sets in the CV setting.

In [None]:
root_dir = "../../../data/experiments/image_embeddings/specificity_target_emb_cv_strat/final_1024/"
n_folds = 4

# Collect the held-out (test) latents of every CV fold and tag each row with
# the fold it came from.
all_latents = []
for i in range(n_folds):
    fold_latents = pd.read_hdf(f"{root_dir}fold_{i}/test_latents.h5")
    fold_latents["fold"] = f"fold_{i}"
    all_latents.append(fold_latents)
latents = pd.concat(all_latents)

# DataFrame.shape avoids materialising the whole mixed-dtype frame as an array.
print(f"Read in latent embeddings of shape: {latents.shape}")

We will decode the numeric class labels to identify which regulator each embedding corresponds to.

In [None]:
# Regulator gene name -> numeric class label used during training.
label_dict = {
    "AKT1S1": 0,
    "ATF4": 1,
    "BAX": 2,
    "BCL2L11": 3,
    "BRAF": 4,
    "CASP8": 5,
    "CDC42": 6,
    "CDKN1A": 7,
    "CEBPA": 8,
    "CREB1": 9,
    "CXXC4": 10,
    "DIABLO": 11,
    "E2F1": 12,
    "ELK1": 13,
    "EMPTY": 14,
    "ERG": 15,
    "FGFR3": 16,
    "FOXO1": 17,
    "GLI1": 18,
    "HRAS": 19,
    "IRAK4": 20,
    "JUN": 21,
    "MAP2K3": 22,
    "MAP3K2": 23,
    "MAP3K5": 24,
    "MAP3K9": 25,
    "MAPK7": 26,
    "MOS": 27,
    "MYD88": 28,
    "PIK3R2": 29,
    "PRKACA": 30,
    "PRKCE": 31,
    "RAF1": 32,
    "RELB": 33,
    "RHOA": 34,
    "SMAD4": 35,
    "SMO": 36,
    "SRC": 37,
    "SREBF1": 38,
    "TRAF2": 39,
    "TSC2": 40,
    "WWTR1": 41,
}
# Invert to numeric label -> gene name for decoding.
label_dict = {code: gene for gene, code in label_dict.items()}

# Replace the whole column (instead of an in-place .loc write of strings into
# an integer column) and only decode while the labels are still numeric, so
# re-running this cell does not turn already-decoded names into NaN.
if pd.api.types.is_numeric_dtype(latents["labels"]):
    latents["labels"] = latents["labels"].map(label_dict)

oe_targets = set(latents["labels"])

---

## 2. Visualization of the embeddings

Next, we will visualize the individual image embeddings. To this end, we will use UMAP to compute a 2D representation of the individual embeddings.

### 2.1. Overview of the joint image embeddings

As a first step, we show that the image embeddings differ between folds, which is expected by design.

In [None]:
# Joint embedding/plot of the latents of all folds; uses the notebook-wide
# `seed` (1234) instead of a hard-coded literal.
embs = plot_struct_embs_cv(latents, random_state=seed, normalize_all=True)

---

### 2.2. Visualization of individual perturbation settings.

We now once more plot the image embeddings of a given gene perturbation against a background formed by all other gene perturbations and the control condition.

To this end, we arbitrarily select the image embeddings computed for the first fold of the 4-fold Group K-Fold setup.

In [None]:
# Embedding/plot restricted to the first fold; uses the notebook-wide `seed`
# (1234) instead of a hard-coded literal.
embs_0 = plot_struct_embs_cv(latents, random_state=seed, folds=["fold_0"])

In [None]:
# Switching to the default style also resets the DPI, so re-apply it.
mpl.style.use("default")
mpl.rcParams["figure.dpi"] = 600

# Perturbations to highlight; any gene in np.unique(embs_0.label) can be used.
highlight_genes = ["EMPTY", "JUN", "MAP3K9", "RAF1"]

for gene in highlight_genes:
    geneset = [gene]
    # Compute the selection mask once and reuse it for both scatter layers.
    is_selected = embs_0.label.isin(geneset)
    background = embs_0.loc[~is_selected]
    selected = embs_0.loc[is_selected]

    fig, ax = plt.subplots(figsize=[8, 6])
    # Background: all other perturbations and the control condition.
    ax.scatter(
        np.array(background["umap_0"]),
        np.array(background["umap_1"]),
        c="silver",
        alpha=0.1,
        label="other",
        s=3,
    )
    # Foreground: embeddings of the highlighted perturbation.
    ax.scatter(
        np.array(selected["umap_0"]),
        np.array(selected["umap_1"]),
        s=3,
        alpha=1,
        color="r",
        label=gene,
    )
    # Axis labels are left blank; only the tick labels are enlarged.
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.tick_params(axis="both", labelsize=14)
    plt.show()
    plt.close(fig)