<font size=18> EDA of the morphological profiles for the JUMP data set </font>

---

This notebook serves to scrape the morphological profiles for the different conditions from the JUMP data set and analyze the ones corresponding to selected overexpression conditions.

---

# Environmental setup

First, we load all required software packages.

In [None]:
import boto3
import pandas as pd
import numpy as np
import os
import sys
from collections import Counter
from botocore import UNSIGNED
from botocore.client import Config
from tqdm.notebook import tqdm
import pickle
from umap import UMAP
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import (
    cross_validate,
    StratifiedGroupKFold,
    StratifiedKFold,
)
from imblearn.under_sampling import RandomUnderSampler
from sklearn.preprocessing import OneHotEncoder
import seaborn as sns
from venn import venn
from scipy.spatial.distance import squareform
from scipy.spatial.distance import pdist
from numpy.linalg import svd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score

import matplotlib as mpl

# Render all matplotlib figures of this notebook at 600 dpi.
mpl.rcParams["figure.dpi"] = 600


# Global random seed (passed e.g. to the UMAP reducers further down).
seed = 1234


# IPython magic, not plain Python: loads the nb_black extension (presumably
# auto-formats cells with black); it must be installed in the kernel env.
%load_ext nb_black

In [None]:
def iterative_null_space_projection(
    embeddings,
    batch_labels,
    bac_threshold=0.3,
    max_iter=200,
    max_iter_clf=1000,
    random_state=1234,
    n_jobs=1,
    balance_batches=False,
    n_samples_per_batch=None,
):
    """Remove linearly decodable batch information by iterative null-space projection.

    In every iteration a multinomial logistic regression is trained to predict
    the batch label from the current embeddings and its 3-fold cross-validated
    balanced accuracy is recorded. While that accuracy is still at or above
    ``bac_threshold``, the embeddings are projected onto vectors orthogonal to
    the row space of the classifier's weight matrix. This removes the
    directions the classifier used and lowers the dimensionality by the number
    of weight rows.

    Parameters
    ----------
    embeddings : array-like of shape (n_samples, n_features)
        Embeddings to clean.
    batch_labels : array-like of shape (n_samples,)
        Batch label of every sample (any hashable labels; they are encoded).
    bac_threshold : float
        Stop as soon as the cross-validated balanced accuracy of the batch
        classifier drops below this value.
    max_iter : int
        Maximum number of iterations.
    max_iter_clf : int
        ``max_iter`` passed to the logistic regression solver.
    random_state : int
        Seed for the random under-sampling of the training samples.
    n_jobs : int
        Passed to ``LogisticRegression``.
    balance_batches : bool
        If True, the classifiers are trained on a batch-balanced random
        subsample. The projections are still applied to all samples.
    n_samples_per_batch : int or None
        If given (and ``balance_batches`` is True), additionally subsample the
        training set to exactly this many samples per batch.

    Returns
    -------
    embeddings : numpy.ndarray
        The projected embeddings of all samples.
    bacs : list of float
        Balanced accuracy recorded in each iteration. If the loop stopped
        because the threshold was reached, the last entry describes the
        returned embeddings.
    """
    embeddings = np.asarray(embeddings)
    batch_labels = LabelEncoder().fit_transform(batch_labels)
    n_batches = batch_labels.max() + 1

    if balance_batches:
        ru = RandomUnderSampler(random_state=random_state)
        sample_idc, _ = ru.fit_resample(
            np.arange(len(batch_labels)).reshape(-1, 1), batch_labels
        )
        sample_idc = sample_idc.ravel()
        train_batch_labels = batch_labels[sample_idc]
        train_embeddings = embeddings[sample_idc]

        if n_samples_per_batch is not None:
            batch_sample_count = {k: n_samples_per_batch for k in range(n_batches)}
            ru = RandomUnderSampler(
                random_state=random_state, sampling_strategy=batch_sample_count
            )
            sample_idc, _ = ru.fit_resample(
                np.arange(len(train_batch_labels)).reshape(-1, 1),
                train_batch_labels,
            )
            sample_idc = sample_idc.ravel()
            train_batch_labels = train_batch_labels[sample_idc]
            train_embeddings = train_embeddings[sample_idc]
    else:
        train_batch_labels = batch_labels
        train_embeddings = embeddings

    print(train_embeddings.shape)

    bacs = []
    for _ in tqdm(range(max_iter)):
        clf = LogisticRegression(
            class_weight="balanced",
            multi_class="multinomial",
            solver="lbfgs",
            n_jobs=n_jobs,
            penalty="l2",
            max_iter=max_iter_clf,
        )

        bac = cross_val_score(
            clf,
            train_embeddings,
            train_batch_labels,
            scoring="balanced_accuracy",
            cv=StratifiedKFold(n_splits=3),
        ).mean()
        bacs.append(bac)
        # Batch is no longer linearly decodable: stop *before* projecting, so
        # the returned embeddings are the ones whose accuracy was just
        # measured and no further directions are removed.
        if bac < bac_threshold:
            break

        clf.fit(train_embeddings, train_batch_labels)
        weights = clf.coef_
        # coef_ has one row per class for multi-class problems but only a
        # single row for binary ones, so remove as many directions as it has
        # rows (not the number of batches).
        n_remove = weights.shape[0]
        if n_remove >= embeddings.shape[1]:
            # Projecting would leave no feature dimensions at all.
            break
        _, _, vt = svd(weights, full_matrices=True)
        # Rows n_remove: of V^T are orthonormal and orthogonal to the row
        # space of `weights`, i.e. they lie in its null space.
        v_tilde = vt[n_remove:, :].transpose()
        embeddings = np.matmul(embeddings, v_tilde)
        train_embeddings = np.matmul(train_embeddings, v_tilde)
    return embeddings, bacs

In [None]:
def compare_clustering_to_rohban(
    jump_mean_embeddings,
    rohban_mean_embeddings,
    metric="euclidean",
    figsize=(12, 12),
    method="complete",
):
    """Plot clustered distance matrices of JUMP and Rohban et al. condition embeddings.

    Only conditions (index labels) present in both inputs are used. For each
    data set, the pairwise distances between the condition embeddings are
    computed with ``pdist(metric=metric)``. They are shown as a
    ``seaborn.clustermap`` using linkage ``method``: the JUMP map first
    (viridis), then the Rohban map (inferno).

    Parameters
    ----------
    jump_mean_embeddings, rohban_mean_embeddings : pandas.DataFrame
        One row per condition (index = condition label), numeric columns.
    metric : str
        Distance metric passed to ``scipy.spatial.distance.pdist``.
    figsize : sequence of two numbers
        Figure size of each clustermap.
    method : str
        Linkage method passed to ``seaborn.clustermap``.
    """
    # Sorted so that the matrix layout (and hence the clustering of tied
    # distances) does not depend on the hash-randomized order of a set.
    shared_targets = sorted(
        set(jump_mean_embeddings.index).intersection(rohban_mean_embeddings.index)
    )
    print(shared_targets)

    def _distance_frame(mean_embeddings):
        # Square, labelled distance matrix over the shared conditions.
        dists = pdist(mean_embeddings.loc[shared_targets], metric=metric)
        return pd.DataFrame(
            squareform(dists), index=shared_targets, columns=shared_targets
        )

    jump_dist = _distance_frame(jump_mean_embeddings)
    rohban_dist = _distance_frame(rohban_mean_embeddings)

    for dist, cmap, title in (
        (jump_dist, "viridis", "JUMP embeddings"),
        (rohban_dist, "inferno", "Rohban embeddings"),
    ):
        grid = sns.clustermap(dist, figsize=figsize, cmap=cmap, method=method)
        grid.fig.suptitle(title)
        plt.show()
        plt.close()

In [None]:
def compute_batch_statistics(
    embeddings,
    batch_col="Metadata_Batch",
    label_col="Metadata_Symbol",
    ctrl_label="EMPTY",
):
    """Return the per-batch mean and standard deviation of the control samples.

    Rows whose ``label_col`` equals ``ctrl_label`` are grouped by ``batch_col``.
    For every numeric column, the group mean and the pandas default (ddof=1)
    standard deviation are computed.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Means and standard deviations, each indexed by batch.
    """
    is_ctrl = embeddings.loc[:, label_col] == ctrl_label
    ctrl_by_batch = embeddings.loc[is_ctrl].groupby(batch_col)
    return (
        ctrl_by_batch.mean(numeric_only=True),
        ctrl_by_batch.std(numeric_only=True),
    )

