In [None]:
import sys
from os import path
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from intervaltree import IntervalTree
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
from scipy.stats import pearsonr, zscore

# --- Plot / aggregation configuration ---
# Symmetric clipping bound applied to log2 values (gene-level aggregates and
# scatter points are clipped to [-PLOT_Y_LIM, +PLOT_Y_LIM]).
PLOT_Y_LIM = 7.5
# Default `merge_distance` of aggregate_data_by_gene: gene representatives whose
# absolute positions differ by no more than this are merged into one entry.
# NOTE(review): presumably expressed in base pairs — confirm.
GENE_MERGE_DISTANCE = 2_000_000
# Colour per arm label ('p', 'q', 'spanning'); the None key is presumably used
# for segments without an arm classification — verify against the plotting code.
ARM_COLORS = {'p': 'tab:blue', 'q': 'tab:orange', 'spanning': 'tab:green', None: 'grey'}

# --- Directories ---
BASE_DATA_DIR = Path('data')
PLOT_OUTPUT_DIR = Path('graphics')
# Side effect when this cell runs: the output directory is created if missing.
PLOT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Input files ---
CASES_FILE_PATH = BASE_DATA_DIR / 'cases.csv'
CYTOBAND_FILE_PATH = BASE_DATA_DIR / 'cytoBand.txt'
RELEVANT_GENES_FILE_PATH = BASE_DATA_DIR / 'relevant_genes.csv'

# Cases table: positional indices of the CSV columns that are kept, and the
# names assigned to them (same order, one name per index).
CASES_COLUMNS_INDICES = [0, 1, 2, 3, 4, 5, 6, 13, 14]
CASES_COLUMN_NAMES = ['patient_id', 'case_n_number', 'age', 'sex', 'tumor_type', 'diagnosis', 'DIN', 'sentrix_id', 'barcode']

# Cytoband table: column names supplied to read_csv (the file is read without
# using a header row of its own).
CYTOBAND_COLUMN_NAMES = ['chromosome', 'start', 'end', 'band', 'giemsa']

# Relevant-genes table: column names supplied to read_csv.
RELEVANT_GENES_COLUMN_NAMES = ['gene', 'chromosome', 'start', 'end']

In [None]:
def _sanitize_chromosome_column(df, chromosome_col='chromosome'):
    """
    Converts a chromosome column to numeric Int64Dtype, removing 'chr' prefix.
    Rows with unparsable chromosome values (e.g., 'X', 'Y') are dropped.
    """
    if chromosome_col not in df.columns:
        return df

    df = df.copy()
    df[chromosome_col] = (
        df[chromosome_col]
        .astype(str)
        .str.replace('chr', '', regex=False)
        .pipe(pd.to_numeric, errors='coerce')
    )
    df = df.dropna(subset=[chromosome_col])
    df[chromosome_col] = df[chromosome_col].astype(pd.Int64Dtype())
    return df

In [None]:
def load_and_preprocess_cases(file_path):
    """
    Load the cases CSV and return only complete, cleaned rows.

    The file is read as ISO-8859-1 with ',' as field delimiter and decimal
    mark. Columns are selected by position (CASES_COLUMNS_INDICES) and renamed
    to CASES_COLUMN_NAMES. Leading 'P' characters are stripped from
    `patient_id` and the remainder is parsed as a number; `DIN` is parsed as a
    number as well. Any row with a missing value in one of the kept columns
    (including a patient_id or DIN that could not be parsed) is dropped.

    Args:
        file_path: Location of the cases CSV.

    Returns:
        DataFrame with a fresh 0..n-1 index and an integer `patient_id`, or an
        empty DataFrame when the file does not exist.
    """
    try:
        raw = pd.read_csv(file_path, delimiter=',', encoding='ISO-8859-1', decimal=',')
    except FileNotFoundError:
        print(f"Error: Cases file not found at {file_path}")
        return pd.DataFrame()

    # Work on an independent copy so the column assignments below never write
    # into a view of the frame returned by read_csv.
    df = raw.iloc[:, CASES_COLUMNS_INDICES].copy()
    df.columns = CASES_COLUMN_NAMES

    # Parse leniently first: a blank or malformed ID becomes NaN and is removed
    # by dropna() below instead of aborting the whole load with a ValueError.
    df['patient_id'] = pd.to_numeric(
        df['patient_id'].astype(str).str.lstrip('P'), errors='coerce'
    )
    df['DIN'] = pd.to_numeric(df['DIN'], errors='coerce')

    df = df.dropna().reset_index(drop=True)
    # Safe now: every remaining patient_id is a real number.
    df['patient_id'] = df['patient_id'].astype(int)
    return df

