# This notebook uses the outputs of the pediatric_cancer_atlas_profiling pipeline (up to step 2.feature_extraction) — the relevant metadata, single-cell features and QC metrics — to generate training, held-out and evaluation data splits based on cell line, plate and seeding density

In [1]:
import pathlib
import yaml

import numpy as np
import pandas as pd

## Read config

In [2]:
# Load the shared project configuration, which lives one directory above
# the current working directory.
config_path = pathlib.Path('.').absolute().parent / "config.yml"
with config_path.open("r") as config_file:
    config = yaml.safe_load(config_file)

## Define paths to metadata, LoadData CSVs, the QC exclusion list and single-cell features

In [3]:
# Root of the pediatric_cancer_atlas_profiling outputs, taken from config.yml.
PROFILING_DIR = pathlib.Path(config['paths']['pediatric_cancer_atlas_profiling_path'])

# Destination for the generated split CSVs (relative to the working directory).
DATASPLIT_OUTPUT_DIR = pathlib.Path('.') / 'data_split_loaddata'
DATASPLIT_OUTPUT_DIR.mkdir(exist_ok=True)

# Directory holding the Barcode_*.csv files and the per-plate platemap CSVs.
platemap_csv_path = PROFILING_DIR / "0.download_data" / "metadata" / "platemaps"

# Directory holding the LoadData CSVs produced by illumination correction.
loaddata_csv_path = PROFILING_DIR / "1.illumination_correction" / "loaddata_csvs"

# Sites flagged for exclusion by the preprocessing QC step.
qc_path = pathlib.Path('.').absolute() / "preprocessing_output" / "qc_exclusion.csv"

# Single-cell feature parquet, taken from config.yml.
sc_features_parquet_path = pathlib.Path(config['paths']['sc_features_path'])

# Fail early and name the missing input. An explicit raise (unlike `assert`)
# is not stripped when Python runs with -O.
for required_path in (
    platemap_csv_path,
    loaddata_csv_path,
    qc_path,
    sc_features_parquet_path,
):
    if not required_path.exists():
        raise FileNotFoundError(f"Required input not found: {required_path}")

## Training split condition and other data split parameters

In [4]:
# When True, sites listed in the QC exclusion file are dropped before splitting.
QC = True

# LoadData column names used throughout the notebook.
SITE_COLUMN = 'Metadata_Site'
WELL_COLUMN = 'Metadata_Well'
PLATE_COLUMN = 'Metadata_Plate'

# A single imaging site is identified by its (site, well, plate) combination;
# this is the key used to match sites against the QC exclusion list.
UNIQUE_IDENTIFIERS = [SITE_COLUMN, WELL_COLUMN, PLATE_COLUMN]

# Metadata values selecting the train/heldout pool. A scalar value must match
# exactly; a list value accepts any of its entries. Every site outside this
# pool is saved for evaluation.
TRAIN_CONDITION_KWARGS = dict(
    cell_line='U2-OS',
    platemap_file='Assay_Plate1_platemap',
    seeding_density=[1_000, 2_000, 4_000, 8_000, 12_000],
)

# The keys of TRAIN_CONDITION_KWARGS define a "condition"; the holdout step
# groups the train/heldout pool by exactly these columns.
CONDITIONS = list(TRAIN_CONDITION_KWARGS)

## Load all barcode/platemap metadata and all loaddata csv files and merge

In [5]:
## Load platemap and well cell line metadata
# Globs are sorted because pathlib's glob yields files in arbitrary,
# filesystem-dependent order. A fixed order keeps the concatenated row order,
# and therefore the seeded random holdout selection downstream, reproducible
# across machines.
barcode_df = pd.concat(
    [pd.read_csv(f) for f in sorted(platemap_csv_path.glob('Barcode_*.csv'))],
    ignore_index=True)

# Read each platemap referenced by a barcode once, tagging its rows with the
# platemap name so it can be joined back to the barcodes. A single concat
# avoids the quadratic grow-by-concat loop.
platemap_df = pd.concat(
    [
        pd.read_csv(platemap_csv_path / f'{platemap}.csv').assign(platemap_file=platemap)
        for platemap in barcode_df['platemap_file'].unique()
    ],
    ignore_index=True)
barcode_platemap_df = pd.merge(barcode_df, platemap_df, on='platemap_file', how='inner')

## QC removal
remove_sites = pd.read_csv(qc_path)

## Load data csvs
loaddata_df = pd.concat(
    [pd.read_csv(f) for f in sorted(loaddata_csv_path.glob('*.csv'))],
    ignore_index=True)

## Merge loaddata with barcode/platemap metadata to map condition to well
# NOTE(review): how='left' keeps platemap wells that have no LoadData rows
# (their site/file columns become NaN); confirm this is intended, as
# how='inner' would drop such wells.
loaddata_barcode_platemap_df = pd.merge(
    barcode_platemap_df.rename(columns={'barcode': PLATE_COLUMN, 'well': WELL_COLUMN}),
    loaddata_df,
    on=[PLATE_COLUMN, WELL_COLUMN],
    how='left')

