# 1. Create Sample Sheets and Define Groups/Conditions

This notebook creates a metadata table for differential expression analysis by:
1. Processing the nf-core RNA-seq pipeline output
2. Adding condition and group annotations
3. Generating a standardized metadata file

### Required User Input

1. Define the following in Section 1.1:
   - Functions for condition/group filtering
   - Condition levels
   - Group levels
2. (Optional) Merge additional metadata from SRA run table if needed
3. Specify any samples to exclude from analysis

#### Expected Input Files
1. `rnaseq_output/pipeline_info/samplesheet.valid.csv`
   - Output from nf-core RNA-seq pipeline
   - Contains sample metadata and QC information

2. (Optional) `rnaseq_output/pipeline_info/SraRunTable.txt`
   - SRA run table with additional metadata
   - Only needed if merging additional SRA metadata

#### Output Files
- Standardized metadata CSV written to `de_results/` (full path stored in `METADATA_FH`)

In [None]:
import os
from pathlib import Path

import pandas as pd
import numpy as np

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

### 1.1 Configure Notebook: Define Functions, Condition Levels, and Group Levels

#### Required Metadata Columns in SAMPLE_CONDITIONS:

1. **sample_type_1**
   - Primary designation of cell type, tissue, etc.
   - Examples: BLOOD, PBMC, SYNOVIUM, ILEUM

2. **sample_type_***
   - Secondary description of cell type, tissue, etc.
   - Examples: Subdivided tissues

3. **disease_***
   - Related disease for each sample
   - Examples: SYSTEMIC_LUPUS_ERYTHEMATOSUS, RHEUMATOID_ARTHRITIS, SJOGRENS_SYNDROME, ULCERATIVE_COLITIS
   - Use NA for levels in factors that don't apply to specific disease

4. **sample_condition_***
   - Description of condition (HEALTHY/DISEASE)
   - sample_condition_1: primary disease phenotype (e.g., SLE, RA, T1D)
   - Note: Partially redundant with condition_*. The condition_* columns are used for the design matrix, whereas the sample_condition_* columns store additional metadata. In some situations (e.g., nested covariates), not all metadata can be included in a full-rank model.

5. **condition_***
   - Description of condition (HEALTHY/DISEASE)
   - condition_1: primary disease phenotype (e.g., SLE, RA, T1D, T2D)
   - condition_*: SNPs, drug treatments, etc.
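
The note on nested covariates can be checked numerically. The following is a minimal sketch on invented toy data; `is_full_rank` is a hypothetical helper that is not part of this notebook, and the treatment coding via `pd.get_dummies(..., drop_first=True)` is an assumption about how the downstream design matrix is built.

```python
import numpy as np
import pandas as pd

# Toy metadata (invented): the disease subtype is nested within condition_1.
toy = pd.DataFrame({
    'condition_1': ['CONTROL', 'CONTROL', 'DISEASE_1', 'DISEASE_1', 'DISEASE_1', 'DISEASE_1'],
    'sample_condition_2': ['HEALTHY', 'HEALTHY', 'T2D_RETINOPATHY', 'T2D_RETINOPATHY',
                           'T2D_MELLITUS', 'T2D_MELLITUS'],
})

def is_full_rank(frame: pd.DataFrame, columns: list) -> bool:
    '''Hypothetical helper: build an intercept + dummy-coded design matrix and test its rank.'''
    dummies = pd.get_dummies(frame[columns], drop_first=True, dtype=float)
    design = np.column_stack([np.ones(len(frame)), dummies.to_numpy()])
    return bool(np.linalg.matrix_rank(design) == design.shape[1])

# condition_1 alone gives a full-rank design.
full_rank_single = is_full_rank(toy, ['condition_1'])
# Adding the nested subtype makes the DISEASE_1 column the sum of the two subtype columns.
full_rank_nested = is_full_rank(toy, ['condition_1', 'sample_condition_2'])
```

In this toy case the subtype information would be kept in a sample_condition_* column rather than added as a second condition_* term.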

#### Required Variable Definitions:

1. **SRX_COLUMN**
   - Name of the column that holds the SRX accession (usually 'experiment_accession' or 'run_accession')

2. **SAMPLE_CONDITION_COLUMNS**
   - Dictionary: {condition_name: column_to_search}

3. **SAMPLE_GROUP_COLUMNS**
   - Dictionary: {group_name: column_to_search}

4. **SAMPLE_CONDITIONS**
   - Nested dictionary: {condition_name: {search_pattern: condition_level}}

5. **SAMPLE_GROUPS**
   - Nested dictionary: {group_name: {search_pattern: group_level}}

#### Functions to Configure:

- **filter_condition**: See function docstring
- **filter_group**: See function docstring
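
The choice of matching rule inside `filter_condition` and `filter_group` changes which samples receive a level. The sketch below compares substring and prefix matching on invented cell values; the strings are illustrative only and do not come from any real samplesheet.

```python
# Toy cell values (invented) standing in for entries of a metadata column.
cells = ['DM_patient_01', 'Control_03', 'preDM_screen_07']

# Substring matching also catches 'preDM_screen_07'.
substring_hits = [c for c in cells if 'DM' in c]

# Prefix matching is stricter and only keeps 'DM_patient_01'.
prefix_hits = [c for c in cells if c.startswith('DM')]

# The empty string is contained in every string, so an empty pattern matches all cells.
empty_hits = [c for c in cells if '' in c]
```

The last case is why an empty search pattern can be used to assign one level to every sample.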

In [None]:
DATA_PATH = Path.cwd().parent

EXPERIMENT_ID = DATA_PATH.parts[-1]

RESULTS_PATH = DATA_PATH / 'de_results'

RESULTS_PATH.mkdir(exist_ok=True, parents=True)

METADATA_FH = RESULTS_PATH / f'{Path().cwd().parts[-1]}_metadata.csv'

# Define functions, condition levels, and group levels in metadata dataframe.
SRX_COLUMN = 'experiment_accession'

# Example setup for SAMPLE_CONDITIONS, SAMPLE_CONDITION_COLUMNS, etc.
# Keys must match between SAMPLE_CONDITION_COLUMNS and SAMPLE_CONDITIONS.

# SAMPLE_CONDITION_COLUMNS = {
#     'condition_1': 'sample_description',
#     'sample_type_1': 'sample_description',
#     'disease_1': 'sample_description',
#     'sample_condition_1': 'sample_description',
#     'sample_condition_2': 'sample_description',
# }

# SAMPLE_GROUP_COLUMNS = {
# }

# SAMPLE_CONDITIONS = {
#     'condition_1' : {
#         'DR': 'DISEASE_1',
#         'DM': 'DISEASE_1',
#         'Control': 'CONTROL',
#     },
#     'sample_type_1' : {
#         '':'PBMC',
#     },
#     'sample_condition_2': {
#         'DR': 'TYPE_2_DIABETES_RETINOPATHY',
#         'DM': 'TYPE_2_DIABETES_MELLITUS',
#         'Control': 'HEALTHY',
#     },
#     'sample_condition_1': {
#         'DR': 'TYPE_2_DIABETES',
#         'DM': 'TYPE_2_DIABETES',
#         'Control': 'HEALTHY',
#     },
#     'disease_1' : {
#         'DR':'TYPE_2_DIABETES',
#         'DM':'TYPE_2_DIABETES',
#         'Control':'TYPE_2_DIABETES',
#     },
# }


# SAMPLE_GROUPS = {
# }

# List of columns to keep in metadata file.
KEEP_COLUMNS = [ 
    'single_end', 
    'strandedness', 
    'experiment_accession', 
    'submission_accession', 
    'library_layout', 
    'library_selection', 
    'library_source', 
    'library_strategy', 
    'library_name', 
    'instrument_model', 
    'instrument_platform', 
    'read_count', 
    'tax_id', 
    'sample_title', 
    'experiment_title', 
    'sample_description',
]

def filter_condition(cell: str, condition_key: str) -> bool:
    '''Decide whether a metadata cell matches a search pattern from SAMPLE_CONDITIONS.

    Modify the matching rule (prefix, substring, or suffix) to suit the metadata.

    Args:
        cell (str): contents of a single cell from the metadata dataframe
        condition_key (str): search pattern to look for in the cell

    Returns:
        bool: True if the cell matches, so the pattern's condition level is assigned
    '''
    # return cell.startswith(condition_key)
    return condition_key in cell
    # return cell.endswith(condition_key)

def filter_group(cell: str, group_key: str) -> bool:
    '''Decide whether a metadata cell matches a search pattern from SAMPLE_GROUPS.

    Modify the matching rule (prefix, substring, or suffix) to suit the metadata.

    Args:
        cell (str): contents of a single cell from the metadata dataframe
        group_key (str): search pattern to look for in the cell

    Returns:
        bool: True if the cell matches, so the pattern's group level is assigned
    '''
    # return cell.startswith(group_key)
    return group_key in cell
    # return cell.endswith(group_key)

In [None]:
# Read output of nf-core RNA-seq samplesheet into dataframe

samplesheet_valid = pd.read_csv(
    DATA_PATH / 'rnaseq_output/pipeline_info/samplesheet.valid.csv',
)
samplesheet_valid

### 1.2 Merge Additional Metadata from SRA Run Table (Optional)
* Merge run_table into samplesheet_valid if extra metadata from run_table is needed.
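
As a rough illustration of the merge, the sketch below uses invented toy frames in place of `samplesheet_valid` and the SRA run table. The `tissue` column is hypothetical, and `how='left'` with `validate='one_to_one'` are choices made for this sketch; the commented cell in this section uses the default inner join.

```python
import pandas as pd

# Toy stand-ins (invented values) for samplesheet_valid and the SRA run table.
samplesheet_toy = pd.DataFrame({
    'run_accession': ['SRR0000001', 'SRR0000002'],
    'experiment_accession': ['SRX0000001', 'SRX0000002'],
})
run_table_toy = pd.DataFrame({
    'Run': ['SRR0000001', 'SRR0000002', 'SRR0000003'],
    'tissue': ['PBMC', 'PBMC', 'PBMC'],  # hypothetical extra metadata column
})

merged = samplesheet_toy.merge(
    run_table_toy.loc[:, ['Run', 'tissue']].rename({'Run': 'run_accession'}, axis=1),
    on='run_accession',
    how='left',             # keep every samplesheet row, even without a match
    validate='one_to_one',  # raise if a run accession is duplicated on either side
)
```

Only the columns listed in the `.loc` selection are carried over, so any extra run table column has to be added to that list.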

In [None]:
# Merge run_table into samplesheet_valid if extra metadata from run_table is needed.

# run_table = pd.read_csv(
#     DATA_PATH / 'rnaseq_output/pipeline_info/SraRunTable.txt',
# )
# samplesheet_valid = samplesheet_valid.merge(
#     run_table.loc[:,['Run',]].rename(
#         {'Run': 'run_accession'},
#         axis=1,
#     ), 
#     on='run_accession',
# )
# samplesheet_valid

### 1.3 Drop Samples from Metadata
* Samples are dropped based on "SRX" id in metadata dataframe.
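
Because the samplesheet has one row per run, removing an SRX accession removes every run that belongs to it. A minimal sketch on invented accessions:

```python
import pandas as pd

# Toy samplesheet (invented accessions): SRX0000001 has two runs.
toy = pd.DataFrame({
    'experiment_accession': ['SRX0000001', 'SRX0000001', 'SRX0000002', 'SRX0000003'],
    'run_accession': ['SRR0000001', 'SRR0000002', 'SRR0000003', 'SRR0000004'],
})
to_remove = ['SRX0000001']

# Keep only rows whose SRX accession is not in the removal list.
kept = toy.loc[~toy['experiment_accession'].isin(to_remove)].reset_index(drop=True)
```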

In [None]:
# Filter out samples that aren't desired in further analysis.

samples_to_remove = []

# Index labels of all rows (runs) whose SRX accession is listed in samples_to_remove.
sample_indices = samplesheet_valid.loc[samplesheet_valid[SRX_COLUMN].isin(samples_to_remove)].index

samplesheet_valid.drop(sample_indices, axis=0, inplace=True)

samplesheet_valid

### 1.4 Transfer Condition and Group Keys to Metadata
* Match each sample against the search patterns in SAMPLE_CONDITIONS and SAMPLE_GROUPS, and assign the corresponding level.
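
The transfer step can be sketched on a toy frame as follows. The dictionaries `columns_toy` and `levels_toy` are illustrative stand-ins for SAMPLE_CONDITION_COLUMNS and SAMPLE_CONDITIONS, and substring matching is assumed as the filter rule.

```python
import pandas as pd

# Toy metadata and configuration (invented values).
toy = pd.DataFrame({'sample_description': ['DR_01', 'DM_02', 'Control_03']})
columns_toy = {'condition_1': 'sample_description'}
levels_toy = {'condition_1': {'DR': 'DISEASE_1', 'DM': 'DISEASE_1', 'Control': 'CONTROL'}}

for name, source in columns_toy.items():
    toy[name] = None  # start with every sample unassigned
    for pattern, level in levels_toy[name].items():
        # Rows whose source cell contains the pattern receive the level.
        mask = toy[source].map(lambda cell: pattern in cell)
        toy.loc[mask, name] = level

# Samples that matched no pattern would remain null and need attention.
unmatched = toy[toy['condition_1'].isnull()]
```

When a cell matches more than one pattern, the pattern listed last in the dictionary wins, so the order of patterns matters.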

In [None]:
for k, v in SAMPLE_CONDITION_COLUMNS.items():
    # Start unassigned with object dtype so string levels can be written without dtype warnings.
    samplesheet_valid[k] = None
    for kg, vg in SAMPLE_CONDITIONS[k].items():
        # Later patterns overwrite earlier ones when both match the same cell.
        mask = samplesheet_valid[v].map(lambda x: filter_condition(x, kg))
        samplesheet_valid.loc[mask, k] = vg
    # Every sample must receive a level; show the column before failing.
    if samplesheet_valid[k].isnull().any():
        print(k, v, samplesheet_valid[k])
        raise ValueError(f'Unmatched samples for {k!r} using column {v!r}')


for k, v in SAMPLE_GROUP_COLUMNS.items():
    samplesheet_valid[k] = None
    for kg, vg in SAMPLE_GROUPS[k].items():
        mask = samplesheet_valid[v].map(lambda x: filter_group(x, kg))
        samplesheet_valid.loc[mask, k] = vg
    if samplesheet_valid[k].isnull().any():
        print(k, v, samplesheet_valid[k])
        raise ValueError(f'Unmatched samples for {k!r} using column {v!r}')


### 1.5 Save Metadata to CSV
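
Before saving, runs (SRR) that belong to the same experiment (SRX) are collapsed into one row. A minimal sketch of that step on invented values:

```python
import pandas as pd

# Toy metadata (invented): SRX0000001 was sequenced in two runs.
toy = pd.DataFrame({
    'accession': ['SRX0000001', 'SRX0000001', 'SRX0000002'],
    'condition_1': ['DISEASE_1', 'DISEASE_1', 'CONTROL'],
    'read_count': [100, 150, 200],
})

# Give every run of an experiment the experiment-wide total read count.
toy['read_count'] = toy.groupby('accession')['read_count'].transform('sum')

# The rows of one experiment are now identical, so duplicates collapse to one row.
collapsed = toy.drop_duplicates().reset_index(drop=True)
```

Runs that differ in any other retained column would not collapse, leaving more than one row for that accession.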

In [None]:
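# Assemble the metadata table and write it to CSV.

metadata = samplesheet_valid.loc[
    :, [SRX_COLUMN] + 
    list(SAMPLE_CONDITIONS.keys()) + 
    list(SAMPLE_GROUPS.keys()) + 
    [c for c in KEEP_COLUMNS if c != SRX_COLUMN]
].copy()  # explicit copy so later in-place edits do not raise SettingWithCopyWarning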

metadata.rename({SRX_COLUMN: 'accession'}, inplace=True, axis=1)


# Collapse technical replicates into each other by summing read counts across runs.
if 'read_count' in metadata.columns:
    metadata['read_count'] = metadata.groupby('accession')['read_count'].transform('sum')

# Sample dataframe from output of nf-core rnaseq lists distinct samples by SRR*, 
# but groups samples by SRX* for analysis.
metadata.drop_duplicates(inplace=True)

# Drop group columns that have only one level; condition columns are always kept.
single_groups = []
for c in metadata.columns:
    if not c.startswith('group'):
        continue
    if metadata[c].nunique() == 1:
        single_groups.append(c)

metadata.drop(single_groups, axis=1, inplace=True)

metadata.to_csv(METADATA_FH, index=False)
metadata