# Multi-Disease Data Extraction Pipeline

Extract disease cases and matched controls from the AGP dataset.

## Summary
- Extract cases and matched controls for a selected disease
- Build BIOM tables and phylogenetic distance matrices
- Save outputs under `{DISEASE}_analysis_output/`

## Supported Diseases
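IBD, Diabetes, Cancer, Autoimmune, Depression, Mental illness, Cardiovascular, Kidney, Liver, and Lung disease. An unrecognized disease name falls back to IBD with a warning.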


In [None]:
# Load DISEASE from experiment config or environment
import json
import yaml
import os
from pathlib import Path

def _as_mapping(value):
    """Return *value* if it is a dict, otherwise an empty dict.

    YAML documents can be empty (``None``) or hold a non-mapping where a
    mapping is expected; normalising keeps the ``.get()`` lookups below safe.
    """
    return value if isinstance(value, dict) else {}


# Optional experiment configuration: read from the YAML file named by the
# EXPERIMENT_CONFIG_PATH environment variable when that file exists.
EXPERIMENT_CONFIG = {}
if 'EXPERIMENT_CONFIG_PATH' in os.environ:
    config_path = os.environ['EXPERIMENT_CONFIG_PATH']
    if Path(config_path).exists():
        with open(config_path, 'r') as f:
            # safe_load() yields None for an empty file; fall back to {}.
            EXPERIMENT_CONFIG = _as_mapping(yaml.safe_load(f))

# Extract disease from config (check multiple paths like other notebooks).
# Precedence: top-level 'disease', data_extraction.disease,
# data_extraction.disease_criteria.disease, then the DISEASE environment
# variable, and finally the default 'IBD'.
_data_extraction = _as_mapping(EXPERIMENT_CONFIG.get('data_extraction'))
_disease_criteria = _as_mapping(_data_extraction.get('disease_criteria'))
_disease = (
    EXPERIMENT_CONFIG.get('disease') or
    _data_extraction.get('disease') or
    _disease_criteria.get('disease') or
    os.environ.get('DISEASE', 'IBD')
)
# str() guards against a non-string YAML scalar; strip() tolerates padding.
DISEASE = str(_disease).strip().upper()

# Answer strings shared by most disease columns.
_DIAGNOSED = 'Diagnosed by a medical professional (doctor, physician assistant)'
_NO_CONDITION = 'I do not have this condition'


def _disease_entry(column, type_column=None, valid_types=None,
                   positive=_DIAGNOSED, negative=_NO_CONDITION):
    """Build one DISEASE_CONFIGS entry.

    Args:
        column: Metadata column holding the disease status.
        type_column: Optional metadata column holding the disease subtype.
        valid_types: Optional list of subtype values to keep.
        positive: Status value that marks a case.
        negative: Status value that marks a potential control.

    Returns:
        A dict with the keys 'column', 'positive_values', 'negative_values',
        'type_column' and 'valid_types'. The value lists are created fresh
        for every entry so entries never share mutable state.
    """
    return {
        'column': column,
        'positive_values': [positive],
        'negative_values': [negative],
        'type_column': type_column,
        'valid_types': valid_types,
    }


# Disease-specific configurations
DISEASE_CONFIGS = {
    'IBD': _disease_entry(
        'ibd',
        type_column='ibd_diagnosis_refined',
        valid_types=["Crohn's disease", "Ulcerative colitis"],
    ),
    'DIABETES': _disease_entry('diabetes'),
    'CANCER': _disease_entry('cancer', type_column='cancer_type'),
    'AUTOIMMUNE': _disease_entry('autoimmune'),
    'DEPRESSION': _disease_entry('depression_bipolar_schizophrenia'),
    'MENTAL_ILLNESS': _disease_entry(
        'mental_illness',
        type_column='mental_illness_type',
        positive='Yes',
        negative='No',
    ),
    'CARDIOVASCULAR': _disease_entry('cardiovascular_disease'),
    'KIDNEY': _disease_entry('kidney_disease'),
    'LIVER': _disease_entry('liver_disease'),
    'LUNG': _disease_entry('lung_disease'),
}

# Unknown names do not abort the notebook; they fall back to IBD.
if DISEASE not in DISEASE_CONFIGS:
    print(f"WARNING: Unknown disease '{DISEASE}', using IBD defaults")
    DISEASE = 'IBD'

DISEASE_CONFIG = DISEASE_CONFIGS[DISEASE]

print(f"{DISEASE} Data Extraction Pipeline")
print("=" * 40)
print(f"Disease column: {DISEASE_CONFIG['column']}")

In [None]:
import pandas as pd
import numpy as np
from biom import load_table
from biom.table import Table
import skbio
from skbio.tree import TreeNode
from pathlib import Path
import pickle
from collections import defaultdict


In [None]:
import pickle
from pathlib import Path

# Checkpoint written at the end of the BIOM filtering cell for this disease.
checkpoint_path = Path(f"{DISEASE}_analysis_output/checkpoint_data.pkl")

if checkpoint_path.exists():
    print("Loading checkpoint...")

    # NOTE: unpickling executes code embedded in the file; only load
    # checkpoints that this pipeline wrote itself.
    with open(checkpoint_path, 'rb') as f:
        checkpoint_data = pickle.load(f)

    # Core objects saved by the BIOM filtering cell.
    cases_valid = checkpoint_data['cases_valid']
    matched_controls = checkpoint_data['matched_controls']
    case_biom = checkpoint_data['case_biom']
    control_biom = checkpoint_data['control_biom']
    config = checkpoint_data['config']
    biom_table = checkpoint_data['biom_table']
    case_sample_ids = checkpoint_data['case_sample_ids']
    control_sample_ids = checkpoint_data['control_sample_ids']

    # Progress markers added by later cells; absent in older checkpoints,
    # hence .get() (missing keys become None).
    phylogeny_tip_count = checkpoint_data.get('phylogeny_tip_count')
    distance_matrix_created = checkpoint_data.get('distance_matrix_created')
    phylogeny_copied = checkpoint_data.get('phylogeny_copied')
    sequences_extracted = checkpoint_data.get('sequences_extracted')
    all_tsv_files_created = checkpoint_data.get('all_tsv_files_created')

    # The output cells also read the column names and the in-BIOM sample sets,
    # which the checkpoint does not store. Re-derive them here exactly as the
    # metadata and BIOM cells do, so skipping ahead does not hit a NameError.
    sample_col = "#SampleID"
    disease_col = config['disease_column']
    age_col = "age_cat"
    bmi_col = "bmi_cat"
    sex_col = "sex"
    type_col = config.get('type_column')
    if type_col and type_col not in cases_valid.columns:
        type_col = None

    if biom_table is not None:
        biom_sample_ids = set(biom_table.ids(axis='sample'))
        cases_in_biom = set(case_sample_ids) & biom_sample_ids
        control_in_biom = set(control_sample_ids) & biom_sample_ids
    else:
        cases_in_biom = set()
        control_in_biom = set()

    print(f"Loaded: {len(cases_valid)} cases, {len(matched_controls)} controls")
    print(f"{DISEASE} BIOM: {case_biom.shape if case_biom is not None else 'None'}")
    print(f"Control BIOM: {control_biom.shape if control_biom is not None else 'None'}")
    if phylogeny_tip_count:
        print(f"Phylogeny tips: {phylogeny_tip_count}")
    print("Can skip to final output cells")

    CHECKPOINT_LOADED = True

else:
    print("No checkpoint - run full pipeline first")
    CHECKPOINT_LOADED = False


# Multi-Disease Data Extraction
### Steps and outputs


In [None]:
# When a checkpoint was restored, make sure the output folder tree exists so
# the final cells can write into it; otherwise just tell the user what to run.
if not globals().get('CHECKPOINT_LOADED'):
    print("No checkpoint - run processing cells first")
else:
    print("Checkpoint loaded - skip to final cells")

    # Prefer the location recorded in the restored config.
    if 'config' in globals():
        output_dir = Path(config["output_dir"])
    else:
        output_dir = Path(f"{DISEASE}_analysis_output")
    output_dir.mkdir(exist_ok=True)

    for subfolder in ("metadata", "biom_tables", "sequences", "phylogeny", "config"):
        (output_dir / subfolder).mkdir(exist_ok=True)

    print("Output dirs ready")


In [None]:
# Root folder for everything this run produces, plus its fixed sub-folders.
dirs_to_create = ["metadata", "biom_tables", "sequences", "phylogeny", "config"]
output_dir = Path(f"{DISEASE}_analysis_output")

output_dir.mkdir(exist_ok=True)
for dir_name in dirs_to_create:
    (output_dir / dir_name).mkdir(exist_ok=True)

print("Created dirs:")
for d in dirs_to_create:
    print(f"  {output_dir / d}")

# Pipeline configuration, assembled section by section. Insertion order is
# kept the same as before so the JSON file lists keys in the same order.
config = {
    # Input files
    "biom_path": "../data/AGP.data.biom",
    "metadata_path": "../data/AGP-metadata.tsv",
    "phylogeny_path": "../data/2024.09.phylogeny.asv.nwk",
    "sequences_path": "ASV_sequences.fasta",
}

# Disease selection criteria taken from the chosen DISEASE_CONFIG entry.
config["disease"] = DISEASE
config["disease_column"] = DISEASE_CONFIG['column']
for criteria_key in ("positive_values", "negative_values", "type_column", "valid_types"):
    config[criteria_key] = DISEASE_CONFIG[criteria_key]

# Matching strata, reproducibility seed and feature-filter switch.
config.update({
    "age_categories": ["20s", "30s", "40s", "50s", "60s", "70+"],
    "bmi_categories": ["Underweight", "Normal", "Overweight", "Obese"],
    "random_seed": 42,
    "enable_feature_filtering": False,
})

# Output locations (stored as strings so the dict is JSON-serialisable).
config["output_dir"] = str(output_dir)
for output_key, subfolder in (
    ("metadata_output", "metadata"),
    ("biom_output", "biom_tables"),
    ("sequences_output", "sequences"),
    ("phylogeny_output", "phylogeny"),
):
    config[output_key] = str(output_dir / subfolder)

with open(output_dir / "config" / "pipeline_config.json", "w") as f:
    json.dump(config, f, indent=2)

print(f"\nConfiguration saved to: {output_dir / 'config' / 'pipeline_config.json'}")
print(f"Disease: {config['disease']}")
print(f"Column: {config['disease_column']}")
print(f"Positive values: {config['positive_values']}")
print(f"Negative values: {config['negative_values']}")

In [None]:
# Inputs the pipeline cannot run without.
required_files = [
    config["metadata_path"],
    config["biom_path"],
    config["phylogeny_path"]
]

print("Checking required input files:")
missing_files = []
for file_path in required_files:
    candidate = Path(file_path)
    if not candidate.exists():
        print(f" {file_path} - MISSING")
        missing_files.append(file_path)
        continue
    size_mb = candidate.stat().st_size / (1024*1024)
    print(f" {file_path} ({size_mb:.1f} MB)")

# Inputs that only enable extra steps; their absence is reported, not fatal.
optional_files = [config["sequences_path"]]
print("\n Checking optional input files:")
for file_path in optional_files:
    candidate = Path(file_path)
    if not candidate.exists():
        print(f" {file_path} - Not found (optional)")
        continue
    size_mb = candidate.stat().st_size / (1024*1024)
    print(f" {file_path} ({size_mb:.1f} MB)")

if not missing_files:
    print(f"\n All required files found!")
else:
    print(f"\n Missing required files: {missing_files}")
    print("Please ensure all required files are in the current directory.")

# Load the phylogeny unless a restored checkpoint already supplied its tip
# count. gg_phylogeny stays None in every case except a successful read.
phylogeny_path = config["phylogeny_path"]
phylogeny_file_exists = Path(phylogeny_path).exists()
gg_phylogeny = None
if not phylogeny_file_exists:
    print(f" Phylogeny file missing: {phylogeny_path}")
    phylogeny_tip_count = None
elif globals().get('phylogeny_tip_count') is not None:
    print(f" Phylogeny file present; tips from checkpoint: {phylogeny_tip_count:,}")
else:
    try:
        gg_phylogeny = TreeNode.read(phylogeny_path)
        n_tips = sum(1 for _ in gg_phylogeny.tips())
        print(f" Phylogeny loaded: {n_tips:,} tips")
        phylogeny_tip_count = n_tips
    except Exception as e:
        # Best effort: a broken tree is reported and later cells see None.
        print(f" Error loading phylogeny: {e}")
        gg_phylogeny = None
        phylogeny_tip_count = None

## Load and Process Metadata
### Read and filter metadata


In [None]:
# Read the metadata table with every column kept as text.
print("Loading AGP metadata...")
df = pd.read_csv(config["metadata_path"], sep="\t", low_memory=False, dtype=str)
n_samples, n_variables = df.shape
print(f"Loaded {n_samples:,} samples with {n_variables} variables")

# Header cells may carry stray whitespace; normalise before any lookup.
df.columns = [name.strip() for name in df.columns]

# Column names used throughout the rest of the notebook.
sample_col, age_col, bmi_col, sex_col = "#SampleID", "age_cat", "bmi_cat", "sex"
disease_col = config['disease_column']

# The subtype column is optional: disable it when the metadata lacks it.
type_col = config.get('type_column')
type_col_missing = bool(type_col) and type_col not in df.columns
if type_col_missing:
    print(f"Warning: Type column '{type_col}' not found in metadata")
    type_col = None

print(f"\nKey columns identified:")
print(f"  Sample ID: {sample_col}")
print(f"  Disease status: {disease_col}")
print(f"  Age category: {age_col}")
print(f"  BMI category: {bmi_col}")
print(f"  Sex: {sex_col}")
print(f"  Disease type: {type_col}")

print(f"\nSample of metadata:")
preview_columns = [sample_col, disease_col, age_col, bmi_col, sex_col]
display(df[preview_columns].head())


In [None]:
def _valid_strata_mask(frame):
    """Boolean mask of rows whose age and BMI categories are both configured."""
    age_ok = frame[age_col].isin(config["age_categories"])
    bmi_ok = frame[bmi_col].isin(config["bmi_categories"])
    return age_ok & bmi_ok


print(f"Identifying {DISEASE} cases and controls...")

# Cases: rows whose disease status is one of the configured positive answers.
is_case = df[disease_col].isin(config["positive_values"])
cases = df[is_case].copy()
print(f"Total {DISEASE} cases: {len(cases):,}")

# Optional subtype restriction (only when both a type column and a list of
# accepted subtypes are available).
if type_col and config.get("valid_types"):
    has_valid_type = cases[type_col].isin(config["valid_types"])
    cases = cases[has_valid_type]
    print(f"{DISEASE} cases (valid types): {len(cases):,}")

    if len(cases) > 0:
        print(f"{DISEASE} type breakdown:")
        for disease_type, count in cases[type_col].value_counts().items():
            print(f"  - {disease_type}: {count:,}")

# Potential controls: rows with one of the configured negative answers.
is_control = df[disease_col].isin(config["negative_values"])
controls = df[is_control].copy()
print(f"Total potential controls: {len(controls):,}")

# Matching is done on age/BMI strata, so drop rows outside the known ones.
cases_valid = cases[_valid_strata_mask(cases)].copy()
controls_valid = controls[_valid_strata_mask(controls)].copy()

print(f"\nAfter filtering to valid age/BMI categories:")
print(f"  {DISEASE} cases: {len(cases_valid):,}")
print(f"  Controls: {len(controls_valid):,}")

if len(cases_valid) > 0:
    print(f"\n{DISEASE} cases age/BMI distribution:")
    age_bmi_dist = pd.crosstab(cases_valid[age_col], cases_valid[bmi_col])
    display(age_bmi_dist)


## Create Matched Case-Control Pairs
### Match by age and BMI


In [None]:
print("Creating matched case-control pairs...")

# Number of cases per (age, BMI) stratum: the number of controls to draw.
target_counts = cases_valid.groupby([age_col, bmi_col]).size()
print(f"Target distribution has {len(target_counts)} age/BMI combinations")

np.random.seed(config["random_seed"])
matched_samples = []

for (age_val, bmi_val), n_needed in target_counts.items():
    matching_controls = controls_valid[
        (controls_valid[age_col] == age_val) &
        (controls_valid[bmi_col] == bmi_val)
    ]
    n_available = len(matching_controls)

    if n_available == 0:
        print(f" No controls found for age={age_val}, BMI={bmi_val}")
        continue

    if n_available >= n_needed:
        sampled = matching_controls.sample(n=n_needed, replace=False, random_state=config["random_seed"])
    else:
        # Too few controls in this stratum. Sampling with replacement would
        # only create duplicates that are dropped again below, and by chance
        # it could leave out some of the few unique controls. Keep them all.
        sampled = matching_controls
        print(f" Using all {n_available} available controls for age={age_val}, BMI={bmi_val} (needed {n_needed})")

    matched_samples.append(sampled)

if matched_samples:
    matched_controls = pd.concat(matched_samples, ignore_index=True)
    # Each sample may appear only once in the control set.
    matched_controls = matched_controls.drop_duplicates(subset=[sample_col])
else:
    # Keep the metadata schema so downstream cells can still index columns.
    matched_controls = pd.DataFrame(columns=df.columns)

print(f"\n Matching results:")
print(f"{DISEASE} cases: {len(cases_valid):,}")
print(f"Matched controls: {len(matched_controls):,}")
if len(matched_controls) < len(cases_valid):
    print(f" Note: {len(cases_valid) - len(matched_controls):,} cases have no unique matched control")

if len(matched_controls) > 0:
    print(f"\n Control age/BMI distribution:")
    control_dist = pd.crosstab(matched_controls[age_col], matched_controls[bmi_col])
    display(control_dist)

    print(f"\n {DISEASE} case age/BMI distribution:")
    case_dist = pd.crosstab(cases_valid[age_col], cases_valid[bmi_col])
    display(case_dist)


In [None]:
import pickle

# Persist the results of the metadata steps so a crash during the BIOM step
# does not force them to be recomputed.
print("Saving intermediate checkpoint (metadata processing)...")

intermediate_checkpoint = dict(
    cases_valid=cases_valid,
    matched_controls=matched_controls,
    config=config,
)

checkpoint_path = Path(config["output_dir"]) / "intermediate_checkpoint.pkl"
with open(checkpoint_path, 'wb') as checkpoint_file:
    pickle.dump(intermediate_checkpoint, checkpoint_file)

print(f" Intermediate checkpoint saved to: {checkpoint_path}")
for summary_line in (
    " Saved variables:",
    "- cases_valid (DataFrame)",
    "- matched_controls (DataFrame)",
    "- config (configuration)",
):
    print(summary_line)
print("\n Continue to BIOM processing...")


## Filter BIOM Tables
### Subset samples and filter features


In [None]:
# Load the full BIOM table, split it into a case table and a control table,
# optionally drop rare features, then pickle everything as the final checkpoint.
print("Loading and filtering BIOM table...")

biom_table = load_table(config["biom_path"])
print(f"Complete BIOM table: {biom_table.shape[0]:,} features × {biom_table.shape[1]:,} samples")

# Sample IDs chosen by the metadata cells, as strings.
case_sample_ids = set(cases_valid[sample_col].astype(str))
control_sample_ids = set(matched_controls[sample_col].astype(str))

print(f"{DISEASE} samples to extract: {len(case_sample_ids):,}")
print(f"Control samples to extract: {len(control_sample_ids):,}")

# Only samples that actually occur in the BIOM table can be extracted.
biom_sample_ids = set(biom_table.ids(axis='sample'))
cases_in_biom = case_sample_ids.intersection(biom_sample_ids)
control_in_biom = control_sample_ids.intersection(biom_sample_ids)

print(f"{DISEASE} samples found in BIOM: {len(cases_in_biom):,}")
print(f"Control samples found in BIOM: {len(control_in_biom):,}")

# inplace=False: biom_table itself is reused below and stored in the checkpoint.
if len(cases_in_biom) > 0:
    case_biom = biom_table.filter(cases_in_biom, axis='sample', inplace=False)
    print(f"{DISEASE} BIOM table: {case_biom.shape[0]:,} features × {case_biom.shape[1]:,} samples")
else:
    print(f" No {DISEASE} samples found in BIOM table!")
    case_biom = None

if len(control_in_biom) > 0:
    control_biom = biom_table.filter(control_in_biom, axis='sample', inplace=False)
    print(f"Control BIOM table: {control_biom.shape[0]:,} features × {control_biom.shape[1]:,} samples")
else:
    print(f" No control samples found in BIOM table!")
    control_biom = None

# Optional feature filtering, off unless config["enable_feature_filtering"] is
# truthy. Thresholds come from the config, with the defaults given here.
if config.get("enable_feature_filtering", False) and (case_biom is not None or control_biom is not None):
    print("\n Applying feature filtering (T2D-style)...")
    min_prev = float(config.get("feature_min_prevalence", 0.01))
    min_total = float(config.get("feature_min_total_abundance", 0.001))

    try:
        # Statistics are computed over cases and controls together so both
        # tables end up with the same kept-feature set.
        tables = []
        sample_counts = 0
        if case_biom is not None:
            tables.append(case_biom)
            sample_counts += case_biom.shape[1]
        if control_biom is not None:
            tables.append(control_biom)
            sample_counts += control_biom.shape[1]
        union_table = tables[0] if len(tables) == 1 else tables[0].concat(tables[1:])

        # Prevalence: stored non-zero entries per matrix row divided by the
        # number of samples (rows are treated as features here).
        nonzero_counts = union_table.matrix_data.getnnz(axis=1)
        prevalence = nonzero_counts / sample_counts

        # Total abundance per row, flattened to a 1-D array.
        totals = np.asarray(union_table.matrix_data.sum(axis=1)).ravel()

        feature_ids = np.array(union_table.ids(axis='observation'))
        # A feature is kept only when it passes BOTH thresholds.
        keep_mask = (prevalence >= min_prev) & (totals >= min_total)
        kept_features = set(feature_ids[keep_mask])

        # Apply max_features safety cap
        # (a missing or non-positive "max_features" disables the cap).
        max_feats = int(config.get("max_features", 0))
        if max_feats > 0 and len(kept_features) > max_feats:
            # Keep only the most prevalent features
            kept_indices = np.where(keep_mask)[0]
            kept_prev = prevalence[kept_indices]
            # argsort of the negated values gives descending prevalence order.
            top_k = np.argsort(-kept_prev)[:max_feats]
            final_indices = kept_indices[top_k]
            kept_features = set(feature_ids[final_indices])
            print(f"  ⚠ Capped to {max_feats:,} most prevalent features")

        print(f" Kept {len(kept_features):,} / {len(feature_ids):,} features after filtering")

        if case_biom is not None:
            case_biom = case_biom.filter(kept_features, axis='observation', inplace=False)
            print(f" {DISEASE} table now: {case_biom.shape[0]:,} features × {case_biom.shape[1]:,} samples")
        if control_biom is not None:
            control_biom = control_biom.filter(kept_features, axis='observation', inplace=False)
            print(f" Control table now: {control_biom.shape[0]:,} features × {control_biom.shape[1]:,} samples")
    except Exception as e:
        # Best effort: on any failure the message is printed and the notebook
        # continues; a table already re-filtered above keeps its new value.
        print(f" Feature filtering failed: {e}")

print("\n Saving final checkpoint with all variables...")

import pickle

# Everything the checkpoint-load cell restores. Note that the complete,
# unfiltered biom_table is pickled as well.
checkpoint_data = {
    'cases_valid': cases_valid,
    'matched_controls': matched_controls,
    'case_biom': case_biom,
    'control_biom': control_biom,
    'config': config,
    'biom_table': biom_table,
    'case_sample_ids': case_sample_ids,
    'control_sample_ids': control_sample_ids
}

checkpoint_path = Path(config["output_dir"]) / "checkpoint_data.pkl"
with open(checkpoint_path, 'wb') as f:
    pickle.dump(checkpoint_data, f)

print(f" Final checkpoint saved to: {checkpoint_path}")
print(f" Saved variables:")
print(f"- cases_valid (DataFrame)")
print(f"- matched_controls (DataFrame)")
print(f"- case_biom (BIOM table)")
print(f"- control_biom (BIOM table)")
print(f"- config (configuration)")
print(f"- biom_table (original large BIOM table)")
print(f"- Sample ID sets")
print("\n If kernel crashes, restart and run the checkpoint load cell to continue from here!")


## Sequence Extraction and Phylogeny Processing
### Copy phylogeny and extract sequences


In [None]:
print("Processing sequences and phylogeny...")

import shutil
from pathlib import Path

# --- Phylogeny: copy the reference tree next to the other outputs ----------
phylogeny_source = Path(config["phylogeny_path"])
phylogeny_dest = Path(config["phylogeny_output"]) / f"phylogeny_{DISEASE}.nwk"

if phylogeny_source.exists():
    print(f" Copying phylogeny file...")
    # Create the folder in case the directory-setup cell was skipped.
    phylogeny_dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(phylogeny_source, phylogeny_dest)
    print(f" Phylogeny copied to: {phylogeny_dest}")

    size_mb = phylogeny_dest.stat().st_size / (1024*1024)
    print(f" File size: {size_mb:.1f} MB")
else:
    print(f" Source phylogeny file not found: {phylogeny_source}")

# --- Sequences: keep only the FASTA records present in the BIOM tables -----
sequences_source = Path(config["sequences_path"])
if sequences_source.exists():
    # Imported here because this cell runs before the cell that imports
    # SeqIO; without it a fresh run fails with NameError. Importing inside
    # the branch keeps Biopython optional when no FASTA file is provided.
    from Bio import SeqIO

    print(f" Processing ASV sequences...")

    # Union of feature IDs across the case and control tables.
    obs_ids = set()
    try:
        if case_biom is not None:
            obs_ids |= set(case_biom.ids(axis='observation'))
        if control_biom is not None:
            obs_ids |= set(control_biom.ids(axis='observation'))
    except Exception as e:
        # Best effort: continue with an empty ID set (nothing is extracted).
        print(f" Could not collect observation IDs: {e}")
        obs_ids = set()

    print(f" Extracting sequences for {len(obs_ids):,} observation IDs")

    sequences_output = Path(config["sequences_output"]) / f"ASV_sequences_{DISEASE}.fasta"
    sequences_output.parent.mkdir(parents=True, exist_ok=True)

    # Stream the FASTA file; SeqIO.write returns the number of records written.
    with open(sequences_output, 'w') as outfile:
        wanted_records = (
            record
            for record in SeqIO.parse(sequences_source, "fasta")
            if record.id in obs_ids
        )
        extracted_count = SeqIO.write(wanted_records, outfile, "fasta")

    print(f" Extracted {extracted_count:,} sequences to: {sequences_output}")

    if sequences_output.exists():
        size_mb = sequences_output.stat().st_size / (1024*1024)
        print(f" Sequences file size: {size_mb:.1f} MB")
else:
    print(f" ASV sequences file not found: {sequences_source}")
    print(f" Skipping sequence extraction (optional step)")

print(f" Sequence and phylogeny processing complete!")


## Distance Matrix Generation
### Compute distance matrices


In [None]:
# Build up to three feature-by-feature distance matrices (nested dicts keyed
# by feature ID), pickle them together, and record progress in the checkpoint.
print("Generating THREE phylogenetic distance matrices...")
print(f"1. TreeDistMatrix - Phylogenetic tree distances")
print(f"2. SeqDistMatrix - Jukes-Cantor sequence distances")
print(f"3. GraphDistMatrix - Phylogenetic network distances")

from skbio import TreeNode
from Bio import SeqIO
import networkx as nx
import pickle
import numpy as np
from math import log

phylogeny_path = Path(config["phylogeny_output"]) / f"phylogeny_{DISEASE}.nwk"

if phylogeny_path.exists():
    try:
        print(f"\n    Loading phylogeny from: {phylogeny_path}")
        tree = TreeNode.read(str(phylogeny_path))
        total_tips = len(list(tree.tips()))
        print(f" Phylogeny loaded: {total_tips:,} tips")

        # Features present in either BIOM table.
        observed_feature_ids = set()
        if case_biom is not None:
            observed_feature_ids.update(case_biom.ids(axis='observation'))
        if control_biom is not None:
            observed_feature_ids.update(control_biom.ids(axis='observation'))

        print(f" Using filtered BIOM feature set with {len(observed_feature_ids):,} total features")

        observed_in_tree = [tip.name for tip in tree.tips() if tip.name in observed_feature_ids]
        print(f" Filtering tree to {len(observed_in_tree):,} observed features...")

        # Upper bound on tips, because the matrices below are dense (n x n).
        MAX_TIPS_FOR_DENSE_MATRIX = 25000

        # Reset on every run so the checkpoint update below never reports
        # matrices left over from an earlier execution of this cell.
        matrices_to_save = {}

        if not observed_in_tree:
            print(f" No observed features found in tree - skipping distance matrix")
            TreeDistMatrix = None
            SeqDistMatrix = None
            GraphDistMatrix = None
        else:
            if len(observed_in_tree) > MAX_TIPS_FOR_DENSE_MATRIX:
                print(f" {len(observed_in_tree):,} observed tips exceed cap ({MAX_TIPS_FOR_DENSE_MATRIX}); reducing set...")
                # Keeps the first tips in tree traversal order.
                observed_in_tree = observed_in_tree[:MAX_TIPS_FOR_DENSE_MATRIX]

            tree_pruned = tree.shear(observed_in_tree)
            print(f" Pruned tree to {len(list(tree_pruned.tips())):,} tips")

            # ---- 1. Tree distances --------------------------------------
            print(f"\n    Computing TreeDistMatrix...")
            try:
                tree_dist_mat = tree_pruned.tip_tip_distances()
                print(f"TreeDistMatrix: {tree_dist_mat.shape}")

                tip_names = [tip.name for tip in tree_pruned.tips()]
                TreeDistMatrix = {
                    tip1: {
                        tip2: tree_dist_mat[i, j]
                        for j, tip2 in enumerate(tip_names)
                    }
                    for i, tip1 in enumerate(tip_names)
                }

            except Exception as e:
                print(f"Error computing TreeDistMatrix: {e}")
                TreeDistMatrix = None

            # ---- 2. Sequence (Jukes-Cantor) distances -------------------
            print(f"\n    Computing SeqDistMatrix (Jukes-Cantor)...")
            try:
                # NOTE(review): this alignment path is fixed and T2D-named
                # regardless of DISEASE; confirm it covers the selected set.
                aln_path = Path("../data/T2D_seqs_CMaligned_with_Rfam.sto")
                if not aln_path.exists():
                    print(f"Alignment file not found: {aln_path}")
                    SeqDistMatrix = None
                else:
                    aln_dict = SeqIO.to_dict(SeqIO.parse(str(aln_path), "stockholm"))
                    print(f"Loaded alignment with {len(aln_dict)} sequences")

                    def JC_distance(seq1, seq2):
                        """Jukes Cantor distance: (-3/4)ln[1-p*(4/3)]

                        p is the fraction of differing positions among the
                        aligned columns where neither sequence has '-', '.'
                        or 'N'. Returns 1.0 when no column is comparable and
                        3.0 when p >= 0.75 (where the formula is undefined).
                        """
                        seq1_str = str(seq1.seq).upper()
                        seq2_str = str(seq2.seq).upper()
                        ignored = ('-', '.', 'N')

                        valid_positions = 0
                        differences = 0

                        # zip stops at the shorter sequence.
                        for base1, base2 in zip(seq1_str, seq2_str):
                            if base1 in ignored or base2 in ignored:
                                continue
                            valid_positions += 1
                            if base1 != base2:
                                differences += 1

                        if valid_positions == 0:
                            return 1.0

                        p = differences / valid_positions

                        if p >= 0.75:
                            return 3.0

                        return (-3/4) * log(1 - p * (4/3))

                    SeqDistMatrix = {}
                    aligned_features = [f for f in observed_in_tree if f in aln_dict]
                    print(f"Computing distances for {len(aligned_features)} aligned features...")

                    for i, feat1 in enumerate(aligned_features):
                        SeqDistMatrix[feat1] = {}
                        for feat2 in aligned_features:
                            SeqDistMatrix[feat1][feat2] = JC_distance(aln_dict[feat1], aln_dict[feat2])

                        if (i + 1) % 1000 == 0:
                            print(f"  Progress: {i+1}/{len(aligned_features)}")

                    print(f"SeqDistMatrix: {len(SeqDistMatrix)} x {len(SeqDistMatrix)} features")

            except Exception as e:
                print(f"Error computing SeqDistMatrix: {e}")
                import traceback
                traceback.print_exc()
                SeqDistMatrix = None

            # ---- 3. Network (graph) distances ---------------------------
            print(f"\n    Computing GraphDistMatrix (Phylogenetic Network)...")
            try:
                network_path = Path("../data/neighbornet_txt.gml")
                if not network_path.exists():
                    print(f"Network file not found: {network_path}")
                    GraphDistMatrix = None
                else:
                    phylo_graph = nx.read_gml(str(network_path), label='id')
                    print(f"Loaded network with {len(phylo_graph.nodes())} nodes, {len(phylo_graph.edges())} edges")

                    def find_node_matching_seqID(seqID, graph):
                        """Return the first node of *graph* whose "label" equals *seqID*, else None.

                        Kept in case other cells call it; the loops below use
                        the one-off index instead of scanning per lookup.
                        """
                        for node_id, attrs in graph.nodes(data=True):
                            if "label" in attrs and attrs["label"] == seqID:
                                return node_id
                        return None

                    # Index label -> node once. setdefault keeps the first
                    # node per label, matching the linear scan's first-match
                    # behaviour, and avoids rescanning the graph per pair.
                    label_to_node = {}
                    for node_id, attrs in phylo_graph.nodes(data=True):
                        if "label" in attrs:
                            label_to_node.setdefault(attrs["label"], node_id)

                    GraphDistMatrix = {}
                    network_features = [f for f in observed_in_tree if f in label_to_node]
                    network_nodes = [label_to_node[f] for f in network_features]
                    print(f"Computing distances for {len(network_features)} features in network...")

                    for i, (feat1, node1) in enumerate(zip(network_features, network_nodes)):
                        GraphDistMatrix[feat1] = {}

                        for feat2, node2 in zip(network_features, network_nodes):
                            try:
                                # NOTE(review): shortest_path is called without
                                # weight=, so the path minimises hop count; the
                                # sum below is the weighted length of that
                                # path. Confirm this is intended.
                                path = nx.shortest_path(phylo_graph, node1, node2)
                                dist = 0
                                for u, v in zip(path, path[1:]):
                                    edge_data = phylo_graph.get_edge_data(u, v)
                                    dist += float(edge_data.get("weight", 1.0))
                                GraphDistMatrix[feat1][feat2] = dist
                            except nx.NetworkXNoPath:
                                GraphDistMatrix[feat1][feat2] = float('inf')

                        if (i + 1) % 1000 == 0:
                            print(f"  Progress: {i+1}/{len(network_features)}")

                    print(f"GraphDistMatrix: {len(GraphDistMatrix)} x {len(GraphDistMatrix)} features")

            except Exception as e:
                print(f"Error computing GraphDistMatrix: {e}")
                import traceback
                traceback.print_exc()
                GraphDistMatrix = None

            # ---- Persist whatever was computed --------------------------
            print(f"\n    Saving distance matrices...")
            matrices_output = Path(config["phylogeny_output"]) / f"MATRICES_{DISEASE}.pickle"

            if TreeDistMatrix is not None:
                matrices_to_save["TreeDistMatrix"] = TreeDistMatrix
                print(f"TreeDistMatrix included")
            if SeqDistMatrix is not None:
                matrices_to_save["SeqDistMatrix"] = SeqDistMatrix
                print(f"SeqDistMatrix included")
            if GraphDistMatrix is not None:
                matrices_to_save["GraphDistMatrix"] = GraphDistMatrix
                print(f"GraphDistMatrix included")

            if matrices_to_save:
                with open(matrices_output, "wb") as f:
                    pickle.dump(matrices_to_save, f, protocol=pickle.HIGHEST_PROTOCOL)
                print(f"\n    Distance matrices saved to: {matrices_output}")
                print(f" Total matrices: {len(matrices_to_save)}")
            else:
                print(f" No distance matrices to save")

        # Record progress in the main checkpoint (best effort).
        try:
            checkpoint_path = Path(config["output_dir"]) / "checkpoint_data.pkl"
            if checkpoint_path.exists():
                with open(checkpoint_path, 'rb') as f:
                    checkpoint_data = pickle.load(f)

                checkpoint_data.update({
                    'phylogeny_tip_count': total_tips,
                    # True only when at least one matrix was produced.
                    'distance_matrix_created': bool(matrices_to_save),
                    'num_distance_matrices': len(matrices_to_save)
                })
                with open(checkpoint_path, 'wb') as f:
                    pickle.dump(checkpoint_data, f, protocol=pickle.HIGHEST_PROTOCOL)
                print(f" Checkpoint updated")
        except Exception as e:
            print(f" Could not update checkpoint: {e}")

        print("\n    Distance matrix generation complete!")

    except Exception as e:
        print(f" Error in distance matrix generation: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f" Phylogeny file not found: {phylogeny_path}")


## Generate Output Files
### Save metadata and BIOM tables


In [None]:
                                                                               
                     
                                                                               
# Write the combined, case-only and control-only metadata tables plus plain
# text lists of the sample IDs that were found in the BIOM table.
print(f"Saving metadata files for {DISEASE}...")

# Label each group; assign() works on a copy, the source frames are untouched.
case_metadata = cases_valid.assign(case_control=DISEASE.upper(), group='case')
control_metadata = matched_controls.assign(case_control='Control', group='control')

combined_metadata = pd.concat([case_metadata, control_metadata], ignore_index=True)

# Columns to export; the subtype column is added only when one is in use.
output_cols = [sample_col, age_col, bmi_col, sex_col, disease_col, 'case_control', 'group']
if type_col:
    output_cols.append(type_col)

present_columns = set(combined_metadata.columns)
available_cols = [name for name in output_cols if name in present_columns]
final_metadata = combined_metadata[available_cols].copy()

metadata_path = Path(config["metadata_output"])
exports = (
    (final_metadata, f"AGP_{DISEASE}_metadata.txt"),
    (case_metadata, f"AGP_{DISEASE}_cases_metadata.txt"),
    (control_metadata, f"AGP_{DISEASE}_controls_metadata.txt"),
)
for frame, file_name in exports:
    frame[available_cols].to_csv(metadata_path / file_name, sep="\t", index=False)

# One sample ID per line, sorted, with a trailing newline.
with open(metadata_path / f"samples_{DISEASE.lower()}_cases.txt", "w") as f:
    print("\n".join(sorted(cases_in_biom)), file=f)

with open(metadata_path / f"samples_{DISEASE.lower()}_controls.txt", "w") as f:
    print("\n".join(sorted(control_in_biom)), file=f)

print(f"✓ Metadata saved to: {metadata_path}")
print(f"  Combined dataset: {len(final_metadata):,} samples")
print(f"  {DISEASE} cases: {len(case_metadata):,}")
print(f"  Controls: {len(control_metadata):,}")


In [None]:
                                                                               
                         
                                                                               
# Export the case and control BIOM tables as tab-separated files.
#
# Reads case_biom / control_biom from earlier cells; a table that is None is
# skipped. Files are written under config["biom_output"].
print(f"Saving BIOM tables for {DISEASE}...")

biom_path = Path(config["biom_output"])
# Create the target directory up front so to_csv cannot fail on a missing path.
biom_path.mkdir(parents=True, exist_ok=True)


def _write_biom_tsv(table, tsv_path):
    """Write *table* to *tsv_path* as TSV and return the written DataFrame.

    The frame returned by ``table.to_dataframe()`` is transposed first, and
    the resulting index is labelled ``#SampleID`` in the output header.
    """
    print("  Converting to TSV...")
    frame = table.to_dataframe().T
    frame.index.name = "#SampleID"
    frame.to_csv(tsv_path, sep="\t")
    return frame


if case_biom is not None:
    print(f"  Processing {DISEASE} BIOM table...")
    case_tsv_path = biom_path / f"AGP_{DISEASE}_cases.tsv"
    case_df = _write_biom_tsv(case_biom, case_tsv_path)
    print(f"✓ {DISEASE} TSV saved: {case_tsv_path}")

if control_biom is not None:
    print("  Processing Control BIOM table...")
    control_tsv_path = biom_path / f"AGP_{DISEASE}_controls.tsv"
    control_df = _write_biom_tsv(control_biom, control_tsv_path)
    print(f"✓ Control TSV saved: {control_tsv_path}")

print("\n✓ BIOM table processing complete!")
print("→ TSV files are ready for graph construction pipeline")


In [None]:
# Refresh the pipeline checkpoint with flags recording which output
# artifacts currently exist on disk. Does nothing (beyond a message) when no
# checkpoint file is present.
print("Updating checkpoint with new variables...")

checkpoint_path = Path(config["output_dir"]) / "checkpoint_data.pkl"
if checkpoint_path.exists():
    # NOTE(security): pickle.load can execute arbitrary code from the file.
    # This is acceptable only while the checkpoint is a file written by this
    # pipeline itself; never point output_dir at untrusted data.
    with open(checkpoint_path, 'rb') as f:
        checkpoint_data = pickle.load(f)

    output_root = Path(config["output_dir"])

    # Each output directory falls back to a default sub-folder of output_dir
    # when its key is missing from config.
    biom_path = Path(config["biom_output"]) if "biom_output" in config else output_root / "biom_tables"
    disease_tsv_path = biom_path / f"AGP_{DISEASE}_cases.tsv"
    control_tsv_path = biom_path / f"AGP_{DISEASE}_controls.tsv"
    all_tsv_files_created = disease_tsv_path.exists() and control_tsv_path.exists()

    phylogeny_output_dir = Path(config["phylogeny_output"]) if "phylogeny_output" in config else output_root / "phylogeny"
    phylogeny_copied_path = phylogeny_output_dir / f"phylogeny_{DISEASE}.nwk"
    matrices_output_path = phylogeny_output_dir / f"MATRICES_{DISEASE}.pickle"

    sequences_output_dir = Path(config["sequences_output"]) if "sequences_output" in config else output_root / "sequences"
    sequences_output_path = sequences_output_dir / f"ASV_sequences_{DISEASE}.fasta"

    # Flags reflect file existence at this moment, not whether the producing
    # cell reported success.
    artifact_flags = {
        'phylogeny_copied': phylogeny_copied_path.exists(),
        'distance_matrix_created': matrices_output_path.exists(),
        'sequences_extracted': sequences_output_path.exists(),
        'all_tsv_files_created': all_tsv_files_created,
    }
    checkpoint_data.update(artifact_flags)

    # Use the same pickle protocol as the other checkpoint writer in this
    # notebook so the file format does not depend on which cell ran last.
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(checkpoint_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(" Checkpoint updated with new variables")
    print(" Added variables:")
    # Printed from the dict itself so the report cannot drift from what was stored.
    for flag_name in artifact_flags:
        print(f"- {flag_name}")
else:
    print(" No existing checkpoint found to update")


## Pipeline Summary
### Outputs overview


In [None]:
# Final report for the run: dataset sizes, BIOM table dimensions (when both
# tables are available) and a listing of the output files.
print(f"{DISEASE} Data Extraction Pipeline Complete!")
print("=" * 60)

n_cases = len(cases_valid)
n_controls = len(matched_controls)
print(
    "\n Dataset Summary:",
    f"Total {DISEASE} cases: {n_cases:,}",
    f"Total matched controls: {n_controls:,}",
    f"Total samples in analysis: {n_cases + n_controls:,}",
    sep="\n",
)

# Table dimensions are reported only when both BIOM tables were built.
if not (case_biom is None or control_biom is None):
    print(
        "\n Microbiome Data:",
        f"{DISEASE} features: {case_biom.shape[0]:,}",
        f"Control features: {control_biom.shape[0]:,}",
        f"{DISEASE} samples with data: {case_biom.shape[1]:,}",
        f"Control samples with data: {control_biom.shape[1]:,}",
        sep="\n",
    )

print("\n Output Directory Structure:")
print(f"{config['output_dir']}/")

# Static listing of folder -> file names; nothing here is checked against
# the filesystem.
disease_lower = DISEASE.lower()
expected_layout = {
    " metadata/": (
        f"AGP_{DISEASE}_metadata.txt",
        f"AGP_{DISEASE}_cases_metadata.txt",
        f"AGP_{DISEASE}_controls_metadata.txt",
        f"samples_{disease_lower}_cases.txt",
        f"samples_{disease_lower}_controls.txt",
    ),
    " biom_tables/": (
        f"AGP_{DISEASE}_cases.tsv",
        f"AGP_{DISEASE}_controls.tsv",
    ),
    " phylogeny/": (
        f"phylogeny_{DISEASE}.nwk",
        f"MATRICES_{DISEASE}.pickle",
    ),
    " config/": (
        " pipeline_config.json",
    ),
}
for folder, filenames in expected_layout.items():
    print(folder, *filenames, sep="\n")

print(
    "\n Next Steps:",
    "1. Use the generated files in your graph construction notebook",
    "2. The TSV files can be used for traditional microbiome analysis",
    "3. The distance matrix pickle is ready for graph-based ML",
    "\n Pipeline completed successfully!",
    sep="\n",
)