## Perform QC removal
if QC:
    print(f"{loaddata_barcode_platemap_df.shape[0]} sites prior to QC")
    # Only the site key is needed to identify excluded sites. Restricting and
    # de-duplicating it keeps the anti-join from appending QC-file columns
    # (or _x/_y-suffixed duplicates) to the site table.
    qc_keys = remove_sites[UNIQUE_IDENTIFIERS].drop_duplicates()
    qc_merge_df = loaddata_barcode_platemap_df.merge(
        qc_keys,
        on=UNIQUE_IDENTIFIERS,
        how='left',
        indicator=True
        )

    # Keep only rows that were NOT found in the exclusion list
    loaddata_barcode_platemap_df = qc_merge_df[qc_merge_df['_merge'] == 'left_only'].drop(columns=['_merge'])
    print(f"{loaddata_barcode_platemap_df.shape[0]} sites after QC")

10249 sites prior to QC
9358 sites after QC


## Data split

In [6]:
## Filter the site table dynamically with TRAIN_CONDITION_KWARGS
# Build one positional boolean mask over the full table. The evaluation split
# is then the exact complement of the train/heldout pool and does not depend
# on the DataFrame index being unique; the previous index-based `isin`
# complement would misassign rows if the index ever held duplicates.
train_mask = np.ones(len(loaddata_barcode_platemap_df), dtype=bool)
for k, v in TRAIN_CONDITION_KWARGS.items():
    # List values accept any listed entry; scalar values require an exact match.
    allowed_values = v if isinstance(v, list) else [v]
    train_mask &= loaddata_barcode_platemap_df[k].isin(allowed_values).to_numpy()
    # Fail as soon as the accumulated conditions select nothing.
    if not train_mask.any():
        raise ValueError(f'No data found for {k}={v}')

loaddata_barcode_platemap_train_df = loaddata_barcode_platemap_df[train_mask].copy()
print(f"{loaddata_barcode_platemap_train_df.shape[0]} sites for train and heldout")

# Every site not selected for train/heldout goes to evaluation.
loaddata_barcode_platemap_eval_df = loaddata_barcode_platemap_df[~train_mask]
print(f"{loaddata_barcode_platemap_eval_df.shape[0]} sites for evaluation")

537 sites for train and heldout
8821 sites for evaluation


## For each unique condition combination in the train/heldout split, hold out one well at random

In [7]:
# Fixed seed for a reproducible holdout. This deliberately uses NumPy's legacy
# global RNG: the chosen wells depend on this exact call sequence, and a
# different generator (e.g. np.random.default_rng) would generally pick
# different wells and change the split.
seed = 42
np.random.seed(seed)

# One group per unique combination of the CONDITIONS columns
# (cell line, platemap file and seeding density).
grouped = loaddata_barcode_platemap_train_df.groupby(CONDITIONS)

# Per-condition held-out and training pieces, concatenated after the loop.
heldout_list = []
train_list = []

for _, group in grouped:
    wells = group[WELL_COLUMN]

    # Draw exactly one well per condition to hold out; all remaining wells train.
    held_out_well = [np.random.choice(wells.unique())]
    is_heldout = wells.isin(held_out_well)
    train_wells = wells[~is_heldout].unique()

    condition = group[CONDITIONS].iloc[0].to_dict()
    print(f"For Condition: {condition} Heldout well: {held_out_well} Train wells: {train_wells}")

    heldout_list.append(group[is_heldout].copy())
    train_list.append(group[wells.isin(train_wells)].copy())

# Stitch the per-condition pieces together with a fresh 0..n-1 index.
loaddata_heldout_df = pd.concat(heldout_list, ignore_index=True)
print(f"{loaddata_heldout_df.shape[0]} sites Heldout")
loaddata_train_df = pd.concat(train_list, ignore_index=True)
print(f"{loaddata_train_df.shape[0]} sites for Training")

For Condition: {'cell_line': 'U2-OS', 'platemap_file': 'Assay_Plate1_platemap', 'seeding_density': 1000} Heldout well: ['M14'] Train wells: ['M13' 'N13' 'N14']
For Condition: {'cell_line': 'U2-OS', 'platemap_file': 'Assay_Plate1_platemap', 'seeding_density': 2000} Heldout well: ['N16'] Train wells: ['M15' 'N15' 'M16']
For Condition: {'cell_line': 'U2-OS', 'platemap_file': 'Assay_Plate1_platemap', 'seeding_density': 4000} Heldout well: ['M17'] Train wells: ['N17' 'M18' 'N18']
For Condition: {'cell_line': 'U2-OS', 'platemap_file': 'Assay_Plate1_platemap', 'seeding_density': 8000} Heldout well: ['M20'] Train wells: ['M19' 'N19' 'N20']
For Condition: {'cell_line': 'U2-OS', 'platemap_file': 'Assay_Plate1_platemap', 'seeding_density': 12000} Heldout well: ['M22'] Train wells: ['M21' 'N21' 'N22']
135 sites Heldout
402 sites for Training


In [8]:
# Write each split to the output directory. The DataFrame index is written as
# the first CSV column. The heldout/train indices were reset to 0..n-1, while
# the eval split keeps its original row labels.
splits_to_save = {
    'loaddata_heldout.csv': loaddata_heldout_df,
    'loaddata_train.csv': loaddata_train_df,
    'loaddata_eval.csv': loaddata_barcode_platemap_eval_df,
}
for csv_name, split_df in splits_to_save.items():
    split_df.to_csv(DATASPLIT_OUTPUT_DIR / csv_name)