# Data split for the 4-fold CV

---

This notebook generates cross-validation data splits of the screening data, one set of folds per gene, so that the individual gene screens can be computed in parallel.

---

## 0. Environment setup

In [1]:
import os
from collections import Counter

import numpy as np
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from tqdm import tqdm

%load_ext nb_black

<IPython.core.display.Javascript object>

---
## 1. Read in data

To start, we read in the whole metadata set.

In [2]:
# Location of the cleaned nuclear (Hoechst) profiles; the first CSV column is the index.
profiles_path = "../../../data/resources/images/rohban/profiles/nuclei_profiles_hoechst_cleaned.csv"
md = pd.read_csv(profiles_path, index_col=0)
md.head()

Unnamed: 0,ImageNumber,Nuclei_AreaShape_Area,Nuclei_AreaShape_Center_X,Nuclei_AreaShape_Center_Y,Nuclei_AreaShape_Compactness,Nuclei_AreaShape_Eccentricity,Nuclei_AreaShape_EulerNumber,Nuclei_AreaShape_Extent,Nuclei_AreaShape_FormFactor,Nuclei_AreaShape_MajorAxisLength,...,Nuclei_Texture_SumEntropy_Hoechst_10_0,Nuclei_Texture_SumEntropy_Hoechst_3_0,Nuclei_Texture_SumEntropy_Hoechst_5_0,Nuclei_Texture_SumVariance_Hoechst_10_0,Nuclei_Texture_SumVariance_Hoechst_3_0,Nuclei_Texture_SumVariance_Hoechst_5_0,Nuclei_Texture_Variance_Hoechst_10_0,Nuclei_Texture_Variance_Hoechst_3_0,Nuclei_Texture_Variance_Hoechst_5_0,labels
0,1959,1055,505,897,1.05076,0.582823,1,0.775735,0.829029,41.2344,...,2.28077,2.46149,2.39829,6.95743,10.1031,8.70054,3.17906,2.92942,2.91259,EMPTY
1,2621,692,1048,833,1.02947,0.487037,1,0.697581,0.865848,32.1004,...,2.22891,2.50667,2.39297,6.23979,11.4035,8.74462,4.92492,3.89204,4.06769,PRKACA
2,1959,1130,254,912,1.12291,0.778182,1,0.713384,0.823856,48.128,...,2.09847,2.20948,2.15659,4.90396,7.23318,5.97824,2.78244,2.50232,2.56194,EMPTY
3,2621,945,1025,768,1.0303,0.562608,1,0.751192,0.888179,38.3836,...,2.09782,2.30247,2.21613,4.87518,7.33326,6.19254,2.48965,2.36182,2.38637,PRKACA
4,1959,1131,316,926,1.03167,0.581268,1,0.764189,0.880571,42.2784,...,2.3439,2.49636,2.44547,7.31812,10.2496,9.03631,2.87705,2.95287,2.8269,EMPTY


<IPython.core.display.Javascript object>

---

## 2. Grouped K-Fold

We now go over each individual gene. For each gene, we select its nuclei together with the EMPTY control nuclei, randomly undersample the majority of the two classes, and split the result into train, validation, and test sets for each of the folds. All nuclei from the same image (`ImageNumber`) are kept in the same split, so that the model cannot cheat by focusing on imaging artifacts specific to individual images.

In [3]:
def get_data_splits_for_label(
    data,
    label_col,
    target_list,
    n_folds,
    group_col,
    random_state=1234,
    val_size=0.2,
):
    """Create grouped train/validation/test splits for one set of target labels.

    Rows of ``data`` whose ``label_col`` value is in ``target_list`` are
    selected. If ``"EMPTY"`` is one of the targets, the majority class is then
    randomly undersampled (``RandomUnderSampler(sampling_strategy="majority")``).
    The remaining rows are split with ``GroupKFold`` into ``n_folds``
    (train+validation, test) pairs, and every train+validation part is further
    split with ``GroupShuffleSplit`` into train and validation sets. Rows that
    share a ``group_col`` value are never spread across different splits.

    Parameters
    ----------
    data : pandas.DataFrame
        Full data set (in this notebook: one row per nucleus).
    label_col : str
        Column holding the class label of each row.
    target_list : list of str
        Labels whose rows are kept.
    n_folds : int
        Number of ``GroupKFold`` folds.
    group_col : str
        Column defining the groups that must stay together.
    random_state : int, default 1234
        Seed for the undersampling and for the train/validation split.
    val_size : float, default 0.2
        Fraction of the *groups* (not rows) of each train+validation part
        that is assigned to the validation set.

    Returns
    -------
    dict
        ``{"train": [...], "val": [...], "test": [...]}``; each list holds one
        DataFrame per fold, in ``GroupKFold`` order.
    """
    # Subsample the data
    label_data = data.loc[data.loc[:, label_col].isin(target_list), :]
    if "EMPTY" in target_list:
        # RandomUnderSampler needs a feature matrix: resample positional row
        # indices and recover the full rows with .iloc afterwards.
        row_idc = np.arange(len(label_data)).reshape(-1, 1)
        row_idc, _ = RandomUnderSampler(
            sampling_strategy="majority", random_state=random_state
        ).fit_resample(row_idc, label_data.loc[:, label_col])
        label_data = label_data.iloc[row_idc.flatten(), :]

    # Split in folds
    features = np.arange(len(label_data)).reshape(-1, 1)
    labels = label_data.loc[:, label_col].to_numpy()
    groups = label_data.loc[:, group_col].to_numpy()

    fold_data = {"train": [], "val": [], "test": []}
    group_kfold = GroupKFold(n_splits=n_folds)
    # Only the first train/validation split is ever used. For a seeded
    # splitter the first split does not depend on n_splits, so n_splits=1 is
    # equivalent to the former n_splits=2 + next(); the seed is re-applied on
    # every .split() call, so one instance can be reused across folds.
    val_splitter = GroupShuffleSplit(
        test_size=val_size, n_splits=1, random_state=random_state
    )
    for train_val_idc, test_idc in group_kfold.split(features, labels, groups=groups):
        train_val_fold_data = label_data.iloc[train_val_idc]

        train_idc, val_idc = next(
            val_splitter.split(
                train_val_fold_data,
                labels[train_val_idc],
                groups=groups[train_val_idc],
            )
        )

        fold_data["train"].append(train_val_fold_data.iloc[train_idc])
        fold_data["val"].append(train_val_fold_data.iloc[val_idc])
        fold_data["test"].append(label_data.iloc[test_idc])

    return fold_data