In [None]:
def do_within_batch_normalization(
    embeddings,
    batch_col="Metadata_Batch",
    label_col="Metadata_Symbol",
    ctrl_label="EMPTY",
):
    """Standardize features per batch using that batch's control samples.

    Each numeric feature of every row is centred by the mean and divided by
    the standard deviation of the ``ctrl_label`` samples from the same batch.
    These statistics come from ``compute_batch_statistics``. Non-numeric
    columns are copied unchanged.

    Parameters
    ----------
    embeddings : pandas.DataFrame
        Samples as rows; must contain ``batch_col`` and ``label_col``.
    batch_col : str
        Column holding the batch of each sample.
    label_col : str
        Column holding the condition label of each sample.
    ctrl_label : str
        Label identifying the control samples.

    Returns
    -------
    pandas.DataFrame
        A normalized copy of ``embeddings``. The division is not guarded:
        features whose control standard deviation is zero or undefined (e.g. a
        single control sample) become inf/NaN.

    Raises
    ------
    KeyError
        If some batch contains no control samples, so no statistics exist.
    """
    ctrl_means, ctrl_stds = compute_batch_statistics(
        embeddings, batch_col=batch_col, label_col=label_col, ctrl_label=ctrl_label
    )
    print("Statistics computed.")

    scaled_embeddings = embeddings.copy()
    feature_cols = list(ctrl_means.columns)
    if not feature_cols:
        return scaled_embeddings

    batch_values = embeddings.loc[:, batch_col].to_numpy()
    missing = set(batch_values) - set(ctrl_means.index)
    if missing:
        raise KeyError(
            "No control samples ({}) found for batch(es): {}".format(
                ctrl_label, sorted(map(str, missing))
            )
        )

    # Look up the statistics of each row's batch once and normalize all rows
    # and features in one vectorized step. This avoids two masked .loc
    # assignments per (batch, feature) pair.
    row_means = ctrl_means.loc[batch_values, feature_cols].to_numpy(dtype=float)
    row_stds = ctrl_stds.loc[batch_values, feature_cols].to_numpy(dtype=float)
    values = embeddings.loc[:, feature_cols].to_numpy(dtype=float)
    scaled_embeddings[feature_cols] = (values - row_means) / row_stds
    return scaled_embeddings

In [None]:
# Overexpression conditions (plus the "EMPTY" control) that serve as the
# training / "impactful" conditions, i.e. those identified in the Rohban et al.
# data set as affecting chromatin organization and nuclear morphology.
impactful_conditions = [
    "AKT1S1",
    "ATF4",
    "BAX",
    "BCL2L11",
    "BRAF",
    "CASP8",
    "CDC42",
    "CDKN1A",
    "CEBPA",
    "CREB1",
    "CXXC4",
    "DIABLO",
    "E2F1",
    "ELK1",
    "EMPTY",
    "ERG",
    "FGFR3",
    "FOXO1",
    "GLI1",
    "HRAS",
    "IRAK4",
    "JUN",
    "MAP2K3",
    "MAP3K2",
    "MAP3K5",
    "MAP3K9",
    "MAPK7",
    "MOS",
    "MYD88",
    "PIK3R2",
    "PRKACA",
    "PRKCE",
    "RAF1",
    "RELB",
    "RHOA",
    "SMAD4",
    "SMO",
    "SRC",
    "SREBF1",
    "TRAF2",
    "TSC2",
    "WWTR1",
]

In [None]:
# Directory intended for embedding outputs (presumably written later in the
# notebook); created if it does not exist yet.
emb_output_dir = "../../../data/experiments/jump/images/embedding/embeddings"
os.makedirs(emb_output_dir, exist_ok=True)

---

# Get profile data

As a next step we will use the boto3 package to scrape all deposited profile data from the AWS S3 bucket that contains the JUMP data set.

## Download profiles

All profiles are deposited as .parquet files; we will thus screen the bucket for all data objects with this file extension.

In [None]:
output_dir = "../../../data/resources/images/jump/profiles"
os.makedirs(output_dir, exist_ok=True)

# Set to True to (re-)download the profiles; they only need to be fetched once.
DOWNLOAD_PROFILES = False


def download_jump_profiles(
    target_dir,
    bucket="cellpainting-gallery",
    prefix="cpg0016-jump/source_3/workspace/profiles/",
):
    """Download all ``.parquet`` profile files below ``prefix`` into ``target_dir``.

    The public bucket is accessed with unsigned requests. Objects are listed
    through a paginator because a single ``list_objects_v2`` call returns at
    most 1,000 keys. Each file is stored as ``<source>_<batch>_<file name>``,
    with source and batch taken from the 2nd and 5th "/"-separated
    components of its key.
    """
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    paginator = s3_client.get_paginator("list_objects_v2")
    keys = [
        obj["Key"]
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
        for obj in page.get("Contents", [])
        if obj["Key"].endswith(".parquet")
    ]
    for key in tqdm(keys):
        entries = key.split("/")
        file_name = "_".join([entries[1], entries[4], entries[-1]])
        s3_client.download_file(bucket, key, os.path.join(target_dir, file_name))


if DOWNLOAD_PROFILES:
    download_jump_profiles(output_dir)

---


## Read in profiles

Having downloaded all profile data, we will now read it in and append the gene symbol information.

In [None]:
# Read every downloaded profile file. Non-parquet files (e.g. hidden system
# files) are skipped. The listing is sorted because os.listdir returns entries
# in arbitrary order; otherwise the row order of the concatenated frame, and
# with it order-dependent steps such as UMAP or unshuffled CV splits, could
# differ between machines.
profile_files = sorted(
    file for file in os.listdir(output_dir) if file.endswith(".parquet")
)
all_morph_profiles = [
    pd.read_parquet(os.path.join(output_dir, file)) for file in tqdm(profile_files)
]
morph_profiles = pd.concat(all_morph_profiles)

The respective metadata can be obtained from previously downloaded files.

In [None]:
# Well-, ORF- and plate-level metadata tables of the JUMP data set.
metadata_dir = "../../../data/resources/images/jump/metadata"
well_data, orf_data, plate_data = (
    pd.read_csv(f"{metadata_dir}/{name}.csv") for name in ("well", "orf", "plate")
)

In [None]:
# Attach the ORF annotations to every well (inner join on the JCP2022 ID),
# then add batch and plate-type information for each plate (inner join).
plate_info = plate_data.loc[
    :, ["Metadata_Plate", "Metadata_Batch", "Metadata_PlateType"]
]
all_orf_data = well_data.merge(orf_data, on="Metadata_JCP2022", how="inner").merge(
    plate_info, on="Metadata_Plate", how="inner"
)

We will now combine the two datasets to construct a joint profile data set that contains the derived morphological profiles for the different overexpression conditions.

In [None]:
# Build a "<source>_<plate>_<well>" identifier in both tables to join on.
for frame in (morph_profiles, all_orf_data):
    frame["key"] = frame["Metadata_Source"].str.cat(
        [frame["Metadata_Plate"], frame["Metadata_Well"]], sep="_"
    )

# Keep only profiled wells with ORF metadata and add batch / gene symbol.
selected_cols = ["key", "Metadata_Batch", "Metadata_Symbol"]
morph_profiles = morph_profiles.merge(
    all_orf_data.loc[:, selected_cols], on="key", how="inner"
).drop(columns=["key"])
morph_profiles.describe()

In line with our analyses that directly use the images, we define the condition with a "BFP" treatment as the control condition, which we label "EMPTY".

In [None]:
# Treat BFP wells as the negative control and label them "EMPTY".
morph_profiles["Metadata_Symbol"] = morph_profiles["Metadata_Symbol"].replace(
    "BFP", "EMPTY"
)

---

## Read in other data

In addition to the morphological profiles from the JUMP data set, we now also read in the image embeddings obtained with our CNN encoder trained and evaluated on the images of the JUMP data set, as well as the image embeddings obtained for the data set from Rohban et al.