In [None]:
def load_and_preprocess_cytoband(file_path):
    """
    Read the tab-separated cytoband file and clean its chromosome column.

    Column names come from CYTOBAND_COLUMN_NAMES. Returns an empty DataFrame
    (after printing an error) when the file does not exist.
    """
    try:
        cytoband_raw = pd.read_csv(file_path, delimiter='\t', names=CYTOBAND_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Cytoband file not found at {file_path}")
        return pd.DataFrame()
    else:
        return _sanitize_chromosome_column(cytoband_raw, 'chromosome')


def calculate_chromosome_features(cytoband_df):
    """
    Derive per-chromosome length and cumulative genome offset.

    A chromosome's length is the largest `end` among its bands; its absolute
    start is the summed length of all chromosomes sorted before it.

    Returns:
        (chromosomes_df, chromosome_start_map): the per-chromosome table with
        columns chromosome / length / chromosome_absolute_start, and a dict
        mapping chromosome -> absolute start. For an empty input, an empty
        table with those columns and an empty dict.
    """
    if cytoband_df.empty:
        no_chromosomes = pd.DataFrame(columns=['chromosome', 'length', 'chromosome_absolute_start'])
        return no_chromosomes, {}

    per_chromosome = cytoband_df.groupby('chromosome').agg(length=('end', 'max'))
    chromosomes_df = per_chromosome.reset_index().sort_values('chromosome')

    lengths = chromosomes_df['length']
    # Running total minus own length = total length of everything before it.
    chromosomes_df['chromosome_absolute_start'] = lengths.cumsum() - lengths

    offsets_by_chromosome = chromosomes_df.set_index('chromosome')['chromosome_absolute_start']
    return chromosomes_df, offsets_by_chromosome.to_dict()


def create_arm_mapping_df(cytoband_df, chromosome_start_map):
    """
    Build one row per (chromosome, arm) with local and absolute extents.

    The arm label is the lower-cased first character of each band name. An
    arm spans from the smallest band start to the largest band end; absolute
    coordinates add the chromosome's offset from `chromosome_start_map`.
    Chromosomes missing from the map are left out (inner merge).

    Returns:
        DataFrame with columns chromosome, arm, arm_start, arm_end,
        arm_abs_start, arm_abs_end (empty, with those columns, when either
        input is empty).
    """
    output_columns = ['chromosome', 'arm', 'arm_start', 'arm_end', 'arm_abs_start', 'arm_abs_end']
    if cytoband_df.empty or not chromosome_start_map:
        return pd.DataFrame(columns=output_columns)

    offsets = pd.Series(chromosome_start_map, name='chr_abs_start').reset_index()
    offsets = offsets.rename(columns={'index': 'chromosome_num'})

    with_arm = cytoband_df.assign(arm=lambda bands: bands['band'].astype(str).str[0].str.lower())
    arm_extents = (
        with_arm
        .groupby(['chromosome', 'arm'])
        .agg(arm_start=('start', 'min'), arm_end=('end', 'max'))
        .reset_index()
    )

    located = arm_extents.merge(offsets, left_on='chromosome', right_on='chromosome_num')
    located = located.assign(
        arm_abs_start=located['chr_abs_start'] + located['arm_start'],
        arm_abs_end=located['chr_abs_start'] + located['arm_end'],
    )
    return located[output_columns]


def create_arm_lookup_table(arm_mapping_df):
    """
    Pivot the arm mapping into one row per chromosome for p/q boundary lookup.

    Args:
        arm_mapping_df: Frame with columns chromosome, arm, arm_start, arm_end
            (one row per chromosome/arm pair).

    Returns:
        DataFrame indexed by chromosome with columns p_start, p_end, q_start
        and q_end. A boundary column is filled with NaN when no chromosome has
        that arm. Columns for arm labels other than 'p'/'q' keep their
        flattened 'arm_start_<label>' / 'arm_end_<label>' names. An empty
        input yields an empty DataFrame.
    """
    if arm_mapping_df.empty:
        return pd.DataFrame()

    arm_lookup = arm_mapping_df.pivot(index='chromosome', columns='arm',
                                      values=['arm_start', 'arm_end'])
    if arm_lookup.empty:
        return arm_lookup

    # Flatten the (value, arm) MultiIndex: ('arm_start', 'p') -> 'arm_start_p'.
    arm_lookup.columns = ['_'.join(col).strip() for col in arm_lookup.columns.values]
    arm_lookup = arm_lookup.rename(columns={
        'arm_start_p': 'p_start', 'arm_end_p': 'p_end',
        'arm_start_q': 'q_start', 'arm_end_q': 'q_end'
    })

    # Guarantee the four boundary columns exist even if an arm never occurs.
    for col in ['p_start', 'p_end', 'q_start', 'q_end']:
        if col not in arm_lookup.columns:
            arm_lookup[col] = np.nan
    return arm_lookup

In [None]:
def load_and_preprocess_relevant_genes(file_path):
    """
    Read the ';'-separated relevant-genes file and clean its chromosome column.

    Column names come from RELEVANT_GENES_COLUMN_NAMES. Rows whose chromosome
    cannot be parsed as a number are removed by _sanitize_chromosome_column.
    Returns an empty DataFrame (after printing an error) when the file does
    not exist.
    """
    try:
        genes_raw = pd.read_csv(file_path, delimiter=';', names=RELEVANT_GENES_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Relevant genes file not found at {file_path}")
        return pd.DataFrame()
    else:
        # Leaves 'chromosome' as a nullable integer column.
        return _sanitize_chromosome_column(genes_raw, 'chromosome')


def build_gene_interval_trees(relevant_genes_df):
    """
    Index the relevant genes in one IntervalTree per chromosome.

    Each gene is inserted as the interval [start, end + 1); the payload dict
    carries the gene name and the original integer start/end.

    Returns:
        dict mapping chromosome -> IntervalTree (empty dict for empty input).
    """
    trees = {}
    if relevant_genes_df.empty:
        return trees

    for _, gene_row in relevant_genes_df.iterrows():
        chromosome = gene_row.chromosome
        chromosome_tree = trees.get(chromosome)
        if chromosome_tree is None:
            chromosome_tree = trees[chromosome] = IntervalTree()

        gene_start = int(gene_row.start)
        gene_end = int(gene_row.end)

        # The inserted interval ends one past the gene's own end coordinate.
        payload = {'gene': gene_row.gene, 'start': gene_start, 'end': gene_end}
        chromosome_tree.addi(gene_start, gene_end + 1, payload)

    return trees


def annotate_segments_with_genes(df, gene_interval_trees, genes_df):
    """
    Annotate segments with overlapping genes from the relevant-genes list.

    Works case by case (grouped on 'case_n_number'):
      * existing `gene` values that are not in `genes_df['gene']` are blanked;
      * genes still annotated for a case count as already covered for that case;
      * every un-annotated row is looked up in its chromosome's tree and gets
        the not-yet-covered overlapping gene with the smallest interval start
        (ties broken by gene name).

    The input frame is never modified; all work happens on a copy.

    Args:
        df: Segment/bin frame with 'chromosome', 'start', 'end' and
            'case_n_number'; an optional 'gene' column may be present.
        gene_interval_trees: dict chromosome (int) -> tree exposing
            `overlap(begin, end)`, whose intervals carry `.begin` and
            `.data['gene']`.
        genes_df: Frame whose 'gene' column lists the relevant genes.

    Returns:
        A new DataFrame with a 'gene' column. When `df` is empty or there are
        no trees, a copy of `df` with a 'gene' column added if it was missing.
        Otherwise the per-case results concatenated with a fresh index (rows
        are ordered by case; rows with a missing case id are not part of any
        group and are therefore absent from the result).
    """
    # Private copy: the column writes below must not leak into the caller's frame.
    df = df.copy()
    if 'gene' not in df.columns:
        df['gene'] = pd.NA

    if df.empty or not gene_interval_trees:
        return df

    all_relevant_genes = set(genes_df['gene'])
    case_id_col = 'case_n_number'

    # Blank every pre-existing label that is not a relevant gene (e.g. 'Antitarget').
    df['gene'] = df['gene'].where(df['gene'].isin(all_relevant_genes), pd.NA)

    # Tree keys are integers, so the chromosome column has to be numeric.
    df['chromosome'] = pd.to_numeric(df['chromosome'], errors='coerce')

    processed_cases = []
    for _, case_df in df.groupby(case_id_col):
        case_df = case_df.copy()

        # Genes already validly annotated for this case are not assigned again.
        genes_on_panel_for_this_case = set(case_df['gene'].dropna())

        for index in case_df.index[case_df['gene'].isna()]:
            row = case_df.loc[index]
            if pd.isna(row.chromosome):
                continue

            tree = gene_interval_trees.get(int(row.chromosome))
            if not tree:
                continue

            candidates = [
                iv for iv in tree.overlap(row.start, row.end)
                if iv.data['gene'] not in genes_on_panel_for_this_case
            ]
            if candidates:
                # Deterministic pick: leftmost interval, then alphabetical gene name.
                chosen = min(candidates, key=lambda iv: (iv.begin, iv.data['gene']))
                case_df.at[index, 'gene'] = chosen.data['gene']

        processed_cases.append(case_df)

    if not processed_cases:
        # No group was formed (e.g. every case id is missing).
        return pd.DataFrame(columns=df.columns)

    return pd.concat(processed_cases, ignore_index=True)


def aggregate_data_by_gene(annotated_df, merge_distance=GENE_MERGE_DISTANCE):
    """
    Collapse annotated bins to one representative per gene and case, then merge
    neighbouring genes.

    Per case ('case_n_number'):
      1. each gene is represented by the median `absolute_start` and the
         median `log2` of its rows;
      2. representatives are sorted by position and chained into groups while
         the gap to the previous representative is at most `merge_distance`;
      3. a group of several genes becomes one entry whose name is built by
         _merge_gene_names, whose position is the group median and whose log2
         is the member value with the largest magnitude.

    Args:
        annotated_df: Frame with 'gene', 'absolute_start', 'log2' and
            'case_n_number'. Rows with a missing gene or log2 are ignored.
        merge_distance: Maximum gap for merging; a value <= 0 disables merging.

    Returns:
        DataFrame with columns gene, absolute_position, log2 and
        case_n_number, log2 clipped to +/- PLOT_Y_LIM. An empty DataFrame when
        required columns are missing or nothing can be aggregated.
    """
    if (annotated_df.empty or
            'gene' not in annotated_df.columns or
            'absolute_start' not in annotated_df.columns or
            'log2' not in annotated_df.columns):
        print("Warning: Cannot aggregate by gene due to missing columns.", file=sys.stderr)
        return pd.DataFrame()

    case_id_col = 'case_n_number'
    if case_id_col not in annotated_df.columns:
        print(f"Warning: Expected column '{case_id_col}' not found.", file=sys.stderr)
        return pd.DataFrame()

    ngs_with_gene = annotated_df[annotated_df['gene'].notna() & annotated_df['log2'].notna()].copy()
    if ngs_with_gene.empty:
        return pd.DataFrame()

    all_cases_final_reps = []

    for case_id, case_df in ngs_with_gene.groupby(case_id_col):
        # Gene-level representatives for this case only.
        gene_reps_case = case_df.groupby('gene').agg(
            absolute_position=('absolute_start', 'median'),
            log2=('log2', 'median')
        ).reset_index()

        if gene_reps_case.empty:
            continue

        gene_reps_case = gene_reps_case.sort_values('absolute_position').reset_index(drop=True)

        if merge_distance <= 0:
            # Merging disabled: every gene stays its own entry.
            final_reps_data_case = gene_reps_case.to_dict('records')
        else:
            # Chain position-sorted genes; a gap larger than merge_distance
            # (measured from the previous gene) starts a new group.
            groups = []
            current_group = []
            for _, row in gene_reps_case.iterrows():
                if current_group and (
                        row.absolute_position - current_group[-1].absolute_position) > merge_distance:
                    groups.append(current_group)
                    current_group = []
                current_group.append(row)
            if current_group:
                groups.append(current_group)

            final_reps_data_case = []
            for group in groups:
                group_df = pd.DataFrame(group)
                if len(group_df) == 1:
                    merged_row = group_df.iloc[0].to_dict()
                else:
                    unique_genes = sorted(set(group_df['gene']))
                    merged_row = {
                        'gene': _merge_gene_names(unique_genes),
                        'absolute_position': group_df['absolute_position'].median(),
                        # Keep the most extreme member so a strong signal is not averaged away.
                        'log2': group_df.loc[group_df['log2'].abs().idxmax()]['log2']
                    }
                final_reps_data_case.append(merged_row)

        # Tag every record with its case in BOTH branches; previously the
        # no-merge branch returned records without the case identifier.
        for record in final_reps_data_case:
            record[case_id_col] = case_id

        all_cases_final_reps.extend(final_reps_data_case)

    if not all_cases_final_reps:
        return pd.DataFrame()

    merged_gene_reps = pd.DataFrame(all_cases_final_reps)
    merged_gene_reps['log2'] = merged_gene_reps['log2'].clip(lower=-PLOT_Y_LIM, upper=PLOT_Y_LIM)

    return merged_gene_reps


def _merge_gene_names(gene_list):
    """
    Merges a list of gene names intelligently.
    e.g., ['CDKN2A', 'CDKN2B'] becomes 'CDKN2A/B'.
    """
    if len(gene_list) <= 1:
        return '/'.join(gene_list)

    prefix = path.commonprefix(gene_list)

    # If the prefix is trivial, don't shorten, just join
    if len(prefix) < 3:
        return '/'.join(gene_list)

    # Keep the first gene name whole, append only suffixes of the rest
    suffixes = [gene[len(prefix):] for gene in gene_list[1:]]
    return '/'.join([gene_list[0]] + suffixes)

In [None]:
def load_ngs_cnr_file(item_id, case_id, base_dir):
    """
    Read the NGS bin-level .cnr table of one sample.

    The first file under `<base_dir>/ngs/cnr` whose name contains `item_id`
    and ends in '.cnr' is read as tab-separated text. The frame is tagged with
    `case_n_number` and `barcode`, 'Antitarget' gene labels are blanked, rows
    with |log2| >= 6 are removed (when a log2 column exists) and the
    chromosome column is sanitised.

    Returns:
        The cleaned DataFrame, or None when the id is missing/empty, no file
        matches, or anything fails while reading/cleaning (the error is printed).
    """
    search_dir = base_dir / 'ngs/cnr'
    if pd.isna(item_id) or not str(item_id):
        return None

    candidates = [*search_dir.glob(f'*{str(item_id)}*.cnr')]
    if len(candidates) == 0:
        print(f"CNR file not found for {item_id}")
        return None

    cnr_path = candidates[0]
    try:
        print(f"Loading CNR file: {cnr_path}")
        bins = pd.read_csv(cnr_path, sep='\t', header=0)
        print(f"CNR file shape: {bins.shape}")

        # Tag every row with its case and sample identifiers.
        bins['case_n_number'] = case_id
        bins['barcode'] = item_id

        # 'Antitarget' is a placeholder label, not a gene name.
        bins.loc[bins['gene'] == 'Antitarget', 'gene'] = pd.NA

        # Discard extreme log2 values.
        if 'log2' in bins.columns:
            bins = bins.loc[bins['log2'].abs() < 6].copy()

        return _sanitize_chromosome_column(bins)
    except Exception as e:
        print(f"Error loading CNR file {cnr_path}: {e}")
        return None


def load_ngs_cns_file(item_id, case_id, base_dir):
    """
    Read the NGS segment-level .cns table of one sample.

    The first file under `<base_dir>/ngs/cns` whose name contains `item_id`
    and ends in '.cns' is read as tab-separated text. The frame is tagged with
    `case_n_number` and `barcode`, rows with |log2| >= 6 are removed (when a
    log2 column exists) and the chromosome column is sanitised.

    Returns:
        The cleaned DataFrame, or None when the id is missing/empty, no file
        matches, or anything fails while reading/cleaning (the error is printed).
    """
    search_dir = base_dir / 'ngs/cns'
    if pd.isna(item_id) or not str(item_id):
        return None

    candidates = [*search_dir.glob(f'*{str(item_id)}*.cns')]
    if len(candidates) == 0:
        print(f"CNS file not found for {item_id}")
        return None

    cns_path = candidates[0]
    try:
        print(f"Loading CNS file: {cns_path}")
        segments = pd.read_csv(cns_path, sep='\t', header=0)
        print(f"CNS file shape: {segments.shape}")

        # Tag every row with its case and sample identifiers.
        segments['case_n_number'] = case_id
        segments['barcode'] = item_id

        # Discard extreme log2 values.
        if 'log2' in segments.columns:
            segments = segments.loc[segments['log2'].abs() < 6].copy()

        return _sanitize_chromosome_column(segments)
    except Exception as e:
        print(f"Error loading CNS file {cns_path}: {e}")
        return None


def load_epic_igv_file(item_id, case_id, base_dir):
    """
    Read the EPIC probe-level .igv table of one sample.

    The file `<base_dir>/epic/igv/<item_id>_CNV.igv` is read as tab-separated
    text. Its last column is taken as the log2 signal; 'Chromosome', 'Start'
    and 'End' are renamed to lower case. Only chromosome/start/end/log2 are
    kept, the rows are tagged with `case_n_number` and `sentrix_id`, and the
    chromosome column is sanitised.

    Returns:
        The cleaned DataFrame, or None when the file is missing, empty, or
        cannot be processed (the error is printed).
    """
    igv_path = base_dir / 'epic/igv' / f'{item_id}_CNV.igv'
    if not igv_path.exists():
        print(f"IGV file not found: {igv_path}")
        return None

    try:
        print(f"Loading IGV file: {igv_path}")
        probes = pd.read_csv(igv_path, sep='\t', header=0)
        print(f"IGV file columns: {probes.columns.tolist()}")
        print(f"IGV file shape: {probes.shape}")

        if probes.empty:
            print(f"IGV file is empty: {igv_path}")
            return None

        # The signal column is the last one (typically named after the sample).
        log2_col = probes.columns[-1]
        print(f"Using column '{log2_col}' as log2 values")

        # Map the file's headers onto the names used throughout the pipeline.
        column_renames = {'Chromosome': 'chromosome', 'Start': 'start', 'End': 'end'}
        column_renames[log2_col] = 'log2'
        probes = probes.rename(columns=column_renames)

        # Drop everything except the four standard columns, then tag the rows.
        probes = probes[['chromosome', 'start', 'end', 'log2']]
        probes['case_n_number'] = case_id
        probes['sentrix_id'] = item_id

        return _sanitize_chromosome_column(probes)
    except Exception as e:
        print(f"Error loading IGV file {igv_path}: {e}")
        return None


def load_epic_seg_file(item_id, case_id, base_dir):
    """
    Read the EPIC segment-level .seg table of one sample.

    The first file under `<base_dir>/epic/seg` whose name contains `item_id`
    and ends in '.seg' is read as tab-separated text. The file is only
    accepted if at least one column name contains `item_id`. Chromosome,
    start, end and log2 are taken from the first columns whose names contain
    'chrom', 'loc.start', 'loc.end' and 'seg.mean' respectively.

    Returns:
        DataFrame with chromosome, start, end, log2, case_n_number and
        sentrix_id (chromosome sanitised), or None when no file matches, the
        file is empty, no sample column is found, or processing fails (the
        error is printed).
    """
    seg_dir = base_dir / 'epic/seg'

    # File names may carry extra prefixes/suffixes around the id.
    candidates = [*seg_dir.glob(f'*{item_id}*.seg')]
    if len(candidates) == 0:
        print(f"SEG file not found for {item_id}")
        return None

    seg_path = candidates[0]
    try:
        print(f"Loading SEG file: {seg_path}")
        table = pd.read_csv(seg_path, sep='\t', header=0)
        print(f"SEG file columns: {table.columns.tolist()}")
        print(f"SEG file shape: {table.shape}")

        if table.empty:
            print(f"SEG file is empty: {seg_path}")
            return None

        # Require at least one column that is named after this sample.
        if not any(item_id in column for column in table.columns):
            print(f"No sample columns found for {item_id} in SEG file")
            return None

        def first_column_containing(fragment):
            # First header containing `fragment`; indexing fails if there is none.
            return [column for column in table.columns if fragment in column][0]

        # Expected layout: ID, chrom, loc.start, loc.end, num.mark, bstat,
        # pval, seg.mean, seg.median
        chrom_col = first_column_containing('chrom')
        start_col = first_column_containing('loc.start')
        end_col = first_column_containing('loc.end')
        # The segment mean serves as the log2 value.
        log2_col = first_column_containing('seg.mean')

        segments = pd.DataFrame({
            'chromosome': table[chrom_col],
            'start': table[start_col],
            'end': table[end_col],
            'log2': table[log2_col]
        })
        segments['case_n_number'] = case_id
        segments['sentrix_id'] = item_id

        return _sanitize_chromosome_column(segments)
    except Exception as e:
        print(f"Error loading SEG file {seg_path}: {e}")
        return None


def _normalize_log2(dfs, col='log2'):
    """
    Normalize log2 values across datasets using z-score normalization.
    This ensures all datasets (raw and segments, NGS and EPIC) are on the same scale.
    """
    if not dfs:
        print("Warning: No dataframes provided for normalization")
        return []

    # Collect all log2 values from all dataframes
    all_log2_values = []
    valid_dfs = []

    for i, df in enumerate(dfs):
        if df is None or df.empty:
            print(f"Warning: Dataframe {i} is None or empty, skipping")
            continue
        if col not in df.columns:
            print(f"Warning: Column '{col}' not found in dataframe {i}, skipping")
            continue

        # Get valid (non-NaN) log2 values
        valid_log2 = df[col].dropna()
        if valid_log2.empty:
            print(f"Warning: No valid log2 values in dataframe {i}")
            continue

        all_log2_values.append(valid_log2)
        valid_dfs.append(df)

    if not all_log2_values:
        print("Warning: No valid log2 values found across all dataframes")
        return [d.copy() if d is not None else pd.DataFrame() for d in dfs]

    # Combine all log2 values to calculate global statistics
    combined_log2 = pd.concat(all_log2_values, ignore_index=True)

    # Calculate global mean and standard deviation
    global_mean = combined_log2.mean()
    global_std = combined_log2.std()

    print(f"Log2 normalization statistics:")
    print(f"  - Total values: {len(combined_log2)}")
    print(f"  - Global mean: {global_mean:.4f}")
    print(f"  - Global std: {global_std:.4f}")
    print(f"  - Value range: [{combined_log2.min():.4f}, {combined_log2.max():.4f}]")

    # Handle edge case where std is 0 or NaN
    if pd.isna(global_std) or global_std == 0:
        print("Warning: Standard deviation is 0 or NaN, using raw values without normalization")
        return [d.copy() if d is not None else pd.DataFrame() for d in dfs]

    # Normalize each dataframe using the global statistics
    normalized_dfs = []
    for i, df in enumerate(dfs):
        if df is None or df.empty or col not in df.columns:
            normalized_dfs.append(df.copy() if df is not None else pd.DataFrame())
            continue

        df_copy = df.copy()

        # Apply z-score normalization: (x - mean) / std
        original_values = df_copy[col]
        normalized_values = (original_values - global_mean) / global_std

        # Count how many values were normalized
        valid_count = original_values.notna().sum()

        df_copy[col] = normalized_values
        normalized_dfs.append(df_copy)

        print(f"  - Dataframe {i}: {valid_count} values normalized")
        if valid_count > 0:
            print(f"    Original range: [{original_values.min():.4f}, {original_values.max():.4f}]")
            print(f"    Normalized range: [{normalized_values.min():.4f}, {normalized_values.max():.4f}]")

    return normalized_dfs


def load_combine_genomic_data(cases, base_dir, data_type, file_type, id_col, case_id_col='case_n_number',
                              l2_col='log2'):
    """
    Load one kind of genomic file for every case, normalise log2 and stack the rows.

    Args:
        cases: DataFrame with one row per case.
        base_dir: Base data directory handed to the file loader.
        data_type: 'ngs' or 'epic'.
        file_type: 'raw' (.cnr / .igv) or 'segments' (.cns / .seg).
        id_col: Column of `cases` holding the sample id (barcode / sentrix_id).
        case_id_col: Column of `cases` holding the case id.
        l2_col: Name of the log2 column to normalise.

    Returns:
        All successfully loaded files concatenated into one DataFrame with
        `l2_col` normalised by _normalize_log2; an empty DataFrame when no
        file could be loaded.

    Raises:
        ValueError: for an unsupported data_type / file_type combination.
    """
    # Pick the loader and the wording of the progress message.
    requested = (data_type, file_type)
    if requested == ('ngs', 'raw'):
        loader_func, description = load_ngs_cnr_file, "NGS raw data (.cnr files)"
    elif requested == ('ngs', 'segments'):
        loader_func, description = load_ngs_cns_file, "NGS segments (.cns files)"
    elif requested == ('epic', 'raw'):
        loader_func, description = load_epic_igv_file, "EPIC raw data (.igv files)"
    elif requested == ('epic', 'segments'):
        loader_func, description = load_epic_seg_file, "EPIC segments (.seg files)"
    else:
        raise ValueError(f"Invalid combination: data_type='{data_type}', file_type='{file_type}'")
    print(f"Loading {description} for {len(cases)} cases...")

    # One loader call per case; cases without a usable id are skipped.
    loaded_frames = []
    for _, case_row in cases.iterrows():
        item_id = case_row[id_col]
        case_id = case_row[case_id_col]
        if pd.isna(item_id) or pd.isna(case_id) or str(item_id) == '0':
            print(f"Skipping case {case_id}: missing or invalid {id_col}")
            continue

        frame = loader_func(str(item_id), str(case_id), base_dir)
        if frame is None or frame.empty:
            print(f"No data loaded for case {case_id}")
        else:
            loaded_frames.append(frame)
            print(f"Loaded data for case {case_id}: {frame.shape}")

    if len(loaded_frames) == 0:
        print("No files were successfully loaded")
        return pd.DataFrame()

    print(f"Successfully loaded {len(loaded_frames)} files")

    # Normalise against the pooled distribution of every loaded file.
    print(f"Applying log2 normalization to {data_type} {file_type} data...")
    normalized_data = _normalize_log2(loaded_frames, l2_col)
    if not normalized_data:
        return pd.DataFrame()

    combined_df = pd.concat(normalized_data, ignore_index=True)
    print(f"Combined {data_type} {file_type} data shape: {combined_df.shape}")

    # Report the distribution after normalisation as a sanity check.
    if l2_col in combined_df.columns:
        final_log2 = combined_df[l2_col].dropna()
        if not final_log2.empty:
            print(f"Final normalized {l2_col} statistics:")
            print(f"  - Mean: {final_log2.mean():.4f}")
            print(f"  - Std: {final_log2.std():.4f}")
            print(f"  - Range: [{final_log2.min():.4f}, {final_log2.max():.4f}]")

    return combined_df


def add_segment_arm_classification(df, arm_lookup_table):
    """
    Label each segment with the chromosome arm(s) it lies on.

    The segment's `start` and `end` are each assigned to 'p' or 'q' by testing
    them against the half-open ranges [p_start, p_end) and [q_start, q_end)
    from `arm_lookup_table`; the two labels are then combined into
    `segment_arm`.

    NOTE: `df` is modified in place — 'chromosome' is cast to int, and in the
    early-return case a 'segment_arm' column is added to `df` itself.

    Args:
        df: Segment frame with 'chromosome', 'start' and 'end'.
        arm_lookup_table: Per-chromosome table providing p_start, p_end,
            q_start and q_end, joinable on 'chromosome'.

    Returns:
        The merged frame with a 'segment_arm' column and without the helper
        columns; or `df` itself with 'segment_arm' set to None when either
        input is empty.
    """
    if df.empty or arm_lookup_table.empty:
        df['segment_arm'] = None
        return df
    # Cast in place so the join key is a plain int column.
    df['chromosome'] = df['chromosome'].astype(int)
    # Left join: segments on chromosomes absent from the lookup get NaN bounds.
    dfa = df.merge(arm_lookup_table, on='chromosome', how='left')
    for pos in ['start', 'end']:
        # Missing bounds are opened up: *_start -> -inf, *_end -> +inf.
        ps, pe, qs, qe = (dfa[c].fillna(float('-inf') if 'start' in c else float('inf')) for c in
                          ['p_start', 'p_end', 'q_start', 'q_end'])
        # The first matching condition wins (p is tested before q); positions
        # matching neither range get None.
        dfa[f'{pos}_arm'] = np.select([(dfa[pos] >= ps) & (dfa[pos] < pe), (dfa[pos] >= qs) & (dfa[pos] < qe)],
                                      ['p', 'q'], default=None)
    # First pass: rows whose start/end labels compare equal take that label,
    # every other row is provisionally marked 'spanning'.
    dfa['segment_arm'] = np.where(dfa['start_arm'] == dfa['end_arm'], dfa['start_arm'], 'spanning')
    # Both ends classified but on different arms -> 'spanning'.
    dfa.loc[(dfa['start_arm'].notna() & dfa['end_arm'].notna() & (
            dfa['start_arm'] != dfa['end_arm'])), 'segment_arm'] = 'spanning'
    # At least one end unclassified and the labels compare unequal -> no arm.
    # NOTE(review): whether two missing labels count as "unequal" here depends
    # on how pandas compares nulls in object columns — confirm the intended
    # result for segments with both ends unclassified.
    dfa.loc[
        (dfa['start_arm'].isna() | dfa['end_arm'].isna()) & (dfa['start_arm'] != dfa['end_arm']), 'segment_arm'] = None
    # Remove the helper columns that were only needed for the classification.
    return dfa.drop(
        columns=[c for c in ['p_start', 'p_end', 'q_start', 'q_end', 'start_arm', 'end_arm'] if c in dfa.columns])




In [None]:
def calculate_arm_log2_from_segments(classified_seg_df, arm_map_df, id_col_name='sentrix_id'):
    """
    Aggregate segment log2 values (IGV or SEG data) to chromosome-arm level.

    For every case/chromosome and every arm of that chromosome, the segments
    classified to that arm or as 'spanning' that physically overlap the arm
    are clipped to the arm boundaries; the arm's log2 is their average
    weighted by clipped length.

    Neither input frame is modified.

    Args:
        classified_seg_df: Segments with 'case_n_number', 'chromosome',
            'start', 'end', 'log2' and 'segment_arm'.
        arm_map_df: Arm table with 'chromosome', 'arm', 'arm_start', 'arm_end'
            and 'arm_abs_start'.
        id_col_name: Optional sample-id column copied into the result from the
            first contributing segment.

    Returns:
        One row per case/chromosome/arm that has data, with columns
        case_n_number, chromosome, segment_arm, log2, start, end,
        absolute_start, absolute_end (plus `id_col_name` when available).
        An empty DataFrame when an input is empty, 'segment_arm' is missing,
        or no arm row could be produced.
    """
    if classified_seg_df.empty or arm_map_df.empty:
        return pd.DataFrame()
    if 'segment_arm' not in classified_seg_df.columns:
        return pd.DataFrame()  # classification is required

    df = classified_seg_df.copy()
    df['chromosome'] = df['chromosome'].astype(int)
    # assign() returns a new frame, so the caller's arm map keeps its dtype.
    arm_map = arm_map_df.assign(chromosome=arm_map_df['chromosome'].astype(int))

    results = []
    for (case, chrom), group in df.groupby(['case_n_number', 'chromosome']):
        for _, arm_row in arm_map[arm_map['chromosome'] == chrom].iterrows():
            arm = arm_row['arm']
            arm_start, arm_end = arm_row['arm_start'], arm_row['arm_end']
            arm_abs_start = arm_row['arm_abs_start']

            # Defaults describe "no data for this arm"; such rows are removed below.
            log2_agg, agg_start, agg_end = np.nan, arm_start, arm_end
            item_id_val = None

            # Segments labelled with this arm (or spanning) that overlap it physically.
            on_arm = (((group['segment_arm'] == arm) | (group['segment_arm'] == 'spanning')) &
                      (group['end'] > arm_start) & (group['start'] < arm_end))
            clipped = group[on_arm].copy()

            if not clipped.empty:
                # Clip to the arm so spanning segments count only their part on this arm.
                clipped['s_clip'] = clipped['start'].clip(lower=arm_start)
                clipped['e_clip'] = clipped['end'].clip(upper=arm_end)
                clipped['clip_cov'] = clipped['e_clip'] - clipped['s_clip']
                clipped = clipped[clipped['clip_cov'] > 0]

            if not clipped.empty:
                total_cov = clipped['clip_cov'].sum()
                log2_agg = (clipped['log2'] * clipped['clip_cov']).sum() / total_cov
                agg_start, agg_end = clipped['s_clip'].min(), clipped['e_clip'].max()
                if id_col_name in clipped.columns:
                    item_id_val = clipped[id_col_name].iloc[0]

            entry = {
                'case_n_number': case, 'chromosome': chrom, 'segment_arm': arm, 'log2': log2_agg,
                'start': agg_start, 'end': agg_end,
                # Translate arm-local coordinates into genome-wide ones.
                'absolute_start': arm_abs_start + (agg_start - arm_start),
                'absolute_end': arm_abs_start + (agg_end - arm_start),
            }
            if item_id_val is not None:
                entry[id_col_name] = item_id_val
            results.append(entry)

    if not results:
        # No chromosome matched the arm map (or no group was formed): without
        # this guard dropna(subset=['log2']) raises KeyError on a column-less frame.
        return pd.DataFrame()

    return pd.DataFrame(results).dropna(subset=['log2'])


# --- Functions specific to NGS/CNR aggregation ---

def prep_ngs_agg(df, chrom_start_map):
    """
    Add genome-wide midpoint position and bin length to NGS bins.

    `df` is modified in place and returned: 'chromosome' is cast to int,
    'absolute_start' becomes the chromosome offset from `chrom_start_map`
    plus the integer midpoint of the bin, and 'length' becomes end - start.
    The frame is returned untouched when it is empty or the map is empty.
    """
    if df.empty or not chrom_start_map:
        return df

    # Integer chromosome labels are required for the offset lookup.
    df['chromosome'] = df['chromosome'].astype(int)

    chromosome_offset = df['chromosome'].map(chrom_start_map)
    bin_midpoint = (df['start'] + df['end']) // 2
    # Despite its name, 'absolute_start' holds the bin midpoint position.
    df['absolute_start'] = chromosome_offset + bin_midpoint

    df['length'] = df['end'] - df['start']
    return df


def agg_ngs_points_to_arms(classified_ngs_df, chrom_features_df):
    """
    Aggregate classified NGS bins to one log2 value per case, chromosome and arm label.

    Bins are grouped by their 'segment_arm' label ('p', 'q' or 'spanning');
    the group's log2 is the average weighted by bin length, and its extent
    runs from the smallest start to the largest end of its bins.

    Neither input frame is modified.

    Args:
        classified_ngs_df: Bins with 'case_n_number', 'chromosome', 'start',
            'end', 'log2' and 'segment_arm'; 'length' is derived as
            end - start when absent. Bins without a label or with
            non-positive length are ignored.
        chrom_features_df: Table with 'chromosome' and
            'chromosome_absolute_start'.

    Returns:
        DataFrame with case_n_number, chromosome, segment_arm, log2,
        arm_start, arm_end, arm_absolute_start and arm_absolute_end; groups
        without a log2 value are dropped. An empty DataFrame when there is
        nothing to aggregate.
    """
    if classified_ngs_df.empty or 'segment_arm' not in classified_ngs_df.columns or chrom_features_df.empty:
        return pd.DataFrame()

    ngs_valid = classified_ngs_df[classified_ngs_df['segment_arm'].notna()].copy()
    if 'length' not in ngs_valid.columns:
        ngs_valid['length'] = ngs_valid['end'] - ngs_valid['start']
    # Only bins with positive length can act as weights; copy so the column
    # write below targets an independent frame.
    ngs_valid = ngs_valid[ngs_valid['length'] > 0].copy()
    if ngs_valid.empty:
        return pd.DataFrame()

    ngs_valid['chromosome'] = ngs_valid['chromosome'].astype(int)

    # Private, int-keyed copy of the offsets: the caller's table is not touched.
    chrom_offsets = chrom_features_df[['chromosome', 'chromosome_absolute_start']].copy()
    chrom_offsets['chromosome'] = chrom_offsets['chromosome'].astype(int)

    def _length_weighted_mean(values):
        # Weights are the lengths of exactly the bins in this group.
        if values.empty or not values.notna().any():
            return np.nan
        return np.average(values, weights=ngs_valid.loc[values.index, 'length'])

    arm_agg = (
        ngs_valid.groupby(['case_n_number', 'chromosome', 'segment_arm'])
        .agg(
            log2=('log2', _length_weighted_mean),
            arm_start=('start', 'min'),
            arm_end=('end', 'max'),
        ).reset_index()
    )

    # Translate the aggregated extent into genome-wide coordinates.
    arm_agg = arm_agg.merge(chrom_offsets, on='chromosome', how='left')
    arm_agg['arm_absolute_start'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_start']
    arm_agg['arm_absolute_end'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_end']

    return arm_agg.drop(columns=['chromosome_absolute_start'], errors='ignore').dropna(subset=['log2'])

In [None]:
def _plot_ngs_scatter(ax, ngs_df_case):
    """Plots individual NGS bins. Colors based on log2 value (gain/loss/neutral)."""
    if ngs_df_case.empty or 'log2' not in ngs_df_case.columns or 'absolute_start' not in ngs_df_case.columns:
        return

    log2_values = pd.to_numeric(ngs_df_case['log2'], errors='coerce').clip(-PLOT_Y_LIM, PLOT_Y_LIM)
    abs_pos = ngs_df_case['absolute_start']

    # Calculate alpha based on z-score of log2 values
    if len(log2_values) > 1:  # zscore needs at least 2 points
        # zscore returns a numpy array if input is a Series, re-index to match original
        abs_z = np.abs(zscore(log2_values.to_numpy()))
        calculated_alphas = np.clip((abs_z / 3) ** 2.25, 0.01, 1.0)
        alphas = pd.Series(calculated_alphas, index=log2_values.index)
    else:
        alphas = pd.Series([0.5], index=log2_values.index)  # Default alpha for a single point

    # Define colors
    color_gain = 'tab:orange'
    color_loss = 'tab:blue'
    color_neutral = 'grey'

    # Create masks for plotting, excluding NaNs from these categories
    mask_gain = (log2_values > 1e-6)  # Use a small epsilon for > 0
    mask_loss = (log2_values < -1e-6)  # Use a small epsilon for < 0
    mask_neutral = np.isclose(log2_values, 0, atol=1e-6)

    # Plot gains
    if mask_gain.any():
        ax.scatter(abs_pos[mask_gain], log2_values[mask_gain],
                   color=color_gain, alpha=alphas[mask_gain], s=2)

    # Plot losses
    if mask_loss.any():
        ax.scatter(abs_pos[mask_loss], log2_values[mask_loss],
                   color=color_loss, alpha=alphas[mask_loss], s=2)

    # Plot neutral points
    if mask_neutral.any():
        ax.scatter(abs_pos[mask_neutral], log2_values[mask_neutral],
                   color=color_neutral, alpha=alphas[mask_neutral], s=2)


def _plot_arm_means(ax, df_aggregated_case, color, linestyle, linewidth, label_text_prefix):
    if df_aggregated_case.empty or 'log2' not in df_aggregated_case.columns:
        return

    start_col = 'arm_absolute_start' if 'arm_absolute_start' in df_aggregated_case.columns else 'absolute_start'
    end_col = 'arm_absolute_end' if 'arm_absolute_end' in df_aggregated_case.columns else 'absolute_end'

    if not all(c in df_aggregated_case.columns for c in [start_col, end_col, 'log2']):
        return

    for idx, row in df_aggregated_case.reset_index().iterrows():
        if pd.notna(row['log2']) and pd.notna(row[start_col]) and pd.notna(row[end_col]):
            ax.plot(
                [row[start_col], row[end_col]],
                [row['log2'], row['log2']],
                color=color,
                linestyle=linestyle,
                linewidth=linewidth,
                label=f'{label_text_prefix}' if idx == 0 else None  # Label only first segment
            )


def _plot_gene_labels(ax, gene_reps_case):
    if gene_reps_case.empty or not all(c in gene_reps_case.columns for c in ['absolute_position', 'log2', 'gene']):
        return

    valid_gene_reps = gene_reps_case.dropna(subset=['absolute_position', 'log2', 'gene'])
    if valid_gene_reps.empty: return

    ax.scatter(
        valid_gene_reps['absolute_position'],
        valid_gene_reps['log2'],
        color='black', s=4
    )
    for _, row in valid_gene_reps.iterrows():
        y, name = row['log2'], row['gene']

        offset = 0.15 if y > 0 else -0.15
        if abs(y) > PLOT_Y_LIM * 0.75:
            offset = -1 * offset
        va = 'bottom' if offset > 0 else 'top'

        ax.text(
            row['absolute_position'],
            y + offset,
            name,
            fontsize=8,
            ha='center',
            va=va,
            rotation=90
        )


def draw_cnv_plot(case_info, df_ngs_case, df_ngs_agg, df_epic_agg, df_gene_reps,
                  df_chroms, arm_map_df):
    """
    Draws, saves and shows the genome-wide CNV profile of a single case.

    The figure combines the per-bin NGS scatter, the arm-level means of the NGS
    sample and of the EPIC reference, and the gene labels. It is written to
    PLOT_OUTPUT_DIR / 'scatter' (created on demand) and closed afterwards.

    Args:
        case_info: Mapping/Series with 'case_n_number' and optionally 'sex',
            'age', 'diagnosis'.
        df_ngs_case: Bin-level NGS data for the case.
        df_ngs_agg: Arm-level NGS aggregation for the case.
        df_epic_agg: Arm-level EPIC aggregation for the case.
        df_gene_reps: Gene-level representatives for the case.
        df_chroms: Chromosome layout with 'chromosome',
            'chromosome_absolute_start' and 'length'. When it is empty or
            incomplete, chromosome guides and the x-limits are left at
            matplotlib's defaults instead of raising.
        arm_map_df: Arm layout; 'arm' and 'arm_abs_end' are used to mark the
            p/q boundaries.
    """
    case_n_number = case_info['case_n_number']
    fig, ax = plt.subplots(figsize=(10, 6))

    # Use the style-specific helper functions
    _plot_ngs_scatter(ax, df_ngs_case)
    _plot_arm_means(ax, df_ngs_agg, 'purple', '-', 1.5, 'average cnr (ngs sample)')
    _plot_arm_means(ax, df_epic_agg, 'purple', '--', 1.0, 'average cnr (epic reference)')
    _plot_gene_labels(ax, df_gene_reps)

    # Axis lines and ticks
    ax.axhline(0, color='grey', linewidth=0.5)

    chrom_columns = ['chromosome_absolute_start', 'length', 'chromosome']
    if not df_chroms.empty and all(c in df_chroms.columns for c in chrom_columns):
        chrom_starts = df_chroms['chromosome_absolute_start']
        ax.vlines(chrom_starts, -PLOT_Y_LIM, PLOT_Y_LIM,
                  color='grey', linestyle='--', linewidth=0.5)

        # Draw vertical lines at p/q arm boundaries
        if not arm_map_df.empty and all(c in arm_map_df.columns for c in ['arm', 'arm_abs_end']):
            # The end of the 'p' arm is the centromere / p-q boundary
            pq_boundaries = arm_map_df.loc[arm_map_df['arm'] == 'p', 'arm_abs_end']
            if not pq_boundaries.empty:
                ax.vlines(pq_boundaries, -PLOT_Y_LIM, PLOT_Y_LIM,
                          color='grey', linestyle=':', linewidth=0.3)

        mids = chrom_starts + df_chroms['length'] / 2
        ax.set_xticks(mids)
        ax.set_xticklabels(df_chroms['chromosome'].astype(str), rotation=90)

        # The genome ends where the last chromosome ends. This lives inside the
        # guard because it needs the same columns as the guides above.
        xlim_max = chrom_starts.iloc[-1] + df_chroms['length'].iloc[-1]
        ax.set_xlim(0, xlim_max)

    ax.set(
        ylim=(-PLOT_Y_LIM, PLOT_Y_LIM),
        xlabel='genomic position by chromosome',
        ylabel='copy number deviation (log2)'
    )

    fig.suptitle(f'CNV profile of {case_n_number} from NGS sample vs. EPIC reference', fontsize=14)
    ax.set_title(
        f"{case_info.get('sex', 'N/A')} (age {case_info.get('age', 'N/A')}), {case_info.get('diagnosis', 'N/A')}",
        fontsize=10, pad=6
    )

    ax.grid(False)
    # Only build a legend when something was labelled; an empty legend just warns.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    fig.tight_layout()

    fname = PLOT_OUTPUT_DIR / 'scatter' / f"ngs_scatter_{str(case_n_number).replace('/', '_')}.png"
    fname.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(fname, dpi=300)
    print(f"Saved plot: {fname}")

    plt.show()
    # This function runs once per case; release the figure so they do not pile up.
    plt.close(fig)

In [None]:
def calculate_arm_correlations(ngs_aggregated_df, epic_aggregated_df):
    """
    Calculates Pearson correlation between NGS and EPIC log2 values per arm.

    The two arm-level frames are inner-joined on (case_n_number, chromosome,
    segment_arm), then each (chromosome, segment_arm) group is correlated
    across cases using only the pairs where both log2 values are present.

    Args:
        ngs_aggregated_df: Arm-level NGS frame with 'case_n_number',
            'chromosome', 'segment_arm' and 'log2'.
        epic_aggregated_df: Arm-level EPIC frame with the same columns.

    Returns:
        DataFrame with columns chromosome, arm, pearson_r, p_value, n_samples,
        sorted by chromosome and arm. n_samples counts the complete pairs.
        pearson_r / p_value are NaN when fewer than three complete pairs exist
        or when either side is constant. The frame is empty (with the same
        columns) when nothing can be paired.
    """
    result_columns = ['chromosome', 'arm', 'pearson_r', 'p_value', 'n_samples']
    if ngs_aggregated_df.empty or epic_aggregated_df.empty:
        return pd.DataFrame(columns=result_columns)

    key_columns = ['case_n_number', 'chromosome', 'segment_arm']
    merged_df = pd.merge(
        ngs_aggregated_df[key_columns + ['log2']],
        epic_aggregated_df[key_columns + ['log2']],
        on=key_columns,
        suffixes=('_ngs', '_epic'),
        how='inner'  # Only cases/arms present in both
    )
    if merged_df.empty:
        return pd.DataFrame(columns=result_columns)

    corr_results = []
    for (chrom, arm), group in merged_df.groupby(['chromosome', 'segment_arm']):
        # pearsonr has no NaN handling: one missing value would make r NaN.
        pairs = group[['log2_ngs', 'log2_epic']].dropna()
        # Pearson r needs at least 2 points, but 3 is the minimum for a
        # meaningful p-value; it is also undefined for constant input.
        can_correlate = (
            len(pairs) >= 3
            and pairs['log2_ngs'].nunique() > 1
            and pairs['log2_epic'].nunique() > 1
        )
        if can_correlate:
            r, p = pearsonr(pairs['log2_ngs'], pairs['log2_epic'])
            r, p = float(r), float(p)
        else:
            r, p = np.nan, np.nan
        corr_results.append({'chromosome': chrom, 'arm': arm, 'pearson_r': r, 'p_value': p, 'n_samples': len(pairs)})

    # groupby drops NaN keys, so the list can be empty even if the merge is not.
    if not corr_results:
        return pd.DataFrame(columns=result_columns)

    return (
        pd.DataFrame(corr_results, columns=result_columns)
        .sort_values(['chromosome', 'arm'])
        .reset_index(drop=True)
    )

In [None]:
def plot_correlation_scatter(merged_correlation_df, arm_corr_stats_df):
    """
    Plots scatter of NGS vs EPIC log2 values, faceted by chromosome, colored by arm, with correlation annotations.

    Args:
        merged_correlation_df: One row per (case, chromosome, arm) with columns
            'chromosome', 'segment_arm', 'log2_ngs' and 'log2_epic'.
        arm_corr_stats_df: Per-arm statistics with columns 'chromosome', 'arm'
            and 'pearson_r'; arms with a NaN coefficient get no annotation.

    The figure is saved under PLOT_OUTPUT_DIR / 'correlation' (created on
    demand) and shown. Nothing is drawn when the merged frame is empty.
    """
    if merged_correlation_df.empty:
        print("Merged data for correlation plot is empty. Skipping plot.")
        return

    g = sns.relplot(
        data=merged_correlation_df, x='log2_ngs', y='log2_epic',
        col='chromosome', hue='segment_arm', kind='scatter', palette=ARM_COLORS,
        col_wrap=6, height=2, aspect=1, s=20, alpha=0.7,
        facet_kws={'sharex': True, 'sharey': True}
    )

    # Each arm gets its own corner: (x, y, ha, va) in axes coordinates. Giving
    # 'spanning' a separate slot keeps it from being drawn on top of 'q'.
    label_slots = {
        'p': (0.05, 0.95, 'left', 'top'),
        'q': (0.95, 0.95, 'right', 'top'),
        'spanning': (0.95, 0.05, 'right', 'bottom'),
    }
    fallback_slot = (0.95, 0.95, 'right', 'top')

    for ax, chrom_name in zip(g.axes.flat, g.col_names):
        if pd.isna(chrom_name):
            continue  # Skip if chrom_name is NaN (can happen if few chromosomes)
        chrom_corrs_on_plot = arm_corr_stats_df[arm_corr_stats_df['chromosome'] == chrom_name]

        for arm_label, color in ARM_COLORS.items():
            if arm_label is None:
                continue  # Skip None arm for text annotation
            arm_stats = chrom_corrs_on_plot[chrom_corrs_on_plot['arm'] == arm_label]
            if arm_stats.empty:
                continue
            r_val = arm_stats.iloc[0]['pearson_r']
            if pd.isna(r_val):
                continue

            x_pos, y_pos, ha, va = label_slots.get(arm_label, fallback_slot)
            ax.text(x_pos, y_pos, f'r({arm_label})={r_val:.2f}', transform=ax.transAxes, color=color,
                    ha=ha, va=va, fontsize=8,
                    bbox=dict(facecolor='#d4edda', alpha=0.6, edgecolor='none', pad=0.2))

    g.set_axis_labels('ngs cn ratio (log2)', 'epic cn ratio (log2)', fontsize=10)
    g.fig.subplots_adjust(top=0.92)  # Adjust top for suptitle
    g.fig.suptitle('NGS vs EPIC: copy number correlation by chromosome and arm', fontsize=14)

    fname = PLOT_OUTPUT_DIR / 'correlation' / "ngs_epic_arm_correlation.png"
    fname.parent.mkdir(parents=True, exist_ok=True)
    g.fig.savefig(fname, dpi=300)
    print(f"Saved correlation plot: {fname}")
    plt.show()

In [None]:
# Load and preprocess the case metadata from CASES_FILE_PATH.
# The bare expression on the last line renders the frame as the cell output.
cases_df = load_and_preprocess_cases(CASES_FILE_PATH)
cases_df

In [None]:
# Load the cytoband table from CYTOBAND_FILE_PATH (loader defined earlier in
# the notebook) and display it for inspection.
cytoband_df = load_and_preprocess_cytoband(CYTOBAND_FILE_PATH)
cytoband_df

In [None]:
# Load the relevant-genes table from RELEVANT_GENES_FILE_PATH (loader defined
# earlier in the notebook) and display it for inspection.
genes_df = load_and_preprocess_relevant_genes(RELEVANT_GENES_FILE_PATH)
genes_df

In [None]:
# Report where data is read from and where figures are written. The correlation
# figure is saved under the 'correlation' subdirectory of PLOT_OUTPUT_DIR, so
# that path is shown rather than the base plot directory a second time.
print(
    f"Base Dir: {BASE_DATA_DIR.resolve()}, Plot Dir: {PLOT_OUTPUT_DIR.resolve()}, "
    f"Corr. Plot Dir: {(PLOT_OUTPUT_DIR / 'correlation').resolve()}")

# Start from empty placeholders so later cells can test `.empty` / truthiness
# even when an input table could not be loaded.
chroms_df, chrom_map = pd.DataFrame(), {}
arm_map_df, arm_lookup_table = pd.DataFrame(), pd.DataFrame()
gene_trees = {}

# Chromosome layout and arm lookup structures are all derived from the cytobands.
if not cytoband_df.empty:
    chroms_df, chrom_map = calculate_chromosome_features(cytoband_df)
    arm_map_df = create_arm_mapping_df(cytoband_df, chrom_map)
    arm_lookup_table = create_arm_lookup_table(arm_map_df)
# Interval trees are built from the relevant-genes table.
if not genes_df.empty:
    gene_trees = build_gene_interval_trees(genes_df)
gene_trees

In [None]:
# Process EPIC data - both raw and segments.
# The three frames below start empty so the summary at the end of the cell (and
# later cells that test `.empty`) work even when loading is skipped.
epic_raw_df = pd.DataFrame()
epic_segments_df = pd.DataFrame()
epic_df_aggregated_all = pd.DataFrame()

# Loading is attempted only when case metadata and both arm tables are available.
if not cases_df.empty and not arm_lookup_table.empty and not arm_map_df.empty:
    print("=== Loading EPIC data ===")

    # Load EPIC raw data from IGV files
    print("\n--- Loading EPIC raw data from IGV files ---")
    # The directory listing is diagnostic output only; the load call below
    # receives BASE_DATA_DIR, not this file list.
    igv_dir = BASE_DATA_DIR / 'epic/igv'
    if igv_dir.exists():
        igv_files = list(igv_dir.glob('*.igv'))
        print(f"Found {len(igv_files)} IGV files in {igv_dir}")
    else:
        print(f"Warning: IGV directory does not exist: {igv_dir}")

    epic_raw_df = load_combine_genomic_data(cases_df, BASE_DATA_DIR, 'epic', 'raw', 'sentrix_id')
    print(f"EPIC raw data loaded: {epic_raw_df.shape}")

    # Load EPIC segments from SEG files
    print("\n--- Loading EPIC segments from SEG files ---")
    # Diagnostic listing only, as for the IGV directory above.
    seg_dir = BASE_DATA_DIR / 'epic/seg'
    if seg_dir.exists():
        seg_files = list(seg_dir.glob('*.seg'))
        print(f"Found {len(seg_files)} SEG files in {seg_dir}")
    else:
        print(f"Warning: SEG directory does not exist: {seg_dir}")

    epic_segments_df = load_combine_genomic_data(cases_df, BASE_DATA_DIR, 'epic', 'segments', 'sentrix_id')
    print(f"EPIC segments loaded: {epic_segments_df.shape}")

    # Process EPIC segments for arm aggregation (using segments for correlation analysis)
    if not epic_segments_df.empty:
        print("\n--- Processing EPIC segments for arm aggregation ---")
        # A copy is passed so epic_segments_df itself stays as loaded.
        epic_classified = add_segment_arm_classification(epic_segments_df.copy(), arm_lookup_table)
        print(f"EPIC segments classified: {epic_classified.shape}")

        epic_df_aggregated_all = calculate_arm_log2_from_segments(epic_classified, arm_map_df, 'sentrix_id')
        print(f"EPIC aggregated data: {epic_df_aggregated_all.shape}")
    else:
        print("No EPIC segments loaded - check if SEG files exist and have correct naming")

# Shapes are printed unconditionally; (0, 0) marks a stage that produced nothing.
print(f"\nEPIC data summary:")
print(f"- Raw data (IGV): {epic_raw_df.shape}")
print(f"- Segments (SEG): {epic_segments_df.shape}")
print(f"- Aggregated: {epic_df_aggregated_all.shape}")

# Display the arm-level EPIC aggregation as the cell output.
epic_df_aggregated_all

In [None]:
# Load and process NGS data. Every result frame starts empty so the summary at
# the end of the cell (and later cells that test `.empty`) work even when a
# stage is skipped.
ngs_raw_df = pd.DataFrame()  # Raw CNR data (for scatter plots)
ngs_segments_df = pd.DataFrame()  # CNS segments (for heatmaps and analysis)
ngs_df_processed_full = pd.DataFrame()  # Processed raw data (for scatter plots)
ngs_arm_aggregated_all = pd.DataFrame()  # Arm-level aggregation (for correlation)
ngs_gene_reps_all = pd.DataFrame()  # Gene-level aggregation (for gene labels)

if not cases_df.empty:
    print("=== Loading NGS data ===")

    # Load NGS raw data from CNR files
    print("\n--- Loading NGS raw data from CNR files ---")
    # The directory listing is diagnostic output only; the load call below
    # receives BASE_DATA_DIR, not this file list.
    cnr_dir = BASE_DATA_DIR / 'ngs/cnr'
    if cnr_dir.exists():
        cnr_files = list(cnr_dir.glob('*.cnr'))
        print(f"Found {len(cnr_files)} CNR files in {cnr_dir}")
    else:
        print(f"Warning: CNR directory does not exist: {cnr_dir}")

    ngs_raw_df = load_combine_genomic_data(cases_df, BASE_DATA_DIR, 'ngs', 'raw', 'barcode')
    print(f"NGS raw data loaded: {ngs_raw_df.shape}")

    # Load NGS segments from CNS files
    print("\n--- Loading NGS segments from CNS files ---")
    # Diagnostic listing only, as for the CNR directory above.
    cns_dir = BASE_DATA_DIR / 'ngs/cns'
    if cns_dir.exists():
        cns_files = list(cns_dir.glob('*.cns'))
        print(f"Found {len(cns_files)} CNS files in {cns_dir}")
    else:
        print(f"Warning: CNS directory does not exist: {cns_dir}")

    ngs_segments_df = load_combine_genomic_data(cases_df, BASE_DATA_DIR, 'ngs', 'segments', 'barcode')
    print(f"NGS segments loaded: {ngs_segments_df.shape}")

    # Process NGS raw data for scatter plots
    if not ngs_raw_df.empty:
        print("\n--- Processing NGS raw data ---")
        # Work on a copy so ngs_raw_df stays as loaded; each step below runs
        # only if the lookup structure it needs was built earlier.
        ngs_curr = ngs_raw_df.copy()

        # Add absolute genomic positions if chromosome map is available
        if chrom_map:
            ngs_curr = prep_ngs_agg(ngs_curr, chrom_map)

        # Classify each segment/bin by chromosome arm (p/q/spanning) if lookup table is available
        if not arm_lookup_table.empty:
            ngs_curr = add_segment_arm_classification(ngs_curr, arm_lookup_table)

        # Annotate with relevant genes and aggregate to gene level if gene trees are available
        if gene_trees:
            ngs_curr = annotate_segments_with_genes(ngs_curr, gene_trees, genes_df)
            ngs_gene_reps_all = aggregate_data_by_gene(ngs_curr)

        # Aggregate NGS points to chromosome arms for correlation/plotting if possible
        if not chroms_df.empty and 'segment_arm' in ngs_curr.columns:
            ngs_arm_aggregated_all = agg_ngs_points_to_arms(ngs_curr, chroms_df)

        # Store the fully processed NGS DataFrame
        ngs_df_processed_full = ngs_curr

    # Process NGS segments for additional analysis if needed
    if not ngs_segments_df.empty:
        print("\n--- Processing NGS segments ---")
        # Unlike the raw data, the segment frame is rebound to its processed
        # version, so the "Segments (CNS)" shape below includes added columns.
        # Add absolute genomic positions for segments
        if chrom_map:
            ngs_segments_df = prep_ngs_agg(ngs_segments_df, chrom_map)

        # Classify segments by chromosome arm
        if not arm_lookup_table.empty:
            ngs_segments_df = add_segment_arm_classification(ngs_segments_df, arm_lookup_table)

# Shapes are printed unconditionally; (0, 0) marks a stage that produced nothing.
print(f"\nNGS data summary:")
print(f"- Raw data (CNR): {ngs_raw_df.shape}")
print(f"- Segments (CNS): {ngs_segments_df.shape}")
print(f"- Processed raw: {ngs_df_processed_full.shape}")
print(f"- Arm aggregated: {ngs_arm_aggregated_all.shape}")
print(f"- Gene reps: {ngs_gene_reps_all.shape}")

# Show the processed NGS DataFrame for inspection
ngs_df_processed_full

In [None]:
# Render one CNV figure per case. Each aggregated table is narrowed to the
# current case first; a case with no rows in any table is reported and skipped.
if not (cases_df.empty or ngs_df_processed_full.empty):
    print(f"\n--- Generating CNV plots for {len(cases_df)} cases ---")
    print(f"EPIC data shape: {epic_df_aggregated_all.shape}")
    print(f"NGS arm aggregated shape: {ngs_arm_aggregated_all.shape}")
    print(f"NGS gene reps shape: {ngs_gene_reps_all.shape}")

    for _, case_r in cases_df.iterrows():
        cnum = case_r['case_n_number']

        # Bin-level data is known to be non-empty here, so it is filtered directly.
        ngs_c = ngs_df_processed_full[ngs_df_processed_full['case_n_number'] == cnum].copy()

        # The optional tables may be empty frames without a 'case_n_number'
        # column, so they are only filtered when they hold data.
        if ngs_arm_aggregated_all.empty:
            ngs_agg_c = pd.DataFrame()
        else:
            ngs_agg_c = ngs_arm_aggregated_all[ngs_arm_aggregated_all['case_n_number'] == cnum].copy()

        if epic_df_aggregated_all.empty:
            epic_agg_c = pd.DataFrame()
        else:
            epic_agg_c = epic_df_aggregated_all[epic_df_aggregated_all['case_n_number'] == cnum].copy()

        if ngs_gene_reps_all.empty:
            gene_reps_c = pd.DataFrame()
        else:
            gene_reps_c = ngs_gene_reps_all[ngs_gene_reps_all['case_n_number'] == cnum].copy()

        if all(frame.empty for frame in (ngs_c, ngs_agg_c, epic_agg_c, gene_reps_c)):
            print(f"Skipping plot for {cnum}: No data.")
            continue

        draw_cnv_plot(case_r, ngs_c, ngs_agg_c, epic_agg_c, gene_reps_c, chroms_df, arm_map_df)

In [None]:
# Correlate NGS and EPIC arm-level log2 values and plot them per chromosome.
# Both aggregations are required; otherwise the analysis is skipped with a note.
if ngs_arm_aggregated_all.empty or epic_df_aggregated_all.empty:
    print("Skipping correlation analysis: Aggregated NGS or EPIC data is missing.")
else:
    print("\n--- Calculating and plotting NGS vs EPIC arm correlations ---")
    arm_correlations_df = calculate_arm_correlations(ngs_arm_aggregated_all, epic_df_aggregated_all)
    print("Arm correlation:\n", arm_correlations_df)

    # The statistics frame holds one row per arm; the scatter needs the
    # individual (case, arm) points, so the same inner join is built here.
    merged_for_plot = ngs_arm_aggregated_all[['case_n_number', 'chromosome', 'segment_arm', 'log2']].merge(
        epic_df_aggregated_all[['case_n_number', 'chromosome', 'segment_arm', 'log2']],
        how='inner',
        on=['case_n_number', 'chromosome', 'segment_arm'],
        suffixes=('_ngs', '_epic'),
    )
    plot_correlation_scatter(merged_for_plot, arm_correlations_df)

In [None]:
def _paint_segments_on_row(heatmap_array, row_idx, segments, chrom_starts, bin_size):
    """Writes each segment's log2 value into the genome bins it overlaps on one heatmap row."""
    for _, seg in segments.iterrows():
        chrom = seg['chromosome']
        if chrom not in chrom_starts.index:
            continue
        chrom_offset = chrom_starts[chrom]
        first_bin = int((chrom_offset + seg['start']) / bin_size)
        end_bin = int(np.ceil((chrom_offset + seg['end']) / bin_size))
        heatmap_array[row_idx, first_bin:end_bin] = seg['log2']


def plot_cnv_heatmap_new(ngs_data, epic_data, cases_info, chrom_info, title, ax,
                     show_tumor_type_axis=True, group_by_column=None):
    """
    Generates a CNV heatmap with a robust, unified logic for all plotting scenarios.

    The genome is divided into 1 Mb bins along x. Each patient occupies one
    row (when only one of ngs_data / epic_data has rows) or two rows (EPIC on
    top, NGS below). Values are clipped to [-2, 2] and bins without data are
    drawn white.

    - If group_by_column is provided (e.g., 'tumor_type'), it creates visually
      distinct sub-diagrams for each group with headers and dividers. A header
      is inserted whenever the group value changes from one row of cases_info
      to the next, so cases_info should be sorted by that column to get one
      block per group. If the column is missing, no grouping is applied.

    Args:
        ngs_data: Segment frame with 'case_n_number', 'chromosome', 'start',
            'end', 'log2' (may be empty).
        epic_data: Segment frame with the same columns (may be empty).
        cases_info: One row per patient with 'patient_id', 'case_n_number' and,
            when show_tumor_type_axis is True, 'tumor_type'.
        chrom_info: Chromosome layout with 'chromosome',
            'chromosome_absolute_start' and 'length'.
        title: Axes title.
        ax: Matplotlib axes to draw on.
        show_tumor_type_axis: Add a right-hand axis labelling each patient
            with its colour-coded tumor type.
        group_by_column: Optional cases_info column used for grouping.

    Returns:
        The axes that was drawn on.
    """
    # --- Setup genomic coordinate system ---
    chrom_starts = chrom_info.set_index('chromosome')['chromosome_absolute_start']
    chrom_sizes = chrom_info.set_index('chromosome')['length']
    chrom_ends = chrom_starts + chrom_sizes
    total_genome_length = chrom_ends.max()
    bin_size = 1_000_000
    total_bins = int(np.ceil(total_genome_length / bin_size))

    # --- Determine mode and setup dimensions ---
    single_row_mode = epic_data.empty or ngs_data.empty
    rows_per_patient = 1 if single_row_mode else 2
    num_patients = len(cases_info)

    # --- Calculate total rows needed, including gaps for headers ---
    header_gap_size = 2  # Space for the header and visual separation
    grouping_active = bool(group_by_column) and group_by_column in cases_info.columns
    group_ids = cases_info[group_by_column].tolist() if grouping_active else []

    # Count headers with the same "value changed" rule the drawing loop uses,
    # so the matrix always has exactly as many rows as the loop will consume
    # (counting distinct values under-allocates for unsorted or NaN groups).
    num_headers = 0
    previous_group_id = None
    for group_id in group_ids:
        if group_id != previous_group_id:
            num_headers += 1
            previous_group_id = group_id

    total_rows = (num_patients * rows_per_patient) + (num_headers * header_gap_size)
    heatmap_array = np.full((total_rows, total_bins), np.nan)

    # --- Unified Loop to Populate Data and Store Tick Positions ---
    y_ticks_positions = []
    y_tick_labels = []
    current_row = 0
    last_group_id = None

    for i in range(num_patients):
        case_info = cases_info.iloc[i]

        if grouping_active:
            current_group_id = group_ids[i]
            # If a new group starts, create the header section
            if current_group_id != last_group_id:
                # Add a thick divider line above the header, but not for the very first group
                if last_group_id is not None:
                    ax.axhline(y=current_row - 0.5, color='black', linewidth=1.5)

                # Place the header text in the middle of the dedicated gap space
                header_y_pos = current_row + (header_gap_size / 2) - 0.5
                ax.text(0, header_y_pos, current_group_id,
                        ha='left', va='center',
                        fontsize=9, fontweight='bold', color='black')

                # Advance the row cursor to create the gap
                current_row += header_gap_size
                last_group_id = current_group_id

        # Store positions for patient labels
        y_ticks_positions.append(current_row + (rows_per_patient / 2 - 0.5))
        y_tick_labels.append('P' + str(case_info['patient_id']))

        # Populate heatmap data for the current patient
        case_id = case_info['case_n_number']
        if single_row_mode:
            # Whichever source has rows fills the single row. If both are
            # empty there is nothing to paint (and no column to filter on).
            data_source = ngs_data if epic_data.empty else epic_data
            row_sources = [] if data_source.empty else [(current_row, data_source)]
        else:  # Two-row mode: EPIC on top, NGS directly below
            row_sources = [(current_row, epic_data), (current_row + 1, ngs_data)]

        for row_idx, source in row_sources:
            case_segments = source[source['case_n_number'] == case_id]
            _paint_segments_on_row(heatmap_array, row_idx, case_segments, chrom_starts, bin_size)

        # Dotted separator between the EPIC and NGS row of the same patient
        if not single_row_mode:
            ax.axhline(y=current_row + 0.5, color='grey', linestyle=':', linewidth=0.5)

        # Thin line between patients, but only when the next one is in the same group
        if i < num_patients - 1:
            if not grouping_active or (group_ids[i + 1] == group_ids[i]):
                ax.axhline(y=current_row + rows_per_patient - 0.5, color='black', linestyle='-', linewidth=0.5)

        current_row += rows_per_patient

    # --- Plot the final matrix ---
    # Work on a copy so set_bad does not alter the shared 'RdBu_r' colormap.
    cmap = plt.get_cmap('RdBu_r').copy()
    cmap.set_bad(color='white')
    norm = mcolors.Normalize(vmin=-2, vmax=2)
    clipped_data = np.clip(heatmap_array, -2, 2)
    ax.imshow(clipped_data, cmap=cmap, norm=norm, aspect='auto', interpolation='none')

    # --- Configure axes ---
    ax.grid(False)
    ax.set_yticks(y_ticks_positions)
    ax.set_yticklabels(y_tick_labels, fontsize=8)

    chrom_mid_bins = (chrom_starts + chrom_sizes / 2) / bin_size
    ax.set_xticks(chrom_mid_bins)
    ax.set_xticklabels(['chr' + str(c) for c in chrom_mid_bins.index], rotation=90, fontsize=8)

    if show_tumor_type_axis:
        # Twin axis mirroring the patient ticks, labelled with the tumor type
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_yticks(y_ticks_positions)
        tumor_labels = cases_info['tumor_type'].tolist()
        ax2.set_yticklabels(tumor_labels, fontsize=8)
        ax2.grid(False)
        ax2.tick_params(axis='y', colors='black', direction='out')
        for spine in ax2.spines.values():
            spine.set_edgecolor('black')
        # key=str keeps the ordering well-defined if a label is missing (NaN)
        unique_tumor_types = sorted(set(tumor_labels), key=str)
        cmap_tumor = plt.get_cmap('Dark2', len(unique_tumor_types))
        tumor_type_to_color = {tt: cmap_tumor(i) for i, tt in enumerate(unique_tumor_types)}
        for tick_label, tumor_type in zip(ax2.get_yticklabels(), tumor_labels):
            tick_label.set_color(tumor_type_to_color.get(tumor_type, 'black'))
        ax2.set_ylabel('Tumor type', fontsize=10)

    ax.tick_params(axis='y', colors='black', direction='out')
    for spine in ax.spines.values():
        spine.set_edgecolor('black')

    # Dashed guide at every chromosome boundary except the genome end
    chrom_end_bins = chrom_ends / bin_size
    for end_bin in chrom_end_bins.values[:-1]:
        boundary_pos = np.ceil(end_bin) - 0.5
        ax.axvline(x=boundary_pos, color='black', linestyle='--', linewidth=0.25)

    ax.set_xlim(-0.5, total_bins - 0.5)
    ax.set_ylim(total_rows - 0.5, -0.5)  # Set ylim explicitly to match total rows
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Chromosome', fontsize=10)
    ax.set_ylabel('Sample', fontsize=10)

    return ax

In [None]:
def plot_cnv_heatmap(ngs_data, epic_data, cases_info, chrom_info, title, ax):
    """
    Draws a genome-wide CNV heatmap with two rows per case: EPIC above NGS.

    Segment log2 values are rasterised into 1 Mb bins along the concatenated
    genome, clipped to [-2, 2] and shown with the 'RdBu_r' colormap. Bins that no
    segment covers stay NaN and are drawn in the colormap's "bad" colour.

    Cases are taken row by row from `cases_info`, so a patient that appears with
    several cases gets one pair of rows per case. The y-axis is flipped after
    drawing, which means the first row of `cases_info` ends up at the bottom of
    the plot and the last row at the top.

    Args:
        ngs_data (pd.DataFrame): NGS segments with columns 'case_n_number',
            'chromosome', 'start', 'end' and 'log2'. May be an empty frame.
        epic_data (pd.DataFrame): EPIC segments with the same columns. May be
            an empty frame.
        cases_info (pd.DataFrame): One row per plotted case; must provide
            'patient_id', 'case_n_number' and 'tumor_type'.
        chrom_info (pd.DataFrame): Must provide 'chromosome',
            'chromosome_absolute_start' and 'length'.
        title (str): Title of the axes.
        ax (matplotlib.axes.Axes): Axes to draw on. A twin y-axis carrying the
            tumor-type labels is created on top of it.

    Returns:
        matplotlib.axes.Axes: The axes that was passed in as `ax`.

    NOTE(review): this definition takes exactly six parameters. The batch
    heatmap cell further down the notebook passes `show_tumor_type_axis=` and
    `group_by_column=`, which this signature rejects with a TypeError. Confirm
    which `plot_cnv_heatmap` definition is meant to be active when that cell runs.
    """
    bin_size = 1_000_000

    # 1. SETUP
    chrom_table = chrom_info.set_index('chromosome')
    chrom_starts = chrom_table['chromosome_absolute_start']
    chrom_sizes = chrom_table['length']
    chrom_ends = chrom_starts + chrom_sizes
    total_bins = int(np.ceil(chrom_ends.max() / bin_size))

    patient_ids_subset = cases_info['patient_id'].tolist()
    # Pair rows directly instead of going through a patient_id -> case dict:
    # a dict would collapse patients with several cases onto their last case.
    case_ids = cases_info['case_n_number'].tolist()
    num_patients = len(patient_ids_subset)
    heatmap_array = np.full((num_patients * 2, total_bins), np.nan)

    def _paint_segments(row_idx, segments, case_id):
        """Writes the log2 value of each segment of `case_id` into matrix row `row_idx`."""
        if segments.empty:
            return
        case_segments = segments[segments['case_n_number'] == case_id]
        for chrom, start, end, log2 in zip(case_segments['chromosome'], case_segments['start'],
                                           case_segments['end'], case_segments['log2']):
            if chrom not in chrom_starts.index:
                continue
            offset = chrom_starts[chrom]
            s_bin = int((offset + start) / bin_size)
            e_bin = int(np.ceil((offset + end) / bin_size))
            heatmap_array[row_idx, s_bin:e_bin] = log2

    # 2. POPULATE MATRIX
    for i, case_id in enumerate(case_ids):
        if pd.isna(case_id):
            continue
        # The y-axis is flipped below, so within a pair the higher matrix row is
        # rendered on top. EPIC therefore takes the odd row to appear above NGS.
        _paint_segments(2 * i + 1, epic_data, case_id)
        _paint_segments(2 * i, ngs_data, case_id)

    # 3. PLOT THE MATRIX
    cmap = plt.get_cmap('RdBu_r')
    cmap.set_bad(color='#fbfbfb')
    norm = mcolors.Normalize(vmin=-2, vmax=2)
    clipped_data = np.clip(heatmap_array, -2, 2)
    ax.imshow(clipped_data, cmap=cmap, norm=norm, aspect='auto', interpolation='none')
    # imshow puts matrix row 0 at the top; flip so that row 0 sits at the bottom.
    ax.invert_yaxis()

    # 4. CONFIGURE AXES AND LABELS
    ax.grid(False)
    y_ticks = [i * 2 + 0.5 for i in range(num_patients)]  # centre of each EPIC/NGS pair
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(['P' + str(pid) for pid in patient_ids_subset], fontsize=8)

    chrom_mid_bins = (chrom_starts + chrom_sizes / 2) / bin_size
    ax.set_xticks(chrom_mid_bins)
    ax.set_xticklabels(['chr' + str(c) for c in chrom_mid_bins.index], rotation=90, fontsize=8)

    # The twin axis copies the (already flipped) limits of the main axis, so tick i
    # on both axes refers to the same case and the labels need no reordering.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(y_ticks)
    tumor_labels = cases_info['tumor_type'].tolist()
    ax2.set_yticklabels(tumor_labels, fontsize=8)
    ax2.grid(False)

    for axis in (ax, ax2):
        axis.tick_params(axis='y', colors='black', direction='out')
        for spine in axis.spines.values():
            spine.set_edgecolor('black')

    # Colour each tumor-type label with a stable colour per tumor type.
    unique_tumor_types = sorted(set(tumor_labels))
    cmap_tumor = plt.get_cmap('Dark2', len(unique_tumor_types))
    tumor_type_to_color = {tt: cmap_tumor(i) for i, tt in enumerate(unique_tumor_types)}
    for tick_label, tumor_type in zip(ax2.get_yticklabels(), tumor_labels):
        tick_label.set_color(tumor_type_to_color.get(tumor_type, 'black'))

    # Dashed vertical line at every chromosome boundary except the genome end.
    chrom_end_bins = chrom_ends / bin_size
    for end_bin in chrom_end_bins.values[:-1]:
        boundary_pos = np.ceil(end_bin) - 0.5
        ax.axvline(x=boundary_pos, color='black', linestyle='--', linewidth=0.25)

    # Thin dashed line between the two rows of a case, solid line between cases.
    for i in range(num_patients):
        ax.axhline(y=(i * 2) + 0.5, color='grey', linestyle='--',
                   linewidth=0.15, antialiased=False)
        if i < num_patients - 1:
            ax.axhline(y=(i * 2) + 1.5, color='black', linestyle='-', linewidth=0.75)

    ax.set_xlim(-0.5, total_bins - 0.5)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Chromosome', fontsize=10)
    ax.set_ylabel('Sample', fontsize=10)
    ax2.set_ylabel('Tumor type', fontsize=10)

    return ax

In [None]:
# --- Generate a single CNV Heatmap for all cases ---
# Draws one figure containing every case and writes it to
# PLOT_OUTPUT_DIR / 'heatmap' / 'cnv_heatmap_all_cases.png'.

if cases_df.empty or ngs_segments_df.empty or epic_segments_df.empty or chroms_df.empty:
    print("Skipping heatmap generation: Missing necessary data (cases, NGS segments, EPIC segments, or chromosome info).")
else:
    print(f"--- Generating a single CNV heatmap for all {len(cases_df)} cases ---")

    # Descending sort: plot_cnv_heatmap receives the highest patient_id first.
    cases_sorted = cases_df.sort_values(['patient_id', 'tumor_type'], ascending=False).reset_index(drop=True)

    # No sub-selection here: the "subsets" are the complete data sets.
    cases_subset = cases_sorted
    ngs_subset = ngs_segments_df
    epic_subset = epic_segments_df

    # Figure height grows with the number of cases.
    fig, ax = plt.subplots(figsize=(8, len(cases_subset) * 0.175))

    plot_cnv_heatmap(
        ngs_data=ngs_subset,
        epic_data=epic_subset,
        cases_info=cases_subset,
        chrom_info=chroms_df,
        title='CNV heatmap from EPIC (upper row) vs. NGS (lower row)',
        ax=ax,
    )

    fig.tight_layout()

    fname = PLOT_OUTPUT_DIR / 'heatmap' / "cnv_heatmap_all_cases.png"
    fname.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    print(f"Saved heatmap: {fname}")

    plt.show()
    # Release the figure so it does not stay alive in the kernel.
    plt.close(fig)

In [None]:
print("--- Starting batch heatmap generation ---")

# NOTE(review): every plot_cnv_heatmap call in this cell passes show_tumor_type_axis=
# (and the grouped plots also group_by_column=). The six-parameter definition
# plot_cnv_heatmap(ngs_data, epic_data, cases_info, chrom_info, title, ax) elsewhere
# in this notebook accepts neither, so this cell needs the variant that does to be
# the active definition when it runs -- confirm the cell order / duplicate def.

if cases_df.empty or ngs_segments_df.empty or epic_segments_df.empty or chroms_df.empty:
    print("Skipping heatmap generation: Missing necessary data.")
else:
    # Create the output directory once, before any figure is saved.
    heatmap_dir = PLOT_OUTPUT_DIR / 'heatmap'
    heatmap_dir.mkdir(parents=True, exist_ok=True)

    print("\n1. Generating one heatmap per tumor type...")
    unique_tumor_types = cases_df['tumor_type'].unique()
    for tumor_type in unique_tumor_types:
        cases_subset = (
            cases_df[cases_df['tumor_type'] == tumor_type]
            .sort_values('patient_id')
            .reset_index(drop=True)
        )
        # A missing (NaN) tumor type never compares equal, which leaves the subset empty.
        if cases_subset.empty:
            continue

        subset_case_ids = cases_subset['case_n_number']
        fig, ax = plt.subplots(figsize=(12, len(cases_subset) * 0.3))
        plot_cnv_heatmap(
            ngs_data=ngs_segments_df[ngs_segments_df['case_n_number'].isin(subset_case_ids)],
            epic_data=epic_segments_df[epic_segments_df['case_n_number'].isin(subset_case_ids)],
            cases_info=cases_subset,
            chrom_info=chroms_df,
            title=f'CNV Heatmap for {tumor_type}',
            ax=ax,
            show_tumor_type_axis=False,
        )
        # str() keeps the file name construction working for non-string labels.
        safe_tumor_name = str(tumor_type).replace(' ', '_').replace('/', '_')
        fname = heatmap_dir / f"cnv_heatmap_{safe_tumor_name}.png"
        fig.savefig(fname, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    print("\n2. Generating sub-diagram heatmaps for all EPIC and all NGS data...")

    cases_sorted_all = cases_df.sort_values(['patient_id', 'tumor_type']).reset_index(drop=True)

    # One visual gap between each pair of neighbouring tumor-type groups.
    num_gaps = cases_sorted_all['tumor_type'].nunique() - 1
    extra_height = num_gaps * 0.4 # Heuristic for sizing

    print("  - Plotting grouped EPIC data for all samples...")
    fig_epic, ax_epic = plt.subplots(figsize=(12, len(cases_sorted_all) * 0.2 + extra_height))
    plot_cnv_heatmap(
        ngs_data=pd.DataFrame(),
        epic_data=epic_segments_df,
        cases_info=cases_sorted_all,
        chrom_info=chroms_df,
        title='CNV heatmap (EPIC)',
        ax=ax_epic,
        show_tumor_type_axis=False,
        group_by_column='tumor_type'
    )
    fname_epic = heatmap_dir / "cnv_heatmap_all_cases_EPIC_grouped.png"
    fig_epic.savefig(fname_epic, dpi=300, bbox_inches='tight')
    print(f"  - Saved grouped EPIC heatmap: {fname_epic}")
    plt.show()
    plt.close(fig_epic)

    print("  - Plotting grouped NGS data for all samples...")
    fig_ngs, ax_ngs = plt.subplots(figsize=(12, len(cases_sorted_all) * 0.2 + extra_height))
    plot_cnv_heatmap(
        ngs_data=ngs_segments_df,
        epic_data=pd.DataFrame(),
        cases_info=cases_sorted_all,
        chrom_info=chroms_df,
        title='CNV heatmap (NGS)',
        ax=ax_ngs,
        show_tumor_type_axis=False,
        group_by_column='tumor_type'
    )
    fname_ngs = heatmap_dir / "cnv_heatmap_all_cases_NGS_grouped.png"
    fig_ngs.savefig(fname_ngs, dpi=300, bbox_inches='tight')
    print(f"  - Saved grouped NGS heatmap: {fname_ngs}")
    plt.show()
    plt.close(fig_ngs)

print("\n--- Batch heatmap generation complete ---")

In [None]:
# Mean Pearson r over all rows of arm_correlations_df, rounded to 3 decimals.
np.round(arm_correlations_df['pearson_r'].mean(), 3)

In [None]:
# Mean p-value over all rows of arm_correlations_df, rounded to 6 decimals.
np.round(arm_correlations_df['p_value'].mean(), 6)

In [None]:
### Stratified Correlation Analysis by DNA Quality (DIN)
# This analysis separates the cohort into two groups based on the DIN value
# and calculates the per-arm correlations for each group.

# Cases with DIN <= DIN_THRESHOLD form the "low" group, all others the "high" group.
DIN_THRESHOLD = 6


def run_and_print_stratified_correlation(ngs_agg, epic_agg, cases_list, group_title):
    """
    Prints per-arm NGS/EPIC correlation statistics for one group of cases.

    Both aggregated frames are restricted to the rows whose 'case_n_number' is in
    `cases_list`, passed to `calculate_arm_correlations`, and the mean Pearson r,
    the mean p-value and the full result are printed. Nothing is returned; when
    there is no data at any step, a short message is printed instead.

    Args:
        ngs_agg (pd.DataFrame): Aggregated NGS data with a 'case_n_number' column.
        epic_agg (pd.DataFrame): Aggregated EPIC data with a 'case_n_number' column.
        cases_list (list): Case numbers that belong to the group.
        group_title (str): Heading printed above the results.
    """
    print(f"\n--- {group_title} ---")
    if not cases_list or ngs_agg.empty or epic_agg.empty:
        print("Not enough data to proceed.")
        return

    # Filter the main aggregated dataframes by the stratified case list
    ngs_subset = ngs_agg[ngs_agg['case_n_number'].isin(cases_list)]
    epic_subset = epic_agg[epic_agg['case_n_number'].isin(cases_list)]

    if ngs_subset.empty or epic_subset.empty:
        print("Subsets for NGS or EPIC are empty for this group.")
        return

    # Calculate correlations for the subset
    arm_corr_stats = calculate_arm_correlations(ngs_subset, epic_subset)

    if arm_corr_stats.empty:
        print("No overlapping arms found to calculate correlation.")
        return

    print(f"Mean r: {arm_corr_stats['pearson_r'].mean():.3f}")
    # NOTE(review): an arithmetic mean of p-values is only a descriptive summary,
    # not a combined significance test.
    print(f"Mean p-value: {arm_corr_stats['p_value'].mean():.5f}")
    print("Per-arm correlation statistics:")
    print(arm_corr_stats)


print("=== Starting Stratified Correlation Analysis by DIN ===")

if 'DIN' not in cases_df.columns:
    print("Warning: 'DIN' column not found in cases_df. Skipping stratified analysis.")
else:
    # Coerce to numbers so that unparsable entries become NaN instead of breaking
    # the comparison; NaN satisfies neither condition below.
    din_values = pd.to_numeric(cases_df['DIN'], errors='coerce')

    # Stratify cases into two groups
    low_din_cases = cases_df.loc[din_values <= DIN_THRESHOLD, 'case_n_number'].tolist()
    high_din_cases = cases_df.loc[din_values > DIN_THRESHOLD, 'case_n_number'].tolist()

    print(f"\nLow DIN (<= {DIN_THRESHOLD}) group size: {len(low_din_cases)} cases")
    print(f"High DIN (> {DIN_THRESHOLD}) group size: {len(high_din_cases)} cases")

    # Make the otherwise silent exclusion of cases without a DIN value visible.
    num_missing_din = int(din_values.isna().sum())
    if num_missing_din:
        print(f"Cases without a usable DIN value (in neither group): {num_missing_din}")

    # Run the analysis for both groups
    run_and_print_stratified_correlation(
        ngs_arm_aggregated_all, epic_df_aggregated_all, low_din_cases,
        f'Low DIN Group (DIN <= {DIN_THRESHOLD})',
    )
    run_and_print_stratified_correlation(
        ngs_arm_aggregated_all, epic_df_aggregated_all, high_din_cases,
        f'High DIN Group (DIN > {DIN_THRESHOLD})',
    )