In [None]:
import sys
from os import path
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from intervaltree import IntervalTree
from matplotlib import pyplot as plt
from scipy.stats import pearsonr, zscore
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.mixture import GaussianMixture

# Apply one seaborn theme to every figure produced by this notebook.
sns.set_theme(style='ticks')

# log2 values are clipped to +/- PLOT_Y_LIM before plotting; the same value is
# the vertical extent of the chromosome boundary lines.
PLOT_Y_LIM = 10
# NOTE(review): not referenced in the visible part of the notebook — confirm it is still used.
GENE_ANNOTATION_PADDING = 0
# Default `merge_distance` of aggregate_data_by_gene: neighbouring gene
# representatives closer than this (in absolute genome coordinates) are merged.
GENE_MERGE_DISTANCE = 2_000_000

# Quantiles of the chromosome-24 (Y) log2 values used by predict_and_add_sex.
SEX_PREDICTION_FEMALE_Q = 0.25
SEX_PREDICTION_MALE_Q = 0.75

# log2 levels at which the female / male fitness of predict_and_add_sex saturates at 1.0.
SEX_PREDICTION_FEMALE_CONFIDENT_THRESHOLD = -10.0
SEX_PREDICTION_MALE_CONFIDENT_THRESHOLD = 0.0

# log2 level at which both fitness scores of predict_and_add_sex are 0.
SEX_PREDICTION_AMBIGUOUS_THRESHOLD = -3.0

# Colour per segment-arm class; the None key covers segments without a class.
# NOTE(review): not referenced in the visible part of the notebook.
ARM_COLORS = {'p': 'tab:blue', 'q': 'tab:orange', 'spanning': 'tab:green', None: 'grey'}

# Input data is read from BASE_DATA_DIR; PLOT_OUTPUT_DIR is created on import of this cell.
BASE_DATA_DIR = Path('data')
PLOT_OUTPUT_DIR = Path('graphics')
PLOT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

CASES_FILE_PATH = BASE_DATA_DIR / 'cases.csv'
CYTOBAND_FILE_PATH = BASE_DATA_DIR / 'cytoBand.txt'
RELEVANT_GENES_FILE_PATH = BASE_DATA_DIR / 'relevant_genes.csv'

# Cases table: positional indices of the columns that are kept, and the names
# assigned to them (same order) by load_and_preprocess_cases.
CASES_COLUMNS_INDICES = [1, 2, 3, 5, 13, 14]
CASES_COLUMN_NAMES = ['case_n_number', 'age', 'sex', 'diagnosis', 'sentrix_id', 'barcode']

# Cytoband table: column names passed to read_csv (the file is read without a header row).
CYTOBAND_COLUMN_NAMES = ['chromosome', 'start', 'end', 'band', 'giemsa']

# Relevant-genes table: column names passed to read_csv (the file is read without a header row).
RELEVANT_GENES_COLUMN_NAMES = ['gene', 'chromosome', 'start', 'end']

In [None]:
def _sanitize_chromosome_column(df, chromosome_col='chromosome'):
    """
    Converts a chromosome column to numeric Int64Dtype, replacing 'chr' prefix
    and mapping sex chromosomes to numeric equivalents (X->23, Y->24).
    Rows with unparsable chromosome values are dropped.
    """
    if chromosome_col not in df.columns:
        return df

    df = df.copy()
    # Standardize to string, lowercase, and remove 'chr'
    s = df[chromosome_col].astype(str).str.lower().str.replace('chr', '', regex=False)

    # Map sex and mitochondrial chromosomes to numeric representations
    s = s.replace({'x': '23', 'y': '24', 'm': '25', 'mt': '25'})

    # Coerce to numeric, dropping any rows that can't be converted
    df[chromosome_col] = pd.to_numeric(s, errors='coerce')
    df = df.dropna(subset=[chromosome_col])

    # Convert to an integer type that supports NaNs
    df[chromosome_col] = df[chromosome_col].astype(pd.Int64Dtype())
    return df

In [None]:
def load_and_preprocess_cases(file_path):
    """
    Read the cases CSV and reduce it to the columns used by the pipeline.

    The columns at positions CASES_COLUMNS_INDICES are kept and renamed to
    CASES_COLUMN_NAMES; rows with any missing value are dropped and the index
    is renumbered. A missing file is reported on stdout and yields an empty
    DataFrame.
    """
    try:
        raw_cases = pd.read_csv(file_path, delimiter=',', encoding='ISO-8859-1')
    except FileNotFoundError:
        print(f"Error: Cases file not found at {file_path}")
        return pd.DataFrame()

    selected = raw_cases.iloc[:, CASES_COLUMNS_INDICES]
    selected.columns = CASES_COLUMN_NAMES
    return selected.dropna().reset_index(drop=True)