In [None]:
jump_img_embs = pd.read_hdf(
    "../../../data/experiments/jump/images/embedding/specificity_target_emb_cv_strat/fold_0/nuclei_regions/test_latents.h5"
)
# Condition name -> integer class label used for the JUMP image latents.
label_dict = {
    "AKT1S1": 0,
    "ATF4": 1,
    "BAX": 2,
    "BCL2L11": 3,
    "BRAF": 4,
    "CASP8": 5,
    "CDKN1A": 6,
    "CREB1": 7,
    "DIABLO": 8,
    "ELK1": 9,
    "EMPTY": 10,
    "ERG": 11,
    "FGFR3": 12,
    "HRAS": 13,
    "IRAK4": 14,
    "JUN": 15,
    "MAP3K2": 16,
    "MAP3K5": 17,
    "MAP3K9": 18,
    "MYD88": 19,
    "PIK3R2": 20,
    "PRKACA": 21,
    "PRKCE": 22,
    "RAF1": 23,
    "RELB": 24,
    "RHOA": 25,
    "SMAD4": 26,
    "SRC": 27,
    "SREBF1": 28,
    "TRAF2": 29,
    "TSC2": 30,
    "WWTR1": 31,
    "ALOX5": 32,
    "AMPH": 33,
    "AQP1": 34,
    "ARMCX2": 35,
    "AURKA": 36,
    "AURKB": 37,
    "AXL": 38,
    "BEX1": 39,
    "BIRC5": 40,
    "BMP4": 41,
    "BUB1": 42,
    "BUB1B": 43,
    "CCNA2": 44,
    "CCND2": 45,
    "CD40": 46,
    "CDC20": 47,
    "CDC42EP1": 48,
    "CDC45": 49,
    "CDCA3": 50,
    "CDCA8": 51,
    "CDK1": 52,
    "CDK14": 53,
    "CDK2": 54,
    "CDK6": 55,
    "CENPA": 56,
    "CKS2": 57,
    "CLSPN": 58,
    "CLU": 59,
    "CNN1": 60,
    "CRYAB": 61,
    "CYBA": 62,
    "DHRS2": 63,
    "DUSP6": 64,
    "EEF1A2": 65,
    "EFEMP1": 66,
    "EPB41L3": 67,
    "EXO1": 68,
    "FEN1": 69,
    "FGF1": 70,
    "FGFR2": 71,
    "FKBP4": 72,
    "FOS": 73,
    "FOXM1": 74,
    "FSTL1": 75,
    "GTSE1": 76,
    "HJURP": 77,
    "HK2": 78,
    "HSP90AB1": 79,
    "HSPA1B": 80,
    "IGF2": 81,
    "IGFBP5": 82,
    "INHBA": 83,
    "KIF2C": 84,
    "KLK6": 85,
    "KPNA2": 86,
    "KRT15": 87,
    "KRT18": 88,
    "KRT8": 89,
    "KRT80": 90,
    "KRT81": 91,
    "LBH": 92,
    "LDOC1": 93,
    "LGALS1": 94,
    "LOXL4": 95,
    "LRP1": 96,
    "LTBP2": 97,
    "LZTS2": 98,
    "MAGEA1": 99,
    "MAGEA4": 100,
    "MAGED1": 101,
    "MAPK8": 102,
    "MBP": 103,
    "MCM10": 104,
    "MCM3": 105,
    "MCM5": 106,
    "MCM7": 107,
    "MDFI": 108,
    "MDK": 109,
    "MSH2": 110,
    "MYH9": 111,
    "MYL2": 112,
    "NCF2": 113,
    "NDC80": 114,
    "NDN": 115,
    "NEFL": 116,
    "NEK2": 117,
    "NUF2": 118,
    "PAK2": 119,
    "PCLO": 120,
    "PCNA": 121,
    "PDGFRB": 122,
    "PLAT": 123,
    "PLK1": 124,
    "PRAME": 125,
    "PRC1": 126,
    "PRKCA": 127,
    "PRSS23": 128,
    "PTCH1": 129,
    "RACGAP1": 130,
    "RARA": 131,
    "RPS6KB1": 132,
    "RRM2": 133,
    "S100A14": 134,
    "SDC1": 135,
    "SDC2": 136,
    "SDC3": 137,
    "SERPINE1": 138,
    "SFN": 139,
    "SKP2": 140,
    "SPARC": 141,
    "SPC24": 142,
    "SPC25": 143,
    "STAC": 144,
    "STC2": 145,
    "SUV39H1": 146,
    "TCF4": 147,
    "TGFB1": 148,
    "TGM2": 149,
    "THRA": 150,
    "THY1": 151,
    "TIMP1": 152,
    "TINAGL1": 153,
    "TK1": 154,
    "TNC": 155,
    "TNNC1": 156,
    "TNNT1": 157,
    "TNNT2": 158,
    "TONSL": 159,
    "TPM1": 160,
    "TPM2": 161,
    "TPX2": 162,
    "TRIB3": 163,
    "TTK": 164,
    "TUBA1A": 165,
    "TUBB": 166,
    "TUBB2B": 167,
    "TUBB6": 168,
    "UBE2S": 169,
    "UCHL5": 170,
    "VCAN": 171,
    "VIM": 172,
    "YES1": 173,
    "YWHAQ": 174,
    "ZWINT": 175,
}

# Invert to integer class label -> condition name.
label_dict = {idx: name for name, idx in label_dict.items()}
# Replace the column instead of writing through .loc[:, "labels"]. The latter
# tries to store the strings in place in the existing integer column, which
# pandas (>= 2.1) deprecates for incompatible dtypes.
jump_img_embs["labels"] = jump_img_embs["labels"].map(label_dict).to_numpy()

In [None]:
# Parse the "_"-separated image IDs (index of jump_img_embs) into metadata.
# Token positions: 0-1 -> source, 2-5 -> batch, 7 -> plate, 8 -> well.
# NOTE(review): token 6 is unused, and exp_well (tokens 0-7) does not contain
# token 8 ("well"). Confirm against the ID format that grouping by exp_well
# yields one group per experimental well.
idc = np.array(jump_img_embs.index)
source, batch, plate, well, exp_well = [], [], [], [], []
for idx in tqdm(idc):
    s = idx.split("_")
    source.append("_".join(s[0:2]))
    batch.append("_".join(s[2:6]))
    # s[7] / s[8] are single tokens and are used as-is. "_".join() on a str
    # would interleave "_" between its characters (e.g. "A01" -> "A_0_1").
    plate.append(s[7])
    well.append(s[8])
    exp_well.append("_".join(s[:8]))
jump_img_embs["batch"] = np.array(batch)
jump_img_embs["plate"] = np.array(plate)
jump_img_embs["well"] = np.array(well)
jump_img_embs["exp_well"] = np.array(exp_well)

---

In [None]:
rohban_img_embs = pd.read_hdf(
    "../../../data/experiments/rohban/images/embeddings/four_fold_cv/fold_0/test_latents.h5"
)

# Condition name -> integer class label used for the Rohban et al. latents.
label_dict = {
    "AKT1S1": 0,
    "ATF4": 1,
    "BAX": 2,
    "BCL2L11": 3,
    "BRAF": 4,
    "CASP8": 5,
    "CDC42": 6,
    "CDKN1A": 7,
    "CEBPA": 8,
    "CREB1": 9,
    "CXXC4": 10,
    "DIABLO": 11,
    "E2F1": 12,
    "ELK1": 13,
    "EMPTY": 14,
    "ERG": 15,
    "FGFR3": 16,
    "FOXO1": 17,
    "GLI1": 18,
    "HRAS": 19,
    "IRAK4": 20,
    "JUN": 21,
    "MAP2K3": 22,
    "MAP3K2": 23,
    "MAP3K5": 24,
    "MAP3K9": 25,
    "MAPK7": 26,
    "MOS": 27,
    "MYD88": 28,
    "PIK3R2": 29,
    "PRKACA": 30,
    "PRKCE": 31,
    "RAF1": 32,
    "RELB": 33,
    "RHOA": 34,
    "SMAD4": 35,
    "SMO": 36,
    "SRC": 37,
    "SREBF1": 38,
    "TRAF2": 39,
    "TSC2": 40,
    "WWTR1": 41,
}
# Invert to integer class label -> condition name.
label_dict = {idx: name for name, idx in label_dict.items()}
# Replace the column rather than writing strings into the integer column in
# place via .loc (incompatible-dtype setting is deprecated in pandas >= 2.1).
rohban_img_embs["labels"] = rohban_img_embs["labels"].map(label_dict).to_numpy()
# Mean latent per condition. numeric_only keeps this working (pandas >= 2) if
# the frame carries non-numeric columns besides "labels".
mean_rohban_img_embs = rohban_img_embs.groupby("labels").mean(numeric_only=True)
mean_rohban_img_embs.describe()

---

# Preprocess profile data

We will now preprocess the downloaded profile data. To this end, we will perform similar filtering steps as we had done for the dataset by Rohban et al.

## Filter out non-DNA related features

As a first step, we will subset the profiles to only contain features that can be computed from the DNA channel and are thus related to nuclei. This removes any features from the other 4 channels and other image-level descriptions, such as the background intensity information for the DNA channel.

In [None]:
# Channel / compartment names whose features are removed. They are matched
# against whole "_"-separated tokens of the lower-cased column name. Plain
# substring matching would also discard DNA nuclear features such as
# Nuclei_AreaShape_Perimeter, *_MaxFeretDiameter, *_Zernike_*,
# *_UpperQuartileIntensity_DNA or Texture *Difference* (all contain "er").
# NOTE: the feature counts quoted in the text of this section were obtained
# with substring matching and need to be recomputed.
filter_keys = ["er", "mito", "phgolgi", "agp", "rna", "cytoplasm", "image", "cells"]
filter_key_set = set(filter_keys)
drop_columns = [
    col
    for col in morph_profiles.columns
    if filter_key_set.intersection(col.lower().split("_"))
]
len(drop_columns)

In [None]:
# Keep only the DNA-channel / nuclear features.
morph_profiles_dna = morph_profiles.drop(drop_columns, axis=1)

---

## Filter out positional features.

The previous filtering step removes 4,589/4,766 features. We next filter out any positional features, i.e. features with positional indicators such as X and Y.

