In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from intervaltree import IntervalTree
from matplotlib import pyplot as plt
from scipy.stats import pearsonr, zscore

sns.set_theme(style='ticks')

# Symmetric y-axis limit of the per-case CNV plots; the NGS processing cell also
# passes it to aggregate_data_by_gene as the log2 clipping cutoff.
PLOT_Y_LIM = 10
# Default `padding` of annotate_segments_with_genes (same units as segment start/end).
GENE_ANNOTATION_PADDING = 1_000_000
# Matplotlib colour per arm classification; the None key is the grey fallback.
ARM_COLORS = {'p': 'tab:blue', 'q': 'tab:orange', 'spanning': 'tab:green', None: 'grey'}

# Input and output locations, relative to the working directory.
BASE_DATA_DIR = Path('data')
PLOT_OUTPUT_DIR = Path('graphics')
# Only the top-level output folder is created here; draw_cnv_plot and
# plot_correlation_scatter write into its 'scatter' and 'correlation' sub-folders.
PLOT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

CASES_FILE_PATH = BASE_DATA_DIR / 'cases.csv'
CYTOBAND_FILE_PATH = BASE_DATA_DIR / 'cytoBand.txt'
RELEVANT_GENES_FILE_PATH = BASE_DATA_DIR / 'relevant_genes.csv'

# For cases_df processing: positional columns picked from the raw CSV with iloc,
# and the names assigned to them (same order).
CASES_COLUMNS_INDICES = [1, 2, 3, 5, 13, 14]
CASES_NEW_COLUMN_NAMES = ['case_n_number', 'age', 'sex', 'diagnosis', 'sentrix_id', 'barcode']

# For cytoband_df processing: column names passed to read_csv(names=...).
CYTOBAND_COLUMN_NAMES = ['chromosome', 'start', 'end', 'band', 'giemsa']

# For relevant_genes_df processing: column names passed to read_csv(names=...).
RELEVANT_GENES_COLUMN_NAMES = ['gene', 'chromosome', 'start', 'end']

In [None]:
def _sanitize_chromosome_column(df, chromosome_col='chromosome'):
    """
    Converts a chromosome column to numeric Int64Dtype, removing 'chr' prefix.
    Rows with unparsable chromosome values (e.g., 'X', 'Y') are dropped.
    """
    if chromosome_col not in df.columns:
        return df

    df = df.copy()
    df[chromosome_col] = (
        df[chromosome_col]
        .astype(str)
        .str.replace('chr', '', regex=False)
        .pipe(pd.to_numeric, errors='coerce')
    )
    df = df.dropna(subset=[chromosome_col])
    df[chromosome_col] = df[chromosome_col].astype(pd.Int64Dtype())
    return df

In [None]:
def load_and_preprocess_cases(file_path):
    """
    Read the cases CSV and reduce it to the columns used by the notebook.

    The columns at positions CASES_COLUMNS_INDICES are kept and renamed to
    CASES_NEW_COLUMN_NAMES; rows with a missing value in any of them are
    dropped and the index is renumbered. If the file does not exist an error
    is printed and an empty DataFrame is returned.
    """
    try:
        raw_cases = pd.read_csv(file_path, delimiter=',', encoding='ISO-8859-1')
    except FileNotFoundError:
        print(f"Error: Cases file not found at {file_path}")
        return pd.DataFrame()

    selected = raw_cases.iloc[:, CASES_COLUMNS_INDICES]
    selected.columns = CASES_NEW_COLUMN_NAMES
    return selected.dropna().reset_index(drop=True)

