# Tutorial on preparing train-test splits for MORPH

To evaluate MORPH’s ability to generalize to unseen genetic perturbations, you can define custom train-test splits based on perturbation identity. Each test set is a list of perturbations that are completely withheld during training.

To create your split file:
1. Create a `.csv` file named `[your_data_id]_splits.csv`.
2. In this file, define the following columns:

| Column        | Description                                                                 |
|---------------|-----------------------------------------------------------------------------|
| `data`        | The name of your dataset (i.e., `[your_data_id]`)                           |
| `test_set_id` | A unique identifier for this test set (i.e., `[your_test_set_id]`)          |
| `test_set`    | A comma-separated list of perturbation names to be held out for testing     |
| `note`        | (Optional) Any additional comments or notes for this split                  |

3. Save the file to `MORPH/data/`.
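
The steps above can be sketched with pandas. This is a minimal illustration rather than MORPH's own tooling: the dataset ID `my_data`, the test set ID `holdout_1`, and the perturbation names are placeholders, and the file is written to a temporary directory instead of `MORPH/data/`. Combination perturbations are written with `+`, following the example split file.

```python
import os
import tempfile

import pandas as pd

# Hypothetical values for illustration; replace with your own.
data_id = "my_data"
held_out = ["GENE_A", "GENE_B", "GENE_C+GENE_D"]

split_df = pd.DataFrame(
    [
        {
            "data": data_id,
            "test_set_id": "holdout_1",
            # Perturbation names are joined into one comma-separated string.
            "test_set": ",".join(held_out),
            "note": "hand-picked held-out perturbations",
        }
    ]
)

# Written to a temporary directory here; in practice, save to MORPH/data/.
out_path = os.path.join(tempfile.mkdtemp(), f"{data_id}_splits.csv")
split_df.to_csv(out_path, index=False)

# Read the file back to confirm the layout round-trips.
loaded = pd.read_csv(out_path)
print(loaded)
```

Passing `index=False` keeps pandas from writing its row index as an extra unnamed column.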

#### Example 
You can refer to the following example split file for formatting: `MORPH/data/example_data_splits.csv`

In [2]:
import pandas as pd
split_df = pd.read_csv('../data/example_data_splits.csv')
split_df

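| Unnamed: 0 | data         | test_set_id    | test_set                                           | note                                               |
|------------|--------------|----------------|----------------------------------------------------|----------------------------------------------------|
| 0          | example_data | random_split_1 | MAP4K5,KMT2A,UBASH3B+PTPN12,COL2A1,KLF1+BAK1,C... | random splits that includes 0/2, 1/2, 2/2 unse... |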
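
One way to use a row of such a file is to expand `test_set` back into a list and mark which cells belong to held-out perturbations. The sketch below runs on toy data with made-up names. It assumes that perturbation labels live in a `gene` column and that controls are labeled `non-targeting`, as in the 5-fold example in this tutorial. How MORPH itself consumes the file may differ.

```python
import pandas as pd

# Toy stand-in for one row of a split file (names are made up).
split_df = pd.DataFrame(
    {
        "data": ["toy_data"],
        "test_set_id": ["toy_split_1"],
        "test_set": ["GENE_A,GENE_C+GENE_D"],
    }
)

# Toy stand-in for per-cell metadata such as adata.obs.
obs = pd.DataFrame(
    {"gene": ["non-targeting", "GENE_A", "GENE_B", "GENE_C+GENE_D", "GENE_B"]}
)

# Select the split by its ID and expand the comma-separated string.
row = split_df.loc[split_df["test_set_id"] == "toy_split_1"].iloc[0]
test_perturbations = row["test_set"].split(",")

# Cells whose perturbation is held out go to the test set; the rest train.
is_test = obs["gene"].isin(test_perturbations)
train_obs = obs[~is_test]
test_obs = obs[is_test]

print("held out:", test_perturbations)
print("train cells:", len(train_obs), "test cells:", len(test_obs))
```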


#### Example Code for 5-Fold Splits

5-fold splitting is a common cross-validation strategy where the list of perturbations is divided into 5 subsets (or "folds"). In each round, one fold is used as the test set, and the remaining four are used for training. This allows you to evaluate how well the model generalizes across different subsets of perturbations.

Below is an example code snippet to generate 5-fold train-test splits based on perturbations:

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.model_selection import KFold

data_path = 'your/path/to/your/data.h5ad' # Replace with your actual data path
csv_file_path = 'your/path/to/your/data_splits.csv' # Replace with your desired output path
data_id = 'your_data_id'  # Replace with your actual data ID

adata = sc.read_h5ad(data_path)

# Set random seed
random_seed = 12
np.random.seed(random_seed)

# Get unique perturbations excluding control
# (adjust the 'gene' column and 'non-targeting' label if your dataset uses different names)
all_ptb_targets = adata.obs['gene'].unique().tolist()
all_ptb_targets = [g for g in all_ptb_targets if g != 'non-targeting']
print('Total perturbations:', len(all_ptb_targets))

# Shuffle and split into folds
np.random.shuffle(all_ptb_targets)
kf = KFold(n_splits=5)
all_ptb_targets = np.array(all_ptb_targets)
folds = []
for _, test_index in kf.split(all_ptb_targets):
    # Only the test indices are needed; the other four folds form the training set
    folds.append(all_ptb_targets[test_index].tolist())

# Create dataframe with one row per fold
rows = [
    {
        'test_set_id': f'random_fold_{i + 1}',
        'test_set': ','.join(test_set),
        'random_seed': random_seed
    }
    for i, test_set in enumerate(folds)
]
df = pd.DataFrame(rows, columns=['test_set_id', 'test_set', 'random_seed'])

# Add data_id column
df['data'] = data_id
# Save to CSV
df.to_csv(csv_file_path, index=False)
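
After writing the file, it can be worth checking that the folds partition the perturbations: no perturbation should appear in two test sets, and together the test sets should cover every perturbation. The helper below, `check_folds`, is a hypothetical name introduced for this sketch and is not part of MORPH. It is shown on toy data, but should apply equally to the `df` and `all_ptb_targets` built above.

```python
import pandas as pd


def check_folds(split_df, all_perturbations):
    """Return True if the test sets are disjoint and cover all perturbations."""
    seen = []
    for test_set in split_df["test_set"]:
        seen.extend(test_set.split(","))
    no_overlap = len(seen) == len(set(seen))
    full_coverage = set(seen) == set(all_perturbations)
    return no_overlap and full_coverage


# Toy example with made-up perturbation names.
toy_perturbations = ["A", "B", "C", "D", "E", "F"]
toy_splits = pd.DataFrame(
    {
        "test_set_id": ["random_fold_1", "random_fold_2", "random_fold_3"],
        "test_set": ["A,B", "C,D", "E,F"],
    }
)
valid_result = check_folds(toy_splits, toy_perturbations)

# Break the partition: "A" now appears twice and "E" is missing.
overlapping = toy_splits.copy()
overlapping.loc[2, "test_set"] = "A,F"
invalid_result = check_folds(overlapping, toy_perturbations)

print(valid_result, invalid_result)
```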