In [None]:
# Positional features: lower-cased name contains "_x" or "_y"
# (e.g. Nuclei_Location_Center_X).
pos_markers = ("_x", "_y")
pos_columns = [
    col
    for col in morph_profiles_dna.columns
    if any(marker in col.lower() for marker in pos_markers)
]
len(pos_columns)

In [None]:
# Remove the positional features identified above.
morph_profiles_dna = morph_profiles_dna.loc[
    :, ~morph_profiles_dna.columns.isin(pos_columns)
]

---

## Filter out constant features and incomplete entries

The previous filtering step reduces the total number of features to 171. We will now filter out any features for which we observe a constant value, as well as any rows where any of the features was not measured.

In [None]:
print(morph_profiles_dna.shape, flush=True)
# Keep only columns with at least two distinct non-missing values. Comparing
# against the first row (df != df.iloc[0]) treats NaN as different from
# everything. That would keep all-NaN columns and columns that are constant
# apart from missing values, and an all-NaN column would make the dropna()
# below discard every row.
n_unique = morph_profiles_dna.nunique(dropna=True)
morph_profiles_dna = morph_profiles_dna.loc[:, (n_unique > 1).to_numpy()]
# Drop observations with any missing value.
morph_profiles_dna = morph_profiles_dna.dropna()
print(morph_profiles_dna.shape, flush=True)

This filtering step removes 2 additional features and 804 observations.

---

# Exploratory data analyses

To get an idea of the data set characteristics and, importantly, the similarities of the different overexpression conditions, we will perform a number of exploratory data analyses. To this end, we will subset the data set to the 176 conditions that we will focus on in our subsequent analyses to validate our proposed computational platform, Image2Reg.

In [None]:
# NOTE(review): this path climbs four directory levels, while the other data
# paths in this notebook climb three. Confirm it points to the intended file.
targets_path = "../../../../data/other/target_lists/jump_targets.pkl"
with open(targets_path, "rb") as handle:
    jump_targets = pickle.load(handle)
# Selected OE conditions plus the control, in sorted order.
selected_conditions = sorted([*jump_targets, "EMPTY"])
print(selected_conditions)

In [None]:
# Restrict the profiles to the selected conditions (including the control).
is_selected = morph_profiles_dna["Metadata_Symbol"].isin(selected_conditions)
jump_morph_profiles = morph_profiles_dna[is_selected]
jump_morph_profiles.describe()

In total there are 2,305 observations that correspond to the 175 selected overexpression conditions and the one control condition.

---

## UMAP visualization of the data set.

We now visualize the individual overexpression conditions using UMAP plots. In each plot, we will highlight one of the overexpression conditions in color and plot the observations of all other conditions in gray.

In [None]:
# Numeric (incl. boolean) feature columns, selected through the public
# select_dtypes API instead of the private DataFrame._get_numeric_data.
features = jump_morph_profiles.select_dtypes(include=["number", "bool"])
labels = jump_morph_profiles.Metadata_Symbol

# 2D UMAP of the standardized profiles, annotated with the well metadata.
mapper = UMAP(random_state=seed)
norm_features = StandardScaler().fit_transform(features)
embs = mapper.fit_transform(norm_features)
embs = pd.DataFrame(embs, index=features.index, columns=["umap_0", "umap_1"])
embs["label"] = labels.loc[embs.index]
embs["plate"] = jump_morph_profiles.loc[embs.index, "Metadata_Plate"]
embs["well"] = jump_morph_profiles.loc[embs.index, "Metadata_Well"]
embs["batch"] = jump_morph_profiles.loc[embs.index, "Metadata_Batch"]

In [None]:
# UMAP of all selected profiles, without colour coding.
fig, ax = plt.subplots(figsize=[18, 12])
ax = sns.scatterplot(x="umap_0", y="umap_1", data=embs, s=16, ax=ax)
plt.show()

---

## Analyses of batch effects

### Visualization of the batch effects

We notice that samples from the same batch cluster together in the UMAP plot, which indicates the presence of prominent batch effects.

In [None]:
# UMAP coloured by batch (fixed, sorted batch order for consistent colours).
batch_order = np.unique(embs.batch)
fig, ax = plt.subplots(figsize=[12, 8])
ax = sns.scatterplot(
    x="umap_0",
    y="umap_1",
    hue="batch",
    hue_order=batch_order,
    palette="tab20",
    s=6,
    data=embs,
    ax=ax,
)
ax.set(xlabel="", ylabel="")
ax.legend(title="Batch")
plt.show()

In [None]:
# Same plot restricted to the control ("EMPTY") samples; the batch order is
# taken from all samples so the colours match the previous plot.
batch_order = np.unique(embs.batch)
ctrl_embs = embs[embs["label"] == "EMPTY"]
fig, ax = plt.subplots(figsize=[12, 8])
ax = sns.scatterplot(
    x="umap_0",
    y="umap_1",
    hue="batch",
    hue_order=batch_order,
    palette="tab20",
    s=6,
    data=ctrl_embs,
    ax=ax,
)
ax.set(xlabel="", ylabel="")
ax.legend(title="Batch")
plt.show()

This shows the presence of dominant batch effects. Since the morphological profiles were computed based on the features we considered, these batch effects might confound our analyses in terms of the similarity of the chromatin states in the different overexpression conditions.

---

### Different OE conditions in the batches

We notice that not all of the overexpression conditions were assessed in each batch, but only a subset of our selected OE conditions was targeted in each batch.

In [None]:
# Number and fraction of the selected conditions measured in each batch.
per_batch = (
    jump_morph_profiles.groupby("Metadata_Batch")
    .Metadata_Symbol.nunique()
    .to_frame("n_conditions")
)
per_batch["batch"] = per_batch.index.to_numpy()
n_selected_present = len(set(jump_morph_profiles["Metadata_Symbol"]))
per_batch["frac_conditions"] = per_batch["n_conditions"] / n_selected_present

In [None]:
fig, ax = plt.subplots(figsize=[6, 4])
# A single colour for all bars. Passing palette=[...] without a hue variable
# is deprecated in seaborn >= 0.13; color= yields the same look.
ax = sns.barplot(data=per_batch, x="batch", y="frac_conditions", color="b", ax=ax)
ax.set_ylabel("Coverage of selected conditions")
ax.set_xlabel("Batch")
plt.xticks(rotation=90)
plt.show()

As we see, the largest fraction of the 176 considered conditions that is covered by an individual batch is 15.9%, which corresponds to 28 of the 176 conditions. Subsetting the data such that we only use the data of a single batch is thus not a feasible option.

---

### Different specific targets in the batches

Naively, one could circumvent the problems caused by the batch effects by simply executing our pipeline individually on the data of each batch. However, we see that each batch contains only a very small fraction of the already small number of specific targets, which would make it difficult for the pipeline to generalize to new, unseen conditions.

In [None]:
# Number and fraction of the impactful (training) conditions in each batch.
is_impactful = jump_morph_profiles.Metadata_Symbol.isin(impactful_conditions)
per_batch_spec = (
    jump_morph_profiles.loc[is_impactful]
    .groupby("Metadata_Batch")
    .Metadata_Symbol.nunique()
    .to_frame("n_conditions")
)
per_batch_spec["batch"] = np.array(per_batch_spec.index)
# Normalize by the number of impactful conditions present in the data, not by
# all selected conditions, so the value is the coverage of the impactful set.
per_batch_spec["frac_conditions"] = per_batch_spec.n_conditions / len(
    set(jump_morph_profiles.loc[is_impactful, "Metadata_Symbol"])
)

In [None]:
# Absolute number of impactful conditions per batch.
fig, ax = plt.subplots(figsize=[6, 4])
ax = sns.barplot(x="batch", y="n_conditions", data=per_batch_spec, ax=ax)
ax.set(xlabel="Batch", ylabel="Coverage of selected, impactful conditions")
plt.xticks(rotation=90)
plt.show()

In [None]:
# Summary statistics of the per-batch counts / fractions of impactful conditions.
per_batch_spec.describe()

----

### Summary of OE condition-batch distribution

To better understand which OE conditions are grouped in which batch, which of them are replicated, and how the test (non-specific OE) conditions co-localize at the batch level with the training conditions (i.e. the conditions targeting genes that we identified to impact the chromatin organization and nuclear morphology of U2OS cells in the data set of Rohban et al.), we visualize the condition-batch co-occurrence matrix below.

In [None]:
# Binary condition x batch matrix: 1 if the condition (row) has at least one
# profile in the batch (column), else 0. Selected conditions without any
# profile get a row of zeros. Axis names are removed to keep the heatmap
# axes unlabeled.
cond_batch_coocc_matrix = (
    pd.crosstab(
        jump_morph_profiles.Metadata_Symbol, jump_morph_profiles.Metadata_Batch
    )
    .gt(0)
    .astype(int)
    .reindex(
        index=sorted(list(selected_conditions)),
        columns=np.unique(jump_morph_profiles.Metadata_Batch),
        fill_value=0,
    )
    .rename_axis(index=None, columns=None)
)