In [None]:
def load_and_preprocess_cytoband(file_path):
    """
    Read the tab-separated cytoband file (no header row is expected; columns are
    named from CYTOBAND_COLUMN_NAMES) and sanitize its chromosome column.

    Prints an error and returns an empty DataFrame when the file is missing.
    """
    try:
        bands = pd.read_csv(file_path, delimiter='\t', names=CYTOBAND_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Cytoband file not found at {file_path}")
        return pd.DataFrame()
    else:
        return _sanitize_chromosome_column(bands, 'chromosome')


def calculate_chromosome_features(cytoband_df):
    """
    Derive per-chromosome length and cumulative genome offset from cytobands.

    Returns a tuple ``(chromosomes_df, chromosome_start_map)``:
    ``chromosomes_df`` has one row per chromosome (sorted) with its ``length``
    (largest band end) and ``chromosome_absolute_start`` (sum of the lengths of
    all preceding chromosomes); the map is ``{chromosome: absolute_start}``.
    An empty input yields an empty frame with those columns and an empty dict.
    """
    if cytoband_df.empty:
        empty_features = pd.DataFrame(columns=['chromosome', 'length', 'chromosome_absolute_start'])
        return empty_features, {}

    per_chromosome = cytoband_df.groupby('chromosome').agg(length=('end', 'max'))
    chromosomes_df = per_chromosome.reset_index().sort_values('chromosome')

    # Offset of a chromosome = running total up to and including it, minus its own length.
    running_total = chromosomes_df['length'].cumsum()
    chromosomes_df['chromosome_absolute_start'] = running_total - chromosomes_df['length']

    offsets = chromosomes_df.set_index('chromosome')['chromosome_absolute_start']
    return chromosomes_df, offsets.to_dict()


def create_arm_mapping_df(cytoband_df, chromosome_start_map):
    """
    Build one row per (chromosome, arm) with local and genome-wide coordinates.

    The arm is the lower-cased first character of the band name. Arm extents
    are the minimum band start and maximum band end; absolute coordinates add
    the chromosome offset from ``chromosome_start_map``. Chromosomes missing
    from the map are dropped by the inner merge. Empty inputs give an empty
    frame with the output columns.
    """
    output_columns = ['chromosome', 'arm', 'arm_start', 'arm_end', 'arm_abs_start', 'arm_abs_end']
    if cytoband_df.empty or not chromosome_start_map:
        return pd.DataFrame(columns=output_columns)

    offsets = (
        pd.Series(chromosome_start_map, name='chr_abs_start')
        .reset_index()
        .rename(columns={'index': 'chromosome_num'})
    )

    banded = cytoband_df.assign(arm=cytoband_df['band'].astype(str).str[0].str.lower())
    arm_extents = (
        banded.groupby(['chromosome', 'arm'])
        .agg(arm_start=('start', 'min'), arm_end=('end', 'max'))
        .reset_index()
    )

    located = arm_extents.merge(offsets, left_on='chromosome', right_on='chromosome_num')
    located['arm_abs_start'] = located['chr_abs_start'] + located['arm_start']
    located['arm_abs_end'] = located['chr_abs_start'] + located['arm_end']
    return located[output_columns]


def create_arm_lookup_table(arm_mapping_df):
    """
    Pivot ``arm_mapping_df`` into one row per chromosome for p/q boundary lookup.

    The result is indexed by chromosome and always carries the columns
    ``p_start``, ``p_end``, ``q_start`` and ``q_end``; a boundary that the
    pivot could not provide is NaN. Columns for any other arm label keep their
    flattened ``arm_start_<arm>`` / ``arm_end_<arm>`` names. An empty input
    returns an empty DataFrame.
    """
    if arm_mapping_df.empty:
        return pd.DataFrame()

    arm_lookup = arm_mapping_df.pivot(index='chromosome', columns='arm',
                                      values=['arm_start', 'arm_end'])
    if arm_lookup.empty:
        return arm_lookup

    # Flatten the (value, arm) MultiIndex to 'arm_start_p'-style names, then shorten.
    arm_lookup.columns = ['_'.join(level_pair).strip() for level_pair in arm_lookup.columns.values]
    friendly_names = {
        'arm_start_p': 'p_start', 'arm_end_p': 'p_end',
        'arm_start_q': 'q_start', 'arm_end_q': 'q_end',
    }
    arm_lookup = arm_lookup.rename(columns=friendly_names)

    # Guarantee all four boundary columns, even when an arm never occurs.
    for expected in ('p_start', 'p_end', 'q_start', 'q_end'):
        if expected not in arm_lookup.columns:
            arm_lookup[expected] = np.nan
    return arm_lookup

In [None]:
def load_and_preprocess_relevant_genes(file_path):
    """
    Read the semicolon-separated gene list (columns named from
    RELEVANT_GENES_COLUMN_NAMES) and sanitize its chromosome column.

    Prints an error and returns an empty DataFrame when the file is missing.
    """
    try:
        genes = pd.read_csv(file_path, delimiter=';', names=RELEVANT_GENES_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Relevant genes file not found at {file_path}")
        return pd.DataFrame()
    else:
        # After sanitizing, 'chromosome' is a nullable integer column.
        return _sanitize_chromosome_column(genes, 'chromosome')


def build_gene_interval_trees(relevant_genes_df):
    """
    Builds a dictionary of IntervalTrees for gene lookups, per chromosome.

    Each gene is stored as the interval ``[start, end)`` taken from the frame,
    with a payload dict ``{'gene', 'start', 'end'}``. Genes whose coordinates
    are missing or do not describe a non-empty interval (``end <= start``) are
    skipped with a printed warning, because the interval tree rejects them.

    Parameters
    ----------
    relevant_genes_df : pd.DataFrame
        Needs the columns 'gene', 'chromosome', 'start' and 'end'.

    Returns
    -------
    dict
        ``{chromosome: IntervalTree}``; empty when the frame is empty.
    """
    if relevant_genes_df.empty:
        return {}

    trees = {}
    skipped_genes = []
    for _, row in relevant_genes_df.iterrows():
        if pd.isna(row['start']) or pd.isna(row['end']):
            skipped_genes.append(row['gene'])
            continue

        start, end = int(row['start']), int(row['end'])
        if end <= start:
            # IntervalTree.addi raises on null intervals; skip instead of aborting the build.
            skipped_genes.append(row['gene'])
            continue

        chrom = row['chromosome']
        if chrom not in trees:
            trees[chrom] = IntervalTree()
        trees[chrom].addi(start, end, {'gene': row['gene'], 'start': start, 'end': end})

    if skipped_genes:
        print(f"Warning: skipped {len(skipped_genes)} gene(s) with missing or empty coordinates: {skipped_genes}")
    return trees


def annotate_segments_with_genes(df, gene_interval_trees, padding=GENE_ANNOTATION_PADDING):
    """
    Add a 'gene' column naming one gene near each segment (None if there is none).

    For every row the tree of its chromosome is queried with the window
    ``[max(0, start - padding), end + 1)``; when several genes overlap, the
    first one in sorted interval order is used. ``df`` is modified in place
    and returned. With an empty frame or no trees the column is set to None.

    NOTE(review): the padding only widens the window towards lower
    coordinates; confirm that the asymmetric window is intended.
    """
    if df.empty or not gene_interval_trees:
        df['gene'] = None
        return df

    def first_gene_near(chrom, seg_start, seg_end):
        tree = gene_interval_trees.get(chrom)
        if not tree:
            return None
        # The upper bound of overlap() is exclusive, hence the + 1.
        hits = tree.overlap(max(0, seg_start - padding), seg_end + 1)
        if not hits:
            return None
        return sorted(hits)[0].data['gene']

    df['gene'] = [
        first_gene_near(int(row.chromosome), int(row.start), int(row.end))
        for _, row in df.iterrows()
    ]
    return df


def aggregate_data_by_gene(annotated_df, log2_cutoff):
    """
    Collapse annotated rows to one representative point per (gene, case).

    Rows without a gene (null or '-') are ignored. For each remaining
    gene/case pair the median 'absolute_start' becomes 'absolute_position' and
    the median 'log2' is clipped to ``[-log2_cutoff, log2_cutoff]``. Returns an
    empty DataFrame when there is nothing to aggregate or a required column is
    missing (a warning is printed in the latter case).
    """
    if annotated_df.empty or 'gene' not in annotated_df.columns:
        return pd.DataFrame()

    has_real_gene = annotated_df['gene'].notna() & (annotated_df['gene'] != '-')
    gene_rows = annotated_df[has_real_gene]
    if gene_rows.empty:
        return pd.DataFrame()

    if any(needed not in gene_rows.columns for needed in ('absolute_start', 'log2')):
        print("Warning: 'absolute_start' or 'log2' missing for gene aggregation.")
        return pd.DataFrame()

    gene_reps = (
        gene_rows
        .groupby(['gene', 'case_n_number'])
        .agg(absolute_position=('absolute_start', 'median'), log2=('log2', 'median'))
        .reset_index()
    )

    # Cap extreme medians on both sides so they stay inside the plotted range.
    too_high = gene_reps['log2'] > log2_cutoff
    gene_reps.loc[too_high, 'log2'] = log2_cutoff
    too_low = gene_reps['log2'] < -log2_cutoff
    gene_reps.loc[too_low, 'log2'] = -log2_cutoff
    return gene_reps

In [None]:
def load_seg_cnr_file(is_seg, item_id, case_id, base_dir):  # Combined loader
    """
    Load one EPIC segment file (``is_seg=True``) or one NGS .cnr file.

    * SEG: reads ``base_dir/epic/seg/<item_id>_CNV.seg``, skipping its first
      line and taking columns 0, 1, 2, 3 and 7 as fsid, chromosome, start, end
      and log2. The frame is tagged with 'sentrix_id' and 'fsid' is dropped.
    * CNR: reads the first file (in sorted order) in ``base_dir/ngs/cnr``
      matching ``*<item_id>*.cnr`` with its own header. The frame is tagged
      with 'barcode', a 'gene' column is dropped if present, and rows with
      ``|log2| >= 6`` are removed.

    Both variants get a 'case_n_number' column and a sanitized chromosome
    column. Returns None when no file is found or the file cannot be read;
    read failures and ambiguous CNR matches are reported with a printed warning.
    """
    sub_path = 'epic/seg' if is_seg else 'ngs/cnr'
    file_suffix = '_CNV.seg' if is_seg else '.cnr'

    if is_seg:
        path = base_dir / sub_path / f'{item_id}{file_suffix}'
        if not path.exists():
            return None
    else:
        item_id_str = str(item_id) if pd.notna(item_id) else ""
        if not item_id_str:
            return None
        # Sorted so that the chosen file does not depend on directory listing order.
        matches = sorted((base_dir / sub_path).glob(f'*{item_id_str}*{file_suffix}'))
        if not matches:
            return None
        if len(matches) > 1:
            print(f"Warning: {len(matches)} CNR files match '{item_id_str}'; using {matches[0].name}")
        path = matches[0]

    try:
        if is_seg:
            df = pd.read_csv(path, sep='\t', skiprows=1, usecols=[0, 1, 2, 3, 7],
                             names=['fsid', 'chromosome', 'start', 'end', 'log2'])
        else:
            df = pd.read_csv(path, sep='\t')  # column names come from the file header
    except Exception as exc:
        # Best-effort loading: an unreadable file is skipped, but no longer silently.
        print(f"Warning: could not read {path}: {exc}")
        return None

    df['case_n_number'] = case_id
    if is_seg:
        df['sentrix_id'] = item_id
        df = df.drop(columns=['fsid'])
    else:
        df['barcode'] = item_id
        df = df.drop(columns=['gene'], errors='ignore')
        if 'log2' in df.columns:
            df = df[df['log2'].abs() < 6].copy()

    return _sanitize_chromosome_column(df)


def _normalize_log2(dfs, col='log2'):
    if not dfs: return []
    vals = pd.concat([d[col] for d in dfs if col in d.columns and not d.empty], ignore_index=True)
    if vals.empty: return [d.copy() for d in dfs]
    m, s = vals.mean(), vals.std();
    s = 1.0 if (pd.isna(s) or s == 0) else s
    return [(d.assign(**{col: (d[col] - m) / s}) if col in d.columns and not d.empty else d.copy()) for d in dfs]


def load_combine_genomic(cases, bdir, is_seg_loader, id_col, cid_col='case_n_number', l2_col='log2'):
    """
    Load the per-case genomic file for every row of ``cases`` and stack them.

    Rows whose id (``id_col``) or case number (``cid_col``) is missing, or
    whose id renders as '0', are skipped, as are files that fail to load or
    are empty. The loaded frames are normalized jointly on ``l2_col`` and
    concatenated with a fresh index. Returns an empty DataFrame if nothing
    was loaded.
    """
    loaded = []
    for _, case_row in cases.iterrows():
        item_id = case_row[id_col]
        case_id = case_row[cid_col]
        unusable = pd.isna(item_id) or pd.isna(case_id) or str(item_id) == '0'
        if unusable:
            continue

        frame = load_seg_cnr_file(is_seg_loader, str(item_id), str(case_id), bdir)
        if frame is None or frame.empty:
            continue
        loaded.append(frame)

    if not loaded:
        return pd.DataFrame()
    return pd.concat(_normalize_log2(loaded, l2_col), ignore_index=True)


def add_segment_arm_classification(df, arm_lookup_table):
    """
    Add a 'segment_arm' column ('p', 'q', 'spanning' or None) to each segment.

    The start and the end position of every row are classified separately
    against the p/q boundaries of its chromosome, then combined:

    * both positions on the same arm         -> that arm
    * both classified, but on different arms -> 'spanning'
    * exactly one position unclassified      -> None
    * both positions unclassified            -> None

    With an empty ``df`` or lookup table the column is set to None on ``df``
    itself, which is returned. Otherwise ``df['chromosome']`` is cast to int in
    place and a new, merged frame is returned without the helper columns.
    """
    if df.empty or arm_lookup_table.empty:
        df['segment_arm'] = None
        return df
    # Cast in place (the caller's frame is modified) before joining the boundaries.
    df['chromosome'] = df['chromosome'].astype(int)
    # Left join: segments on chromosomes absent from the lookup get NaN boundaries.
    dfa = df.merge(arm_lookup_table, on='chromosome', how='left')
    for pos in ['start', 'end']:
        # Missing boundaries are treated as unbounded: -inf for *_start columns, +inf for *_end columns.
        ps, pe, qs, qe = (dfa[c].fillna(float('-inf') if 'start' in c else float('inf')) for c in
                          ['p_start', 'p_end', 'q_start', 'q_end'])
        # Half-open test [arm_start, arm_end); p is checked first, positions in neither arm stay None.
        dfa[f'{pos}_arm'] = np.select([(dfa[pos] >= ps) & (dfa[pos] < pe), (dfa[pos] >= qs) & (dfa[pos] < qe)],
                                      ['p', 'q'], default=None)
    # First pass: equal classifications keep the arm, everything else is provisionally 'spanning'.
    dfa['segment_arm'] = np.where(dfa['start_arm'] == dfa['end_arm'], dfa['start_arm'], 'spanning')
    # Re-assert 'spanning' only where both ends were classified and differ.
    dfa.loc[(dfa['start_arm'].notna() & dfa['end_arm'].notna() & (
            dfa['start_arm'] != dfa['end_arm'])), 'segment_arm'] = 'spanning'
    # Where exactly one end could not be classified, the segment as a whole is unclassified.
    dfa.loc[
        (dfa['start_arm'].isna() | dfa['end_arm'].isna()) & (dfa['start_arm'] != dfa['end_arm']), 'segment_arm'] = None
    # Remove the boundary and per-position helper columns again.
    return dfa.drop(
        columns=[c for c in ['p_start', 'p_end', 'q_start', 'q_end', 'start_arm', 'end_arm'] if c in dfa.columns])




In [None]:
def calculate_arm_log2_from_seg(classified_seg_df, arm_map_df, id_col_name='sentrix_id'):
    """
    Aggregates log2 from SEG data to arm level using segment clipping.

    For every (case, chromosome) and every arm of that chromosome, the
    segments classified to the arm or as 'spanning' that physically overlap
    the arm are clipped to the arm boundaries; the arm log2 is the mean of
    their log2 weighted by clipped length.

    Parameters
    ----------
    classified_seg_df : pd.DataFrame
        Segments with 'case_n_number', 'chromosome', 'start', 'end', 'log2'
        and 'segment_arm'.
    arm_map_df : pd.DataFrame
        One row per arm with 'chromosome', 'arm', 'arm_start', 'arm_end' and
        'arm_abs_start'. It is not modified.
    id_col_name : str
        Optional identifier column copied from the first contributing segment.

    Returns
    -------
    pd.DataFrame
        One row per (case, chromosome, arm) that has a non-null log2, with
        local and absolute start/end of the covered range. Empty when an input
        is empty, unclassified, or no segment overlaps any arm.
    """
    if classified_seg_df.empty or arm_map_df.empty:
        return pd.DataFrame()
    if 'segment_arm' not in classified_seg_df.columns:
        return pd.DataFrame()  # Need classification

    segments = classified_seg_df.copy()
    segments['chromosome'] = segments['chromosome'].astype(int)
    # assign() returns a new frame, so the caller's arm table keeps its dtype.
    arms = arm_map_df.assign(chromosome=arm_map_df['chromosome'].astype(int))

    results = []
    for (case, chrom), group in segments.groupby(['case_n_number', 'chromosome']):
        for _, arm_row in arms[arms['chromosome'] == chrom].iterrows():
            arm = arm_row['arm']
            arm_start, arm_end = arm_row['arm_start'], arm_row['arm_end']
            arm_abs_start = arm_row['arm_abs_start']

            # Segments classified to this arm OR spanning, and overlapping the arm physically.
            on_arm = (group['segment_arm'] == arm) | (group['segment_arm'] == 'spanning')
            overlaps = (group['end'] > arm_start) & (group['start'] < arm_end)
            relevant = group[on_arm & overlaps].copy()
            if relevant.empty:
                continue

            # Clip segments to the arm boundaries and weight by what remains.
            relevant['s_clip'] = relevant['start'].clip(lower=arm_start)
            relevant['e_clip'] = relevant['end'].clip(upper=arm_end)
            relevant['clip_cov'] = relevant['e_clip'] - relevant['s_clip']
            valid_clipped = relevant[relevant['clip_cov'] > 0]
            if valid_clipped.empty:
                continue

            total_clip_cov = valid_clipped['clip_cov'].sum()
            log2_agg = (valid_clipped['log2'] * valid_clipped['clip_cov']).sum() / total_clip_cov
            agg_start, agg_end = valid_clipped['s_clip'].min(), valid_clipped['e_clip'].max()

            entry = {
                'case_n_number': case, 'chromosome': chrom, 'segment_arm': arm, 'log2': log2_agg,
                'start': agg_start, 'end': agg_end,
                'absolute_start': arm_abs_start + (agg_start - arm_start),
                'absolute_end': arm_abs_start + (agg_end - arm_start),
            }
            if id_col_name in valid_clipped.columns:
                entry[id_col_name] = valid_clipped[id_col_name].iloc[0]
            results.append(entry)

    if not results:
        # Without rows there is no 'log2' column for dropna to look at.
        return pd.DataFrame()
    return pd.DataFrame(results).dropna(subset=['log2'])


# --- Functions specific to NGS/CNR aggregation ---

def prep_ngs_agg(df, chrom_start_map):
    """
    Add genome-wide position and bin length to NGS bins, in place.

    'absolute_start' is the chromosome offset from ``chrom_start_map`` plus the
    integer midpoint of the bin; 'length' is ``end - start``. The chromosome
    column is cast to int first. ``df`` is returned (unchanged when it is empty
    or the map is empty).
    """
    if df.empty or not chrom_start_map:
        return df

    df['chromosome'] = df['chromosome'].astype(int)  # int keys are needed for the map lookup
    chromosome_offset = df['chromosome'].map(chrom_start_map)
    bin_midpoint = (df['start'] + df['end']) // 2
    df['absolute_start'] = chromosome_offset + bin_midpoint
    df['length'] = df['end'] - df['start']
    return df


def agg_ngs_points_to_arms(classified_ngs_df, chrom_features_df):
    """
    Aggregates NGS points/bins to arms based on classification ('p','q','spanning').

    Bins without a classification or with non-positive length are ignored. For
    every (case, chromosome, segment_arm) the log2 is the bin-length-weighted
    average, and the covered range runs from the smallest bin start to the
    largest bin end; absolute coordinates add the chromosome offset from
    ``chrom_features_df``. Groups whose log2 is NaN are dropped.

    Neither input frame is modified. Returns an empty DataFrame when an input
    is empty, 'segment_arm' is missing, or no usable bin remains.
    """
    if classified_ngs_df.empty or 'segment_arm' not in classified_ngs_df.columns or chrom_features_df.empty:
        return pd.DataFrame()

    # Work with classified bins of positive length (the length is the weight).
    ngs_valid = classified_ngs_df[classified_ngs_df['segment_arm'].notna()].copy()
    if 'length' not in ngs_valid.columns:
        ngs_valid['length'] = ngs_valid['end'] - ngs_valid['start']
    ngs_valid = ngs_valid[ngs_valid['length'] > 0].copy()
    if ngs_valid.empty:
        return pd.DataFrame()

    ngs_valid['chromosome'] = ngs_valid['chromosome'].astype(int)
    # Cast on a derived frame so the caller's chromosome table keeps its dtype.
    chrom_offsets = chrom_features_df[['chromosome', 'chromosome_absolute_start']].assign(
        chromosome=lambda d: d['chromosome'].astype(int)
    )

    def _length_weighted_mean(log2_values):
        """Average of one group's log2, weighted by the matching bin lengths."""
        if log2_values.empty or not log2_values.notna().any():
            return np.nan
        return np.average(log2_values, weights=ngs_valid.loc[log2_values.index, 'length'])

    arm_agg = (
        ngs_valid.groupby(['case_n_number', 'chromosome', 'segment_arm'])
        .agg(
            log2=('log2', _length_weighted_mean),
            arm_start=('start', 'min'),
            arm_end=('end', 'max'),
        )
        .reset_index()
    )

    # Add absolute start/end for the aggregated arm segment.
    arm_agg = arm_agg.merge(chrom_offsets, on='chromosome', how='left')
    arm_agg['arm_absolute_start'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_start']
    arm_agg['arm_absolute_end'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_end']

    return arm_agg.drop(columns=['chromosome_absolute_start'], errors='ignore').dropna(subset=['log2'])

In [None]:
def _plot_ngs_scatter(ax, ngs_df_case, arm_colors_dict):
    if ngs_df_case.empty or 'log2' not in ngs_df_case.columns:
        return

    # Alpha calculation based on z-score of log2 for the current case's ngs_df
    # Ensure log2 is numeric before zscore
    log2_values_numeric = pd.to_numeric(ngs_df_case['log2'], errors='coerce')

    alphas_series = pd.Series(0.5, index=ngs_df_case.index)  # Default alpha
    if not log2_values_numeric.isna().all():
        abs_z = np.abs(zscore(log2_values_numeric, nan_policy='omit'))
        calculated_alphas = np.clip(abs_z / 3, 0.1, 1.0) ** 1.75
        # Ensure calculated_alphas aligns with ngs_df_case.index
        alphas_series = pd.Series(calculated_alphas, index=ngs_df_case.index).fillna(0.1)

    # Iterate through unique arms present in the case's data
    # Handle if 'segment_arm' column doesn't exist or has no unique values
    unique_arms = []
    if 'segment_arm' in ngs_df_case.columns and ngs_df_case['segment_arm'].notna().any():
        unique_arms = ngs_df_case['segment_arm'].unique()

    if not any(unique_arms):  # If only NaN or None, or empty unique_arms
        if not ngs_df_case.empty:  # Plot all as 'grey' if no specific arms
            ax.scatter(ngs_df_case['absolute_start'], ngs_df_case['log2'],
                       color=arm_colors_dict.get(None, 'grey'),  # Default to grey
                       alpha=alphas_series,
                       s=2, label=None)  # No specific arm label
        return

    for arm in unique_arms:
        sub_df = ngs_df_case[ngs_df_case['segment_arm'] == arm]
        if sub_df.empty:
            continue

        sub_alphas = alphas_series.reindex(sub_df.index).fillna(0.1)  # Get alphas for this subset

        ax.scatter(
            sub_df['absolute_start'],
            sub_df['log2'],
            c=arm_colors_dict.get(arm, 'grey'),  # Use .get for fallback to grey
            alpha=sub_alphas,
            s=2,  # Original point size
            label=str(arm) if arm is not None else "None"  # Original labeling (might create duplicate legend entries)
        )


def _plot_arm_means(ax, df_aggregated_case, color, linestyle, linewidth, label_text_prefix):
    if df_aggregated_case.empty or 'log2' not in df_aggregated_case.columns:
        return

    start_col = 'arm_absolute_start' if 'arm_absolute_start' in df_aggregated_case.columns else 'absolute_start'
    end_col = 'arm_absolute_end' if 'arm_absolute_end' in df_aggregated_case.columns else 'absolute_end'

    if not all(c in df_aggregated_case.columns for c in [start_col, end_col, 'log2']):
        return

    for idx, row in df_aggregated_case.reset_index().iterrows():
        if pd.notna(row['log2']) and pd.notna(row[start_col]) and pd.notna(row[end_col]):
            ax.plot(
                [row[start_col], row[end_col]],
                [row['log2'], row['log2']],
                color=color,
                linestyle=linestyle,
                linewidth=linewidth,
                label=f'{label_text_prefix}' if idx == 0 else None  # Label only first segment
            )


def _plot_gene_labels(ax, gene_reps_case):
    """
    Mark each gene representative with a black dot and a rotated name label.

    Rows lacking position, log2 or gene name are ignored. The label sits 0.15
    above positive values and 0.15 below the others; when |log2| exceeds 75 %
    of PLOT_Y_LIM the side is flipped so that the text points back towards the
    axis centre.
    """
    needed = ['absolute_position', 'log2', 'gene']
    if gene_reps_case.empty or any(col not in gene_reps_case.columns for col in needed):
        return

    labelled = gene_reps_case.dropna(subset=needed)
    if labelled.empty:
        return

    ax.scatter(
        labelled['absolute_position'],
        labelled['log2'],
        color='black', s=4, label='genes'
    )

    for _, gene_row in labelled.iterrows():
        level = gene_row['log2']

        if level > 0:
            offset = 0.15
        else:
            offset = -0.15
        if abs(level) > PLOT_Y_LIM * 0.75:
            offset = -1 * offset
        anchor = 'bottom' if offset > 0 else 'top'

        ax.text(
            gene_row['absolute_position'],
            level + offset,
            gene_row['gene'],
            fontsize=8,
            ha='center',
            va=anchor,
            rotation=90
        )


def draw_cnv_plot(case_info, df_ngs_case, df_ngs_agg, df_epic_agg, df_gene_reps,
                  df_chroms):
    """
    Draw, save and show the genome-wide CNV profile of one case.

    Layers: NGS bins (scatter), arm-level means of the NGS sample (solid) and
    of the EPIC reference (dashed), and gene labels. Chromosome boundaries,
    tick labels and the x-range are only set when ``df_chroms`` provides
    'chromosome', 'length' and 'chromosome_absolute_start'; otherwise the
    x-axis is left to autoscale.

    The figure is written to ``PLOT_OUTPUT_DIR / 'scatter'`` (created if
    necessary) as ``ngs_scatter_<case>.png`` with '/' in the case number
    replaced by '_', then shown and closed.

    Parameters
    ----------
    case_info : mapping or pd.Series
        Must provide 'case_n_number'; 'sex', 'age' and 'diagnosis' are optional.
    df_ngs_case, df_ngs_agg, df_epic_agg, df_gene_reps : pd.DataFrame
        Per-case data for the four layers; each may be empty.
    df_chroms : pd.DataFrame
        Chromosome layout table.
    """
    case_n_number = case_info['case_n_number']
    fig, ax = plt.subplots(figsize=(10, 6))

    _plot_ngs_scatter(ax, df_ngs_case, ARM_COLORS)
    _plot_arm_means(ax, df_ngs_agg, 'purple', '-', 1.5, 'average cnr (ngs sample)')
    _plot_arm_means(ax, df_epic_agg, 'purple', '--', 1.0, 'average cnr (epic reference)')
    _plot_gene_labels(ax, df_gene_reps)

    # Axis lines and ticks
    ax.axhline(0, color='grey', linewidth=0.5)
    has_chrom_layout = not df_chroms.empty and all(
        c in df_chroms.columns for c in ['chromosome_absolute_start', 'length', 'chromosome'])
    if has_chrom_layout:
        chrom_starts = df_chroms['chromosome_absolute_start']
        ax.vlines(chrom_starts, -PLOT_Y_LIM, PLOT_Y_LIM,
                  color='grey', linestyle='--', linewidth=0.5)

        mids = chrom_starts + df_chroms['length'] / 2
        ax.set_xticks(mids)
        ax.set_xticklabels(df_chroms['chromosome'].astype(str), rotation=90)

        # The genome ends where the last chromosome ends; only known with a layout table.
        ax.set_xlim(0, chrom_starts.iloc[-1] + df_chroms['length'].iloc[-1])

    ax.set(
        ylim=(-PLOT_Y_LIM, PLOT_Y_LIM),
        xlabel='genomic position by chromosome',
        ylabel='copy number deviation (log2)'
    )

    fig.suptitle(f'CNV profile of {case_n_number} from NGS sample vs. EPIC reference', fontsize=14)
    ax.set_title(
        f"{case_info.get('sex', 'N/A')} (age {case_info.get('age', 'N/A')}), {case_info.get('diagnosis', 'N/A')}",
        fontsize=10, pad=6
    )

    ax.grid(False)
    ax.legend()
    fig.tight_layout()

    fname = PLOT_OUTPUT_DIR / 'scatter' / f"ngs_scatter_{str(case_n_number).replace('/', '_')}.png"
    # savefig does not create missing folders.
    fname.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(fname, dpi=150)
    print(f"Saved plot: {fname}")

    plt.show()
    # Called once per case in a loop: release the figure to keep memory bounded.
    plt.close(fig)

In [None]:
def calculate_arm_correlations(ngs_aggregated_df, epic_aggregated_df):
    """
    Pearson correlation of NGS vs. EPIC arm-level log2, per chromosome arm.

    The two tables are inner-joined on case, chromosome and arm. For each
    (chromosome, arm) with at least three paired cases the Pearson r and its
    p-value are computed; smaller groups get NaN for both. The result has the
    columns 'chromosome', 'arm', 'pearson_r', 'p_value' and 'n_samples',
    sorted by chromosome and arm; it is empty when an input is empty or no
    pair matches.
    """
    result_columns = ['chromosome', 'arm', 'pearson_r', 'p_value', 'n_samples']
    if ngs_aggregated_df.empty or epic_aggregated_df.empty:
        return pd.DataFrame(columns=result_columns)

    join_keys = ['case_n_number', 'chromosome', 'segment_arm']
    paired = pd.merge(
        ngs_aggregated_df[join_keys + ['log2']],
        epic_aggregated_df[join_keys + ['log2']],
        on=join_keys,
        suffixes=('_ngs', '_epic'),
        how='inner'  # Only cases/arms present in both
    )
    if paired.empty:
        return pd.DataFrame(columns=result_columns)

    def summarise(chrom, arm, pairs):
        sample_count = len(pairs)
        if sample_count >= 3:
            r, p = pearsonr(pairs['log2_ngs'], pairs['log2_epic'])
        else:
            r, p = np.nan, np.nan
        return {'chromosome': chrom, 'arm': arm, 'pearson_r': r, 'p_value': p, 'n_samples': sample_count}

    rows = [
        summarise(chrom, arm, pairs)
        for (chrom, arm), pairs in paired.groupby(['chromosome', 'segment_arm'])
    ]
    return pd.DataFrame(rows).sort_values(['chromosome', 'arm']).reset_index(drop=True)

In [None]:
def plot_correlation_scatter(merged_correlation_df, arm_corr_stats_df):
    """
    Plots scatter of NGS vs EPIC log2 values, faceted by chromosome, colored by arm,
    with correlation annotations.

    Every facet is annotated with ``r(<arm>)`` for each arm of ARM_COLORS that
    has a non-null Pearson r in ``arm_corr_stats_df``. The p annotation sits
    top-left, q top-right and spanning below q, so labels do not overlap. The
    label background turns green for p < 0.05 and darker green for p < 0.01.

    The figure is saved as ``ngs_epic_arm_correlation.png`` in
    ``PLOT_OUTPUT_DIR / 'correlation'`` (created if necessary) and then shown.
    Prints a note and returns without plotting when the merged data is empty.

    Parameters
    ----------
    merged_correlation_df : pd.DataFrame
        Paired values with 'log2_ngs', 'log2_epic', 'chromosome' and 'segment_arm'.
    arm_corr_stats_df : pd.DataFrame
        Statistics with 'chromosome', 'arm', 'pearson_r' and 'p_value'.
    """
    if merged_correlation_df.empty:
        print("Merged data for correlation plot is empty. Skipping plot.")
        return

    g = sns.relplot(
        data=merged_correlation_df, x='log2_ngs', y='log2_epic',
        col='chromosome', hue='segment_arm', kind='scatter', palette=ARM_COLORS,
        col_wrap=6, height=2, aspect=1, s=20, alpha=0.7,
        facet_kws={'sharex': True, 'sharey': True}
    )

    # (x, y, horizontal alignment) in axes coordinates, one slot per arm.
    annotation_slots = {
        'p': (0.05, 0.95, 'left'),
        'q': (0.95, 0.95, 'right'),
        'spanning': (0.95, 0.80, 'right'),
    }
    default_slot = (0.95, 0.95, 'right')

    for ax, chrom_name in zip(g.axes.flat, g.col_names):
        if pd.isna(chrom_name):
            continue
        chrom_stats = arm_corr_stats_df[arm_corr_stats_df['chromosome'] == chrom_name]

        for arm_label, color in ARM_COLORS.items():
            if arm_label is None:
                continue  # No text annotation for unclassified points
            arm_stats = chrom_stats[chrom_stats['arm'] == arm_label]
            if arm_stats.empty or pd.isna(arm_stats.iloc[0]['pearson_r']):
                continue

            r_val = arm_stats.iloc[0]['pearson_r']
            p_val = arm_stats.iloc[0]['p_value']
            x_pos, y_pos, ha = annotation_slots.get(arm_label, default_slot)

            # Significance highlighting for p-value
            bg_color = '#f0f0f0'  # default
            if p_val < 0.05:
                bg_color = '#d4edda'  # green-ish for p < 0.05
            if p_val < 0.01:
                bg_color = '#c3e6cb'  # darker green-ish for p < 0.01

            ax.text(x_pos, y_pos, f'r({arm_label})={r_val:.2f}', transform=ax.transAxes, color=color,
                    ha=ha, va='top', fontsize=8,
                    bbox=dict(facecolor=bg_color, alpha=0.6, edgecolor='none', pad=0.2))

    g.set_axis_labels('ngs cn ratio (log2)', 'epic cn ratio (log2)', fontsize=10)
    g.fig.subplots_adjust(top=0.92)  # Leave room for the suptitle
    g.fig.suptitle('NGS vs EPIC: copy number correlation by chromosome and arm', fontsize=14)

    fname = PLOT_OUTPUT_DIR / 'correlation' / "ngs_epic_arm_correlation.png"
    # savefig does not create missing folders.
    fname.parent.mkdir(parents=True, exist_ok=True)
    g.fig.savefig(fname, dpi=200)
    print(f"Saved correlation plot: {fname}")
    plt.show()

In [None]:
# Load the case table; the bare name on the last line shows it as the cell output.
cases_df = load_and_preprocess_cases(CASES_FILE_PATH)
cases_df

In [None]:
# Load the cytoband table (numeric chromosomes only) and show it as the cell output.
cytoband_df = load_and_preprocess_cytoband(CYTOBAND_FILE_PATH)
cytoband_df

In [None]:
# Load the relevant-gene list (numeric chromosomes only) and show it as the cell output.
genes_df = load_and_preprocess_relevant_genes(RELEVANT_GENES_FILE_PATH)
genes_df

In [None]:
# The plotting functions write into these two sub-folders of PLOT_OUTPUT_DIR;
# create them up front and report the real locations.
scatter_plot_dir = PLOT_OUTPUT_DIR / 'scatter'
correlation_plot_dir = PLOT_OUTPUT_DIR / 'correlation'
for plot_dir in (scatter_plot_dir, correlation_plot_dir):
    plot_dir.mkdir(parents=True, exist_ok=True)

print(
    f"Base Dir: {BASE_DATA_DIR.resolve()}, Plot Dir: {scatter_plot_dir.resolve()}, "
    f"Corr. Plot Dir: {correlation_plot_dir.resolve()}")

# Empty placeholders so that later cells can test `.empty` / truthiness even if inputs are missing.
chroms_df, chrom_map = pd.DataFrame(), {}
arm_map_df, arm_lookup_table = pd.DataFrame(), pd.DataFrame()
gene_trees = {}

# Chromosome layout and arm boundaries are derived from the cytoband table.
if not cytoband_df.empty:
    chroms_df, chrom_map = calculate_chromosome_features(cytoband_df)
    arm_map_df = create_arm_mapping_df(cytoband_df, chrom_map)
    arm_lookup_table = create_arm_lookup_table(arm_map_df)
# One interval tree per chromosome for gene annotation.
if not genes_df.empty:
    gene_trees = build_gene_interval_trees(genes_df)
gene_trees

In [None]:
# Process EPIC data: load all SEG files, classify segments by arm, aggregate per arm.
epic_df_aggregated_all = pd.DataFrame()

epic_inputs_ready = not (cases_df.empty or arm_lookup_table.empty or arm_map_df.empty)
if epic_inputs_ready:
    epic_raw_df = load_combine_genomic(cases_df, BASE_DATA_DIR, True, 'sentrix_id')  # True selects the SEG loader
    if not epic_raw_df.empty:
        # Classify on a copy so the raw frame stays as loaded.
        epic_classified = add_segment_arm_classification(epic_raw_df.copy(), arm_lookup_table)
        epic_df_aggregated_all = calculate_arm_log2_from_seg(epic_classified, arm_map_df, 'sentrix_id')
        print(f"Aggregated EPIC data: {epic_df_aggregated_all.shape}")

epic_df_aggregated_all

In [None]:
# Process NGS data
ngs_df_processed_full = pd.DataFrame()  # For CNV plots (points)
ngs_arm_aggregated_all = pd.DataFrame()  # For CNV plots (lines) & correlation
ngs_gene_reps_all = pd.DataFrame()  # For CNV plots (gene labels)

if not cases_df.empty:
    ngs_raw_df = load_combine_genomic(cases_df, BASE_DATA_DIR, False, 'barcode')  # False for CNR
    if not ngs_raw_df.empty:
        # Each step below only runs when the reference data it needs is available.
        ngs_curr = ngs_raw_df.copy()
        if chrom_map:
            ngs_curr = prep_ngs_agg(ngs_curr, chrom_map)
        if not arm_lookup_table.empty:
            ngs_curr = add_segment_arm_classification(ngs_curr, arm_lookup_table)
        if gene_trees:
            ngs_curr = annotate_segments_with_genes(ngs_curr, gene_trees)
            ngs_gene_reps_all = aggregate_data_by_gene(ngs_curr, PLOT_Y_LIM)
        if not chroms_df.empty and 'segment_arm' in ngs_curr.columns:
            ngs_arm_aggregated_all = agg_ngs_points_to_arms(ngs_curr, chroms_df)

        ngs_df_processed_full = ngs_curr  # This df has points for CNV plot
        print(
            f"Processed NGS. Points: {ngs_df_processed_full.shape}, Arm Agg: {ngs_arm_aggregated_all.shape}, Gene Reps: {ngs_gene_reps_all.shape}")

# Show the processed frame. `ngs_curr` only exists when data was loaded, so the
# always-defined result variable is displayed instead (same object in that case).
ngs_df_processed_full

In [None]:
# Generate CNV plots for each case
def _rows_for_case(frame, case_number):
    """Return a copy of the rows of ``frame`` for one case.

    An empty DataFrame is returned when ``frame`` is empty or has no
    'case_n_number' column, as is the case for the placeholder frames that
    remain when a processing step produced no data.
    """
    if frame.empty or 'case_n_number' not in frame.columns:
        return pd.DataFrame()
    return frame[frame['case_n_number'] == case_number].copy()


if not cases_df.empty and not ngs_df_processed_full.empty:
    print(f"\n--- Generating CNV plots for {len(cases_df)} cases ---")
    for _, case_r in cases_df.iterrows():
        cnum = case_r['case_n_number']
        ngs_c = _rows_for_case(ngs_df_processed_full, cnum)
        ngs_agg_c = _rows_for_case(ngs_arm_aggregated_all, cnum)
        epic_agg_c = _rows_for_case(epic_df_aggregated_all, cnum)
        gene_reps_c = _rows_for_case(ngs_gene_reps_all, cnum)
        if ngs_c.empty and ngs_agg_c.empty and epic_agg_c.empty and gene_reps_c.empty:
            print(f"Skipping plot for {cnum}: No data.")
            continue
        draw_cnv_plot(case_r, ngs_c, ngs_agg_c, epic_agg_c, gene_reps_c, chroms_df)

In [None]:
# Calculate and plot correlations between NGS and EPIC arm-level log2 values
have_both_platforms = not (ngs_arm_aggregated_all.empty or epic_df_aggregated_all.empty)

if not have_both_platforms:
    print("Skipping correlation analysis: Aggregated NGS or EPIC data is missing.")
else:
    print("\n--- Calculating and plotting NGS vs EPIC arm correlations ---")
    arm_correlations_df = calculate_arm_correlations(ngs_arm_aggregated_all, epic_df_aggregated_all)
    print("Arm correlation:\n", arm_correlations_df)

    # The scatter plot needs the paired values themselves, not only the statistics.
    pairing_keys = ['case_n_number', 'chromosome', 'segment_arm']
    paired_columns = pairing_keys + ['log2']
    merged_for_plot = pd.merge(
        ngs_arm_aggregated_all[paired_columns],
        epic_df_aggregated_all[paired_columns],
        on=pairing_keys,
        suffixes=('_ngs', '_epic'), how='inner'
    )
    plot_correlation_scatter(merged_for_plot, arm_correlations_df)