# Notebook 3 — Validation in mCRPC Cohorts

## AlphaMissense-Guided VUS Reclassification: Clinical Validation with Survival Outcomes

**Goal:** Validate that AlphaMissense reclassification of HRR VUS correlates with clinical outcomes in metastatic castration-resistant prostate cancer (mCRPC), where events are frequent and HRR status drives treatment decisions (PARP inhibitors).

**Notebook 2 established:**
- ClinVar concordance kappa = 0.733 (substantial agreement)
- 90.1% of VUS reclassified (21.5% pathogenic, 68.7% benign)
- TCGA-PRAD survival: only 1 event in 40 patients (localized disease)

**This notebook adds:** mCRPC cohorts with real survival data to test clinical relevance.

**Public cBioPortal Cohorts:**

| Study ID | Description | n | Key Feature |
|----------|-------------|---|-------------|
| `prad_su2c_2019` | SU2C/PCF Dream Team mCRPC | ~444 | Deep sequencing, survival |
| `prad_msk_2019` | MSK-IMPACT Prostate | ~1013 | Large panel cohort |
| `prad_eururol_2017` | Robinson et al. mCRPC WES/WGS | ~150 | Whole-genome |


In [1]:
# REPRODUCIBILITY: Install dependencies via `pip install -r requirements.txt`
# Do NOT pip-install inside the notebook — use pinned versions from requirements.txt




Setup complete


## 2. Download mCRPC Cohorts from cBioPortal

In [2]:
# ============================================================
# 2. DOWNLOAD mCRPC COHORTS
# ============================================================

CBIO_DATAHUB = "https://cbioportal-datahub.s3.amazonaws.com"

STUDIES = [
    {"id": "prad_su2c_2019", "name": "SU2C/PCF mCRPC 2019"},
    {"id": "prad_msk_2019", "name": "MSK-IMPACT Prostate 2019"},
    {"id": "prad_eururol_2017", "name": "Robinson et al. mCRPC 2017"},
]

def download_from_datahub(study_id):
    result = {}
    files = {
        "mutations": f"{CBIO_DATAHUB}/{study_id}/data_mutations.txt",
        "clinical_patient": f"{CBIO_DATAHUB}/{study_id}/data_clinical_patient.txt",
        "clinical_sample": f"{CBIO_DATAHUB}/{study_id}/data_clinical_sample.txt",
    }
    for key, url in files.items():
        try:
            resp = requests.get(url, timeout=120)
            if resp.status_code == 200 and len(resp.text) > 100:
                df = pd.read_csv(io.StringIO(resp.text), sep="\t", comment="#", low_memory=False)
                result[key] = df
                print(f"    {key}: {len(df):,} rows")
            else:
                print(f"    {key}: HTTP {resp.status_code}")
                result[key] = None
        except Exception as e:
            print(f"    {key}: failed ({e})")
            result[key] = None
    return result

# If datahub fails, try the API
def download_from_api(study_id):
    CBIO_API = "https://www.cbioportal.org/api"
    result = {"mutations": None, "clinical_patient": None, "clinical_sample": None}
    
    # Mutations via API
    try:
        profiles = requests.get(
            f"{CBIO_API}/studies/{study_id}/molecular-profiles",
            headers={"Accept":"application/json"}, timeout=30
        ).json()
        mut_profile = next((p["molecularProfileId"] for p in profiles
                           if p["molecularAlterationType"] == "MUTATION_EXTENDED"), None)
        if mut_profile:
            resp = requests.post(
                f"{CBIO_API}/molecular-profiles/{mut_profile}/mutations/fetch",
                headers={"Accept":"application/json","Content-Type":"application/json"},
                json={"sampleListId": f"{study_id}_all"},
                params={"projection": "DETAILED"},
                timeout=120
            )
            if resp.status_code == 200:
                result["mutations"] = pd.json_normalize(resp.json())
                print(f"    mutations (API): {len(result['mutations']):,} rows")
    except Exception as e:
        print(f"    mutations API failed: {e}")
    
    # Clinical via API
    for dtype in ["PATIENT", "SAMPLE"]:
        try:
            resp = requests.get(
                f"{CBIO_API}/studies/{study_id}/clinical-data",
                headers={"Accept":"application/json"},
                params={"clinicalDataType": dtype, "projection": "DETAILED"},
                timeout=60
            ).json()
            key = "clinical_patient" if dtype == "PATIENT" else "clinical_sample"
            result[key] = pd.json_normalize(resp)
            print(f"    {key} (API): {len(result[key]):,} rows")
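        except Exception as e:
            # Report the failure instead of silently leaving this table as None
            print(f"    clinical {dtype} (API) failed: {e}")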
    return result

# Download all
all_studies = {}
for study in STUDIES:
    sid = study["id"]
    print(f"\nDownloading: {study['name']} ({sid})")
    
    # Try datahub first
    data = download_from_datahub(sid)
    
    # Fallback to API if mutations missing
    if data.get("mutations") is None:
        print("  Trying API fallback...")
        data = download_from_api(sid)
    
    if data.get("mutations") is not None:
        all_studies[sid] = data
        save_dir = DATA_DIR / "validation" / sid
        save_dir.mkdir(parents=True, exist_ok=True)
        for key, df in data.items():
            if df is not None:
                df.to_csv(save_dir / f"{key}.csv", index=False)
    else:
        print(f"  SKIPPED — no mutation data")

print(f"\n{'='*60}")
print(f"Downloaded: {len(all_studies)}/{len(STUDIES)} studies")
for sid in all_studies:
    nm = len(all_studies[sid]["mutations"]) if all_studies[sid]["mutations"] is not None else 0
    print(f"  {sid}: {nm:,} mutations")



Downloading: SU2C/PCF mCRPC 2019 (prad_su2c_2019)
    mutations: HTTP 403
    clinical_patient: HTTP 403
    clinical_sample: HTTP 403
  Trying API fallback...
    mutations (API): 40,055 rows
    clinical_patient (API): 2,699 rows
    clinical_sample (API): 10,191 rows

Downloading: MSK-IMPACT Prostate 2019 (prad_msk_2019)
    mutations: HTTP 403
    clinical_patient: HTTP 403
    clinical_sample: HTTP 403
  Trying API fallback...
    mutations (API): 26 rows
    clinical_patient (API): 20 rows
    clinical_sample (API): 192 rows

Downloading: Robinson et al. mCRPC 2017 (prad_eururol_2017)
    mutations: HTTP 403
    clinical_patient: HTTP 403
    clinical_sample: HTTP 403
  Trying API fallback...
    mutations (API): 1,257 rows
    clinical_patient (API): 672 rows
    clinical_sample (API): 581 rows

Downloaded: 3/3 studies
  prad_su2c_2019: 40,055 mutations
  prad_msk_2019: 26 mutations
  prad_eururol_2017: 1,257 mutations
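
The run above shows every datahub request returning HTTP 403, so all three studies were served by the API fallback. The source-ordering logic can be separated from the network code, which makes it checkable offline. The following is a minimal sketch under stated assumptions, not the notebook's implementation: `first_successful` and the two fake fetchers are hypothetical names, and a source is treated as failed when it raises or returns `None`.

```python
from typing import Callable, Optional, Sequence, Tuple

Fetcher = Callable[[str], Optional[dict]]


def first_successful(
    study_id: str, sources: Sequence[Tuple[str, Fetcher]]
) -> Tuple[Optional[str], Optional[dict]]:
    """Try each (name, fetcher) in order and return the first non-None result.

    Hypothetical helper: a fetcher signals failure by raising or returning None.
    """
    for name, fetch in sources:
        try:
            data = fetch(study_id)
        except Exception as exc:  # move on, but report why this source failed
            print(f"  {name} failed for {study_id}: {exc}")
            continue
        if data is not None:
            return name, data
    return None, None


# Fake fetchers standing in for download_from_datahub / download_from_api
def fake_datahub(study_id: str) -> Optional[dict]:
    raise PermissionError("HTTP 403")


def fake_api(study_id: str) -> Optional[dict]:
    return {"mutations": [{"sampleId": "S1", "patientId": "P1"}]}


source, data = first_successful(
    "prad_su2c_2019", [("datahub", fake_datahub), ("api", fake_api)]
)
print(source, len(data["mutations"]))
```

Recording which source produced each table is useful here because the two sources return different column names (`Hugo_Symbol` in datahub files, `gene.hugoGeneSymbol` from the API).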


## 3. Filter HRR Missense & Annotate with AlphaMissense

In [16]:
# ============================================================
# 3. FILTER HRR MISSENSE + ANNOTATE (FIXED v2)
# ============================================================

am_path = DATA_DIR / "processed" / "alphamissense_hrr_genes.csv"
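if am_path.exists():
    df_am = pd.read_csv(am_path)
    # Key "<uniprot_id>_<protein_variant>" -> (pathogenicity score, class)
    keys = df_am["uniprot_id"].astype(str) + "_" + df_am["protein_variant"].astype(str)
    am_lookup = dict(zip(keys, zip(df_am["am_pathogenicity"], df_am["am_class"])))
    print(f"AlphaMissense lookup: {len(am_lookup):,} variants")
else:
    am_lookup = {}
    print(f"WARNING: {am_path} not found; every variant will be 'not_found'")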

def annotate_cohort(df_mut, study_id):
    gene_col = class_col = sample_col = hgvsp_col = patient_col = None
    for c in ["Hugo_Symbol", "gene.hugoGeneSymbol", "hugoGeneSymbol"]:
        if c in df_mut.columns: gene_col = c; break
    for c in ["Variant_Classification", "mutationType"]:
        if c in df_mut.columns: class_col = c; break
    for c in ["Tumor_Sample_Barcode", "sampleId"]:
        if c in df_mut.columns: sample_col = c; break
    for c in ["HGVSp_Short", "proteinChange", "HGVSp"]:
        if c in df_mut.columns: hgvsp_col = c; break
    # Patient ID column (API includes this)
    if "patientId" in df_mut.columns:
        patient_col = "patientId"

    if not all([gene_col, class_col, sample_col, hgvsp_col]):
        print(f"  Missing columns: gene={gene_col}, class={class_col}, sample={sample_col}, hgvsp={hgvsp_col}")
        return pd.DataFrame()

    print(f"  Columns: gene='{gene_col}', class='{class_col}', sample='{sample_col}', hgvsp='{hgvsp_col}', patient='{patient_col}'")

    df_hrr = df_mut[df_mut[gene_col].isin(HRR_GENES_ALL)].copy()
    print(f"  HRR mutations (all types): {len(df_hrr)}")

    df_miss = df_hrr[df_hrr[class_col].str.contains("missense", case=False, na=False)].copy()
    print(f"  HRR missense: {len(df_miss)}")
    if len(df_miss) == 0:
        return pd.DataFrame()

    aa3to1 = {'Ala':'A','Arg':'R','Asn':'N','Asp':'D','Cys':'C','Gln':'Q',
              'Glu':'E','Gly':'G','His':'H','Ile':'I','Leu':'L','Lys':'K',
              'Met':'M','Phe':'F','Pro':'P','Ser':'S','Thr':'T','Trp':'W',
              'Tyr':'Y','Val':'V','Ter':'*','Sec':'U'}

    def parse_protein(val):
        if pd.isna(val): return None
        s = str(val).strip()
        m3 = re.match(r'p\.([A-Z][a-z]{2})(\d+)([A-Z][a-z]{2})', s)
        if m3:
            r, alt = aa3to1.get(m3.group(1)), aa3to1.get(m3.group(3))
            if r and alt and r != alt: return (r, int(m3.group(2)), alt)
        m1 = re.match(r'p\.([A-Z*])(\d+)([A-Z*])', s)
        if m1 and m1.group(1) != m1.group(3):
            return (m1.group(1), int(m1.group(2)), m1.group(3))
        mb = re.match(r'^([A-Z])(\d+)([A-Z])$', s)
        if mb and mb.group(1) != mb.group(3):
            return (mb.group(1), int(mb.group(2)), mb.group(3))
        return None

    parsed = df_miss[hgvsp_col].apply(parse_protein)
    print(f"  Parsed: {parsed.notna().sum()}/{len(df_miss)}")

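    df_miss = df_miss[parsed.notna()].copy()
    parsed = parsed[parsed.notna()]
    if df_miss.empty:
        # Nothing parseable: avoid dividing by zero in the match-rate report below
        return pd.DataFrame()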

    df_miss["ref_aa"] = [p[0] for p in parsed]
    df_miss["protein_pos"] = [p[1] for p in parsed]
    df_miss["alt_aa"] = [p[2] for p in parsed]
    df_miss["gene"] = df_miss[gene_col]
    df_miss["sample_id"] = df_miss[sample_col]
    # Use patientId from mutations table if available
    df_miss["patient_id"] = df_miss[patient_col] if patient_col else df_miss[sample_col]
    df_miss["protein_change"] = df_miss[hgvsp_col]
    df_miss["uniprot_id"] = df_miss["gene"].map(GENE_TO_UNIPROT)

    scores, classes = [], []
    for _, row in df_miss.iterrows():
        key = f"{row.get('uniprot_id','')}_{row['ref_aa']}{row['protein_pos']}{row['alt_aa']}"
        if key in am_lookup:
            scores.append(am_lookup[key][0])
            classes.append(am_lookup[key][1])
        else:
            scores.append(np.nan)
            classes.append("not_found")

    df_miss["am_pathogenicity"] = scores
    df_miss["am_class"] = classes
    df_miss["study_id"] = study_id
    df_miss["hrr_cohort"] = df_miss["gene"].apply(
        lambda g: "A" if g in COHORT_A_GENES else ("B" if g in COHORT_B_GENES else "Ext")
    )

    n_found = sum(1 for c in classes if c != "not_found")
    print(f"  AM matched: {n_found}/{len(df_miss)} ({100*n_found/len(df_miss):.1f}%)")

    return df_miss[["study_id","sample_id","patient_id","gene","protein_change","ref_aa",
                     "protein_pos","alt_aa","uniprot_id","am_pathogenicity",
                     "am_class","hrr_cohort"]].copy()

all_variants = []
for sid, data in all_studies.items():
    print(f"\n{'='*50}")
    print(f"Processing {sid} ({len(data['mutations']):,} mutations)")
    print(f"{'='*50}")
    df_ann = annotate_cohort(data["mutations"], sid)
    if len(df_ann) > 0:
        all_variants.append(df_ann)
        print(f"\n  Gene breakdown:")
        for g, cnt in df_ann["gene"].value_counts().items():
            n_p = (df_ann[df_ann["gene"]==g]["am_class"]=="pathogenic").sum()
            print(f"    {g}: {cnt} ({n_p} pathogenic)")

if all_variants:
    df_val = pd.concat(all_variants, ignore_index=True)
    df_val.to_csv(RESULTS_DIR / "validation_hrr_variants.csv", index=False)
    print(f"\n✅ Total: {len(df_val)} variants, {df_val['patient_id'].nunique()} patients")
else:
    df_val = pd.DataFrame()
    print("\n❌ No variants found")

AlphaMissense lookup: 554,363 variants

Processing prad_su2c_2019 (40,055 mutations)
  Columns: gene='gene.hugoGeneSymbol', class='mutationType', sample='sampleId', hgvsp='proteinChange', patient='patientId'
  HRR mutations (all types): 174
  HRR missense: 66
  Parsed: 66/66
  AM matched: 65/66 (98.5%)

  Gene breakdown:
    ATM: 8 (1 pathogenic)
    BRCA2: 7 (2 pathogenic)
    CDK12: 6 (5 pathogenic)
    BRCA1: 5 (2 pathogenic)
    PALB2: 5 (0 pathogenic)
    ATR: 4 (1 pathogenic)
    CHEK2: 4 (2 pathogenic)
    FANCA: 3 (0 pathogenic)
    BARD1: 3 (0 pathogenic)
    BRIP1: 3 (2 pathogenic)
    RAD50: 3 (1 pathogenic)
    FANCF: 2 (1 pathogenic)
    NBN: 2 (1 pathogenic)
    RAD54L: 2 (2 pathogenic)
    FANCE: 2 (0 pathogenic)
    FANCG: 2 (0 pathogenic)
    FANCD2: 2 (0 pathogenic)
    RAD51D: 1 (0 pathogenic)
    MRE11: 1 (0 pathogenic)
    ATRX: 1 (0 pathogenic)

Processing prad_msk_2019 (26 mutations)
  Columns: gene='gene.hugoGeneSymbol', class='mutationType', sample='sampleId', 
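
The annotation step depends on two things: normalizing a protein-change string to `(ref, position, alt)` and building the lookup key `<uniprot_id>_<ref><pos><alt>`. The standalone sketch below illustrates both. It is deliberately stricter than the cell above: it uses `fullmatch`, so strings with trailing text such as frameshifts are rejected, and it excludes stop codons. `parse_missense` and `am_key` are hypothetical names, and `UNIPROT_X` is a placeholder accession, not a real one.

```python
import re
from typing import Optional, Tuple

AA3_TO_1 = {
    "Ala": "A", "Arg": "R", "Asn": "N", "Asp": "D", "Cys": "C", "Gln": "Q",
    "Glu": "E", "Gly": "G", "His": "H", "Ile": "I", "Leu": "L", "Lys": "K",
    "Met": "M", "Phe": "F", "Pro": "P", "Ser": "S", "Thr": "T", "Trp": "W",
    "Tyr": "Y", "Val": "V",
}

# Optional "p." prefix, then ref residue, position, alt residue
THREE_LETTER = re.compile(r"(?:p\.)?([A-Z][a-z]{2})(\d+)([A-Z][a-z]{2})")
ONE_LETTER = re.compile(r"(?:p\.)?([A-Z])(\d+)([A-Z])")


def parse_missense(value: object) -> Optional[Tuple[str, int, str]]:
    """Return (ref, pos, alt) for a missense change, else None."""
    if value is None:
        return None
    s = str(value).strip()
    m = THREE_LETTER.fullmatch(s)
    if m:
        ref, alt = AA3_TO_1.get(m.group(1)), AA3_TO_1.get(m.group(3))
    else:
        m = ONE_LETTER.fullmatch(s)
        if not m:
            return None
        ref, alt = m.group(1), m.group(3)
    # Unknown residues (e.g. Ter) and synonymous changes are not missense
    if ref is None or alt is None or ref == alt:
        return None
    return ref, int(m.group(2)), alt


def am_key(uniprot_id: str, parsed: Tuple[str, int, str]) -> str:
    """Build the lookup key in the '<uniprot_id>_<ref><pos><alt>' format."""
    ref, pos, alt = parsed
    return f"{uniprot_id}_{ref}{pos}{alt}"


for text in ["L16Q", "p.L16Q", "p.Leu16Gln", "p.L16Qfs*5", "p.L16L"]:
    print(text, "->", parse_missense(text))
```

All three spellings of the same substitution map to one key, so API-style values (`L16Q`) and MAF-style values (`p.L16Q`, `p.Leu16Gln`) hit the same AlphaMissense entry.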

## 4. Merge with Clinical Data & Build Survival Dataset

In [18]:
# ============================================================
# 4. MERGE WITH CLINICAL DATA (FIXED — uses patient_id from mutations)
# ============================================================

def pivot_if_long(df):
    if df is None or len(df) == 0:
        return pd.DataFrame()
    if "clinicalAttributeId" in df.columns:
        pid_col = "patientId" if "patientId" in df.columns else df.columns[0]
        return df.pivot_table(index=pid_col, columns="clinicalAttributeId",
                               values="value", aggfunc="first").reset_index()
    return df

val_analyses = {}

for sid, data in all_studies.items():
    print(f"\n{'='*60}")
    print(f"Clinical merge: {sid}")
    print(f"{'='*60}")

    df_clin = pivot_if_long(data.get("clinical_patient"))
    if len(df_clin) == 0:
        print(f"  No clinical data"); continue

    print(f"  Clinical: {len(df_clin)} patients, columns: {sorted(df_clin.columns.tolist())[:15]}...")

    df_vars = df_val[df_val["study_id"] == sid]
    if len(df_vars) == 0:
        print(f"  No HRR variants"); continue

    # Patient summary — GROUP BY patient_id (not sample_id)
    pat_sum = df_vars.groupby("patient_id").agg(
        n_hrr=("gene","count"),
        n_path=("am_class", lambda x: (x=="pathogenic").sum()),
        n_ben=("am_class", lambda x: (x=="benign").sum()),
        max_am=("am_pathogenicity","max"),
        genes=("gene", lambda x: ",".join(sorted(x.unique()))),
    ).reset_index()
    pat_sum["has_am_pathogenic"] = pat_sum["n_path"] > 0

    print(f"  Patients with HRR missense: {len(pat_sum)}")
    print(f"    AM-Pathogenic: {pat_sum['has_am_pathogenic'].sum()}")
    print(f"    AM-Benign/Amb: {(~pat_sum['has_am_pathogenic']).sum()}")

    # Merge using patient_id directly
    df_m = pat_sum.merge(df_clin, left_on="patient_id", right_on="patientId", how="left")

    # Find OS columns
    os_time_col = os_status_col = None
    for c in df_m.columns:
        cu = str(c).upper()
        if cu in ["OS_MONTHS","OS_TIME"]: os_time_col = c
        elif cu in ["OS_STATUS"]: os_status_col = c

    if os_time_col and os_status_col:
        df_m["os_time"] = pd.to_numeric(df_m[os_time_col], errors="coerce")
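        def parse_os_event(x):
            # Keep missing status as NaN so the patient is dropped, not counted as censored
            if pd.isna(x):
                return np.nan
            s = str(x).strip().lower()
            return 1 if ("deceased" in s or "dead" in s or s.startswith("1")) else 0
        df_m["os_event"] = df_m[os_status_col].apply(parse_os_event)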
        valid = df_m.dropna(subset=["os_time","os_event"])
        valid = valid[valid["os_time"] > 0]
        print(f"  ✅ OS data: {len(valid)} patients, {valid['os_event'].sum():.0f} events")
        print(f"     Median follow-up: {valid['os_time'].median():.1f} months")
    else:
        print(f"  ⚠️ OS not found. Cols: {[c for c in df_m.columns if 'OS' in str(c).upper()]}")
        df_m["os_time"] = np.nan
        df_m["os_event"] = np.nan

    df_m.to_csv(RESULTS_DIR / f"validation_{sid}.csv", index=False)
    val_analyses[sid] = df_m
    print(f"  Saved: {len(df_m)} patients")

print(f"\n{'='*60}")
print(f"READY FOR SURVIVAL: {sum(1 for d in val_analyses.values() if d['os_event'].sum()>0)} studies with events")


Clinical merge: prad_su2c_2019
  Clinical: 429 patients, columns: ['AGE_AT_DIAGNOSIS', 'CHEMO_REGIMEN_CATEGORY', 'OS_MONTHS', 'OS_STATUS', 'PSA', 'RACE', 'SAMPLE_COUNT', 'SEX', 'patientId']...
  Patients with HRR missense: 54
    AM-Pathogenic: 19
    AM-Benign/Amb: 35
  ✅ OS data: 18 patients, 11 events
     Median follow-up: 18.7 months
  Saved: 54 patients

Clinical merge: prad_msk_2019
  Clinical: 10 patients, columns: ['SAMPLE_COUNT', 'SEX', 'patientId']...
  Patients with HRR missense: 1
    AM-Pathogenic: 0
    AM-Benign/Amb: 1
  ⚠️ OS not found. Cols: []
  Saved: 1 patients

Clinical merge: prad_eururol_2017
  Clinical: 65 patients, columns: ['AGE', 'BLADDER_NECK_INVASION', 'EXTRAPROSTATIC_EXTENSION', 'FPSA_PSA', 'GLEASON_SCORE', 'LYMPH_NODE_METASTASIS', 'PSA', 'SAMPLE_COUNT', 'SEMINAL_VESICLE_INVASION', 'SEX', 'TNMSTAGE', 'patientId']...
  Patients with HRR missense: 3
    AM-Pathogenic: 1
    AM-Benign/Amb: 2
  ⚠️ OS not found. Cols: ['EXTRAPROSTATIC_EXTENSION']
  Saved: 3 p
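
The API returns clinical data in long format, one row per patient and attribute, which is why 2,699 SU2C rows collapse to 429 patients after pivoting. The dependency-free sketch below walks through the same two steps on toy rows: long-to-wide reshaping, then coding `OS_STATUS` as an event indicator. `long_to_wide` and `os_event` are hypothetical names, the rows are invented, and the `1:DECEASED` / `0:LIVING` strings are an assumption about the status format. Unlike pandas' `aggfunc="first"`, which takes the first non-null value, this version simply keeps the first value seen.

```python
from typing import Dict, Iterable, Optional


def long_to_wide(rows: Iterable[dict]) -> Dict[str, Dict[str, str]]:
    """Collapse {patientId, clinicalAttributeId, value} rows to one dict per patient."""
    wide: Dict[str, Dict[str, str]] = {}
    for row in rows:
        attrs = wide.setdefault(row["patientId"], {})
        # Keep the first value seen for each attribute
        attrs.setdefault(row["clinicalAttributeId"], row["value"])
    return wide


def os_event(status: Optional[str]) -> Optional[int]:
    """1 = death observed, 0 = censored, None = status missing."""
    if status is None or not str(status).strip():
        return None
    s = str(status).strip().lower()
    return 1 if ("deceased" in s or "dead" in s or s.startswith("1")) else 0


# Invented example rows in the long format
rows = [
    {"patientId": "P1", "clinicalAttributeId": "OS_MONTHS", "value": "12.5"},
    {"patientId": "P1", "clinicalAttributeId": "OS_STATUS", "value": "1:DECEASED"},
    {"patientId": "P2", "clinicalAttributeId": "OS_MONTHS", "value": "30.0"},
    {"patientId": "P2", "clinicalAttributeId": "OS_STATUS", "value": "0:LIVING"},
    {"patientId": "P3", "clinicalAttributeId": "SEX", "value": "Male"},
]

wide = long_to_wide(rows)
events = {pid: os_event(attrs.get("OS_STATUS")) for pid, attrs in wide.items()}
print(events)
```

Returning `None` for a missing status lets such patients be excluded from the survival set rather than being treated as alive at last follow-up.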

## 5. Survival Analysis — Validation Cohorts

In [19]:
# ============================================================
# 5. SURVIVAL ANALYSIS
# ============================================================

cox_results = {}

for sid, df in val_analyses.items():
    print(f"\n{'='*60}")
    print(f"SURVIVAL: {sid}")
    print(f"{'='*60}")
    
    df_surv = df.dropna(subset=["os_time","os_event"]).copy()
    df_surv = df_surv[df_surv["os_time"] > 0]
    
    n_ev = df_surv["os_event"].sum()
    n_p = df_surv["has_am_pathogenic"].sum()
    n_b = (~df_surv["has_am_pathogenic"]).sum()
    print(f"  n={len(df_surv)}, events={n_ev:.0f}, AM-Path={n_p}, AM-Ben={n_b}")
    
    if n_ev < 3 or n_p < 2 or n_b < 2:
        print(f"  Insufficient for survival analysis")
        continue
    
    # Cox PH
    try:
        df_cox = df_surv[["os_time","os_event","has_am_pathogenic"]].dropna()
        df_cox["has_am_pathogenic"] = df_cox["has_am_pathogenic"].astype(int)
        
        cph = CoxPHFitter()
        cph.fit(df_cox, duration_col="os_time", event_col="os_event")
        
        hr = np.exp(cph.params_["has_am_pathogenic"])
        ci = np.exp(cph.confidence_intervals_.loc["has_am_pathogenic"])
        p = cph.summary.loc["has_am_pathogenic", "p"]
        
        print(f"  HR = {hr:.2f} (95% CI {ci.iloc[0]:.2f}-{ci.iloc[1]:.2f}), p={p:.4f}")
        cox_results[sid] = {"hr":hr,"ci_low":ci.iloc[0],"ci_high":ci.iloc[1],
                            "p":p,"n":len(df_cox),"events":n_ev}
    except Exception as e:
        print(f"  Cox failed: {e}")
    
    # KM Plot
    grp_p = df_surv[df_surv["has_am_pathogenic"]]
    grp_b = df_surv[~df_surv["has_am_pathogenic"]]
    
    kmf_p = KaplanMeierFitter()
    kmf_b = KaplanMeierFitter()
    kmf_p.fit(grp_p["os_time"], grp_p["os_event"], label=f"AM-Pathogenic (n={len(grp_p)})")
    kmf_b.fit(grp_b["os_time"], grp_b["os_event"], label=f"AM-Benign/Amb (n={len(grp_b)})")
    
    lr = logrank_test(grp_p["os_time"], grp_b["os_time"], grp_p["os_event"], grp_b["os_event"])
    print(f"  Log-rank p={lr.p_value:.4f}")
    print(f"  Median OS: Path={kmf_p.median_survival_time_:.1f}, Ben={kmf_b.median_survival_time_:.1f}")
    
    fig, ax = plt.subplots(figsize=(8, 6))
    kmf_p.plot_survival_function(ax=ax, color="#E74C3C", linewidth=2, ci_show=True, ci_alpha=0.15)
    kmf_b.plot_survival_function(ax=ax, color="#3498DB", linewidth=2, ci_show=True, ci_alpha=0.15)
    ax.set_xlabel("Time (months)")
    ax.set_ylabel("Overall Survival Probability")
    ax.set_title(f"Overall Survival by AM HRR Classification\n{sid}")
    ax.set_ylim(0, 1.05)
    
    p_txt = f"Log-rank p = {lr.p_value:.4f}" if lr.p_value >= 0.0001 else "p < 0.0001"
    ax.text(0.98, 0.02, p_txt, transform=ax.transAxes, fontsize=10,
            ha="right", va="bottom", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
    if sid in cox_results:
        cr = cox_results[sid]
        ax.text(0.98, 0.10, f"HR = {cr['hr']:.2f} ({cr['ci_low']:.2f}-{cr['ci_high']:.2f})",
                transform=ax.transAxes, fontsize=9, ha="right", va="bottom",
                bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.9))
    
    ax.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(FIG_DIR / f"Fig_KM_{sid}.png", dpi=300)
    plt.savefig(FIG_DIR / f"Fig_KM_{sid}.pdf")
    plt.show()

print(f"\nCox results summary:")
for sid, cr in cox_results.items():
    print(f"  {sid}: HR={cr['hr']:.2f} ({cr['ci_low']:.2f}-{cr['ci_high']:.2f}), p={cr['p']:.4f}")



SURVIVAL: prad_su2c_2019
  n=18, events=11, AM-Path=8, AM-Ben=10


  HR = 0.77 (95% CI 0.19-3.13), p=0.7103
  Log-rank p=0.7096
  Median OS: Path=28.9, Ben=31.5

SURVIVAL: prad_msk_2019
  n=0, events=0, AM-Path=0, AM-Ben=0
  Insufficient for survival analysis

SURVIVAL: prad_eururol_2017
  n=0, events=0, AM-Path=0, AM-Ben=0
  Insufficient for survival analysis

Cox results summary:
  prad_su2c_2019: HR=0.77 (0.19-3.13), p=0.7103
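
The confidence interval of 0.19 to 3.13 is wide because only 11 deaths were observed among 18 evaluable patients. A rough sense of how many events a two-group comparison needs comes from Schoenfeld's approximation, d = (z_{1-α/2} + z_{power})² / (p₁ p₂ (ln HR)²). The sketch below applies it. The target HR of 0.5, the two-sided α of 0.05, and the 80% power are illustrative assumptions rather than values from this study; the 8:10 group split is taken from the output above. `required_events` is a hypothetical helper name.

```python
from math import ceil, log
from statistics import NormalDist


def required_events(hr: float, frac_group1: float,
                    alpha: float = 0.05, power: float = 0.80) -> int:
    """Schoenfeld approximation for the number of events in a two-group log-rank test."""
    z = NormalDist()
    z_alpha = z.inv_cdf(1 - alpha / 2)   # two-sided significance level
    z_power = z.inv_cdf(power)
    p1, p2 = frac_group1, 1 - frac_group1
    return ceil((z_alpha + z_power) ** 2 / (p1 * p2 * log(hr) ** 2))


needed = required_events(hr=0.5, frac_group1=8 / 18)
print(f"Events needed to detect HR=0.5: {needed} (observed: 11)")
```

Under these assumptions the requirement is several times the 11 events available, so the non-significant result is better read as inconclusive than as evidence of no effect.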


## 6. Pooled mCRPC Analysis

In [20]:
# ============================================================
# 6. POOLED ANALYSIS
# ============================================================

print("="*60)
print("POOLED mCRPC ANALYSIS")
print("="*60)

if val_analyses:
    dfs = []
    for sid, df in val_analyses.items():
        df_tmp = df.copy()
        df_tmp["study"] = sid
        dfs.append(df_tmp)
    
    df_pool = pd.concat(dfs, ignore_index=True)
    df_ps = df_pool.dropna(subset=["os_time","os_event"]).copy()
    df_ps = df_ps[df_ps["os_time"] > 0]
    
    n_ev = df_ps["os_event"].sum()
    n_p = df_ps["has_am_pathogenic"].sum()
    n_b = (~df_ps["has_am_pathogenic"]).sum()
    print(f"\nPooled: {len(df_ps)} patients, {n_ev:.0f} events")
    print(f"  AM-Path: {n_p}, AM-Ben/Amb: {n_b}")
    
    if n_ev >= 5 and n_p >= 3 and n_b >= 3:
        df_cox = df_ps[["os_time","os_event","has_am_pathogenic","study"]].dropna().copy()
        df_cox["has_am_pathogenic"] = df_cox["has_am_pathogenic"].astype(int)
        
        if df_cox["study"].nunique() > 1:
            df_cox = pd.get_dummies(df_cox, columns=["study"], drop_first=True)
        else:
            df_cox = df_cox.drop(columns=["study"])
        
        try:
            cph = CoxPHFitter()
            cph.fit(df_cox, duration_col="os_time", event_col="os_event")
            hr = np.exp(cph.params_["has_am_pathogenic"])
            ci = np.exp(cph.confidence_intervals_.loc["has_am_pathogenic"])
            p = cph.summary.loc["has_am_pathogenic", "p"]
            print(f"\n  POOLED HR = {hr:.2f} ({ci.iloc[0]:.2f}-{ci.iloc[1]:.2f}), p={p:.4f}")
        except Exception as e:
            print(f"  Pooled Cox failed: {e}")
            hr = ci = p = None
        
        # Pooled KM
        grp_p = df_ps[df_ps["has_am_pathogenic"]]
        grp_b = df_ps[~df_ps["has_am_pathogenic"]]
        
        kmf_p = KaplanMeierFitter()
        kmf_b = KaplanMeierFitter()
        kmf_p.fit(grp_p["os_time"], grp_p["os_event"], label=f"AM-Pathogenic (n={len(grp_p)})")
        kmf_b.fit(grp_b["os_time"], grp_b["os_event"], label=f"AM-Benign/Amb (n={len(grp_b)})")
        
        lr = logrank_test(grp_p["os_time"], grp_b["os_time"], grp_p["os_event"], grp_b["os_event"])
        
        fig, ax = plt.subplots(figsize=(8, 6))
        kmf_p.plot_survival_function(ax=ax, color="#E74C3C", linewidth=2.5, ci_show=True, ci_alpha=0.15)
        kmf_b.plot_survival_function(ax=ax, color="#3498DB", linewidth=2.5, ci_show=True, ci_alpha=0.15)
        ax.set_xlabel("Time (months)")
        ax.set_ylabel("Overall Survival Probability")
        ax.set_title("Overall Survival by AlphaMissense HRR Classification\n(Pooled mCRPC Validation)")
        ax.set_ylim(0, 1.05)
        
        p_txt = f"Log-rank p = {lr.p_value:.4f}" if lr.p_value >= 0.0001 else "p < 0.0001"
        ax.text(0.98, 0.02, p_txt, transform=ax.transAxes, fontsize=10,
                ha="right", va="bottom", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        if hr is not None:
            ax.text(0.98, 0.10, f"HR = {hr:.2f} ({ci.iloc[0]:.2f}-{ci.iloc[1]:.2f})",
                    transform=ax.transAxes, fontsize=9, ha="right", va="bottom",
                    bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.9))
        
        ax.legend(loc="lower left")
        plt.tight_layout()
        plt.savefig(FIG_DIR / "Fig6_KM_pooled_mCRPC.png", dpi=300)
        plt.savefig(FIG_DIR / "Fig6_KM_pooled_mCRPC.pdf")
        plt.show()
        print("  Pooled KM saved")
    else:
        print(f"  Insufficient for pooled analysis")
else:
    print("  No validation data")


POOLED mCRPC ANALYSIS

Pooled: 18 patients, 11 events
  AM-Path: 8, AM-Ben/Amb: 10

  POOLED HR = 0.77 (0.19-3.13), p=0.7103
  Pooled KM saved
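
The pooled fit adjusts for cohort by adding study indicator columns (`pd.get_dummies(..., drop_first=True)`). Written out, the model being fitted is the following, where the notation is introduced here for illustration: $x_i$ is `has_am_pathogenic` for patient $i$, $s_i$ is that patient's study, $K$ is the number of studies with usable OS data, and the first study is the reference level.

```latex
h(t \mid x_i, s_i) = h_0(t)\,
  \exp\!\Big(\beta\, x_i + \sum_{k=2}^{K} \gamma_k\, \mathbb{1}[s_i = k]\Big),
\qquad \mathrm{HR} = e^{\beta}
```

When only one study contributes OS data ($K = 1$), the sum is empty and the model reduces to the single-cohort fit. That is the situation here: all 18 evaluable patients come from `prad_su2c_2019`, so the pooled HR of 0.77 (0.19-3.13) is identical to the SU2C result.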


## 7. Study Summary — Ready for Manuscript

In [21]:
# ============================================================
# 7. EXECUTIVE SUMMARY
# ============================================================

print("=" * 70)
print("  COMPLETE STUDY SUMMARY")
print("=" * 70)

# Notebook 2 results
conc_path = RESULTS_DIR / "concordance_results.csv"
if conc_path.exists():
    conc = pd.read_csv(conc_path)
    kappa = conc[conc["Metric"]=="Cohen's kappa"]["Value"].values[0]
    print(f"\n  DISCOVERY (TCGA-PRAD, Notebook 2)")
    print(f"    Concordance: kappa={kappa:.3f}")
    print(f"    VUS reclassified: 90.1%")
    print(f"    Survival: 1 event / 40 patients (exploratory)")

# Notebook 3 results
print(f"\n  VALIDATION (mCRPC, Notebook 3)")
for sid, df in val_analyses.items():
    n_ev = df["os_event"].sum() if "os_event" in df.columns else 0
    print(f"    {sid}: {len(df)} pts, {n_ev:.0f} events")
    if sid in cox_results:
        cr = cox_results[sid]
        print(f"      HR={cr['hr']:.2f} ({cr['ci_low']:.2f}-{cr['ci_high']:.2f}), p={cr['p']:.4f}")

print(f"\n  TARGET JOURNALS")
print(f"    Tier 1: JCO Precision Oncology")
print(f"    Tier 2: European Urology Oncology / Prostate Cancer Prostatic Dis")
print(f"    Tier 3: Cancers / Frontiers in Oncology")

print(f"\n  OUTPUT FILES")
for f in sorted(RESULTS_DIR.glob("*.csv")):
    print(f"    {f}")
for f in sorted(FIG_DIR.glob("*.png")):
    print(f"    {f}")

print(f"\n{'='*70}")
print(f"  Notebook 3 complete. Paper package ready.")
print(f"{'='*70}")


  COMPLETE STUDY SUMMARY

  DISCOVERY (TCGA-PRAD, Notebook 2)
    Concordance: kappa=0.733
    VUS reclassified: 90.1%
    Survival: 1 event / 40 patients (exploratory)

  VALIDATION (mCRPC, Notebook 3)
    prad_su2c_2019: 54 pts, 11 events
      HR=0.77 (0.19-3.13), p=0.7103
    prad_msk_2019: 1 pts, 0 events
    prad_eururol_2017: 3 pts, 0 events

  TARGET JOURNALS
    Tier 1: JCO Precision Oncology
    Tier 2: European Urology Oncology / Prostate Cancer Prostatic Dis
    Tier 3: Cancers / Frontiers in Oncology

  OUTPUT FILES
    results/analysis_dataset.csv
    results/annotated_hrr_variants.csv
    results/concordance_results.csv
    results/patient_hrr_summary.csv
    results/sensitivity_logo.csv
    results/sensitivity_threshold.csv
    results/table_gene_summary.csv
    results/validation_hrr_variants.csv
    results/validation_prad_eururol_2017.csv
    results/validation_prad_msk_2019.csv
    results/validation_prad_su2c_2019.csv
    results/vus_reclassification.csv
    figure

In [14]:
# DEBUG — patient ID matching
sid = "prad_su2c_2019"
data = all_studies[sid]
df_vars = df_val[df_val["study_id"] == sid]

# IDs from variants
var_ids = df_vars["sample_id"].unique()
print(f"Variant sample IDs (first 5): {var_ids[:5]}")

# Extract patient IDs 
import re
def extract_pid(s):
    s = str(s)
    m = re.match(r'(TCGA-[A-Z0-9]+-[A-Z0-9]+)', s)
    if m: return m.group(1)
    for suffix in ["-T", "-Tm", "-T1", "-M1"]:
        if s.endswith(suffix): return s[:-len(suffix)]
    return s

var_patient_ids = [extract_pid(s) for s in var_ids]
print(f"Extracted patient IDs (first 5): {var_patient_ids[:5]}")

# IDs from clinical
df_clin = data["clinical_patient"]
if "clinicalAttributeId" in df_clin.columns:
    df_wide = df_clin.pivot_table(index="patientId", columns="clinicalAttributeId",
                                   values="value", aggfunc="first").reset_index()
else:
    df_wide = df_clin
    
clin_ids = df_wide["patientId"].unique()
print(f"\nClinical patient IDs (first 5): {clin_ids[:5]}")

# Check overlap
overlap = set(var_patient_ids) & set(clin_ids)
print(f"\nOverlap: {len(overlap)} / {len(var_patient_ids)} variant patients")

if len(overlap) == 0:
    # Try direct sample_id match
    overlap2 = set(var_ids) & set(clin_ids)
    print(f"Direct sampleId vs patientId overlap: {len(overlap2)}")
    
    # Show format difference
    print(f"\nFormat comparison:")
    print(f"  Variant sample: {var_ids[0]}")
    print(f"  Extracted patient: {var_patient_ids[0]}")
    print(f"  Clinical patient: {clin_ids[0]}")

Variant sample IDs (first 5): ['MO_1012-Tumor-Subcutaneous_nodule' 'MO_1013-Tumor' 'MO_1071-Tumor'
 'MO_1130-Tumor' 'MO_1176-Tumor']
Extracted patient IDs (first 5): ['MO_1012-Tumor-Subcutaneous_nodule', 'MO_1013-Tumor', 'MO_1071-Tumor', 'MO_1130-Tumor', 'MO_1176-Tumor']

Clinical patient IDs (first 5): ['1115015' '1115016' '1115019' '1115020' '1115021']

Overlap: 0 / 54 variant patients
Direct sampleId vs patientId overlap: 0

Format comparison:
  Variant sample: MO_1012-Tumor-Subcutaneous_nodule
  Extracted patient: MO_1012-Tumor-Subcutaneous_nodule
  Clinical patient: 1115015


In [15]:
# DEBUG — check if mutations have patientId
df_mut = all_studies["prad_su2c_2019"]["mutations"]
print("patientId in mutations?", "patientId" in df_mut.columns)
print("Sample mapping (first 10):")
mapping = df_mut[["sampleId","patientId"]].drop_duplicates().head(10)
print(mapping.to_string(index=False))

patientId in mutations? True
Sample mapping (first 10):
                         sampleId patientId
             DFCI.11-104.02-Tumor   1115019
             DFCI.11-104.13-Tumor   1115020
               MO_1008-Tumor_Dura   5115022
MO_1012-Tumor-Subcutaneous_nodule   6115012
                    MO_1013-Tumor   6115013
                    MO_1014-Tumor   6115014
                    MO_1015-Tumor   5115023
                    MO_1020-Tumor   5115024
                    MO_1040-Tumor   5115026
                    MO_1054-Tumor   5115027
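
The two debug cells show why the first merge attempt matched nothing: SU2C sample IDs such as `MO_1013-Tumor` cannot be converted to patient IDs by trimming suffixes, because the clinical table uses unrelated numeric IDs such as `6115013`. The reliable route is a lookup built from the `sampleId`/`patientId` pairs in the mutation table. The sketch below uses pairs copied from the output above; `build_sample_to_patient` is a hypothetical helper name and `clinical_ids` is an illustrative subset.

```python
from typing import Dict, Iterable, Tuple


def build_sample_to_patient(pairs: Iterable[Tuple[str, str]]) -> Dict[str, str]:
    """Map each sample ID to its patient ID, rejecting conflicting assignments."""
    mapping: Dict[str, str] = {}
    for sample_id, patient_id in pairs:
        previous = mapping.setdefault(sample_id, patient_id)
        if previous != patient_id:
            raise ValueError(f"{sample_id} maps to both {previous} and {patient_id}")
    return mapping


# (sampleId, patientId) pairs as printed by the debug cell
pairs = [
    ("DFCI.11-104.02-Tumor", "1115019"),
    ("MO_1012-Tumor-Subcutaneous_nodule", "6115012"),
    ("MO_1013-Tumor", "6115013"),
]
sample_to_patient = build_sample_to_patient(pairs)

# Illustrative subset of clinical patient IDs
clinical_ids = {"1115019", "6115012", "6115013"}

variant_samples = ["MO_1012-Tumor-Subcutaneous_nodule", "MO_1013-Tumor"]
matched = [s for s in variant_samples if sample_to_patient.get(s) in clinical_ids]
print(f"Overlap: {len(matched)} / {len(variant_samples)}")
```

A patient with several samples maps many sample IDs to one patient ID, which is why the per-patient summary groups on `patient_id` rather than `sample_id`.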


In [8]:
# DEBUG: paste into a new cell and run
for sid, data in all_studies.items():
    df = data["mutations"]
    print(f"\n=== {sid}: {len(df)} rows ===")
    print(f"Columns: {df.columns[:15].tolist()}")
    print(f"Sample row:")
    print(df.iloc[0].to_dict())
    # Check for gene column
    for c in df.columns:
        if "gene" in c.lower() or "hugo" in c.lower() or "symbol" in c.lower():
            print(f"\nGene column '{c}' — unique genes sample: {df[c].dropna().unique()[:10]}")
    # Check for variant classification
    for c in df.columns:
        if "class" in c.lower() or "type" in c.lower() or "consequence" in c.lower():
            vals = df[c].dropna().unique()[:10]
            print(f"Classification column '{c}': {vals}")
    # Check for protein change
    for c in df.columns:
        if "protein" in c.lower() or "hgvs" in c.lower() or "amino" in c.lower():
            vals = df[c].dropna().unique()[:5]
            print(f"Protein column '{c}': {vals}")
    break  # just check SU2C


=== prad_su2c_2019: 40055 rows ===
Columns: ['uniqueSampleKey', 'uniquePatientKey', 'molecularProfileId', 'sampleId', 'patientId', 'entrezGeneId', 'studyId', 'center', 'mutationStatus', 'validationStatus', 'tumorAltCount', 'tumorRefCount', 'startPosition', 'endPosition', 'referenceAllele']
Sample row:
{'uniqueSampleKey': 'REZDSS4xMS0xMDQuMDItVHVtb3I6cHJhZF9zdTJjXzIwMTk', 'uniquePatientKey': 'MTExNTAxOTpwcmFkX3N1MmNfMjAxOQ', 'molecularProfileId': 'prad_su2c_2019_mutations', 'sampleId': 'DFCI.11-104.02-Tumor', 'patientId': '1115019', 'entrezGeneId': 366, 'studyId': 'prad_su2c_2019', 'center': 'broad.mit.edu', 'mutationStatus': 'Somatic', 'validationStatus': 'Untested', 'tumorAltCount': 15.0, 'tumorRefCount': 94.0, 'startPosition': 58430811, 'endPosition': 58430811, 'referenceAllele': 'T', 'proteinChange': 'L16Q', 'mutationType': 'Missense_Mutation', 'ncbiBuild': 'GRCh37', 'variantType': 'SNP', 'keyword': 'AQP9 L16 missense', 'chr': '15', 'variantAllele': 'A', 'refseqMrnaId': 'NM_020980.

## Done!

### Next Steps:
1. Review figures and results
2. Draft manuscript
3. Submit to JCO Precision Oncology

---
*Research OS — Clinical Computational Oncology Pipeline*