In [None]:
# Conditions 0-59 of the co-occurrence matrix; impactful ones in bold red.
impactful_set = set(impactful_conditions)
fig, ax = plt.subplots(figsize=[12, 3])
ax = sns.heatmap(
    cond_batch_coocc_matrix.T.iloc[:, :60],
    cmap="gray",
    cbar=False,
    ax=ax,
)
for label in ax.get_xticklabels():
    if label.get_text() in impactful_set:
        label.set_color("r")
        label.set_fontweight("bold")
plt.show()

In [None]:
# Conditions 60-119 of the co-occurrence matrix; impactful ones in bold red.
impactful_set = set(impactful_conditions)
fig, ax = plt.subplots(figsize=[12, 3])
ax = sns.heatmap(
    cond_batch_coocc_matrix.T.iloc[:, 60:120],
    cmap="gray",
    cbar=False,
    ax=ax,
)
for label in ax.get_xticklabels():
    if label.get_text() in impactful_set:
        label.set_color("r")
        label.set_fontweight("bold")
plt.show()

In [None]:
# Remaining conditions (120+) of the co-occurrence matrix; impactful in red.
impactful_set = set(impactful_conditions)
fig, ax = plt.subplots(figsize=[12, 3])
ax = sns.heatmap(
    cond_batch_coocc_matrix.T.iloc[:, 120:],
    cmap="gray",
    cbar=False,
    ax=ax,
)
for label in ax.get_xticklabels():
    if label.get_text() in impactful_set:
        label.set_color("r")
        label.set_fontweight("bold")
plt.show()

The above map shows that most conditions are either not replicated (111/175) or are replicated in only two batches (43/175). Only 22 conditions are replicated in more than 2 batches, and only the control condition is replicated in all batches.

In [None]:
# How many conditions are measured in 1, 2, ... batches.
n_batches_per_condition = cond_batch_coocc_matrix.sum(axis=1)
Counter(n_batches_per_condition)

---

### Characterization of the batch effects

#### Prediction of OE condition purely from batch labels

In [None]:
# Impactful-condition profiles with the batch as the only (one-hot) feature.
tmp = jump_morph_profiles[
    jump_morph_profiles["Metadata_Symbol"].isin(impactful_conditions)
]
tmp_labels = tmp["Metadata_Symbol"].to_numpy()
tmp_batch_labels = OneHotEncoder().fit_transform(
    tmp["Metadata_Batch"].to_numpy().reshape(-1, 1)
)

In [None]:
# Unpenalized multinomial regression predicting the condition from the batch.
# penalty=None (scikit-learn >= 1.2) replaces the string "none", which was
# removed in scikit-learn 1.4.
clf = LogisticRegression(
    class_weight="balanced",
    multi_class="multinomial",
    solver="lbfgs",
    max_iter=1000,
    penalty=None,
)

bac = cross_val_score(
    clf,
    tmp_batch_labels,
    tmp_labels,
    scoring="balanced_accuracy",
    cv=StratifiedKFold(n_splits=5),
)
print(bac.mean(), bac.std())

We see that, simply by knowing which batch a specific sample comes from, a multinomial regression model can predict the correct overexpression condition with a balanced accuracy of 0.3218 (+/- 0.0212). This is more than half of the performance that we obtained by training the classifier directly on the images, and it emphasizes how much information is contained in the batch labels alone.

#### RandomForest classification of control samples

In [None]:
# Standardized numeric features and batch labels of the control samples.
ctrl_profiles = jump_morph_profiles.loc[jump_morph_profiles.Metadata_Symbol == "EMPTY"]
# Public select_dtypes API instead of the private DataFrame._get_numeric_data.
ctrl_features = ctrl_profiles.select_dtypes(include=["number", "bool"])
ctrl_features_norm = StandardScaler().fit_transform(ctrl_features)
ctrl_features_norm = pd.DataFrame(
    ctrl_features_norm, index=ctrl_features.index, columns=ctrl_features.columns
)
ctrl_batch_labels = ctrl_profiles.Metadata_Batch

In [None]:
# Batch classification of control samples; plates are kept together in folds.
rfc = RandomForestClassifier(random_state=1234)
cv_results = cross_validate(
    rfc,
    X=ctrl_features_norm,
    y=ctrl_batch_labels,
    groups=ctrl_profiles.Metadata_Plate,
    cv=StratifiedGroupKFold(n_splits=10),
    scoring="balanced_accuracy",
)
bacs = cv_results["test_score"]
mean_bac = np.round(bacs.mean(), 4)
std_bac = np.round(bacs.std(), 4)
print(
    "Balanced average batch classification accuracy: {} (+/-{})".format(
        mean_bac, std_bac
    )
)

First, we find that the RFC can distinguish between the different batches of the control cells with an average balanced accuracy of roughly 0.8075 (+/- 0.0393), which is substantially larger than the random baseline of 0.076. Since the evaluation was done using a stratified group 10-fold cross-validation with the plates as groups, we can also rule out that the performance is driven by plate effects.

In [None]:
# Fit on all control samples and show the 15 features with the largest
# mean decrease in impurity (MDI).
rfc = rfc.fit(ctrl_features_norm, ctrl_batch_labels)
importances = rfc.feature_importances_
importances = pd.DataFrame(
    importances, index=ctrl_features_norm.columns, columns=["mdi"]
)
importances["feature"] = np.array(importances.index)
importances = importances.sort_values("mdi", ascending=False)

