# Data split for the 4-fold CV

---

This notebook generates the data splits for the screening data, one set of splits per gene screen, so that the individual screens can be computed in parallel.

---

## 0. Environment setup

In [1]:
import os
from collections import Counter

import numpy as np
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from tqdm import tqdm

---
## 1. Read in data

To start we will read in the whole metadata set.

In [2]:
# Load the nuclei metadata table; the first CSV column becomes the index.
md = pd.read_csv(
    filepath_or_buffer="../../../data/experiments/rohban/images/preprocessing/full_pipeline/padded_nuclei_metadata.csv.gz",
    index_col=0,
)
# Show the first rows as a quick visual check of the loaded table.
md.head(n=5)

Unnamed: 0,plate,well,image_file,gene_id,gene_symbol,is_landmark,allele,expr_vec,toxicity,ie_blast,...,bb_height,minor_axis_length,major_axis_length,aspect_ratio,aspect_ratio_cluster,nuclei_count_image,slide_image_name,aspect_ratio_cluster_ratio,centroid_0,centroid_1
0,41744,k21,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,1977.0,EIF4E,0.0,WT.2,pLX304,,0.91,...,21,15.474579,35.755476,0.432789,1,53,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,0.528302,17.675294,19.835294
1,41744,k21,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,1977.0,EIF4E,0.0,WT.2,pLX304,,0.91,...,17,16.074259,48.34201,0.332511,1,53,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,0.528302,28.882645,273.150413
2,41744,k21,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,1977.0,EIF4E,0.0,WT.2,pLX304,,0.91,...,29,24.256958,35.030138,0.69246,0,53,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,0.528302,149.476762,1002.848576
3,41744,k21,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,1977.0,EIF4E,0.0,WT.2,pLX304,,0.91,...,36,27.689881,51.502812,0.537638,1,53,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,0.528302,250.939748,108.973921
4,41744,k21,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,1977.0,EIF4E,0.0,WT.2,pLX304,,0.91,...,27,26.587002,32.592086,0.81575,0,53,taoe005-u2os-72h-cp-a-au00044859_k21_s7_w10efe...,0.528302,252.834328,913.795522


---

## 2. Grouped K-Fold

We will now go over each individual gene and, for its screen (the gene vs. the `EMPTY` control), partition the corresponding nuclei into a training, a validation and a test set for each fold; each of these sets is saved as a separate CSV file. Thereby, we ensure that all nuclei from the same slide image end up in the same set, so that the model cannot cheat by focusing on imaging artifacts specific to individual slide images.

In [3]:
def get_data_splits_for_label(
    data,
    label_col,
    target_list,
    n_folds,
    group_col,
    random_state=1234,
    val_size=0.2,
    control_label="EMPTY",
):
    """Build grouped train/val/test splits for the rows of the selected labels.

    The rows of ``data`` whose ``label_col`` value is in ``target_list`` are
    selected. If ``control_label`` is one of the targets, the majority class is
    randomly undersampled first. The remaining rows are split into ``n_folds``
    test folds with ``GroupKFold``; the non-test part of every fold is further
    divided into a training and a validation set with ``GroupShuffleSplit``.
    Both splitters receive ``group_col`` as the grouping variable.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing at least the columns ``label_col`` and ``group_col``.
    label_col : str
        Name of the column holding the labels.
    target_list : list
        Labels whose rows are kept.
    n_folds : int
        Number of folds passed to ``GroupKFold``.
    group_col : str
        Name of the column holding the group identifiers.
    random_state : int, default 1234
        Seed used for the undersampling and for the train/validation split.
    val_size : float, default 0.2
        Passed as ``test_size`` to ``GroupShuffleSplit``, i.e. the share of the
        non-test groups assigned to the validation set.
    control_label : str, default "EMPTY"
        Label that triggers the undersampling of the majority class when it is
        part of ``target_list``.

    Returns
    -------
    dict
        Dictionary with the keys ``"train"``, ``"val"`` and ``"test"``; each
        value is a list with one DataFrame per fold.

    Raises
    ------
    ValueError
        If no row of ``data`` carries any of the labels in ``target_list``.
    """
    # Keep only the rows that belong to one of the requested labels.
    label_data = data.loc[data.loc[:, label_col].isin(target_list), :]
    if label_data.empty:
        # Fail early with a clear message instead of an opaque error raised
        # deep inside the resampling / splitting libraries.
        raise ValueError(
            "No rows with a value of column {!r} in {}.".format(
                label_col, list(target_list)
            )
        )

    # Undersample the majority class when the control condition is included.
    # Row positions are resampled in place of the actual features so that the
    # selected rows can be looked up positionally afterwards.
    if control_label in target_list:
        positions = np.arange(len(label_data)).reshape(-1, 1)
        positions, _ = RandomUnderSampler(
            sampling_strategy="majority", random_state=random_state
        ).fit_resample(positions, label_data.loc[:, label_col])
        label_data = label_data.iloc[positions.flatten(), :]

    # Split in folds; the positional dummy features only convey the sample count.
    positions = np.arange(len(label_data)).reshape(-1, 1)
    labels = np.array(label_data.loc[:, label_col])
    groups = np.array(label_data.loc[:, group_col])

    fold_data = {"train": [], "val": [], "test": []}
    group_kfold = GroupKFold(n_splits=n_folds)
    for train_val_index, test_index in group_kfold.split(
        positions, labels, groups=groups
    ):
        train_val_data = label_data.iloc[train_val_index]

        # Only one train/validation split is needed per fold. The splitter is
        # re-created with the same seed for every fold.
        val_splitter = GroupShuffleSplit(
            n_splits=1, test_size=val_size, random_state=random_state
        )
        train_index, val_index = next(
            val_splitter.split(
                train_val_data,
                labels[train_val_index],
                groups=groups[train_val_index],
            )
        )

        fold_data["train"].append(train_val_data.iloc[train_index])
        fold_data["val"].append(train_val_data.iloc[val_index])
        fold_data["test"].append(label_data.iloc[test_index])

    return fold_data