<IPython.core.display.Javascript object>

In [4]:
# NOTE(review): "mroph_profiles" looks like a typo of "morph_profiles"; kept as-is
# because downstream code may read the splits from this exact path — confirm.
output_dir = "../../../data/experiments/rohban/images/screen/mroph_profiles/screen_splits/"
# exist_ok=True keeps this cell re-runnable (Restart & Run All) once the directory exists.
os.makedirs(output_dir, exist_ok=True)

<IPython.core.display.Javascript object>

In [5]:
# Split configuration used by the cells below: label column, grouping column
# (rows with the same value always share a split), seed and number of folds.
label_col, group_col = "labels", "ImageNumber"
random_state, n_folds = 1234, 4

<IPython.core.display.Javascript object>

In [9]:
# Create and save the fold splits of every gene screen (gene vs. EMPTY control).
# Sorting makes the processing order independent of Python's string hash seed.
labels = sorted(set(md.loc[:, label_col]) - {"EMPTY"})
for label in tqdm(labels):
    fold_data = get_data_splits_for_label(
        data=md,
        label_col=label_col,
        target_list=[label, "EMPTY"],
        n_folds=n_folds,
        group_col=group_col,
        random_state=random_state,
    )
    label_output_dir = os.path.join(output_dir, label)
    os.makedirs(label_output_dir, exist_ok=True)
    for split_name, split_folds in fold_data.items():
        for fold_idx, split_fold_data in enumerate(split_folds):
            # The ".gz" suffix makes pandas write a gzip-compressed CSV.
            split_fold_data.to_csv(
                os.path.join(
                    label_output_dir,
                    "morph_md_{}_fold_{}.csv.gz".format(split_name, fold_idx),
                )
            )

100%|██████████| 193/193 [25:55<00:00,  8.06s/it]


<IPython.core.display.Javascript object>

In [11]:
# Build the splits for a single example gene to sanity-check them below.
# NOTE(review): the variable is named after CASP8 but the split is computed for
# MAPK9; the name is kept because the following cells reference it.
example_label = "MAPK9"
casp8_fold_data = get_data_splits_for_label(
    data=md,
    label_col=label_col,
    target_list=[example_label, "EMPTY"],
    n_folds=n_folds,
    group_col=group_col,
    random_state=random_state,
)

<IPython.core.display.Javascript object>

In [13]:
# Distinct images per split of the first fold. The grouped split must never
# place the same image in more than one of train / val / test.
train_images, val_images, test_images = (
    set(casp8_fold_data[split][0].ImageNumber) for split in ("train", "val", "test")
)
assert train_images.isdisjoint(val_images), "image shared by train and val"
assert train_images.isdisjoint(test_images), "image shared by train and test"
assert val_images.isdisjoint(test_images), "image shared by val and test"
print(len(train_images), len(val_images), len(test_images))

564 142 234


<IPython.core.display.Javascript object>

In [14]:
# Class balance (gene vs. EMPTY) in each split of the first fold.
print(
    *(Counter(casp8_fold_data[split][0].labels) for split in ("train", "val", "test"))
)

Counter({'EMPTY': 2607, 'MAPK9': 2359}) Counter({'MAPK9': 925, 'EMPTY': 700}) Counter({'MAPK9': 1110, 'EMPTY': 1087})


<IPython.core.display.Javascript object>