In [None]:
def load_and_preprocess_cytoband(file_path):
    """
    Read the tab-separated cytoband file and normalise its chromosome column.

    The file is read without a header row using CYTOBAND_COLUMN_NAMES. A
    missing file is reported on stdout and yields an empty DataFrame.
    """
    try:
        cytoband_raw = pd.read_csv(file_path, delimiter='\t', names=CYTOBAND_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Cytoband file not found at {file_path}")
        return pd.DataFrame()
    else:
        return _sanitize_chromosome_column(cytoband_raw, 'chromosome')


def calculate_chromosome_features(cytoband_df):
    """
    Derive per-chromosome length and genome-wide start offset from cytobands.

    Returns a tuple `(chromosomes_df, chromosome_start_map)`: the frame has
    the columns 'chromosome', 'length' (largest band end) and
    'chromosome_absolute_start' (sum of the lengths of all preceding
    chromosomes); the dict maps chromosome -> absolute start. An empty input
    yields an empty frame with those columns and an empty dict.
    """
    if cytoband_df.empty:
        empty_features = pd.DataFrame(columns=['chromosome', 'length', 'chromosome_absolute_start'])
        return empty_features, {}

    lengths = cytoband_df.groupby('chromosome').agg(length=('end', 'max'))
    features = lengths.reset_index().sort_values('chromosome')

    # Offset of a chromosome = cumulative length up to, but excluding, itself.
    running_total = features['length'].cumsum()
    features['chromosome_absolute_start'] = running_total - features['length']

    start_by_chromosome = features.set_index('chromosome')['chromosome_absolute_start'].to_dict()
    return features, start_by_chromosome


def create_arm_mapping_df(cytoband_df, chromosome_start_map):
    """
    Build one row per (chromosome, arm) with local and absolute coordinates.

    The arm is the lower-cased first character of the band name. Local
    coordinates are the smallest band start and the largest band end of the
    arm; absolute coordinates add the chromosome offset taken from
    `chromosome_start_map`. Chromosomes missing from the map are dropped by
    the inner merge. Empty inputs yield an empty frame with the output
    columns.
    """
    output_columns = ['chromosome', 'arm', 'arm_start', 'arm_end', 'arm_abs_start', 'arm_abs_end']
    if cytoband_df.empty or not chromosome_start_map:
        return pd.DataFrame(columns=output_columns)

    # Chromosome offsets as a two-column frame that can be merged in.
    offsets = (
        pd.Series(chromosome_start_map, name='chr_abs_start')
        .reset_index()
        .rename(columns={'index': 'chromosome_num'})
    )

    arm_letter = cytoband_df['band'].astype(str).str[0].str.lower()
    arm_extents = (
        cytoband_df
        .assign(arm=arm_letter)
        .groupby(['chromosome', 'arm'])
        .agg(arm_start=('start', 'min'), arm_end=('end', 'max'))
        .reset_index()
    )

    merged = arm_extents.merge(offsets, left_on='chromosome', right_on='chromosome_num')
    merged['arm_abs_start'] = merged['chr_abs_start'] + merged['arm_start']
    merged['arm_abs_end'] = merged['chr_abs_start'] + merged['arm_end']
    return merged[output_columns]


def create_arm_lookup_table(arm_mapping_df):
    """
    Pivot the arm mapping into one row per chromosome for p/q boundary lookup.

    Parameters
    ----------
    arm_mapping_df : DataFrame
        Needs the columns 'chromosome', 'arm', 'arm_start' and 'arm_end',
        with at most one row per (chromosome, arm) pair.

    Returns
    -------
    DataFrame
        Indexed by 'chromosome'. Always carries the columns 'p_start',
        'p_end', 'q_start' and 'q_end'; a value is NaN when the chromosome
        has no such arm. Arms other than p/q keep the flattened pivot name
        (e.g. 'arm_start_c'). An empty input yields an empty DataFrame.
    """
    if arm_mapping_df.empty:
        return pd.DataFrame()

    arm_lookup = arm_mapping_df.pivot(index='chromosome', columns='arm',
                                      values=['arm_start', 'arm_end'])
    if arm_lookup.empty:
        return arm_lookup

    # Flatten the (value, arm) MultiIndex, e.g. ('arm_start', 'p') -> 'arm_start_p'.
    arm_lookup.columns = ['_'.join(col).strip() for col in arm_lookup.columns.values]
    arm_lookup = arm_lookup.rename(columns={
        'arm_start_p': 'p_start', 'arm_end_p': 'p_end',
        'arm_start_q': 'q_start', 'arm_end_q': 'q_end'
    })

    # The pivot only creates columns for arms present in the data; add the
    # missing ones so callers can rely on all four being available.
    for col in ['p_start', 'p_end', 'q_start', 'q_end']:
        if col not in arm_lookup.columns:
            arm_lookup[col] = np.nan
    return arm_lookup

In [None]:
def load_and_preprocess_relevant_genes(file_path):
    """
    Read the semicolon-separated list of relevant genes.

    The file is read without a header row using RELEVANT_GENES_COLUMN_NAMES
    and its chromosome column is normalised to integers. A missing file is
    reported on stdout and yields an empty DataFrame.
    """
    try:
        genes_raw = pd.read_csv(file_path, delimiter=';', names=RELEVANT_GENES_COLUMN_NAMES)
    except FileNotFoundError:
        print(f"Error: Relevant genes file not found at {file_path}")
        return pd.DataFrame()
    else:
        # The sanitiser leaves 'chromosome' as a nullable integer column.
        return _sanitize_chromosome_column(genes_raw, 'chromosome')


def build_gene_interval_trees(relevant_genes_df):
    """
    Index the relevant genes in one IntervalTree per chromosome.

    Returns a dict mapping chromosome -> IntervalTree. Every gene is stored
    as the interval [start, end + 1) so that its last base is included; the
    payload keeps the gene name together with the unmodified start and end.
    An empty input yields an empty dict.
    """
    if relevant_genes_df.empty:
        return {}

    trees_by_chromosome = {}
    for _, gene_row in relevant_genes_df.iterrows():
        chromosome = gene_row.chromosome
        tree = trees_by_chromosome.get(chromosome)
        if tree is None:
            tree = trees_by_chromosome[chromosome] = IntervalTree()

        gene_start, gene_end = int(gene_row.start), int(gene_row.end)
        payload = {'gene': gene_row.gene, 'start': gene_start, 'end': gene_end}
        # IntervalTree intervals are half-open, hence the + 1 on the end.
        tree.addi(gene_start, gene_end + 1, payload)

    return trees_by_chromosome


def annotate_segments_with_genes(df, gene_interval_trees, genes_df):
    """
    Annotates DataFrame segments with overlapping genes from the relevant_genes list,
    applying logic on a per-case basis to correctly handle pre-annotated genes.

    Parameters:
        df: segment/bin table; reads 'case_n_number', 'chromosome', 'start',
            'end' and (optionally) 'gene'.
        gene_interval_trees: dict chromosome -> IntervalTree whose interval
            payloads are dicts with a 'gene' key.
        genes_df: table with a 'gene' column listing the relevant gene names.

    Returns:
        A new DataFrame built by concatenating the per-case results (index
        renumbered, rows ordered by case), or `df` itself when `df` is empty
        or no trees are given.

    Side effects:
        `df` is modified in place: a 'gene' column is added if missing, gene
        values that are not in `genes_df['gene']` are blanked, and
        'chromosome' is converted to numeric.
    """
    if df.empty or not gene_interval_trees:
        # Nothing to annotate; still guarantee that a 'gene' column exists.
        if 'gene' not in df.columns:
            df['gene'] = pd.NA
        return df

    # --- Pre-computation ---
    all_relevant_genes = set(genes_df['gene'])
    case_id_col = 'case_n_number'

    # Ensure a 'gene' column exists and sanitize it globally first.
    # Any value that is not a relevant gene name becomes NA for all rows.
    if 'gene' not in df.columns:
        df['gene'] = pd.NA
    df['gene'] = df['gene'].where(df['gene'].isin(all_relevant_genes), pd.NA)

    # Ensure chromosome column is numeric for tree lookups.
    df['chromosome'] = pd.to_numeric(df['chromosome'], errors='coerce')

    # --- Per-Case Annotation Logic ---
    # NOTE(review): groupby skips rows whose case id is missing, so such rows
    # do not appear in the result.
    processed_cases = []
    for case_id, case_df in df.groupby(case_id_col):
        case_df = case_df.copy()

        # Identify genes already validly annotated for this case
        genes_on_panel_for_this_case = set(case_df['gene'].dropna())

        # Identify rows that still need annotation within THIS case.
        rows_to_annotate_mask = case_df['gene'].isna()
        rows_to_annotate_indices = case_df.index[rows_to_annotate_mask]

        # Iterate through and annotate only the necessary rows for this case.
        for index in rows_to_annotate_indices:
            row = case_df.loc[index]

            # Rows without a usable chromosome cannot be looked up.
            if pd.isna(row.chromosome):
                continue

            tree = gene_interval_trees.get(int(row.chromosome))
            if not tree:
                continue

            overlaps = tree.overlap(row.start, row.end)
            if overlaps:
                # Filter out genes that are already on this specific case's panel.
                additional_gene_overlaps = {
                    iv for iv in overlaps if iv.data['gene'] not in genes_on_panel_for_this_case
                }

                if additional_gene_overlaps:
                    # Deterministic choice: smallest interval begin, then gene name.
                    first_overlap = sorted(additional_gene_overlaps, key=lambda i: (i.begin, i.data['gene']))[0]
                    case_df.at[index, 'gene'] = first_overlap.data['gene']

        processed_cases.append(case_df)

    # Recombine all the processed cases into a single DataFrame.
    if not processed_cases:
        return pd.DataFrame(columns=df.columns)  # Return empty DF if no cases were processed

    return pd.concat(processed_cases, ignore_index=True)


def aggregate_data_by_gene(annotated_df, merge_distance=GENE_MERGE_DISTANCE):
    """
    Aggregate log2 data to gene level and merge nearby genes, case by case.

    Parameters
    ----------
    annotated_df : DataFrame
        Needs the columns 'gene', 'absolute_start', 'log2' and
        'case_n_number'. Rows without gene or log2 are ignored.
    merge_distance : number
        Neighbouring gene representatives (sorted by position) whose distance
        is not larger than this are merged into one representative. A value
        <= 0 disables merging.

    Returns
    -------
    DataFrame
        One row per gene (or merged gene group) and case with the columns
        'gene', 'absolute_position', 'log2' and 'case_n_number'; log2 is
        clipped to +/- PLOT_Y_LIM. An empty DataFrame is returned when
        required columns are missing (with a warning on stderr) or no usable
        rows exist.
    """
    required_columns = ('gene', 'absolute_start', 'log2')
    if annotated_df.empty or any(col not in annotated_df.columns for col in required_columns):
        print("Warning: Cannot aggregate by gene due to missing columns.", file=sys.stderr)
        return pd.DataFrame()

    case_id_col = 'case_n_number'
    if case_id_col not in annotated_df.columns:
        print(f"Warning: Expected column '{case_id_col}' not found.", file=sys.stderr)
        return pd.DataFrame()

    ngs_with_gene = annotated_df[annotated_df['gene'].notna() & annotated_df['log2'].notna()]
    if ngs_with_gene.empty:
        return pd.DataFrame()

    all_cases_final_reps = []

    for case_id, case_df in ngs_with_gene.groupby(case_id_col):
        # One representative per gene (median position and median log2),
        # ordered along the genome.
        gene_reps_case = (
            case_df.groupby('gene')
            .agg(absolute_position=('absolute_start', 'median'), log2=('log2', 'median'))
            .reset_index()
            .sort_values('absolute_position')
            .reset_index(drop=True)
        )
        if gene_reps_case.empty:
            continue

        if merge_distance <= 0:
            case_records = gene_reps_case.to_dict('records')
        else:
            # Chain genes into groups: a gene joins the current group unless
            # it lies more than merge_distance behind the group's last gene.
            groups = []
            for _, row in gene_reps_case.iterrows():
                if not groups or (row.absolute_position - groups[-1][-1].absolute_position) > merge_distance:
                    groups.append([row])
                else:
                    groups[-1].append(row)

            case_records = []
            for group in groups:
                group_df = pd.DataFrame(group)
                if len(group_df) == 1:
                    merged_row = group_df.iloc[0].to_dict()
                else:
                    unique_genes = sorted(set(group_df['gene']))
                    merged_row = {
                        'gene': _merge_gene_names(unique_genes),
                        'absolute_position': group_df['absolute_position'].median(),
                        # The group is represented by its most extreme log2.
                        'log2': group_df.loc[group_df['log2'].abs().idxmax()]['log2']
                    }
                case_records.append(merged_row)

        # Attach the case identifier on BOTH paths; previously the
        # merge_distance <= 0 path returned records without it.
        for record in case_records:
            record[case_id_col] = case_id
        all_cases_final_reps.extend(case_records)

    if not all_cases_final_reps:
        return pd.DataFrame()

    merged_gene_reps = pd.DataFrame(all_cases_final_reps)
    merged_gene_reps['log2'] = merged_gene_reps['log2'].clip(lower=-PLOT_Y_LIM, upper=PLOT_Y_LIM)

    return merged_gene_reps


def _merge_gene_names(gene_list):
    """
    Merges a list of gene names intelligently.
    e.g., ['CDKN2A', 'CDKN2B'] becomes 'CDKN2A/B'.
    """
    if len(gene_list) <= 1:
        return '/'.join(gene_list)

    prefix = path.commonprefix(gene_list)

    # If the prefix is trivial, don't shorten, just join
    if len(prefix) < 3:
        return '/'.join(gene_list)

    # Keep the first gene name whole, append only suffixes of the rest
    suffixes = [gene[len(prefix):] for gene in gene_list[1:]]
    return '/'.join([gene_list[0]] + suffixes)

In [None]:
def load_seg_cnr_file(is_seg, item_id, case_id, base_dir):  # Combined loader
    """
    Load one EPIC segment file (.seg) or one NGS bin file (.cnr) for a case.

    Parameters
    ----------
    is_seg : bool
        True loads `<base_dir>/epic/seg/<item_id>_CNV.seg`; False loads the
        first file (in sorted order) matching
        `<base_dir>/ngs/cnr/*<item_id>*.cnr`.
    item_id : str
        Sentrix id (SEG) or barcode (CNR).
    case_id : str
        Stored in the 'case_n_number' column of the result.
    base_dir : pathlib.Path
        Root data directory.

    Returns
    -------
    DataFrame or None
        The table with a sanitised 'chromosome' column, or None when no file
        is found or it cannot be read (a warning is written to stderr in the
        latter case).
    """
    if is_seg:
        file_path = base_dir / 'epic/seg' / f'{item_id}_CNV.seg'
        if not file_path.exists():
            return None
    else:
        item_id_str = str(item_id) if pd.notna(item_id) else ""
        if not item_id_str:
            return None
        # Sort the matches: glob order depends on the filesystem, and the
        # chosen file must not change between runs.
        matches = sorted((base_dir / 'ngs/cnr').glob(f'*{item_id_str}*.cnr'))
        if not matches:
            return None
        file_path = matches[0]

    try:
        if is_seg:
            # Skip the header row and name the five columns that are kept.
            df = pd.read_csv(file_path, sep='\t', skiprows=1, usecols=[0, 1, 2, 3, 7],
                             names=['fsid', 'chromosome', 'start', 'end', 'log2'])
        else:
            # CNR files carry their own header row.
            df = pd.read_csv(file_path, sep='\t')
    except Exception as exc:
        # Best effort: an unreadable file skips this case, but say so.
        print(f"Warning: Could not read {file_path}: {exc}", file=sys.stderr)
        return None

    df['case_n_number'] = case_id
    if is_seg:
        df['sentrix_id'] = item_id
        df = df.drop(columns=['fsid'])
    else:  # CNR specific
        df['barcode'] = item_id
        if 'gene' in df.columns:
            # 'Antitarget' is a bin label, not a gene name.
            df.loc[df['gene'] == 'Antitarget', 'gene'] = pd.NA
        if 'log2' in df.columns:
            # Keep only bins with |log2| < 6.
            df = df[df['log2'].abs() < 6].copy()

    return _sanitize_chromosome_column(df)


def _normalize_log2(dfs, col='log2'):
    if not dfs: return []
    vals = pd.concat([d[col] for d in dfs if col in d.columns and not d.empty], ignore_index=True)
    if vals.empty: return [d.copy() for d in dfs]
    m, s = vals.mean(), vals.std()
    s = 1.0 if (pd.isna(s) or s == 0) else s
    return [(d.assign(**{col: (d[col] - m) / s}) if col in d.columns and not d.empty else d.copy()) for d in dfs]


def load_combine_genomic(cases, bdir, is_seg_loader, id_col, cid_col='case_n_number', l2_col='log2'):
    """
    Load the SEG or CNR file of every case and return one combined table.

    Cases with a missing item id or case id, or whose item id reads '0', are
    skipped, as are files that cannot be loaded or are empty. The `l2_col`
    values of the loaded tables are normalised jointly before the tables are
    concatenated. An empty DataFrame is returned when nothing was loaded.
    """
    loaded_frames = []
    for _, case_row in cases.iterrows():
        item_id = case_row[id_col]
        case_id = case_row[cid_col]
        if pd.isna(item_id) or pd.isna(case_id) or str(item_id) == '0':
            continue

        # The loader takes a boolean flag selecting SEG vs CNR.
        frame = load_seg_cnr_file(is_seg_loader, str(item_id), str(case_id), bdir)
        if frame is None or frame.empty:
            continue
        loaded_frames.append(frame)

    if loaded_frames:
        return pd.concat(_normalize_log2(loaded_frames, l2_col), ignore_index=True)
    return pd.DataFrame()


def add_segment_arm_classification(df, arm_lookup_table):
    """
    Adds a 'segment_arm' column ('p', 'q', 'spanning' or None) to the segments.

    Parameters:
        df: segment table with 'chromosome', 'start' and 'end'.
        arm_lookup_table: per-chromosome table providing 'p_start', 'p_end',
            'q_start' and 'q_end' (joined on 'chromosome').

    Returns:
        The left-merged copy of `df` with 'segment_arm' added and the helper
        columns removed. When `df` or the lookup table is empty, `df` itself
        is returned with 'segment_arm' set to None.

    Side effects:
        `df` is modified in place ('chromosome' is cast to int; in the empty
        case 'segment_arm' is added).
    """
    if df.empty or arm_lookup_table.empty:
        df['segment_arm'] = None
        return df
    df['chromosome'] = df['chromosome'].astype(int)
    dfa = df.merge(arm_lookup_table, on='chromosome', how='left')
    # Classify the start and the end coordinate separately. A position belongs
    # to an arm when arm_start <= position < arm_end (half-open); 'p' is tested first.
    # NOTE(review): missing arm bounds are filled with -inf / +inf, so a
    # chromosome without p-arm coordinates matches 'p' for every position —
    # confirm this is intended.
    # NOTE(review): a position equal to q_end matches neither arm — confirm
    # segments ending exactly at the chromosome end should get no arm.
    for pos in ['start', 'end']:
        ps, pe, qs, qe = (dfa[c].fillna(float('-inf') if 'start' in c else float('inf')) for c in
                          ['p_start', 'p_end', 'q_start', 'q_end'])
        dfa[f'{pos}_arm'] = np.select([(dfa[pos] >= ps) & (dfa[pos] < pe), (dfa[pos] >= qs) & (dfa[pos] < qe)],
                                      ['p', 'q'], default=None)
    # Same arm on both ends -> that arm; otherwise provisionally 'spanning'.
    dfa['segment_arm'] = np.where(dfa['start_arm'] == dfa['end_arm'], dfa['start_arm'], 'spanning')
    # Both ends classified but different -> 'spanning' (repeats the default above).
    dfa.loc[(dfa['start_arm'].notna() & dfa['end_arm'].notna() & (
            dfa['start_arm'] != dfa['end_arm'])), 'segment_arm'] = 'spanning'
    # One end unclassified -> no arm instead of 'spanning'.
    dfa.loc[
        (dfa['start_arm'].isna() | dfa['end_arm'].isna()) & (dfa['start_arm'] != dfa['end_arm']), 'segment_arm'] = None
    # Remove the lookup and helper columns again.
    return dfa.drop(
        columns=[c for c in ['p_start', 'p_end', 'q_start', 'q_end', 'start_arm', 'end_arm'] if c in dfa.columns])




In [None]:
def calculate_arm_log2_from_seg(classified_seg_df, arm_map_df, id_col_name='sentrix_id'):
    """
    Aggregate SEG log2 values to chromosome-arm level.

    For every (case, chromosome) and every arm of that chromosome, the
    segments classified to the arm or as 'spanning' that physically overlap
    the arm are clipped to the arm boundaries; the arm log2 is the mean of
    the segment log2 values weighted by the clipped length.

    Parameters
    ----------
    classified_seg_df : DataFrame
        Needs 'case_n_number', 'chromosome', 'start', 'end', 'log2' and
        'segment_arm'.
    arm_map_df : DataFrame
        Needs 'chromosome', 'arm', 'arm_start', 'arm_end', 'arm_abs_start'.
    id_col_name : str
        Optional identifier column copied from the first contributing
        segment into the result.

    Returns
    -------
    DataFrame
        One row per (case, chromosome, arm) that has a log2 value, with
        'case_n_number', 'chromosome', 'segment_arm', 'log2', 'start', 'end',
        'absolute_start', 'absolute_end' and, if available, `id_col_name`.
        Empty when the inputs are empty, lack 'segment_arm', or no arm
        matches. Neither input frame is modified.
    """
    if classified_seg_df.empty or arm_map_df.empty:
        return pd.DataFrame()
    if 'segment_arm' not in classified_seg_df.columns:
        return pd.DataFrame()  # Need classification

    # Work on private copies so the caller's frames keep their dtypes.
    df = classified_seg_df.copy()
    df['chromosome'] = df['chromosome'].astype(int)
    arm_map = arm_map_df.assign(chromosome=arm_map_df['chromosome'].astype(int))

    results = []
    for (case, chrom), group in df.groupby(['case_n_number', 'chromosome']):
        for _, arm_row in arm_map[arm_map['chromosome'] == chrom].iterrows():
            arm = arm_row['arm']
            arm_start, arm_end = arm_row['arm_start'], arm_row['arm_end']
            arm_abs_start = arm_row['arm_abs_start']

            # Defaults for an arm without any contributing segment.
            log2_agg, agg_start, agg_end, item_id_val = np.nan, arm_start, arm_end, None

            # Segments of this arm (or spanning) that overlap the arm physically.
            mask = (((group['segment_arm'] == arm) | (group['segment_arm'] == 'spanning')) &
                    (group['end'] > arm_start) & (group['start'] < arm_end))
            relevant = group[mask]

            # Clip to the arm and keep only segments with positive coverage.
            s_clip = relevant['start'].clip(lower=arm_start)
            e_clip = relevant['end'].clip(upper=arm_end)
            coverage = e_clip - s_clip
            valid = coverage > 0

            if valid.any():
                weights = coverage[valid]
                log2_agg = (relevant.loc[valid, 'log2'] * weights).sum() / weights.sum()
                agg_start, agg_end = s_clip[valid].min(), e_clip[valid].max()
                if id_col_name in relevant.columns:
                    item_id_val = relevant.loc[valid, id_col_name].iloc[0]

            entry = {'case_n_number': case, 'chromosome': chrom, 'segment_arm': arm, 'log2': log2_agg,
                     'start': agg_start, 'end': agg_end,
                     'absolute_start': arm_abs_start + (agg_start - arm_start),
                     'absolute_end': arm_abs_start + (agg_end - arm_start)}
            if item_id_val is not None:
                entry[id_col_name] = item_id_val
            results.append(entry)

    # Without rows there is no 'log2' column to filter on.
    if not results:
        return pd.DataFrame()
    return pd.DataFrame(results).dropna(subset=['log2'])


# --- Functions specific to NGS/CNR aggregation ---

def prep_ngs_agg(df, chrom_start_map):
    """
    Add absolute bin midpoints and bin lengths to NGS data, in place.

    'absolute_start' is the chromosome offset from `chrom_start_map` plus the
    integer midpoint of the bin, 'length' is end - start, and 'chromosome' is
    cast to int. The same (modified) frame is returned; an empty frame or an
    empty map returns `df` unchanged.
    """
    if df.empty or not chrom_start_map:
        return df

    # Integer chromosome ids are needed to look up the offsets.
    df['chromosome'] = df['chromosome'].astype(int)

    chromosome_offset = df['chromosome'].map(chrom_start_map)
    bin_midpoint = (df['start'] + df['end']) // 2
    df['absolute_start'] = chromosome_offset + bin_midpoint
    df['length'] = df['end'] - df['start']
    return df


def agg_ngs_points_to_arms(classified_ngs_df, chrom_features_df):
    """
    Aggregate NGS bins to one log2 value per (case, chromosome, arm class).

    Parameters
    ----------
    classified_ngs_df : DataFrame
        Needs 'case_n_number', 'chromosome', 'start', 'end', 'log2' and
        'segment_arm' ('p', 'q' or 'spanning'); 'length' is computed as
        end - start when absent. Bins without class or with non-positive
        length are ignored.
    chrom_features_df : DataFrame
        Needs 'chromosome' and 'chromosome_absolute_start'.

    Returns
    -------
    DataFrame
        Columns 'case_n_number', 'chromosome', 'segment_arm', 'log2'
        (length-weighted mean), 'arm_start', 'arm_end' (extent of the
        contributing bins), 'arm_absolute_start' and 'arm_absolute_end'.
        Groups without a log2 value are dropped. Empty when there is nothing
        to aggregate. Neither input frame is modified.
    """
    if classified_ngs_df.empty or 'segment_arm' not in classified_ngs_df.columns or chrom_features_df.empty:
        return pd.DataFrame()

    ngs_valid = classified_ngs_df[classified_ngs_df['segment_arm'].notna()].copy()
    if 'length' not in ngs_valid.columns:
        ngs_valid['length'] = ngs_valid['end'] - ngs_valid['start']
    # A fresh unique index makes the weight lookup below safe even when the
    # caller's frame carries duplicate index labels.
    ngs_valid = ngs_valid[ngs_valid['length'] > 0].reset_index(drop=True)
    if ngs_valid.empty:
        return pd.DataFrame()
    ngs_valid['chromosome'] = ngs_valid['chromosome'].astype(int)

    # Private copy of the offsets so the caller's frame keeps its dtype.
    chrom_offsets = chrom_features_df[['chromosome', 'chromosome_absolute_start']].copy()
    chrom_offsets['chromosome'] = chrom_offsets['chromosome'].astype(int)

    bin_lengths = ngs_valid['length']

    def _length_weighted_mean(log2_values):
        # Weights are the bin lengths of exactly the rows in this group.
        if log2_values.empty or not log2_values.notna().any():
            return np.nan
        return np.average(log2_values, weights=bin_lengths.loc[log2_values.index])

    arm_agg = (
        ngs_valid.groupby(['case_n_number', 'chromosome', 'segment_arm'])
        .agg(
            log2=('log2', _length_weighted_mean),
            arm_start=('start', 'min'),
            arm_end=('end', 'max'),
        )
        .reset_index()
    )

    # Translate the aggregated extent into absolute genome coordinates.
    arm_agg = arm_agg.merge(chrom_offsets, on='chromosome', how='left')
    arm_agg['arm_absolute_start'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_start']
    arm_agg['arm_absolute_end'] = arm_agg['chromosome_absolute_start'] + arm_agg['arm_end']

    return arm_agg.drop(columns=['chromosome_absolute_start'], errors='ignore').dropna(subset=['log2'])

In [None]:
def predict_and_add_sex(ngs_df, cases_df, case_id_col='case_n_number'):
    """
    Predict the sex of every case from its chromosome-Y (24) log2 values.

    Two quantiles of the Y log2 distribution are scored against the
    SEX_PREDICTION_* thresholds: the lower quantile gives a female fitness,
    the upper one a male fitness, both clipped to [0, 1]. The larger fitness
    decides the label and the absolute difference is the confidence; a tie,
    missing Y bins or undefined quantiles give 'unknown' with confidence 0.

    The columns 'predicted_sex' and 'sex_confidence' are added to `cases_df`
    in place and the same frame is returned. Cases without NGS data get
    'unknown' / 0.0.
    """
    if ngs_df.empty or cases_df.empty:
        print("Warning: Cannot predict sex, input dataframe is empty.", file=sys.stderr)
        cases_df['predicted_sex'] = 'unknown'
        cases_df['sex_confidence'] = 0.0
        return cases_df

    def _classify(y_log2):
        """Return (label, confidence) for the Y-chromosome log2 values of one case."""
        lower_q = y_log2.quantile(SEX_PREDICTION_FEMALE_Q)
        upper_q = y_log2.quantile(SEX_PREDICTION_MALE_Q)
        if pd.isna(lower_q) or pd.isna(upper_q):
            return 'unknown', 0.0

        female_range = SEX_PREDICTION_AMBIGUOUS_THRESHOLD - SEX_PREDICTION_FEMALE_CONFIDENT_THRESHOLD
        male_range = SEX_PREDICTION_MALE_CONFIDENT_THRESHOLD - SEX_PREDICTION_AMBIGUOUS_THRESHOLD

        # The small epsilon guards against a zero-width range.
        female_fitness = np.clip(
            (SEX_PREDICTION_AMBIGUOUS_THRESHOLD - lower_q) / (female_range + 1e-9), 0.0, 1.0)
        male_fitness = np.clip(
            (upper_q - SEX_PREDICTION_AMBIGUOUS_THRESHOLD) / (male_range + 1e-9), 0.0, 1.0)

        if female_fitness > male_fitness:
            return "female", abs(female_fitness - male_fitness)
        if male_fitness > female_fitness:
            return "male", abs(female_fitness - male_fitness)
        return "unknown", 0.0

    predictions = {}
    confidences = {}

    print("\n--- Predicting sex for all cases using dual-quantile model ---")
    for case_id, case_bins in ngs_df.groupby(case_id_col):
        # 24 is the numeric code of chromosome Y.
        y_bins = case_bins[case_bins['chromosome'] == 24]
        if y_bins.empty:
            label, confidence = 'unknown', 0.0
        else:
            label, confidence = _classify(y_bins['log2'])
        predictions[case_id] = label
        confidences[case_id] = confidence

    cases_df['predicted_sex'] = cases_df[case_id_col].map(predictions).fillna('unknown')
    cases_df['sex_confidence'] = cases_df[case_id_col].map(confidences).fillna(0.0)

    print("Sex prediction complete. Results added to cases dataframe.\n")
    return cases_df

In [None]:
import numpy as np
import pandas as pd

def predict_sex(ngs_df, cases_df, case_id_col='case_n_number'):
    """
    Predict sex from the X and Y log2 levels relative to the autosomes.

    For every case the median log2 of chromosome 23 (X) and 24 (Y) is turned
    into a z-score against the case's own autosomal bins (median and standard
    deviation of chromosomes 1-22). The male score rises linearly from 0 to 1
    as the X z-score falls from MALE_Z_LOW_SCORE to MALE_Z_HIGH_SCORE; the
    female score does the same for the Y z-score with the FEMALE_* bounds.
    The higher score decides the label; the confidence is the absolute score
    difference.

    A case is labelled 'unknown' with confidence 0.0 when autosomal, X or Y
    bins are missing, when the autosomal standard deviation is undefined or
    (almost) zero, when a z-score is undefined, or when both scores are
    equal.

    Parameters
    ----------
    ngs_df : DataFrame
        Needs `case_id_col`, 'chromosome' (numeric) and 'log2'.
    cases_df : DataFrame
        Needs `case_id_col`; receives the columns 'predicted_sex' and
        'sex_confidence' in place.

    Returns
    -------
    DataFrame
        `cases_df` itself, with the two columns added.
    """
    MALE_Z_HIGH_SCORE = -1.5
    MALE_Z_LOW_SCORE = 0.0

    FEMALE_Z_HIGH_SCORE = -9.0
    FEMALE_Z_LOW_SCORE = 0.0

    if ngs_df.empty or cases_df.empty:
        cases_df['predicted_sex'], cases_df['sex_confidence'] = 'unknown', 0.0
        return cases_df

    predictions, confidences = {}, {}

    for case_id, group in ngs_df.groupby(case_id_col):
        # Every exit below that cannot decide leaves these defaults in place.
        predictions[case_id], confidences[case_id] = 'unknown', 0.0

        autosomal_bins = group[(group['chromosome'] >= 1) & (group['chromosome'] <= 22)]
        x_bins = group[group['chromosome'] == 23]
        y_bins = group[group['chromosome'] == 24]
        if autosomal_bins.empty or x_bins.empty or y_bins.empty:
            continue

        # The sample's own baseline.
        median_auto = autosomal_bins['log2'].median()
        autosomal_std = autosomal_bins['log2'].std()

        # std is NaN for a single bin; NaN < 1e-9 is False, so test it explicitly.
        if pd.isna(autosomal_std) or autosomal_std < 1e-9:
            continue

        z_score_x = (x_bins['log2'].median() - median_auto) / autosomal_std
        z_score_y = (y_bins['log2'].median() - median_auto) / autosomal_std
        if pd.isna(z_score_x) or pd.isna(z_score_y):
            continue

        # --- Male score (0.0 to 1.0): a single X copy lowers the X level ---
        male_range = MALE_Z_LOW_SCORE - MALE_Z_HIGH_SCORE
        male_score = np.clip((MALE_Z_LOW_SCORE - z_score_x) / male_range, 0.0, 1.0)

        # --- Female score (0.0 to 1.0): an absent Y lowers the Y level ---
        female_range = FEMALE_Z_LOW_SCORE - FEMALE_Z_HIGH_SCORE
        female_score = np.clip((FEMALE_Z_LOW_SCORE - z_score_y) / female_range, 0.0, 1.0)

        # --- Decision and confidence; a tie carries no evidence either way ---
        if male_score > female_score:
            prediction = "male"
        elif female_score > male_score:
            prediction = "female"
        else:
            continue

        predictions[case_id] = prediction
        confidences[case_id] = np.clip(abs(male_score - female_score), 0.0, 1.0)

    # Map the results back to the main cases dataframe
    cases_df['predicted_sex'] = cases_df[case_id_col].map(predictions).fillna('unknown')
    cases_df['sex_confidence'] = cases_df[case_id_col].map(confidences).fillna(0.0)

    return cases_df

In [None]:
def _plot_ngs_scatter(ax, ngs_df_case):
    """
    Scatter the individual NGS bins of one case onto `ax`.

    Points are coloured by the sign of their log2 value (gain: orange,
    loss: blue, neutral: grey) and log2 is clipped to +/- PLOT_Y_LIM. The
    opacity of a point grows with the absolute z-score of its log2 value
    within the case, so outlying bins stand out. Bins with a missing log2
    value are not drawn. Nothing is drawn when the frame is empty or lacks
    'log2' / 'absolute_start'.
    """
    if ngs_df_case.empty or 'log2' not in ngs_df_case.columns or 'absolute_start' not in ngs_df_case.columns:
        return

    log2_values = pd.to_numeric(ngs_df_case['log2'], errors='coerce').clip(-PLOT_Y_LIM, PLOT_Y_LIM)
    abs_pos = ngs_df_case['absolute_start']

    # Alpha from the z-score (population std, as scipy.stats.zscore uses by
    # default). It is computed over the finite values only: a single NaN, or a
    # constant series, would otherwise turn every alpha into NaN, which
    # matplotlib rejects. Points without a usable z-score keep the default.
    default_alpha = 0.5
    values = log2_values.to_numpy(dtype=float)
    alpha_values = np.full(values.shape, default_alpha)
    finite = np.isfinite(values)
    if finite.sum() > 1:
        spread = values[finite].std()
        if spread > 0:
            abs_z = np.abs(values[finite] - values[finite].mean()) / spread
            alpha_values[finite] = np.clip((abs_z / 3) ** 2.25, 0.01, 1.0)
    alphas = pd.Series(alpha_values, index=log2_values.index)

    # NaN compares False everywhere, so missing values fall into no category.
    categories = [
        ((log2_values > 1e-6), 'tab:orange'),          # gain
        ((log2_values < -1e-6), 'tab:blue'),           # loss
        (np.isclose(log2_values, 0, atol=1e-6), 'grey'),  # neutral
    ]
    for mask, color in categories:
        if mask.any():
            ax.scatter(abs_pos[mask], log2_values[mask],
                       color=color, alpha=alphas[mask], s=2)


def _plot_arm_means(ax, df_aggregated_case, color, linestyle, linewidth, label_text_prefix):
    if df_aggregated_case.empty or 'log2' not in df_aggregated_case.columns:
        return

    start_col = 'arm_absolute_start' if 'arm_absolute_start' in df_aggregated_case.columns else 'absolute_start'
    end_col = 'arm_absolute_end' if 'arm_absolute_end' in df_aggregated_case.columns else 'absolute_end'

    if not all(c in df_aggregated_case.columns for c in [start_col, end_col, 'log2']):
        return

    for idx, row in df_aggregated_case.reset_index().iterrows():
        if pd.notna(row['log2']) and pd.notna(row[start_col]) and pd.notna(row[end_col]):
            ax.plot(
                [row[start_col], row[end_col]],
                [row['log2'], row['log2']],
                color=color,
                linestyle=linestyle,
                linewidth=linewidth,
                label=f'{label_text_prefix}' if idx == 0 else None  # Label only first segment
            )


def _plot_gene_labels(ax, gene_reps_case):
    if gene_reps_case.empty or not all(c in gene_reps_case.columns for c in ['absolute_position', 'log2', 'gene']):
        return

    valid_gene_reps = gene_reps_case.dropna(subset=['absolute_position', 'log2', 'gene'])
    if valid_gene_reps.empty: return

    ax.scatter(
        valid_gene_reps['absolute_position'],
        valid_gene_reps['log2'],
        color='black', s=4
    )
    for _, row in valid_gene_reps.iterrows():
        y, name = row['log2'], row['gene']

        offset = 0.15 if y > 0 else -0.15
        if abs(y) > PLOT_Y_LIM * 0.75:
            offset = -1 * offset
        va = 'bottom' if offset > 0 else 'top'

        ax.text(
            row['absolute_position'],
            y + offset,
            name,
            fontsize=8,
            ha='center',
            va=va,
            rotation=90
        )


def draw_cnv_plot(case_info, df_ngs_case, df_ngs_agg, df_epic_agg, df_gene_reps,
                  df_chroms, arm_map_df):
    """
    Draw the genome-wide CNV profile of one case and show it.

    Args:
        case_info: Mapping-like case record (e.g. one row of the case table).
            'case_n_number' is required; 'sex', 'predicted_sex', 'age' and
            'diagnosis' are read with ``.get`` and default to 'N/A';
            'sex_confidence' is optional.
        df_ngs_case: Per-point NGS data, forwarded to ``_plot_ngs_scatter``.
        df_ngs_agg: Aggregated NGS data, drawn as solid lines.
        df_epic_agg: Aggregated EPIC data, drawn as dashed lines.
        df_gene_reps: Gene representatives, forwarded to ``_plot_gene_labels``.
        df_chroms: Chromosome layout with 'chromosome_absolute_start', 'length'
            and 'chromosome' columns. When it is empty or incomplete the
            chromosome guides, ticks and x-limits are skipped instead of
            raising.
        arm_map_df: Arm layout; rows with arm == 'p' supply the p/q boundary
            positions through 'arm_abs_end'.

    The 'scatter' sub-directory of PLOT_OUTPUT_DIR is created as a side effect.
    The figure is closed after being shown so that calling this in a loop does
    not accumulate open figures.
    """
    case_n_number = case_info['case_n_number']
    fig, ax = plt.subplots(figsize=(11, 6))

    # Use the style-specific helper functions
    _plot_ngs_scatter(ax, df_ngs_case)
    _plot_arm_means(ax, df_ngs_agg, 'purple', '-', 1.5, 'average cnr (ngs sample)')
    _plot_arm_means(ax, df_epic_agg, 'purple', '--', 1.0, 'average cnr (epic reference)')
    _plot_gene_labels(ax, df_gene_reps)

    # Axis lines and ticks
    ax.axhline(0, color='grey', linewidth=0.5)

    chrom_cols = ['chromosome_absolute_start', 'length', 'chromosome']
    if not df_chroms.empty and all(c in df_chroms.columns for c in chrom_cols):
        chrom_starts = df_chroms['chromosome_absolute_start']
        ax.vlines(chrom_starts, -PLOT_Y_LIM, PLOT_Y_LIM,
                  color='grey', linestyle='--', linewidth=0.5)

        # Draw vertical lines at p/q arm boundaries
        if not arm_map_df.empty and all(c in arm_map_df.columns for c in ['arm', 'arm_abs_end']):
            # The end of the 'p' arm is the centromere / p-q boundary
            pq_boundaries = arm_map_df.loc[arm_map_df['arm'] == 'p', 'arm_abs_end']
            if not pq_boundaries.empty:
                ax.vlines(pq_boundaries, -PLOT_Y_LIM, PLOT_Y_LIM,
                          color='grey', linestyle=':', linewidth=0.3)

        mids = chrom_starts + df_chroms['length'] / 2
        ax.set_xticks(mids)

        # Create user-friendly labels for ticks, converting 23/24 back to X/Y
        xticklabels = df_chroms['chromosome'].astype(str).replace({'23': 'X', '24': 'Y'}).tolist()
        ax.set_xticklabels(xticklabels, rotation=90)

        # The x-range ends where the last chromosome ends. This lives inside
        # the guard because it indexes the same columns checked above.
        xlim_max = chrom_starts.iloc[-1] + df_chroms['length'].iloc[-1]
        ax.set_xlim(0, xlim_max)

    ax.set(
        ylim=(-PLOT_Y_LIM, PLOT_Y_LIM),
        xlabel='genomic position by chromosome',
        ylabel='copy number deviation (log2)'
    )

    fig.suptitle(f'CNV profile of {case_n_number} from NGS sample vs. EPIC reference', fontsize=14)

    known_sex = case_info.get('sex', 'N/A')
    predicted_sex = case_info.get('predicted_sex', 'N/A')
    confidence = case_info.get('sex_confidence', None)

    pred_text = str(predicted_sex)
    # pd.notna rejects both None and NaN, so a missing confidence is omitted
    # instead of being rendered as "nan%".
    if pd.notna(confidence) and predicted_sex != 'unknown':
        pred_text += f' {confidence:.0%}'

    title_text = (
        f"{known_sex} (prediction: {pred_text}), age {case_info.get('age', 'N/A')}, "
        f"{case_info.get('diagnosis', 'N/A')}"
    )
    ax.set_title(title_text, fontsize=10, pad=6)

    ax.grid(False)
    # Only build a legend when at least one labelled artist was drawn
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    fig.tight_layout()

    # Ensure output directory exists
    (PLOT_OUTPUT_DIR / 'scatter').mkdir(parents=True, exist_ok=True)
    fname = PLOT_OUTPUT_DIR / 'scatter' / f"ngs_scatter_{str(case_n_number).replace('/', '_')}.png"
    # plt.savefig(fname, dpi=150)
    # print(f"Saved plot: {fname}") # Commented out to reduce console noise in a loop

    # plt.show() takes no figure argument; show, then release the figure.
    plt.show()
    plt.close(fig)

In [None]:
def calculate_arm_correlations(ngs_aggregated_df, epic_aggregated_df):
    """
    Calculates Pearson correlation between NGS and EPIC log2 values per arm.

    Args:
        ngs_aggregated_df: Per-case, per-arm NGS values with 'case_n_number',
            'chromosome', 'segment_arm' and 'log2' columns.
        epic_aggregated_df: Per-case, per-arm EPIC values with the same columns.

    Returns:
        DataFrame with columns 'chromosome', 'arm', 'pearson_r', 'p_value' and
        'n_samples', sorted by chromosome and arm. 'n_samples' is the number of
        cases with a non-missing log2 value in both sources. 'pearson_r' and
        'p_value' are NaN when fewer than 3 such cases exist or when either
        side is constant. The frame is empty (same columns) when either input
        is empty or nothing can be matched.
    """
    result_columns = ['chromosome', 'arm', 'pearson_r', 'p_value', 'n_samples']
    if ngs_aggregated_df.empty or epic_aggregated_df.empty:
        return pd.DataFrame(columns=result_columns)

    key_cols = ['case_n_number', 'chromosome', 'segment_arm']
    merged_df = pd.merge(
        ngs_aggregated_df[key_cols + ['log2']],
        epic_aggregated_df[key_cols + ['log2']],
        on=key_cols,
        suffixes=('_ngs', '_epic'),
        how='inner'  # Only cases/arms present in both
    )
    if merged_df.empty:
        return pd.DataFrame(columns=result_columns)

    corr_results = []
    for (chrom, arm), group in merged_df.groupby(['chromosome', 'segment_arm']):
        # Correlate only complete pairs: one missing value would otherwise
        # turn the whole arm's coefficient into NaN.
        pairs = group[['log2_ngs', 'log2_epic']].dropna().astype(float)

        r, p = np.nan, np.nan
        # Pearson r needs at least 2 points, but 3 is the minimum used here.
        # It is also undefined for a constant series.
        has_variation = pairs['log2_ngs'].nunique() > 1 and pairs['log2_epic'].nunique() > 1
        if len(pairs) >= 3 and has_variation:
            r, p = pearsonr(pairs['log2_ngs'], pairs['log2_epic'])
            r, p = float(r), float(p)

        corr_results.append({'chromosome': chrom, 'arm': arm, 'pearson_r': r, 'p_value': p,
                             'n_samples': len(pairs)})

    # groupby drops rows whose chromosome/arm key is missing, so the list can
    # be empty even though the merge was not.
    if not corr_results:
        return pd.DataFrame(columns=result_columns)

    return (
        pd.DataFrame(corr_results, columns=result_columns)
        .sort_values(['chromosome', 'arm'])
        .reset_index(drop=True)
    )

In [None]:
def plot_correlation_scatter(merged_correlation_df, arm_corr_stats_df):
    """
    Plots scatter of NGS vs EPIC log2 values, faceted by chromosome, colored by arm, with correlation annotations.

    Args:
        merged_correlation_df: One row per case and arm with 'log2_ngs',
            'log2_epic', 'chromosome' and 'segment_arm' columns.
        arm_corr_stats_df: Per-arm statistics with 'chromosome', 'arm',
            'pearson_r' and 'p_value' columns.

    Each facet is annotated with r per arm. Every arm has its own corner
    (p: top left, q: top right, spanning: bottom right) so the labels cannot
    be drawn on top of each other. The label background turns green for
    p < 0.05 and a darker green for p < 0.01.
    """
    if merged_correlation_df.empty:
        print("Merged data for correlation plot is empty. Skipping plot.")
        return

    g = sns.relplot(
        data=merged_correlation_df, x='log2_ngs', y='log2_epic',
        col='chromosome', hue='segment_arm', kind='scatter', palette=ARM_COLORS,
        col_wrap=6, height=2, aspect=1, s=20, alpha=0.7,
        facet_kws={'sharex': True, 'sharey': True}
    )

    # (x, y, horizontal alignment, vertical alignment) in axes coordinates
    label_anchor = {
        'p': (0.05, 0.95, 'left', 'top'),
        'q': (0.95, 0.95, 'right', 'top'),
        'spanning': (0.95, 0.05, 'right', 'bottom'),
    }
    default_anchor = (0.95, 0.95, 'right', 'top')

    for ax, chrom_name in zip(g.axes.flat, g.col_names):
        if pd.isna(chrom_name):
            continue  # Skip if chrom_name is NaN (can happen if few chromosomes)
        chrom_corrs_on_plot = arm_corr_stats_df[arm_corr_stats_df['chromosome'] == chrom_name]

        for arm_label, color in ARM_COLORS.items():
            if arm_label is None:
                continue  # Skip None arm for text annotation
            arm_data_for_text = chrom_corrs_on_plot[chrom_corrs_on_plot['arm'] == arm_label]
            if arm_data_for_text.empty:
                continue

            r_val = arm_data_for_text.iloc[0]['pearson_r']
            p_val = arm_data_for_text.iloc[0]['p_value']
            if pd.isna(r_val):
                continue

            text_label = f'r({arm_label})={r_val:.2f}'
            x_pos, y_pos, ha, va = label_anchor.get(arm_label, default_anchor)

            # Significance highlighting for p-value
            if p_val < 0.01:
                bg_color = '#c3e6cb'  # darker green-ish for p < 0.01
            elif p_val < 0.05:
                bg_color = '#d4edda'  # green-ish for p < 0.05
            else:
                bg_color = '#f0f0f0'  # default (also used when p is NaN)

            ax.text(x_pos, y_pos, text_label, transform=ax.transAxes, color=color,
                    ha=ha, va=va, fontsize=8,
                    bbox=dict(facecolor=bg_color, alpha=0.6, edgecolor='none', pad=0.2))

    g.set_axis_labels('ngs cn ratio (log2)', 'epic cn ratio (log2)', fontsize=10)

    # Prefer the `figure` attribute and fall back to `fig` for seaborn
    # releases that predate it.
    figure = getattr(g, 'figure', None)
    if figure is None:
        figure = g.fig
    figure.subplots_adjust(top=0.92)  # Adjust top for suptitle
    figure.suptitle('NGS vs EPIC: copy number correlation by chromosome and arm', fontsize=14)

    # NOTE: the 'correlation' directory is not created here; create it before
    # enabling the savefig line below.
    fname = PLOT_OUTPUT_DIR / 'correlation' / "ngs_epic_arm_correlation.png"
    # plt.savefig(fname, dpi=200)
    # print(f"Saved correlation plot: {fname}")
    plt.show()

In [None]:
# Load the case table from CASES_FILE_PATH; the bare name on the last line
# makes the notebook display the resulting frame.
cases_df = load_and_preprocess_cases(CASES_FILE_PATH)
cases_df

In [None]:
# Load the cytoband table from CYTOBAND_FILE_PATH and display it.
cytoband_df = load_and_preprocess_cytoband(CYTOBAND_FILE_PATH)
cytoband_df

In [None]:
# Load the relevant-genes table from RELEVANT_GENES_FILE_PATH and display it.
genes_df = load_and_preprocess_relevant_genes(RELEVANT_GENES_FILE_PATH)
genes_df

In [None]:
# Report the directories in use. The correlation plot path is built from the
# 'correlation' sub-directory of PLOT_OUTPUT_DIR, so that is what is shown.
print(
    f"Base Dir: {BASE_DATA_DIR.resolve()}, Plot Dir: {PLOT_OUTPUT_DIR.resolve()}, "
    f"Corr. Plot Dir: {(PLOT_OUTPUT_DIR / 'correlation').resolve()}")

# Empty defaults keep every name defined even when an input table is empty;
# the later cells test these for emptiness before using them.
chroms_df, chrom_map = pd.DataFrame(), {}
arm_map_df, arm_lookup_table = pd.DataFrame(), pd.DataFrame()
gene_trees = {}

if not cytoband_df.empty:
    chroms_df, chrom_map = calculate_chromosome_features(cytoband_df)
    arm_map_df = create_arm_mapping_df(cytoband_df, chrom_map)
    arm_lookup_table = create_arm_lookup_table(arm_map_df)
if not genes_df.empty:
    gene_trees = build_gene_interval_trees(genes_df)
gene_trees

In [None]:
# Process EPIC data
epic_df_aggregated_all = pd.DataFrame()

# EPIC aggregation needs the case list and both arm reference tables.
epic_prerequisites_missing = cases_df.empty or arm_lookup_table.empty or arm_map_df.empty
if not epic_prerequisites_missing:
    epic_raw_df = load_combine_genomic(cases_df, BASE_DATA_DIR, True, 'sentrix_id')  # True for SEG
    if not epic_raw_df.empty:
        # Classify on a copy so the raw frame stays untouched
        epic_classified = add_segment_arm_classification(epic_raw_df.copy(), arm_lookup_table)
        epic_df_aggregated_all = calculate_arm_log2_from_seg(epic_classified, arm_map_df, 'sentrix_id')
        print(f"Aggregated EPIC data: {epic_df_aggregated_all.shape}")

epic_df_aggregated_all

In [None]:
# Process NGS data
ngs_df_processed_full = pd.DataFrame()  # For CNV plots (points)
ngs_arm_aggregated_all = pd.DataFrame()  # For CNV plots (lines) & correlation
ngs_gene_reps_all = pd.DataFrame()  # For CNV plots (gene labels)

if not cases_df.empty:
    ngs_raw_df = load_combine_genomic(cases_df, BASE_DATA_DIR, False, 'barcode')  # False for CNR
    if not ngs_raw_df.empty:
        cases_df = predict_sex(ngs_raw_df, cases_df)
        ngs_curr = ngs_raw_df.copy()
        # Each enrichment step only runs when its reference structure exists
        if chrom_map:
            ngs_curr = prep_ngs_agg(ngs_curr, chrom_map)
        if not arm_lookup_table.empty:
            ngs_curr = add_segment_arm_classification(ngs_curr, arm_lookup_table)
        if gene_trees:
            ngs_curr = annotate_segments_with_genes(ngs_curr, gene_trees, genes_df)
            ngs_gene_reps_all = aggregate_data_by_gene(ngs_curr)
        if not chroms_df.empty and 'segment_arm' in ngs_curr.columns:
            ngs_arm_aggregated_all = agg_ngs_points_to_arms(ngs_curr, chroms_df)

        ngs_df_processed_full = ngs_curr  # This df has points for CNV plot
        print(
            f"Processed NGS. Points: {ngs_df_processed_full.shape}, Arm Agg: {ngs_arm_aggregated_all.shape}, Gene Reps: {ngs_gene_reps_all.shape}")

# Display the processed frame. `ngs_curr` only exists when data was loaded, so
# showing it here would raise NameError on an empty run; when data was loaded
# this is the very same object.
ngs_df_processed_full

In [None]:
# Generate CNV plots for each case
def _rows_for_case(df, case_number):
    """
    Return a copy of the rows of `df` that belong to one case.

    The aggregate frames default to a column-less `pd.DataFrame()` when their
    processing step was skipped; in that case an empty frame is returned
    instead of raising KeyError on the missing 'case_n_number' column.
    """
    if 'case_n_number' not in df.columns:
        return df.iloc[0:0].copy()
    return df[df['case_n_number'] == case_number].copy()


if not cases_df.empty and not ngs_df_processed_full.empty:
    print(f"\n--- Generating CNV plots for {len(cases_df)} cases ---")
    for _, case_r in cases_df.iterrows():
        cnum = case_r['case_n_number']
        ngs_c = _rows_for_case(ngs_df_processed_full, cnum)
        ngs_agg_c = _rows_for_case(ngs_arm_aggregated_all, cnum)
        epic_agg_c = _rows_for_case(epic_df_aggregated_all, cnum)
        gene_reps_c = _rows_for_case(ngs_gene_reps_all, cnum)
        if ngs_c.empty and ngs_agg_c.empty and epic_agg_c.empty and gene_reps_c.empty:
            print(f"Skipping plot for {cnum}: No data.")
            continue
        draw_cnv_plot(case_r, ngs_c, ngs_agg_c, epic_agg_c, gene_reps_c, chroms_df, arm_map_df)

In [None]:
# Calculate and plot correlations

# Start from an empty, typed result so that `arm_correlations_df` is always
# defined: the summary cells that follow then evaluate to NaN instead of
# raising NameError when the analysis is skipped.
arm_correlations_df = pd.DataFrame({
    'chromosome': pd.Series(dtype='object'),
    'arm': pd.Series(dtype='object'),
    'pearson_r': pd.Series(dtype='float64'),
    'p_value': pd.Series(dtype='float64'),
    'n_samples': pd.Series(dtype='int64'),
})

if not ngs_arm_aggregated_all.empty and not epic_df_aggregated_all.empty:
    print("\n--- Calculating and plotting NGS vs EPIC arm correlations ---")
    arm_correlations_df = calculate_arm_correlations(ngs_arm_aggregated_all, epic_df_aggregated_all)
    print("Arm correlation:\n", arm_correlations_df)

    # Need the merged df for plotting points (not just stats)
    merge_keys = ['case_n_number', 'chromosome', 'segment_arm']
    merged_for_plot = pd.merge(
        ngs_arm_aggregated_all[merge_keys + ['log2']],
        epic_df_aggregated_all[merge_keys + ['log2']],
        on=merge_keys,
        suffixes=('_ngs', '_epic'), how='inner'
    )
    plot_correlation_scatter(merged_for_plot, arm_correlations_df)
else:
    print("Skipping correlation analysis: Aggregated NGS or EPIC data is missing.")

In [None]:
# Mean Pearson r over all chromosome arms (pandas skips NaN entries by default).
arm_correlations_df['pearson_r'].mean()

In [None]:
# Median Pearson r over all chromosome arms (pandas skips NaN entries by default).
arm_correlations_df['pearson_r'].median()

In [None]:
### Integrated Model Testing and Validation Cell ###

# This cell tests the "Linear Z-Score" model. The male hypothesis is scored
# from the X-chromosome z-score and the female hypothesis from the
# Y-chromosome z-score, each on a linear ramp clipped to [0, 1].

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


# --- Step 1: Define the Prediction Model to Test ---
def predict_sex_linear_zscore(ngs_df, cases_df, case_id_col='case_n_number'):
    """
    Predicts sex using a linear scoring system based on the relative
    z-scores of the X and Y chromosomes.

    For every case the median log2 of chromosome 23 (X) and 24 (Y) is
    expressed as a z-score against the autosomes (median and standard
    deviation of chromosomes 1-22). A low X z-score raises the male score, a
    low Y z-score raises the female score; the higher score wins and the
    absolute difference is the confidence.

    Args:
        ngs_df: Per-bin data with 'chromosome', 'log2' and `case_id_col` columns.
        cases_df: Case table; it is modified in place.
        case_id_col: Column identifying the case in both frames.

    Returns:
        `cases_df` with 'predicted_sex' ('male', 'female' or 'unknown') and
        'sex_confidence' (0.0-1.0) columns added. A case is 'unknown' with
        confidence 0.0 when autosomal, X or Y bins are missing, when any of
        the summary statistics is undefined, or when the autosomal spread is
        effectively zero.
    """
    # --- Model Parameters (Unitless and statistically meaningful) ---
    # For the Male Score (based on z_score_x)
    MALE_Z_HIGH_SCORE = -1.5  # The z-score giving a score of 1.0
    MALE_Z_LOW_SCORE = 0.0     # The z-score giving a score of 0.0

    # For the Female Score (based on z_score_y)
    FEMALE_Z_HIGH_SCORE = -9.0 # The z-score giving a score of 1.0
    FEMALE_Z_LOW_SCORE = 0.0    # The z-score giving a score of 0.0

    male_range = MALE_Z_LOW_SCORE - MALE_Z_HIGH_SCORE
    female_range = FEMALE_Z_LOW_SCORE - FEMALE_Z_HIGH_SCORE

    predictions, confidences = {}, {}

    for case_id, group in ngs_df.groupby(case_id_col):
        autosomal_bins = group[(group['chromosome'] >= 1) & (group['chromosome'] <= 22)]
        x_bins = group[group['chromosome'] == 23]
        y_bins = group[group['chromosome'] == 24]

        if autosomal_bins.empty or x_bins.empty or y_bins.empty:
            predictions[case_id], confidences[case_id] = 'unknown', 0.0
            continue

        median_auto = autosomal_bins['log2'].median()
        autosomal_std = autosomal_bins['log2'].std()
        median_x = x_bins['log2'].median()
        median_y = y_bins['log2'].median()

        # An undefined statistic (e.g. std of a single bin, or all-NaN log2)
        # must be rejected explicitly: every comparison with NaN is False, so
        # it would otherwise fall through to a "female" call with NaN confidence.
        stats_undefined = pd.isna([median_auto, autosomal_std, median_x, median_y]).any()
        if stats_undefined or autosomal_std < 1e-9:
            predictions[case_id], confidences[case_id] = 'unknown', 0.0
            continue

        z_score_x = (median_x - median_auto) / autosomal_std
        z_score_y = (median_y - median_auto) / autosomal_std

        # --- Calculate Male Score (0.0 to 1.0) ---
        male_score = np.clip((MALE_Z_LOW_SCORE - z_score_x) / male_range, 0.0, 1.0)

        # --- Calculate Female Score (0.0 to 1.0) ---
        female_score = np.clip((FEMALE_Z_LOW_SCORE - z_score_y) / female_range, 0.0, 1.0)

        # --- Decision and Confidence ---
        # Ties go to "female"; their confidence is 0.0.
        prediction = "male" if male_score > female_score else "female"
        confidence = abs(male_score - female_score)

        predictions[case_id] = prediction
        confidences[case_id] = np.clip(confidence, 0.0, 1.0)

    cases_df['predicted_sex'] = cases_df[case_id_col].map(predictions).fillna('unknown')
    cases_df['sex_confidence'] = cases_df[case_id_col].map(confidences).fillna(0.0)
    return cases_df


# Ensure the required data is available from previous cells
if 'ngs_raw_df' not in locals() or 'cases_df' not in locals() or ngs_raw_df.empty or cases_df.empty:
    print("Error: 'ngs_raw_df' or 'cases_df' is not loaded. Please run the data loading cells first.", file=sys.stderr)
else:
    # --- Step 2: Run the Prediction ---
    print("--- Running Sex Prediction using the Linear Z-Score Model ---")
    # Predict on a copy so the shared case table keeps its existing columns
    cases_with_predictions = predict_sex_linear_zscore(ngs_raw_df, cases_df.copy())
    print("Prediction complete.")

    # --- Step 3: Validate the Results ---
    # Only cases with a known sex and a definite prediction can be compared
    valid_comparison_df = cases_with_predictions[
        (cases_with_predictions['sex'].isin(['male', 'female'])) &
        (cases_with_predictions['predicted_sex'].isin(['male', 'female']))
    ].copy()

    if not valid_comparison_df.empty:
        print("\n--- Sex Prediction Validation ---")

        incorrect_predictions = valid_comparison_df[
            valid_comparison_df['sex'] != valid_comparison_df['predicted_sex']
        ]

        if not incorrect_predictions.empty:
            print("\nIncorrect Predictions Found:")
            print(incorrect_predictions[['case_n_number', 'sex', 'predicted_sex', 'sex_confidence']])
        else:
            # \u2705 is the check-mark emoji, written as an escape so the
            # message survives re-encoding of the notebook file.
            print("\n\u2705 All predictions were correct!")

        accuracy = accuracy_score(valid_comparison_df['sex'], valid_comparison_df['predicted_sex'])
        print(f"\nOverall Accuracy: {accuracy:.2%}")

        avg_confidence = valid_comparison_df['sex_confidence'].mean()
        print(f"Average Confidence: {avg_confidence:.2%}")

        print("\nClassification Report:")
        labels = sorted(set(valid_comparison_df['sex']))
        print(classification_report(valid_comparison_df['sex'], valid_comparison_df['predicted_sex'], labels=labels, zero_division=0))

        print("\nConfusion Matrix:")
        cm = confusion_matrix(valid_comparison_df['sex'], valid_comparison_df['predicted_sex'], labels=labels)
        cm_df = pd.DataFrame(cm, index=[f'Actual: {l}' for l in labels], columns=[f'Predicted: {l}' for l in labels])
        print(cm_df)

    else:
        print("\nCould not perform validation. No valid known and predicted sexes to compare.")

In [None]:
### Display Key Chromosomal Statistics for All Cases (with Z-Scores) ###

# This cell calculates essential statistics for each sample, including the z-scores
# for the X and Y chromosome medians relative to the autosomal baseline.
# This provides a clear table for debugging and analyzing model inputs.

import numpy as np
import pandas as pd

# Ensure the required data is available from previous cells
if 'ngs_raw_df' not in locals() or 'cases_df' not in locals() or ngs_raw_df.empty or cases_df.empty:
    print("Error: 'ngs_raw_df' or 'cases_df' is not loaded. Please run the data loading cells first.", file=sys.stderr)
else:
    # --- Calculate Statistics for Each Case ---
    print("--- Calculating Chromosomal Statistics and Z-Scores for Each Case ---")

    case_stats = []
    for case_id, group in ngs_raw_df.groupby('case_n_number'):
        # Isolate data for each chromosome type (23 = X, 24 = Y)
        autosomal_bins = group[(group['chromosome'] >= 1) & (group['chromosome'] <= 22)]
        x_bins = group[group['chromosome'] == 23]
        y_bins = group[group['chromosome'] == 24]

        # Calculate base stats, handling cases where a chromosome might be missing data
        median_auto = autosomal_bins['log2'].median() if not autosomal_bins.empty else np.nan
        std_auto = autosomal_bins['log2'].std() if not autosomal_bins.empty else np.nan
        median_x = x_bins['log2'].median() if not x_bins.empty else np.nan
        median_y = y_bins['log2'].median() if not y_bins.empty else np.nan

        # Calculate z-scores, handling potential division by zero
        if pd.notna(std_auto) and std_auto > 1e-9:
            z_score_x = (median_x - median_auto) / std_auto
            z_score_y = (median_y - median_auto) / std_auto
        else:
            z_score_x, z_score_y = np.nan, np.nan

        # Store all results
        case_stats.append({
            'case_n_number': case_id,
            'autosomal_median': median_auto,
            'autosomal_std': std_auto,
            'x_median': median_x,
            'y_median': median_y,
            'z_score_x': z_score_x,
            'z_score_y': z_score_y
        })

    # Convert the list of results into a DataFrame. The explicit column list
    # keeps the merge below valid even if no case produced a row.
    stats_df = pd.DataFrame(case_stats, columns=[
        'case_n_number',
        'autosomal_median',
        'autosomal_std',
        'x_median',
        'y_median',
        'z_score_x',
        'z_score_y'
    ])

    # Merge with cases_df to add the 'sex' column
    summary_df = pd.merge(cases_df[['case_n_number', 'sex']], stats_df, on='case_n_number')

    # Reorder columns for maximum clarity
    summary_df = summary_df[[
        'case_n_number',
        'sex',
        'autosomal_median',
        'autosomal_std',
        'x_median',
        'z_score_x',
        'y_median',
        'z_score_y'
    ]]

    print("\nSummary statistics table created:")

    # Display the final table with 4-decimal floats. The option is scoped to
    # this display call, so the float format of later cells is unaffected.
    with pd.option_context('display.float_format', '{:.4f}'.format):
        display(summary_df)

In [None]:
### Simulation to Optimize Autosomal Z-Score Thresholds

import time
from sklearn.metrics import accuracy_score


# --- Helper function for the simulation ---
def _calculate_metrics_zscore(df, female_z_thresh, male_z_thresh):
    """
    Score one (female, male) threshold pair against the known sexes.

    A sample is called "female" below `female_z_thresh`, "male" above
    `male_z_thresh` and "unknown" in between.

    Args:
        df: Frame with 'z_score' and 'sex' columns, one row per sample.
        female_z_thresh: Upper z-score bound for a female call.
        male_z_thresh: Lower z-score bound for a male call; callers are
            expected to pass a value above `female_z_thresh`.

    Returns:
        Dict with 'accuracy' (over samples with a known sex and a definite
        call, 0.0 if there are none) and 'confidence' (mean confidence over
        all samples, 0.0 if there are none).
    """
    y_true, y_pred, confidences = [], [], []

    for z, true_sex in zip(df['z_score'], df['sex']):
        # --- Prediction and Confidence (mirrors main function) ---
        if z < female_z_thresh:
            prediction = "female"
            scale = abs(female_z_thresh)
            # A zero threshold makes the ratio unbounded; it saturates at 1.0
            confidence = 1.0 if scale == 0 else min(1.0, (female_z_thresh - z) / scale)
        elif z > male_z_thresh:
            prediction = "male"
            scale = abs(male_z_thresh)
            confidence = 1.0 if scale == 0 else min(1.0, (z - male_z_thresh) / scale)
        else:
            prediction = "unknown"
            mid_point = (female_z_thresh + male_z_thresh) / 2
            confidence = 1.0 - (abs(z - mid_point) / abs(mid_point - female_z_thresh))

        confidences.append(np.clip(confidence, 0.0, 1.0))

        if true_sex in ['male', 'female'] and prediction != 'unknown':
            y_true.append(true_sex)
            y_pred.append(prediction)

    return {
        'accuracy': accuracy_score(y_true, y_pred) if y_true else 0.0,
        'confidence': np.mean(confidences) if confidences else 0.0
    }


# Ensure the required data is available from previous cells
if 'ngs_raw_df' not in locals() or 'cases_df' not in locals() or ngs_raw_df.empty or cases_df.empty:
    print("Error: 'ngs_raw_df' or 'cases_df' is not loaded. Please run data loading cells first.", file=sys.stderr)
else:
    # --- Pre-calculation Step (for speed) ---
    print("Pre-calculating Z-scores for all samples...")
    z_score_records = []
    for case_id, group in ngs_raw_df.groupby('case_n_number'):
        y_chrom_bins = group[group['chromosome'] == 24]
        autosomal_bins = group[(group['chromosome'] >= 1) & (group['chromosome'] <= 22)]
        if not y_chrom_bins.empty and not autosomal_bins.empty:
            autosomal_std = autosomal_bins['log2'].std()
            if autosomal_std > 1e-6:
                autosomal_mean = autosomal_bins['log2'].mean()
                y_median = y_chrom_bins['log2'].median()
                z_score = (y_median - autosomal_mean) / autosomal_std
                # Skip undefined z-scores: a single NaN would turn the mean
                # confidence, and with it every combined score, into NaN.
                if pd.notna(z_score):
                    z_score_records.append((case_id, z_score))

    # Convert to a DataFrame for easier processing
    z_scores_df = pd.DataFrame(z_score_records, columns=['case_n_number', 'z_score'])
    # Merge with true sex labels
    z_scores_df = z_scores_df.merge(cases_df[['case_n_number', 'sex']], on='case_n_number')
    print(f"Calculated Z-scores for {len(z_scores_df)} samples.")

    # --- Simulation Setup ---
    print("\n--- Starting Parameter Optimization for Z-Score Model ---")

    ACCURACY_WEIGHT = 0.7
    CONFIDENCE_WEIGHT = 0.3

    # Define search space for z-score thresholds
    female_thresholds = np.arange(-5.0, -1.0, 0.1)
    male_thresholds = np.arange(-3.0, 1.0, 0.1)

    total_iterations = len(female_thresholds) * len(male_thresholds)
    print(f"Search space: {len(female_thresholds)} (female) x {len(male_thresholds)} (male) = {total_iterations} total combinations.")

    # --- Grid Search Execution ---
    best_combined_score = -1.0
    best_params = {}
    best_metrics = {}
    start_time = time.time()

    for f_thresh in female_thresholds:
        for m_thresh in male_thresholds:
            if not f_thresh < m_thresh: continue # Ensure female threshold is lower than male

            metrics = _calculate_metrics_zscore(z_scores_df, f_thresh, m_thresh)
            combined_score = (ACCURACY_WEIGHT * metrics['accuracy']) + (CONFIDENCE_WEIGHT * metrics['confidence'])

            if combined_score > best_combined_score:
                best_combined_score = combined_score
                best_params = {
                    'female_z_thresh': round(f_thresh, 2),
                    'male_z_thresh': round(m_thresh, 2)
                }
                best_metrics = metrics
                print(f"*** New best! Score: {best_combined_score:.4f} "
                      f"(Acc: {best_metrics['accuracy']:.2%}, Conf: {best_metrics['confidence']:.2%}) | "
                      f"Params: F_z={best_params['female_z_thresh']}, M_z={best_params['male_z_thresh']} ***")

    end_time = time.time()
    print("\n--- Optimization Complete ---")
    print(f"Total time: {end_time - start_time:.2f} seconds")

    if not best_params:
        print("\nCould not find any valid parameter combinations.")
    else:
        # \U0001F3C6 (trophy) and \U0001F4C8 (chart) are written as escapes so
        # the messages survive re-encoding of the notebook file.
        print("\n\U0001F3C6 Best Z-Score Thresholds Found:")
        print(f"  - female_z_thresh = {best_params['female_z_thresh']}")
        print(f"  - male_z_thresh   = {best_params['male_z_thresh']}")
        print("\n\U0001F4C8 Resulting Metrics with these parameters:")
        print(f"  - Combined Score: {best_combined_score:.4f}")
        print(f"  - Accuracy:       {best_metrics['accuracy']:.3%}")
        print(f"  - Avg Confidence: {best_metrics['confidence']:.3%}")
        print("\nSuggestion: Update the default parameters in the 'predict_sex_autosomal_zscore' function with these values.")