In [4]:
# Root directory below which one sub-directory per gene screen is created.
output_dir = "../../../data/experiments/rohban/images/preprocessing/screen_splits/"
# exist_ok=True keeps this cell re-runnable: without it, os.makedirs raises
# FileExistsError as soon as the directory exists from a previous run.
os.makedirs(output_dir, exist_ok=True)

In [5]:
# Configuration of the screen splits used by the cells below.
n_folds = 4  # number of cross-validation folds per gene screen
random_state = 1234  # seed passed to the split function
label_col = "gene_symbol"  # column holding the label of each nucleus
group_col = "slide_image_name"  # column whose values must not be split across sets

In [6]:
# Every gene other than the EMPTY control defines one screen (gene vs. EMPTY).
# Missing gene symbols are dropped because they cannot name a screen directory,
# and the labels are sorted so that the processing order is stable across runs.
labels = sorted(set(md.loc[:, label_col].dropna()) - {"EMPTY"})
for label in tqdm(labels):
    fold_data = get_data_splits_for_label(
        data=md,
        label_col=label_col,
        target_list=[label, "EMPTY"],
        n_folds=n_folds,
        group_col=group_col,
        random_state=random_state,
    )
    label_output_dir = os.path.join(output_dir, label)
    # exist_ok=True allows re-running the cell when the directories already exist.
    os.makedirs(label_output_dir, exist_ok=True)
    # Write one CSV file per split type (train/val/test) and fold.
    for split_name, split_folds in fold_data.items():
        for fold_idx, fold_df in enumerate(split_folds):
            fold_df.to_csv(
                os.path.join(
                    label_output_dir,
                    "nuclei_md_{}_fold_{}.csv".format(split_name, fold_idx),
                )
            )

100%|██████████| 193/193 [02:25<00:00,  1.33it/s]


In [7]:
# Sanity check: recompute the splits of a single screen (MAPK9 vs. EMPTY) with the
# same configuration variables that were used for the export above, so that the
# check cannot drift from the exported splits.
# NOTE(review): the variable name says "casp8" although the inspected gene is
# MAPK9; the name is kept because the following cells refer to it.
casp8_fold_data = get_data_splits_for_label(
    data=md,
    label_col=label_col,
    target_list=["MAPK9", "EMPTY"],
    n_folds=n_folds,
    group_col=group_col,
    random_state=random_state,
)

In [8]:
# Number of distinct slide images in the train / val / test set of the first fold.
print(
    *(
        len(set(casp8_fold_data[split][0].slide_image_name))
        for split in ("train", "val", "test")
    )
)

930 233 387


In [9]:
# Number of nuclei per gene symbol in the train / val / test set of the first fold.
# `Counter` is imported in the import cell at the top of the notebook.
print(
    *(
        Counter(casp8_fold_data[split][0].gene_symbol)
        for split in ("train", "val", "test")
    )
)

Counter({'EMPTY': 3841, 'MAPK9': 3700}) Counter({'MAPK9': 1091, 'EMPTY': 956}) Counter({'MAPK9': 1601, 'EMPTY': 1595})