fig, ax = plt.subplots(figsize=[6, 4])
# color= instead of palette=[...] without hue (deprecated in seaborn >= 0.13).
ax = sns.barplot(data=importances.iloc[:15], y="feature", x="mdi", color="gray", ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_xlabel("Mean decrease in impurity")
plt.show()

The above feature importance plot indicates that the batch effects are not solely due to changes in the imaging conditions, which would be reflected in differences of intensity-based features such as the texture measurements. Beyond such differences, we also observe morphological changes and variability of the spatial DNA organization between the different batches.

The boxplots below visualize the differences for 4 chosen morphological profile features.

In [None]:
# Per-batch distributions of four selected features of the control samples.
plot_feats = [
    "Nuclei_AreaShape_Area",
    "Nuclei_AreaShape_FormFactor",
    "Nuclei_Granularity_1_DNA",
    "Nuclei_RadialDistribution_FracAtD_DNA_4of4",
]
fig, ax = plt.subplots(figsize=[12, 8], ncols=2, nrows=2, sharex=True)
ax = ax.flatten()
for panel, feat in zip(ax, plot_feats):
    sns.boxplot(data=ctrl_profiles, x="Metadata_Batch", y=feat, ax=panel)
    panel.tick_params(axis="x", labelrotation=90)
plt.show()

Note that these batch effects also persist for any of the other negative controls (eGFP, HcRed, LacZ and Luciferase). There is no condition available where the cells were not treated at all for the JUMP dataset.

---

### Linearity of batch effects

We next analyze whether the batch effects are captured by the principal components of the data set.

In [None]:
# First two principal components of the standardized profiles plus metadata.
pc = PCA(n_components=2)
pc_embs = pd.DataFrame(
    pc.fit_transform(norm_features),
    index=features.index,
    columns=["pc_0", "pc_1"],
)
pc_embs["label"] = labels.loc[pc_embs.index]
for target_col, meta_col in (
    ("plate", "Metadata_Plate"),
    ("well", "Metadata_Well"),
    ("batch", "Metadata_Batch"),
):
    pc_embs[target_col] = jump_morph_profiles.loc[pc_embs.index, meta_col]

In [None]:
# PCA coloured by batch; axis labels show the explained variance ratios.
evr = pc.explained_variance_ratio_
fig, ax = plt.subplots(figsize=[12, 8])
ax = sns.scatterplot(
    x="pc_0",
    y="pc_1",
    hue="batch",
    hue_order=np.unique(embs.batch),
    palette="tab20",
    s=16,
    data=pc_embs,
    ax=ax,
)
ax.set_xlabel(f"PC 0 ({np.round(evr[0], 4)})")
ax.set_ylabel(f"PC 1 ({np.round(evr[1], 4)})")

In [None]:
# Same PCA plot restricted to the control ("EMPTY") samples.
evr = pc.explained_variance_ratio_
fig, ax = plt.subplots(figsize=[12, 8])
ax = sns.scatterplot(
    x="pc_0",
    y="pc_1",
    hue="batch",
    hue_order=np.unique(embs.batch),
    palette="tab20",
    s=16,
    data=pc_embs[pc_embs["label"] == "EMPTY"],
    ax=ax,
)
ax.set_xlabel(f"PC 0 ({np.round(evr[0], 4)})")
ax.set_ylabel(f"PC 1 ({np.round(evr[1], 4)})")

We notice that roughly 78.5% of the total variance of the data is explained by the first two principal components and in particular the first principal component seems to capture the variance from the batches.

---

### Comparison to Deep learning embeddings

We will now compare this to the embeddings obtained with our proposed image encoder model.

In [None]:
# Mean image embedding and the (single) label, batch and plate per exp_well.
well_groups = jump_img_embs.groupby("exp_well")
jump_image_mean_embs = well_groups.mean(numeric_only=True)

per_well = {}
for meta_col in ("labels", "batch", "plate"):
    unique_vals = well_groups[meta_col].unique()
    # Each group is expected to carry exactly one value. np.concatenate then
    # yields one entry per group, and the Series constructor raises otherwise.
    per_well[meta_col] = pd.Series(
        np.concatenate(unique_vals), index=unique_vals.index
    )
jump_image_labels = per_well["labels"]
jump_batch_labels = per_well["batch"]
jump_plate_labels = per_well["plate"]

To mimic the previous results, we compute an average embedding for each experimental well. We will now visualize the data set using a UMAP plot and highlight the data from the different batches.

In [None]:
# 2D UMAP of the well-averaged image embeddings with per-well metadata attached.
mapper = UMAP(random_state=seed)
jump_umap_embs = pd.DataFrame(
    mapper.fit_transform(jump_image_mean_embs),
    index=jump_image_mean_embs.index,
    columns=["umap_0", "umap_1"],
)
jump_umap_embs = jump_umap_embs.assign(
    label=jump_image_labels.loc[jump_umap_embs.index],
    plate=jump_plate_labels.loc[jump_umap_embs.index],
    batch=jump_batch_labels.loc[jump_umap_embs.index],
)

In [None]:
# UMAP of the well-averaged image embeddings, coloured by batch.
fig, ax = plt.subplots(figsize=[12, 8])
# Derive the hue order from this frame's own batches. ``embs`` holds the
# morphological-profile UMAP and is overwritten by later cells, so relying on it
# could drop or recolour batches here.
ax = sns.scatterplot(
    data=jump_umap_embs,
    x="umap_0",
    y="umap_1",
    hue="batch",
    ax=ax,
    s=6,
    palette="tab20",
    hue_order=np.unique(jump_umap_embs.batch),
)
ax.set_xlabel("")
ax.set_ylabel("")
ax.legend(title="Batch")
plt.show()

As expected the batch effects are also visible when looking at our deep image embeddings.

This clearly shows that we need more sophisticated methods to overcome these problems and to account for the large batch effects.

---

## Comparison to Rohban et al.

We will now compare both the JUMP image embeddings and the morphological profiles to what we have seen for the dataset from Rohban et al. (2017). To visualize the (dis)similarity, we compare the clustering of the per-condition averaged feature representations with the clustering obtained for Rohban et al. using a co-cluster map.

In [None]:
# Average the morphological profiles per OE condition. ``numeric_only=True``
# restricts the average to the numeric feature columns. This is the same choice
# the other groupby averages in this notebook make, because string metadata
# columns such as ``Metadata_Batch`` make ``mean`` raise a TypeError in
# pandas>=2 (older versions silently dropped them with a warning).
mean_jump_morph_profiles = jump_morph_profiles.groupby("Metadata_Symbol").mean(
    numeric_only=True
)
compare_clustering_to_rohban(
    mean_jump_morph_profiles,
    mean_rohban_img_embs,
    metric="euclidean",
    method="complete",
)

---

# Overcoming the batch effects

Based on the previous results, it is obvious that we need to apply methods to reduce the variability of the data due to the different batches. In particular, we propose the following strategies to obtain chromatin image embeddings that capture the differences of the chromatin states due to the different overexpression rather than the different batch effects:
1. Remove the 4 batches that show the largest inter-batch variability (Batch 3, 4, 10 and 13) and train our model solely on the data of the other batches.
2. Perform iterative null-space projection of the obtained embeddings to project the data on a space that does not contain batch information, but hopefully maintains batch-independent information about the chromatin states of the cells in the different overexpression conditions.
3. While training the classifier provide the batch information as an additional feature before the final layer to make the classifier use the image embeddings to derive features related to the OE conditions.
4. Normalize the embeddings by subtracting from each embedding the mean embedding of the control condition in the corresponding batch.

While we will analyze the effect of these steps in more detail in another notebook, here we exemplarily show the effect of step 1 on the condition coverage and the results of applying steps 2 and 4 to the morphological profiles and the image embeddings.

---

## Remove outlier batches

As a first step, we will remove the four previously mentioned batches that look very dissimilar from the others and assess the effect this has on the number of impactful and covered gene conditions.

---

In [None]:
# Batches whose profiles look most dissimilar from all other batches.
outlier_batches = [
    "2021_05_10_Batch3",
    "2021_05_17_Batch4",
    "2021_08_02_Batch10",
    "2021_08_30_Batch13",
]

# Split the ORF profiles into those from the outlier batches and all the others.
in_outlier_batch = all_orf_data.Metadata_Batch.isin(outlier_batches)
outlier_orf_data = all_orf_data.loc[in_outlier_batch]
other_orf_data = all_orf_data.loc[~in_outlier_batch]

In [None]:
# Gene targets (OE conditions) present in the outlier batches vs. the other batches.
outlier_targets, other_targets = (
    set(orf_data.Metadata_Symbol) for orf_data in (outlier_orf_data, other_orf_data)
)

In [None]:
# Overlap of the conditions in the outlier/other batches with the impactful
# conditions and the conditions covered by the GGI.
condition_sets = {
    "Conditions in outlier batches": outlier_targets,
    "Conditions in other batches": other_targets,
    "Impactful conditions": set(impactful_conditions),
    "Conditions covered by the GGI": set(selected_conditions),
}
venn(condition_sets)

We notice that by removing those four batches we would lose coverage of 27 of the 175 conditions, including 5 of the 33 impactful gene perturbation settings. The conditions we would lose are the following:

In [None]:
# Targets that only occur in the outlier batches, i.e. that would be lost by
# dropping those batches, restricted to impactful and GGI-covered conditions.
lost_impactful_targets = (
    outlier_targets.intersection(impactful_conditions) - other_targets
)
lost_ggi_targets = outlier_targets.intersection(selected_conditions) - other_targets
print("Lost impactful gene targets:", lost_impactful_targets)
print("Lost gene targets covered in GGI (prediction candidates):", lost_ggi_targets)

Unfortunately, three of the five impactful conditions that are only contained in the identified outlier batches, namely CXXC4, BCL2L11 and CASP8, were all observed to be highly toxic upon overexpression in the previous study by Rohban et al. Losing these genes is likely to reduce our coverage of the chromatin states of cells in overexpression settings related to the apoptosis pathway.

However, of those three, only CASP8 would be retained if we restricted ourselves to removing only the two batches that showed the strongest batch effects, namely Batch 3 and 4. Thus, for now we will proceed using only the other 9 batches, i.e. all batches except Batch 3, 4, 10 and 13.

---

## Batch-specific normalization

As an alternative or in addition to the batch subsetting, we could also compute the mean and standard deviation of the control condition for each batch (or even for each plate) and use these statistics to normalize the morphological profiles of that batch or plate.

### Morphological profiles

We first assess the effect of normalizing the morphological profiles by centering them on the mean control profile of each batch.

In [None]:
# Batch-wise normalization of the morphological profiles via the project helper
# do_within_batch_normalization (defined outside this view; presumably centers
# each batch on its control profiles — confirm against its implementation).
jump_morph_profiles_bc = do_within_batch_normalization(jump_morph_profiles)

In [None]:
# Copy the plate/well/batch/label metadata of the raw profiles onto the
# batch-normalized profiles, matching rows via the shared index.
bc_metadata_columns = {
    "plate": "Metadata_Plate",
    "well": "Metadata_Well",
    "batch": "Metadata_Batch",
    "label": "Metadata_Symbol",
}
for new_col, source_col in bc_metadata_columns.items():
    jump_morph_profiles_bc[new_col] = jump_morph_profiles.loc[
        jump_morph_profiles_bc.index, source_col
    ]

Let us now check how the scaled data looks in UMAP space.

In [None]:
# Standardize the numeric features of the batch-normalized profiles.
features = jump_morph_profiles_bc._get_numeric_data()
labels = jump_morph_profiles_bc.label
norm_features = StandardScaler().fit_transform(features)

# Embed them in 2D with UMAP and attach label/plate/well/batch for plotting.
mapper = UMAP(random_state=seed)
embs = pd.DataFrame(
    mapper.fit_transform(norm_features),
    index=features.index,
    columns=["umap_0", "umap_1"],
)
embs["label"] = labels.loc[embs.index]
for meta_col in ["plate", "well", "batch"]:
    embs[meta_col] = jump_morph_profiles_bc.loc[embs.index, meta_col]

# Colour the UMAP by batch.
fig, ax = plt.subplots(figsize=[12, 8])
ax = sns.scatterplot(
    data=embs,
    x="umap_0",
    y="umap_1",
    hue="batch",
    ax=ax,
    s=16,
    palette="tab20",
    hue_order=np.unique(embs.batch),
)
plt.show()


The UMAP plot suggests that the normalization approach effectively reduces the scale of the batch effects. To quantify the remaining batch effects, we additionally assess how well a random forest classifier can predict the batch labels from the centered control profiles.

In [None]:
# Select the batch-normalized control (EMPTY) profiles, standardize their
# numeric features and keep their batch labels as classification targets.
is_ctrl = jump_morph_profiles_bc.Metadata_Symbol == "EMPTY"
ctrl_profiles_bc = jump_morph_profiles_bc.loc[is_ctrl]
ctrl_features_bc = ctrl_profiles_bc._get_numeric_data()
ctrl_features_bc_norm = pd.DataFrame(
    StandardScaler().fit_transform(ctrl_features_bc),
    index=ctrl_features_bc.index,
    columns=ctrl_features_bc.columns,
)
ctrl_batch_labels = ctrl_profiles_bc.Metadata_Batch

In [None]:
# Plate of each control profile, looked up via the feature matrix's own index.
# scikit-learn consumes ``groups`` positionally, so it must line up row by row
# with ``X``. A frame from an earlier cell (e.g. ``ctrl_profiles``) is not
# guaranteed to contain the same rows in the same order.
ctrl_plate_labels = jump_morph_profiles.loc[
    ctrl_features_bc_norm.index, "Metadata_Plate"
]

# Random forest predicting the batch of the centered control profiles. Grouping
# the folds by plate keeps all wells of a plate in the same fold.
rfc = RandomForestClassifier(random_state=seed)
bacs = cross_validate(
    rfc,
    X=ctrl_features_bc_norm,
    y=ctrl_batch_labels,
    groups=ctrl_plate_labels,
    cv=StratifiedGroupKFold(n_splits=10),
    scoring="balanced_accuracy",
)["test_score"]
print(
    "Balanced average batch classification accuracy: {} (+/-{})".format(
        np.round(bacs.mean(), 4), np.round(bacs.std(), 4)
    )
)

In agreement with the previous UMAP plot, the balanced accuracy of a random forest classifier drops from 81% to 31% after the centering. However, the performance is still roughly four times what we would expect by random chance (i.e. 1/13 ≈ 7.7%), which shows that the centering did not fully resolve the batch problems.

We are similarly interested in how similar the overall structure of those embeddings look to what we have observed using our CNN-based image embeddings in the data set from Rohban et al.

To assess that we will plot the clustered pair-wise euclidean distance matrix for the two embeddings.

In [None]:
# Average the batch-centered profiles per OE condition. ``numeric_only=True``
# skips the string metadata columns (e.g. ``batch``) added above. Without it,
# ``mean`` raises a TypeError in pandas>=2.
mean_jump_morph_profiles_bc = jump_morph_profiles_bc.groupby("label").mean(
    numeric_only=True
)
compare_clustering_to_rohban(
    mean_jump_morph_profiles_bc,
    mean_rohban_img_embs,
    metric="euclidean",
    method="complete",
)

Looking at the above clustermaps, we see that in both embeddings the conditions JUN and ERG as well as BCL2L11 and CASP8 prominently cluster together. While the rest of the clustering structure looks slightly different, genes of the MAPK pathway also cluster together, both in the embeddings from Rohban et al. and in the batch-centered morphological profiles from the JUMP data set.

---

### Deep image embeddings

We next validate if this also holds true using our deep-learning based image embeddings. At first, we will do that on an image level as well.

In [None]:
# Apply the within-batch normalization to the per-image embeddings, passing
# their ``batch`` column as batch key and their ``labels`` column as condition
# key (parameter semantics inferred from the names — confirm in
# do_within_batch_normalization).
jump_img_embs_bc = do_within_batch_normalization(
    jump_img_embs, batch_col="batch", label_col="labels"
)

In [None]:
# Average the batch-normalized image embeddings per experimental well.
well_groups_bc = jump_img_embs_bc.groupby("exp_well")
jump_image_mean_embs_bc = well_groups_bc.mean(numeric_only=True)

# Collect the (single) label, batch and plate of every well by flattening the
# per-well unique arrays.
jump_image_labels, jump_batch_labels, jump_plate_labels = (
    pd.Series(np.concatenate(uniques), index=uniques.index)
    for uniques in (
        well_groups_bc.labels.unique(),
        well_groups_bc.batch.unique(),
        well_groups_bc.plate.unique(),
    )
)

In [None]:
# 2D UMAP of the well-averaged, batch-normalized image embeddings.
mapper = UMAP(random_state=seed)
jump_umap_embs_bc = pd.DataFrame(
    mapper.fit_transform(jump_image_mean_embs_bc),
    index=jump_image_mean_embs_bc.index,
    columns=["umap_0", "umap_1"],
)
jump_umap_embs_bc = jump_umap_embs_bc.assign(
    label=jump_image_labels.loc[jump_umap_embs_bc.index],
    plate=jump_plate_labels.loc[jump_umap_embs_bc.index],
    batch=jump_batch_labels.loc[jump_umap_embs_bc.index],
)

In [None]:
# UMAP of the batch-normalized image embeddings, coloured by batch.
fig, ax = plt.subplots(figsize=[18, 12])
# Use this frame's own batches for the hue order. ``embs`` belongs to the
# morphological-profile UMAP and changes as later cells run.
ax = sns.scatterplot(
    data=jump_umap_embs_bc,
    x="umap_0",
    y="umap_1",
    hue="batch",
    ax=ax,
    s=16,
    palette="tab20",
    hue_order=np.unique(jump_umap_embs_bc.batch),
)
plt.show()

Similarly to the observations for the morphological profiles, we observe much weaker batch effects after the normalization. However, the samples from Batch 3 and 4 still group somewhat differently, and there are a few outliers from Batch 12.

As before, we next analyze the similarity of the clustering of the batch-centered JUMP image embeddings and the image embeddings from Rohban et al.

In [None]:
# Average the batch-normalized image embeddings per OE condition.
# ``numeric_only=True`` restricts the average to the numeric embedding
# dimensions, as in the per-well averages above. Non-numeric metadata columns
# (e.g. ``exp_well``) would otherwise make ``mean`` raise a TypeError in pandas>=2.
mean_jump_img_embs_bc = jump_img_embs_bc.groupby("labels").mean(numeric_only=True)
compare_clustering_to_rohban(
    mean_jump_img_embs_bc, mean_rohban_img_embs, metric="euclidean", method="complete"
)

The clustermap above shows that the batch-centered JUMP image embeddings do not preserve the cluster structure we had previously observed for the dataset from Rohban et al. as well as the batch-centered morphological profiles do. This can be explained as follows. The classifier is trained to distinguish between the different conditions, and some of these conditions are present in only a few batches during training. Batch information therefore helps the classifier to identify those conditions, so the batch effects are reinforced and reflected non-linearly in the obtained embeddings, where a simple per-batch centering cannot remove them.

----

## Iterative null-space projection

Thus, in particular for the embeddings obtained from our CNN-based method, more sophisticated methods to remove batch effects are needed. As described before, our second approach uses the iterative null-space projection (INSP) algorithm.

The iterative null-space projection algorithm aims to remove batch effects by repeating the following procedure:

1. Given the embeddings $Z$ and the corresponding batch labels $y$ (one column per batch), fit a linear model $ZW^T \approx y$, where $Z \in \mathbb{R}^{n \times d}$, $W \in \mathbb{R}^{b\times d}$ and $y \in \mathbb{R}^{n\times b}$.
2. Decompose $W = U\Sigma V^T$ using a singular value decomposition.
3. Select the rows of $V^T$ that correspond to zero singular values, i.e. an orthonormal basis of the null space of $W$, and collect them in $\tilde{V}^T$.
4. Project the embeddings onto this null space, $\tilde{Z} = Z\tilde{V}$, to obtain embeddings that are less dependent on the batch labels.

This process is repeated until the linear model in step 1 can no longer distinguish well between the batches. Note that each iteration reduces the dimension by the rank of $W$, i.e. by at most the number of batches $b$.

This procedure is implemented in the `iterative_null_space_projection` function defined at the beginning of this notebook.

### Morphological profiles

We will now experiment how well the method works for the morphological profiles. Note that we will assess the method with and without a preceding batch-specific normalization as described before.

For now, we will also not filter out any batches.

Since we have 13 batches in total, random chance corresponds to a balanced accuracy of 1/13 ≈ 7.7%. We set the threshold on the 3-fold cross-validated balanced accuracy to 30%, which is still clearly above random chance.

In [None]:
# INSP inputs: standardized numeric morphological features and the batch labels.
features = jump_morph_profiles._get_numeric_data()
norm_features = StandardScaler().fit_transform(features)
batch_labels = jump_morph_profiles.Metadata_Batch

In [None]:
# Inspect the (n_profiles, n_features) shape of the standardized INSP input.
norm_features.shape

In [None]:
# Run INSP on the standardized profiles with batch-balanced classifier training.
# The threshold is spelled out explicitly; 0.3 is the function's default.
projected_profiles, batch_bacs = iterative_null_space_projection(
    norm_features,
    np.array(batch_labels),
    bac_threshold=0.3,
    balance_batches=True,
)

# Wrap the projected features and re-attach batch and label metadata (positional,
# since the projection keeps the row order of the input).
jump_morph_profiles_insp = pd.DataFrame(
    projected_profiles, index=jump_morph_profiles.index
)
jump_morph_profiles_insp["batch"] = np.array(jump_morph_profiles.Metadata_Batch)
jump_morph_profiles_insp["label"] = np.array(jump_morph_profiles.Metadata_Symbol)

In [None]:
# Balanced accuracy of the batch classifier after each INSP iteration.
iterations = list(range(len(batch_bacs)))
fig, ax = plt.subplots(figsize=[6, 4])
ax.plot(iterations, batch_bacs)
ax.scatter(iterations, batch_bacs)
ax.set(
    xlabel="Iteration",
    ylabel="Avg. 3-fold CV BAC",
    title="Evolution of the performance of the batch classifier",
)
plt.show()

We find that the threshold is hit after 8 iterations of the algorithm. Since at each iteration 13 features are removed the resulting reduced profiles contain 58 features.

A UMAP plot validates that the obtained embeddings do not group samples by batches.

In [None]:
# 2D UMAP of the standardized INSP-projected morphological profiles.
mapper = UMAP(random_state=seed)
embs = mapper.fit_transform(
    StandardScaler().fit_transform(jump_morph_profiles_insp._get_numeric_data())
)
embs = pd.DataFrame(
    embs, index=jump_morph_profiles_insp.index, columns=["umap_0", "umap_1"]
)
embs["label"] = jump_morph_profiles.loc[embs.index, "Metadata_Symbol"]
embs["plate"] = jump_morph_profiles.loc[embs.index, "Metadata_Plate"]
embs["well"] = jump_morph_profiles.loc[embs.index, "Metadata_Well"]
embs["batch"] = jump_morph_profiles.loc[embs.index, "Metadata_Batch"]

fig, ax = plt.subplots(figsize=[12, 8])
# Fix the hue order, as in the other batch plots, so that every batch gets the
# same colour as in the earlier figures.
ax = sns.scatterplot(
    data=embs,
    x="umap_0",
    y="umap_1",
    hue="batch",
    ax=ax,
    s=16,
    palette="tab20",
    hue_order=np.unique(embs.batch),
)
plt.show()

As seen in the plot above, the method removes the batch structure that was previously visible in the UMAP embedding.

We next check whether it nevertheless maintains the biological information regarding the similarities of certain OE conditions.

In [None]:
# Average the INSP-projected profiles per OE condition. ``numeric_only=True``
# skips the string ``batch`` column. Without it, ``mean`` raises a TypeError in
# pandas>=2.
mean_jump_morph_profiles_insp = jump_morph_profiles_insp.groupby("label").mean(
    numeric_only=True
)
# NOTE(review): unlike the batch-centered comparisons, no ``method`` is passed
# here. Confirm that the default linkage of compare_clustering_to_rohban matches
# the "complete" linkage used there, otherwise the clustermaps are not directly
# comparable.
compare_clustering_to_rohban(
    mean_jump_morph_profiles_insp, mean_rohban_img_embs, metric="euclidean"
)

In contrast to what we had seen for the batch-centered profiles, the profiles we obtain by applying INSP directly to the morphological profiles do not seem to reflect the cluster structure we had found in the data set by Rohban et al.

---

### Deep image embeddings

We now check the effect of INSP on the image embeddings obtained from our CNN.

Since we have 13 batches in total, random chance corresponds to a balanced accuracy of 1/13 ≈ 7.7%. As for the morphological profiles, we keep the threshold on the 3-fold cross-validated balanced accuracy at the default of 30%.

In [None]:
# INSP inputs: standardized numeric image-embedding features and the batch labels.
features = jump_img_embs._get_numeric_data()
norm_features = StandardScaler().fit_transform(features)
batch_labels = jump_img_embs.batch

In [None]:
# Run INSP on the standardized image embeddings. The batch classifier is trained
# on balanced batches (presumably 10,000 samples per batch) with a small
# classifier iteration budget (max_iter_clf=10) to keep the runtime manageable.
# The threshold is spelled out explicitly; 0.3 is the function's default.
projected_embs, batch_bacs = iterative_null_space_projection(
    norm_features,
    np.array(batch_labels),
    bac_threshold=0.3,
    n_jobs=10,
    max_iter_clf=10,
    balance_batches=True,
    n_samples_per_batch=10000,
)

# Wrap the projected embeddings and re-attach batch, label and well (positional,
# since the projection keeps the row order of the input).
jump_img_embs_insp = pd.DataFrame(projected_embs, index=jump_img_embs.index)
jump_img_embs_insp["batch"] = np.array(jump_img_embs["batch"])
jump_img_embs_insp["label"] = np.array(jump_img_embs["labels"])
jump_img_embs_insp["exp_well"] = np.array(jump_img_embs["exp_well"])

In [None]:
# Balanced accuracy of the batch classifier after each INSP iteration.
iterations = list(range(len(batch_bacs)))
fig, ax = plt.subplots(figsize=[6, 4])
ax.plot(iterations, batch_bacs)
ax.scatter(iterations, batch_bacs)
ax.set(
    xlabel="Iteration",
    ylabel="Avg. 3-fold CV BAC",
    title="Evolution of the performance of the batch classifier",
)
plt.show()

We find that the threshold is reached after 5 iterations of the algorithm. Since each iteration removes 13 features, the resulting reduced embeddings contain 1024 − 65 = 959 features.

A UMAP plot validates that the obtained embeddings do not group samples by batches.

In [None]:
# Average the INSP-projected image embeddings per experimental well.
well_groups_insp = jump_img_embs_insp.groupby("exp_well")
jump_image_mean_embs_insp = well_groups_insp.mean(numeric_only=True)

# Collect the (single) label and batch of every well by flattening the per-well
# unique arrays.
jump_image_labels, jump_batch_labels = (
    pd.Series(np.concatenate(uniques), index=uniques.index)
    for uniques in (
        well_groups_insp.label.unique(),
        well_groups_insp.batch.unique(),
    )
)

In [None]:
# 2D UMAP of the well-averaged, INSP-projected image embeddings.
mapper = UMAP(random_state=seed)
jump_umap_embs_insp = mapper.fit_transform(jump_image_mean_embs_insp)
jump_umap_embs_insp = pd.DataFrame(
    jump_umap_embs_insp,
    index=jump_image_mean_embs_insp.index,
    columns=["umap_0", "umap_1"],
)
# Look up the per-well metadata by this frame's own well index. Using another
# frame's index would leave NaNs for any well that frame does not contain.
jump_umap_embs_insp["label"] = jump_image_labels.loc[jump_umap_embs_insp.index]
jump_umap_embs_insp["batch"] = jump_batch_labels.loc[jump_umap_embs_insp.index]

In [None]:
# UMAP of the INSP-projected image embeddings, coloured by batch.
fig, ax = plt.subplots(figsize=[18, 12])
# Use this frame's own batches for the hue order. ``embs`` currently holds the
# morphological-profile UMAP, whose batch set need not match these embeddings.
ax = sns.scatterplot(
    data=jump_umap_embs_insp,
    x="umap_0",
    y="umap_1",
    hue="batch",
    ax=ax,
    s=16,
    palette="tab20",
    hue_order=np.unique(jump_umap_embs_insp.batch),
)
plt.show()

As seen in the plot above, the method removes the batch structure that was previously visible in the UMAP embedding.

We next check whether it nevertheless maintains the biological information regarding the similarities of certain OE conditions.

In [None]:
# Average the INSP-projected image embeddings per OE condition.
# ``numeric_only=True`` skips the string ``batch``/``exp_well`` columns
# re-attached after the projection. Without it, ``mean`` raises a TypeError in
# pandas>=2.
mean_jump_img_embs_insp = jump_img_embs_insp.groupby("label").mean(numeric_only=True)
# NOTE(review): no ``method`` is passed here, unlike the batch-centered
# comparisons ("complete"). Confirm that the default linkage matches.
compare_clustering_to_rohban(
    mean_jump_img_embs_insp, mean_rohban_img_embs, metric="euclidean"
)

While these embeddings at least highlight the most obvious cluster, i.e. the one of CASP8 and BCL2L11, they do not show the clustering of JUN and ERG, or of the genes of the MAPK signaling pathway, as clearly as e.g. the batch-centered profiles do.

# Data export

We will now export the data for it to be used in later analyses.