# Multi-Disease Phylogenetic Graph Construction Pipeline

## Summary
- Build per-sample graphs from distance matrices and abundance tables.
- Save graphs, labels, and metadata under `{DISEASE}_analysis_output/graphs`.
- Optional visualizations and checkpoints.


## Overview

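This notebook constructs per-sample phylogenetic graphs from the abundance
tables and distance matrices produced by `01_data_extraction.ipynb`. The target
disease is read from `{DISEASE}_analysis_output/config/pipeline_config.json`
and can be overridden by an experiment YAML file (the path in the
`EXPERIMENT_CONFIG_PATH` environment variable, or `config.yaml` in the working
directory). Outputs are written to `{DISEASE}_analysis_output/graphs`.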

## Pipeline Architecture

```
Input Data (from 01_data_extraction.ipynb)
├── Distance matrix (MATRICES_{DISEASE}.pickle)
├── Abundance tables (AGP_{DISEASE}_cases.tsv, AGP_{DISEASE}_controls.tsv)
├── Sample ID lists (samples_{disease}_cases.txt, samples_{disease}_controls.txt)
└── Optional alignment file ({DISEASE}_seqs_CMaligned_with_Rfam.sto)

Graph Construction Pipeline
1. Configuration and input validation
2. Load distance matrix and abundance tables
3. Build base phylogenetic graph
4. Generate sample-specific graphs
5. Validate graphs and save outputs

Output Files
├── graphs/
│   ├── nx_graphs_{DISEASE}.pkl
│   ├── labels_{DISEASE}.npy
│   ├── graph_metadata_{DISEASE}.csv
│   ├── graph_config_{DISEASE}.json
│   └── graph_examples/
├── visualizations/
│   └── statistics/
└── checkpoints/
    └── graph_construction_checkpoint_{DISEASE}.pkl
```

## Key Features

- Multi-disease support driven by config
- Robust validation and checkpointing
- Graph quality control and optional visualizations
- Organized outputs for downstream modeling

## Prerequisites

- Completed execution of `01_data_extraction.ipynb` for the target disease
- Generated files in `{DISEASE}_analysis_output/` directory
- Required Python packages installed


In [None]:
import os
import sys
import warnings
# NOTE(review): this silences every warning category for the whole session,
# including deprecation warnings raised by the libraries imported below.
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import random
import pickle
import copy
from pathlib import Path
import json
import time
from datetime import datetime

# Seed Python's and NumPy's global random number generators with a fixed value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

from skbio import TreeNode
from Bio import SeqIO

import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

# PyTorch is optional: TORCH_AVAILABLE records whether the import succeeded.
try:
    import torch
    TORCH_AVAILABLE = True
    print(f"PyTorch {torch.__version__}")
    print(f"CUDA: {torch.cuda.is_available()}")
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available")

from tqdm import tqdm

# Report when and where the notebook is running.
print(f"\nGraph Construction - {datetime.now().strftime('%Y-%m-%d %H:%M')}")
print(f"Working dir: {os.getcwd()}")

## Config setup
### Load config and paths


In [None]:
                                                       
# The DISEASE environment variable (upper-cased, default "IBD") selects which
# analysis directory is searched for the config written by the previous notebook.
_default_disease = os.environ.get("DISEASE", "IBD").upper()
config_file = Path(f"{_default_disease}_analysis_output/config/pipeline_config.json")
if config_file.exists():
    with open(config_file, "r") as f:
        config = json.load(f)
    print("Loaded configuration from previous notebook")
else:
    print("No configuration file found - using default settings")
    # Use the environment-selected disease for both entries so that DISEASE
    # and output_dir always describe the same analysis.
    config = {
        "disease": _default_disease,
        "output_dir": f"{_default_disease}_analysis_output"
    }

DISEASE = config.get("disease", _default_disease)

# If an experiment config is already defined in this session, its
# data_extraction.disease entry takes precedence ('N/A' counts as unset).
if 'EXPERIMENT_CONFIG' in globals() and EXPERIMENT_CONFIG:
    exp_disease = EXPERIMENT_CONFIG.get('data_extraction', {}).get('disease', None)
    if exp_disease and exp_disease != 'N/A':
        DISEASE = exp_disease
        print(f'Using disease from experiment config: {DISEASE}')
output_dir = Path(config.get("output_dir", f"{DISEASE}_analysis_output"))

# Input files produced by the data extraction notebook.
PICKLE_FILE = output_dir / "phylogeny" / f"MATRICES_{DISEASE}.pickle"
CASE_BIOM_TSV = output_dir / "biom_tables" / f"AGP_{DISEASE}_cases.tsv"
CONTROL_BIOM_TSV = output_dir / "biom_tables" / f"AGP_{DISEASE}_controls.tsv"
# Optional alignment file; a plain string relative to the working directory.
ALN_STOCKHOLM_FILE = f"{DISEASE}_seqs_CMaligned_with_Rfam.sto"

CASE_IDS_FILE = output_dir / "metadata" / f"samples_{DISEASE.lower()}_cases.txt"
CONTROL_IDS_FILE = output_dir / "metadata" / f"samples_{DISEASE.lower()}_controls.txt"

graphs_output_dir = output_dir / "graphs"
visualizations_dir = output_dir / "visualizations"
checkpoints_dir = output_dir / "checkpoints"

# parents=True: output_dir itself may not exist yet (for example when no
# config file was found), in which case a plain mkdir would raise.
for dir_path in [graphs_output_dir, visualizations_dir, checkpoints_dir]:
    dir_path.mkdir(parents=True, exist_ok=True)

(graphs_output_dir / "graph_examples").mkdir(parents=True, exist_ok=True)
(visualizations_dir / "statistics").mkdir(parents=True, exist_ok=True)

# Output files written by this notebook.
OUT_GRAPHS_PKL = graphs_output_dir / f"nx_graphs_{DISEASE}.pkl"
OUT_LABELS_NPY = graphs_output_dir / f"labels_{DISEASE}.npy"
OUT_METADATA_CSV = graphs_output_dir / f"graph_metadata_{DISEASE}.csv"
OUT_CONFIG_JSON = graphs_output_dir / f"graph_config_{DISEASE}.json"
CHECKPOINT_FILE = checkpoints_dir / f"graph_construction_checkpoint_{DISEASE}.pkl"

print(f" Disease focus: {DISEASE}")
print(f" Output directory: {output_dir}")
print(f" Graph output directory: {graphs_output_dir}")
print(f" Visualizations directory: {visualizations_dir}")
print(f" Checkpoints directory: {checkpoints_dir}")


In [None]:
import yaml

# Locate the experiment YAML: the EXPERIMENT_CONFIG_PATH environment variable
# wins over a config.yaml in the working directory.
config_path = None
if 'EXPERIMENT_CONFIG_PATH' in os.environ:
    config_path = Path(os.environ['EXPERIMENT_CONFIG_PATH'])
    print(f" Using config from environment: {config_path}")
elif Path("config.yaml").exists():
    config_path = Path("config.yaml")
    print(f" Found config in current directory")

if config_path and config_path.exists():
    print(f" Loading experiment configuration from: {config_path}")
    with open(config_path, 'r') as f:
        # An empty YAML document parses to None; normalise it to a dict so the
        # .get() lookups below cannot fail.
        EXPERIMENT_CONFIG = yaml.safe_load(f) or {}
    print(f" Loaded configuration for experiment")

    # Look for the disease in several locations; `or {}` also covers sections
    # that are present in the YAML but left empty (parsed as None).
    data_extraction_config = EXPERIMENT_CONFIG.get('data_extraction') or {}
    exp_disease = (
        EXPERIMENT_CONFIG.get('disease') or  # Root level
        data_extraction_config.get('disease') or  # data_extraction.disease
        (data_extraction_config.get('disease_criteria') or {}).get('disease')  # disease_criteria
    )
    # 'N/A' is treated as "not set", matching the earlier config cell.
    if exp_disease == 'N/A':
        exp_disease = None

    gc_config = EXPERIMENT_CONFIG.get('graph_construction') or {}
    knn_k = (gc_config.get('knn') or {}).get('k', 'default')

    print(f"Disease from config: {exp_disease}")
    print(f"Graph k-NN: k={knn_k}")
else:
    print(" No experiment config found - using default parameters from notebook")
    EXPERIMENT_CONFIG = {}
    exp_disease = None

# Update DISEASE immediately after loading config
if EXPERIMENT_CONFIG and exp_disease:
    DISEASE = exp_disease.upper()  # Ensure uppercase
    print(f"DISEASE updated to: {DISEASE}")

    # Recalculate ALL paths with new DISEASE (NO notebooks/ prefix - we're already in notebooks/)
    output_dir = Path(f"{DISEASE}_analysis_output")
    graphs_output_dir = output_dir / "graphs"
    visualizations_dir = output_dir / "visualizations"
    checkpoints_dir = output_dir / "checkpoints"

    # Create directories, including the sub-directories used for example
    # graphs and statistics plots under the (possibly new) output directory.
    for dir_path in [
        graphs_output_dir,
        visualizations_dir,
        checkpoints_dir,
        graphs_output_dir / "graph_examples",
        visualizations_dir / "statistics",
    ]:
        dir_path.mkdir(parents=True, exist_ok=True)

    PICKLE_FILE = output_dir / "phylogeny" / f"MATRICES_{DISEASE}.pickle"
    CASE_BIOM_TSV = output_dir / "biom_tables" / f"AGP_{DISEASE}_cases.tsv"
    CONTROL_BIOM_TSV = output_dir / "biom_tables" / f"AGP_{DISEASE}_controls.tsv"
    # The optional alignment file name also embeds the disease.
    ALN_STOCKHOLM_FILE = f"{DISEASE}_seqs_CMaligned_with_Rfam.sto"
    CASE_IDS_FILE = output_dir / "metadata" / f"samples_{DISEASE.lower()}_cases.txt"
    CONTROL_IDS_FILE = output_dir / "metadata" / f"samples_{DISEASE.lower()}_controls.txt"

    OUT_GRAPHS_PKL = graphs_output_dir / f"nx_graphs_{DISEASE}.pkl"
    OUT_LABELS_NPY = graphs_output_dir / f"labels_{DISEASE}.npy"
    OUT_METADATA_CSV = graphs_output_dir / f"graph_metadata_{DISEASE}.csv"
    OUT_CONFIG_JSON = graphs_output_dir / f"graph_config_{DISEASE}.json"
    CHECKPOINT_FILE = checkpoints_dir / f"graph_construction_checkpoint_{DISEASE}.pkl"

    print(f"Output directory: {output_dir}")
    print(f"Graphs output: {OUT_GRAPHS_PKL}")
    print(f"Data files: {PICKLE_FILE}")

# Read FORCE_REBUILD from config; this does not depend on a disease being set.
force_rebuild_config = (EXPERIMENT_CONFIG.get('graph_construction') or {}).get('force_rebuild', False)
if force_rebuild_config:
    FORCE_REBUILD = True
    print("FORCE_REBUILD set to True from config")


In [None]:
SKIP_GRAPH_CONSTRUCTION = False

# A rebuild can be forced through the FORCE_REBUILD environment variable or
# through graph_construction.force_rebuild in the experiment config. Both
# sources are combined here so the config request is not discarded when this
# cell assigns FORCE_REBUILD.
_force_rebuild_env = os.environ.get("FORCE_REBUILD", "False").lower() in ["true", "1", "yes"]
_experiment_config = globals().get('EXPERIMENT_CONFIG') or {}
_force_rebuild_cfg = bool(
    (_experiment_config.get('graph_construction') or {}).get('force_rebuild', False)
)
FORCE_REBUILD = _force_rebuild_env or _force_rebuild_cfg

if OUT_GRAPHS_PKL.exists() and OUT_LABELS_NPY.exists() and not FORCE_REBUILD:
    import pickle
    import numpy as np

    try:
        # NOTE: pickle can execute arbitrary code while loading; only reuse
        # graph files that were written by this pipeline.
        with open(OUT_GRAPHS_PKL, 'rb') as f:
            existing_graphs = pickle.load(f)
        existing_labels = np.load(OUT_LABELS_NPY)

        # Empty outputs are not reused; the graphs are rebuilt instead.
        if len(existing_graphs) > 0 and len(existing_labels) > 0:
            print("="*60)
            print("EXISTING GRAPHS FOUND - SKIPPING REBUILD")
            print("="*60)
            print(f"Graphs file: {OUT_GRAPHS_PKL}")
            print(f"Number of graphs: {len(existing_graphs)}")
            print(f"Number of labels: {len(existing_labels)}")
            print(f"File modified: {OUT_GRAPHS_PKL.stat().st_mtime}")
            print()
            print(f"To force rebuild, set FORCE_REBUILD = True")
            print("="*60)
            SKIP_GRAPH_CONSTRUCTION = True

            GRAPHS_PREPARED = existing_graphs
            LABELS = existing_labels
    except Exception as e:
        # Best effort: any failure to read the existing outputs falls back to
        # a full rebuild.
        print(f" Existing graphs file is corrupted or invalid: {e}")
        print(f"Will rebuild graphs from scratch...")
        SKIP_GRAPH_CONSTRUCTION = False

if not SKIP_GRAPH_CONSTRUCTION:
    print("Will build graphs from scratch...")


In [None]:
# Default graph construction settings; entries may be replaced below by values
# from the `graph_construction` section of the experiment config.
GRAPH_PARAMS = dict(
    min_nodes=8,
    min_edges=5,
    normalize_weights=True,
    max_nodes_per_graph=3000,
    connectivity_check=False,
    remove_isolates=True,
    parallel_processing=False,
    batch_size=50,
    save_intermediate=True,
    graph_type="knn",    # supported: knn, tree, mst, threshold, hierarchical
    knn_k=10,
    knn_symmetric=True,
    knn_max_distance_factor=3.5,
    threshold_percentile=25,
    tree_ancestor_levels=3,
    randomize_edges=False,
    preserve_degree=True,
    weight_transform="identity",
)

gc_config = EXPERIMENT_CONFIG.get('graph_construction', {})

# graph_type lives directly under graph_construction and is announced when set.
if 'graph_type' in gc_config:
    GRAPH_PARAMS['graph_type'] = gc_config['graph_type']
    print(f"Using graph_type from config: {GRAPH_PARAMS['graph_type']}")

# Per-topic sub-sections of the graph_construction config.
knn_config = gc_config.get('knn', {})
threshold_config = gc_config.get('threshold', {})
tree_config = gc_config.get('tree', {})
quality_config = gc_config.get('quality', {})
edge_config = gc_config.get('edge_construction', {})
weight_config = gc_config.get('weights', {})

# (config section, option name in that section, GRAPH_PARAMS key it overrides)
_section_overrides = (
    (knn_config, 'k', 'knn_k'),
    (knn_config, 'symmetric', 'knn_symmetric'),
    (knn_config, 'max_distance_factor', 'knn_max_distance_factor'),
    (threshold_config, 'percentile', 'threshold_percentile'),
    (tree_config, 'ancestor_levels', 'tree_ancestor_levels'),
    (quality_config, 'min_nodes', 'min_nodes'),
    (quality_config, 'min_edges', 'min_edges'),
    (edge_config, 'randomize_edges', 'randomize_edges'),
    (edge_config, 'preserve_degree', 'preserve_degree'),
    (weight_config, 'weight_transform', 'weight_transform'),
)
for _section, _option, _param_name in _section_overrides:
    # Only options explicitly present in the config replace a default.
    if _option in _section:
        GRAPH_PARAMS[_param_name] = _section[_option]

# Plotting settings; not overridden from the experiment config in this cell.
VIZ_PARAMS = dict(
    max_nodes_to_plot=100,
    node_size_scale=100,
    edge_width_scale=1,
    figsize=(12, 8),
    dpi=150,
    save_plots=False,
    show_plots=False,
)

print(f" Graph construction parameters:")
for key, value in GRAPH_PARAMS.items():
    print(f"- {key}: {value}")

print(f"\n Visualization parameters:")
for key, value in VIZ_PARAMS.items():
    print(f"- {key}: {value}")

print(f"\n Expected outputs:")
for _label, _target in (
    ("Graphs", OUT_GRAPHS_PKL),
    ("Labels", OUT_LABELS_NPY),
    ("Metadata", OUT_METADATA_CSV),
    ("Config", OUT_CONFIG_JSON),
    ("Checkpoint", CHECKPOINT_FILE),
):
    print(f"- {_label}: {_target}")


## Checkpoint System and Input Validation

### Checkpoint Loading


In [None]:
def load_checkpoint(checkpoint_file):
    """Load checkpoint data if available.

    Parameters:
    - checkpoint_file: Path object or path string of the pickle file to read.

    Returns:
    - The unpickled checkpoint object, or None when the file does not exist
      or cannot be read/unpickled (the error is printed, not raised).
    """
    # Accept plain strings as well as Path objects.
    checkpoint_path = Path(checkpoint_file)
    if not checkpoint_path.exists():
        return None

    try:
        # NOTE: pickle can execute arbitrary code while loading; only open
        # checkpoint files written by this pipeline.
        with open(checkpoint_path, 'rb') as f:
            checkpoint_data = pickle.load(f)
    except Exception as e:
        # Best effort: an unreadable checkpoint is reported and treated as
        # "no checkpoint" so the caller can start fresh.
        print(f" Error loading checkpoint: {e}")
        return None

    print(f" Loaded checkpoint from: {checkpoint_file}")
    return checkpoint_data

def save_checkpoint(checkpoint_data, checkpoint_file):
    """Save checkpoint data.

    The data is pickled to a temporary file next to the target and then moved
    into place, so a failed or interrupted dump does not destroy a previously
    saved checkpoint.

    Parameters:
    - checkpoint_data: Object to pickle.
    - checkpoint_file: Path object or path string of the checkpoint file.

    Returns:
    - None. Errors are printed rather than raised.
    """
    target_path = Path(checkpoint_file)
    temp_path = target_path.with_name(target_path.name + ".tmp")
    try:
        with open(temp_path, 'wb') as f:
            pickle.dump(checkpoint_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Replace the old checkpoint only after the new one is fully written.
        os.replace(temp_path, target_path)
        print(f" Checkpoint saved to: {checkpoint_file}")
    except Exception as e:
        print(f" Error saving checkpoint: {e}")
        # Remove the partial temporary file; it may not exist if opening it
        # was what failed, so a failure here is deliberately ignored.
        try:
            temp_path.unlink()
        except OSError:
            pass

# When existing graphs are reused, a placeholder checkpoint marks the pipeline
# as finished; otherwise try to resume from the checkpoint file.
if SKIP_GRAPH_CONSTRUCTION:
    checkpoint_data = {'step': 5, 'skipped': True}
    print("Skipping checkpoint loading - using existing graphs")
else:
    checkpoint_data = load_checkpoint(CHECKPOINT_FILE)

if checkpoint_data is None:
    # No usable checkpoint: start from an empty state at step 0.
    checkpoint_data = {
        'step': 0,
        'base_graph': None,
        'seq_to_ids': None,
        'graphs': [],
        'labels': [],
        'metadata': [],
        'config': config,
        'graph_params': GRAPH_PARAMS,
        'timestamp': datetime.now().isoformat()
    }
    print("Starting fresh - no checkpoint found")
else:
    # Human-readable name of the pipeline stage each checkpoint step refers to.
    step_names = {
        0: "Initialization",
        1: "Step 4: Data Loading and Distance Matrix Processing",
        2: "Step 4: Sequence Mappings (completed)",
        3: "Step 5: Base Phylogenetic Graph Construction",
        4: "Step 7: Sample-Specific Graph Construction",
        5: "Step 9: Output Generation and Pipeline Summary"
    }

    current_step = checkpoint_data.get('step', 0)
    current_step_name = step_names.get(current_step, f"Unknown step {current_step}")

    print(f" Checkpoint contains:")
    print(f"- Checkpoint Step: {current_step}")
    print(f"- Pipeline Step: {current_step_name}")
    # Report presence explicitly so the two status lines are informative.
    print(f"- Base graph: {'yes' if checkpoint_data.get('base_graph') is not None else 'no'}")
    print(f"- Sequence mappings: {'yes' if checkpoint_data.get('seq_to_ids') is not None else 'no'}")
    print(f"- Sample graphs: {len(checkpoint_data.get('graphs', []))}")
    print(f"- Labels: {len(checkpoint_data.get('labels', []))}")
    print(f"- Timestamp: {checkpoint_data.get('timestamp', 'Unknown')}")

    # Stage that runs next for each checkpoint step.
    next_steps = {
        0: "Step 4: Data Loading and Distance Matrix Processing",
        1: "Step 4: Sequence Mappings",
        2: "Step 5: Base Phylogenetic Graph Construction", 
        3: "Step 7: Sample-Specific Graph Construction",
        4: "Step 9: Output Generation and Pipeline Summary",
        5: "Pipeline completed! "
    }

    next_step = next_steps.get(current_step, "Unknown next step")
    print(f"\n Next step: {next_step}")


In [None]:
# Check that every input produced by the data extraction notebook is present
# before any graph is built. Nothing is validated when existing graphs are reused.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
else:
    print("Validating input files...")

    required_files = [
        (PICKLE_FILE, "Distance matrix pickle"),
        (CASE_BIOM_TSV, f"{DISEASE} cases abundance table"),
        (CONTROL_BIOM_TSV, "Control abundance table"),
        (CASE_IDS_FILE, f"{DISEASE} sample IDs"),
        (CONTROL_IDS_FILE, "Control sample IDs")
    ]

    optional_files = [
        (ALN_STOCKHOLM_FILE, "Stockholm alignment file")
    ]

    # Report each required file with its size, collecting the missing ones.
    missing_files = []
    for file_path, description in required_files:
        if file_path.exists():
            size_mb = file_path.stat().st_size / (1024*1024)
            print(f" {description}: {file_path} ({size_mb:.1f} MB)")
        else:
            print(f" {description}: {file_path} - MISSING")
            missing_files.append(file_path)

    print("\n Optional files:")
    for file_path, description in optional_files:
        # Optional entries are plain strings, so wrap them in Path first.
        path_obj = Path(file_path)
        if path_obj.exists():
            size_mb = path_obj.stat().st_size / (1024*1024)
            print(f" {description}: {path_obj} ({size_mb:.1f} MB)")
        else:
            print(f" {description}: {path_obj} - Not found (optional)")

    if missing_files:
        print(f"\n Missing required files: {len(missing_files)}")
        print("Please run the data extraction notebook first to generate these files.")
        raise FileNotFoundError("Required input files are missing")
    else:
        print(f"\n All required files found!")



## Data Loading and Distance Matrix Processing

### Load Distance Matrix


In [None]:
# Load the distance matrix from PICKLE_FILE unless existing graphs are reused
# or a checkpoint already carries both the base graph and the matrix.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping distance matrix loading - using existing graphs")
elif (
    checkpoint_data.get('step', 0) < 1
    or checkpoint_data.get('base_graph') is None
    # A checkpoint without a stored matrix cannot serve the branch below,
    # so reload from disk in that case as well.
    or checkpoint_data.get('dist_mat') is None
):
    print("Loading phylogenetic distance matrix...")
    try:
        # NOTE: pickle can execute arbitrary code while loading; PICKLE_FILE is
        # expected to be the file written by the data extraction notebook.
        with open(PICKLE_FILE, "rb") as f:
            try:
                mats = pickle.load(f)
            except ModuleNotFoundError:
                # Compatibility shim: alias the `numpy._core` module names to
                # their `numpy.core` counterparts and retry the load once.
                # Presumably this covers pickles written by a NumPy version
                # that uses the `numpy._core` layout - TODO confirm.
                import sys as _sys
                import numpy as _np
                if "numpy._core" not in _sys.modules:
                    _sys.modules["numpy._core"] = _np.core
                try:
                    import numpy.core.numeric as _numeric
                    _sys.modules["numpy._core.numeric"] = _numeric
                except (ImportError, AttributeError):
                    # Best effort: the alias is optional for the retry.
                    pass
                try:
                    import numpy.core.multiarray as _multiarray
                    _sys.modules["numpy._core.multiarray"] = _multiarray
                except (ImportError, AttributeError):
                    # Best effort: the alias is optional for the retry.
                    pass
                # Rewind before retrying; the failed load consumed part of the stream.
                f.seek(0)
                mats = pickle.load(f)

        print(f"Loaded matrices: {list(mats.keys())}")
        dist_mat = mats.get("TreeDistMatrix")

        if dist_mat is None:
            raise ValueError("TreeDistMatrix not found in pickle file")

        print(f"Distance matrix type: {type(dist_mat)}")

        # Describe the matrix according to the container it was stored in.
        if hasattr(dist_mat, 'shape'):
            print(f"Distance matrix shape: {dist_mat.shape}")
        elif isinstance(dist_mat, dict):
            print(f"Distance matrix entries: {len(dist_mat)}")
            if dist_mat:
                key0 = next(iter(dist_mat))
                print(f"Example key: {key0} → type: {type(dist_mat[key0])}")
        else:
            print(f"Distance matrix: {dist_mat}")

        print(f" Distance matrix loaded successfully")

        checkpoint_data['dist_mat'] = dist_mat
        checkpoint_data['step'] = 1
        save_checkpoint(checkpoint_data, CHECKPOINT_FILE)

    except Exception as e:
        print(f" Error loading distance matrix: {e}")
        raise
else:
    print("Distance matrix already loaded from checkpoint")
    dist_mat = checkpoint_data['dist_mat']


In [None]:
# Load the case/control sample ID lists and abundance tables, orient the
# tables as features x samples, and keep only the listed samples.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping abundance table loading - using existing graphs")
else:
    print("Loading abundance tables...")

    try:
        if CASE_IDS_FILE.exists():
            # dtype=str keeps IDs exactly as written; with type inference a
            # numeric-looking ID can be parsed as a number and re-formatted,
            # after which it no longer matches the table's column labels.
            case_ids = pd.read_csv(CASE_IDS_FILE, header=None, dtype=str)[0].astype(str).tolist()
            case_ids = [x for x in case_ids if x.strip()]
            print(f"{DISEASE} sample IDs loaded: {len(case_ids):,}")
        else:
            raise FileNotFoundError(f"{DISEASE} sample IDs file not found: {CASE_IDS_FILE}")

        if CONTROL_IDS_FILE.exists():
            control_ids = pd.read_csv(CONTROL_IDS_FILE, header=None, dtype=str)[0].astype(str).tolist()
            control_ids = [x for x in control_ids if x.strip()]
            print(f"Control sample IDs loaded: {len(control_ids):,}")
        else:
            raise FileNotFoundError(f"Control sample IDs file not found: {CONTROL_IDS_FILE}")

        # Sets give constant-time membership tests for the checks below.
        case_id_set = set(case_ids)
        control_id_set = set(control_ids)

        print(f"Loading {DISEASE} cases table...")
        df_CASE_raw = pd.read_csv(CASE_BIOM_TSV, sep='\t', index_col=0)
        print(f"Raw {DISEASE} cases table: {df_CASE_raw.shape[0]:,} rows × {df_CASE_raw.shape[1]:,} columns")

        print(f"Loading controls table...")
        df_CONTROL_raw = pd.read_csv(CONTROL_BIOM_TSV, sep='\t', index_col=0)
        print(f"Raw controls table: {df_CONTROL_raw.shape[0]:,} rows × {df_CONTROL_raw.shape[1]:,} columns")

        print("\n    Checking table orientation...")

        # If any of the first ten row labels is a known sample ID, the rows
        # are samples and both tables are transposed.
        case_rows_are_samples = any(str(idx) in case_id_set for idx in df_CASE_raw.index[:10])
        control_rows_are_samples = any(str(idx) in control_id_set for idx in df_CONTROL_raw.index[:10])

        if case_rows_are_samples or control_rows_are_samples:
            print(f"  DETECTED: Tables are TRANSPOSED (rows=samples, columns=features)")
            print(f" Transposing tables to correct format (rows=features, columns=samples)...")
            df_CASE_raw = df_CASE_raw.T
            df_CONTROL_raw = df_CONTROL_raw.T
            print(f"After transpose - {DISEASE} cases: {df_CASE_raw.shape[0]:,} features × {df_CASE_raw.shape[1]:,} samples")
            print(f"After transpose - Controls: {df_CONTROL_raw.shape[0]:,} features × {df_CONTROL_raw.shape[1]:,} samples")
        else:
            print(f" Tables are in correct orientation (rows=features, columns=samples)")

        print(f"\n    Filtering to {DISEASE}/control samples...")

        # .copy() makes the filtered tables independent of the raw ones before
        # their index is reassigned below.
        available_case_samples = [col for col in df_CASE_raw.columns if str(col) in case_id_set]
        df_CASE = df_CASE_raw[available_case_samples].copy()
        print(f"{DISEASE} cases filtered: {df_CASE.shape[0]:,} features × {df_CASE.shape[1]:,} samples")
        print(f"Found {len(available_case_samples)}/{len(case_ids)} {DISEASE} samples in table")

        available_control_samples = [col for col in df_CONTROL_raw.columns if str(col) in control_id_set]
        df_CONTROL = df_CONTROL_raw[available_control_samples].copy()
        print(f"Controls filtered: {df_CONTROL.shape[0]:,} features × {df_CONTROL.shape[1]:,} samples")
        print(f"Found {len(available_control_samples)}/{len(control_ids)} control samples in table")

        print("\n    Cleaning feature IDs (removing trailing whitespace)...")
        df_CASE.index = df_CASE.index.str.strip()
        df_CONTROL.index = df_CONTROL.index.str.strip()
        print(f" Feature IDs cleaned")

        # Count shared features on the cleaned IDs so stray whitespace cannot
        # hide a match.
        common_features = set(df_CASE.index) & set(df_CONTROL.index)
        print(f"\n   Common features between cases and controls: {len(common_features):,}")

        print("\n    Final abundance table dimensions:")
        print(f"   {DISEASE} cases: {df_CASE.shape[0]:,} features × {df_CASE.shape[1]:,} samples")
        print(f"   Controls: {df_CONTROL.shape[0]:,} features × {df_CONTROL.shape[1]:,} samples")
        print(f"   Total samples to process: {df_CASE.shape[1] + df_CONTROL.shape[1]:,}")

        if df_CASE.shape[1] == 0 or df_CONTROL.shape[1] == 0:
            raise ValueError("No samples found after filtering! Check sample ID matching.")

        print("\n    Abundance tables loaded and filtered successfully")

    except Exception as e:
        print(f" Error loading abundance tables: {e}")
        raise


### Load Sequence Alignments (Optional)


In [None]:
# Build (or restore from the checkpoint) the mapping stored in seq_to_ids:
# ungapped alignment sequence -> alignment record ID when the Stockholm file
# is usable, otherwise an identity mapping over all feature IDs.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
elif checkpoint_data['step'] >= 2 and checkpoint_data.get('seq_to_ids') is not None:
    # The checkpoint already holds the mappings; reuse them.
    print("Sequence mappings already loaded from checkpoint")
    seq_to_ids = checkpoint_data['seq_to_ids']
    stockholm_available = checkpoint_data.get('stockholm_available', False)
    print(f"Total sequence mappings: {len(seq_to_ids):,}")
    print(f"Stockholm alignment available: {stockholm_available}")
else:
    seq_to_ids = {}
    stockholm_available = False

    if not Path(ALN_STOCKHOLM_FILE).exists():
        print("Stockholm alignment file not found - using direct feature mapping")
        print(f"This may reduce graph quality if sequences don't match phylogenetic IDs")
    else:
        print("Loading Stockholm alignment file...")
        try:
            aln_dict = SeqIO.to_dict(SeqIO.parse(ALN_STOCKHOLM_FILE, "stockholm"))
            # Key each record by its sequence with "-" removed and "U"
            # replaced by "T"; later records win on duplicate sequences.
            seq_to_ids = dict(
                (str(record.seq).replace("-", "").replace("U", "T"), record_id)
                for record_id, record in aln_dict.items()
            )
            print(f"Aligned sequences: {len(seq_to_ids):,}")
            stockholm_available = True
            print(f" Stockholm alignment loaded successfully")
        except Exception as e:
            print(f" Error loading Stockholm alignment: {e}")
            print(f" Continuing without sequence alignment mapping")

    if not stockholm_available:
        # No usable alignment: map every feature ID to itself.
        print("Creating fallback sequence mapping...")
        all_features = set(df_CASE.index).union(df_CONTROL.index)
        seq_to_ids = {feature_id: feature_id for feature_id in all_features}
        print(f"Fallback mappings: {len(seq_to_ids):,}")

    print(f" Total sequence mappings: {len(seq_to_ids):,}")

    # Show up to five mapping keys, truncating long ones to 80 characters.
    print("\n Feature ID diagnostic check:")
    sample_features = list(seq_to_ids)[:5]
    for i, feat_id in enumerate(sample_features):
        if len(feat_id) > 80:
            print(f"[{i}] Length={len(feat_id):3d}: {feat_id[:80]}...")
        else:
            print(f"   [{i}] Length={len(feat_id):3d}: {feat_id}")

    checkpoint_data['seq_to_ids'] = seq_to_ids
    checkpoint_data['stockholm_available'] = stockholm_available
    checkpoint_data['step'] = 2
    save_checkpoint(checkpoint_data, CHECKPOINT_FILE)



In [None]:
def efficient_graph_pruning(G, max_edges=None, distance_threshold=None):
    """
    Shrink a weighted graph in up to three stages.

    Stages, in order:
    1. If ``distance_threshold`` is given, drop every edge whose weight
       exceeds it (edges without a weight count as 0) and then drop nodes
       left without any edge. This stage edits ``G`` in place.
    2. If the graph still has more than ``max_edges`` edges, compute its
       minimum spanning tree.
    3. If the tree has fewer than ``max_edges`` edges, top it up with the
       lightest non-tree edges until the budget is reached.

    Parameters:
    - G: NetworkX graph (modified in place by the threshold stage)
    - max_edges: Edge budget; ``None`` disables stages 2 and 3
    - distance_threshold: Largest edge weight to keep; ``None`` disables stage 1

    Returns:
    - ``G`` itself when it already fits the budget, otherwise a new graph
      built from the spanning tree. If the spanning-tree stage raises, a new
      graph holding all nodes and the ``max_edges`` lightest edges is returned.
    """
    print(f"Starting with {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")

    # Stage 1: threshold filter (mutates G).
    if distance_threshold is not None:
        print(f"Applying distance threshold: {distance_threshold:.4f}")
        too_long = []
        for src, dst, attrs in G.edges(data=True):
            if attrs.get('weight', 0) > distance_threshold:
                too_long.append((src, dst))

        G.remove_edges_from(too_long)
        print(f"After threshold filtering: {G.number_of_edges():,} edges (removed {len(too_long):,})")

        orphans = list(nx.isolates(G))
        if orphans:
            G.remove_nodes_from(orphans)
            print(f"Removed {len(orphans):,} isolated nodes")

    # Nothing more to do when there is no budget or it is already met.
    fits_budget = max_edges is None or G.number_of_edges() <= max_edges
    if fits_budget:
        print(f"Graph already within target size: {G.number_of_edges():,} edges")
        return G

    # Stage 2 + 3: spanning tree, then refill with the lightest leftovers.
    print(f"Computing minimum spanning tree (target: {max_edges:,} edges)...")
    try:
        mst = nx.minimum_spanning_tree(G, weight='weight', algorithm='kruskal')
        print(f"MST has {mst.number_of_edges():,} edges")

        deficit = max_edges - mst.number_of_edges()
        if deficit > 0:
            print(f"Adding back {deficit:,} shortest edges...")

            tree_pairs = set(mst.edges())
            leftovers = []
            for src, dst, attrs in G.edges(data=True):
                # Undirected edges may be reported in either orientation.
                if (src, dst) in tree_pairs or (dst, src) in tree_pairs:
                    continue
                leftovers.append((src, dst, attrs['weight']))
            leftovers.sort(key=lambda edge: edge[2])

            for src, dst, edge_weight in leftovers[:deficit]:
                mst.add_edge(src, dst, weight=edge_weight)

            print(f"Final graph: {mst.number_of_edges():,} edges")

        return mst

    except Exception as e:
        print(f" Error in MST computation: {e}")
        if max_edges is None or G.number_of_edges() <= max_edges:
            return G

        # Best-effort fallback: keep every node, keep only the lightest edges.
        print(f"Fallback: limiting to {max_edges:,} shortest edges")
        ranked = sorted(
            ((src, dst, attrs['weight']) for src, dst, attrs in G.edges(data=True)),
            key=lambda edge: edge[2],
        )

        G_limited = nx.Graph()
        G_limited.add_nodes_from(G.nodes())
        for src, dst, edge_weight in ranked[:max_edges]:
            G_limited.add_edge(src, dst, weight=edge_weight)

        print(f"Fallback graph: {G_limited.number_of_edges():,} edges")
        return G_limited

print("Optimized graph pruning function defined")


## Base Phylogenetic Graph Construction

### Build Base Graph from Distance Matrix


In [None]:
# Build (or restore from checkpoint) the base phylogenetic graph `G`.
#
# Reads:  SKIP_GRAPH_CONSTRUCTION, checkpoint_data, GRAPH_PARAMS, dist_mat,
#         EXPERIMENT_CONFIG (optional), df_CASE, df_CONTROL, output_dir, DISEASE
# Writes: G, nested_dist, node_ids, checkpoint_data['base_graph'] / ['step']
#
# GRAPH_PARAMS['graph_type'] selects the builder:
#   'threshold' / 'mst' / 'tree' -> builders from src.graph_utils
#   'knn' (default, and fallback for unknown types or a missing tree file)
# The graph produced by the selected builder is the one that is kept and
# checkpointed; the k-NN build only runs when k-NN is the effective type.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
else:
    if checkpoint_data['step'] < 3 or checkpoint_data.get('base_graph') is None:
        graph_type = GRAPH_PARAMS.get('graph_type', 'knn')
        print(f"Building base phylogenetic graph (type: {graph_type})...")

        try:
            # ----------------------------------------------------------
            # 1. Resolve the distance source once, for every graph type.
            # ----------------------------------------------------------
            dist_source = dist_mat
            if isinstance(dist_mat, dict):
                print(f"Processing dictionary-based distance matrix...")
                if not dist_mat:
                    raise ValueError("Distance matrix dictionary is empty")

                first_key = next(iter(dist_mat))
                if isinstance(first_key, str) and 'DistMatrix' in first_key:
                    # Container of several named matrices: pick the configured one.
                    _config = globals().get('EXPERIMENT_CONFIG') or {}
                    primary_matrix = (
                        _config.get('data_extraction', {})
                        .get('distance_matrices', {})
                        .get('primary_matrix', 'TreeDistMatrix')
                    )
                    print(f"Selecting primary matrix: {primary_matrix}")
                    if primary_matrix not in dist_mat:
                        print(f"WARNING: {primary_matrix} not found, using {first_key}")
                        primary_matrix = first_key
                    dist_source = dist_mat[primary_matrix]
            else:
                print(f"Processing matrix-based distance matrix...")

            has_matrix_api = hasattr(dist_source, 'ids') and hasattr(dist_source, 'data')
            nested_dist = dist_source if isinstance(dist_source, dict) else None

            if nested_dist is not None:
                if not nested_dist:
                    raise ValueError("Selected distance matrix is empty")
                first_val = next(iter(nested_dist.values()))
                if isinstance(first_val, dict):
                    print(f"Structure: nested dict (node -> {{neighbor -> distance}})")
                else:
                    # Flat (node1, node2) -> distance input: convert to the
                    # nested adjacency form that every builder below expects.
                    print(f"Structure: flat dict ((node1, node2) -> distance)")
                    adjacency = {}
                    for (node1, node2), distance in nested_dist.items():
                        adjacency.setdefault(node1, {})
                        adjacency.setdefault(node2, {})
                        if node1 != node2:
                            adjacency[node1][node2] = distance
                            adjacency[node2][node1] = distance
                    nested_dist = adjacency

            if graph_type not in ('knn', 'threshold', 'mst', 'tree'):
                print(f"Unknown graph_type '{graph_type}', falling back to k-NN")
                graph_type = 'knn'

            # ----------------------------------------------------------
            # 2. Non-kNN builders (src/graph_utils) work on a dense array.
            # ----------------------------------------------------------
            if graph_type in ('threshold', 'mst', 'tree'):
                print(f"Using {graph_type} graph builder from src/graph_utils")

                if has_matrix_api:
                    dm_numpy = dist_source.data
                    node_ids = list(dist_source.ids)
                elif nested_dist is not None:
                    node_ids = list(nested_dist.keys())
                    n = len(node_ids)
                    dm_numpy = np.zeros((n, n))
                    node_to_idx = {nid: i for i, nid in enumerate(node_ids)}

                    for node1, neighbors in tqdm(nested_dist.items(), desc="Converting to numpy"):
                        i = node_to_idx[node1]
                        for node2, dist in neighbors.items():
                            if node2 in node_to_idx:
                                j = node_to_idx[node2]
                                dm_numpy[i, j] = dist
                                dm_numpy[j, i] = dist
                    print(f"Converted to numpy: {dm_numpy.shape}")
                else:
                    # Bare array-like without IDs: fall back to generic names.
                    dm_numpy = dist_source
                    node_ids = [f"node_{i}" for i in range(dm_numpy.shape[0])]

                if graph_type == 'threshold':
                    from src.graph_utils import build_threshold_graph
                    percentile = GRAPH_PARAMS.get('threshold_percentile', 25)
                    G = build_threshold_graph(dm_numpy, node_ids, threshold_percentile=percentile)
                    print(f"Threshold graph: percentile={percentile}, {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

                elif graph_type == 'mst':
                    from src.graph_utils import build_mst_graph
                    G = build_mst_graph(dm_numpy, node_ids)
                    print(f"MST graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

                else:
                    # 'tree' needs the phylogenetic tree file on disk.
                    tree_path = output_dir / "phylogeny" / f"phylogeny_{DISEASE}.nwk"
                    if tree_path.exists():
                        from src.graph_utils import build_tree_graph
                        ancestor_levels = GRAPH_PARAMS.get('tree_ancestor_levels', 3)
                        G = build_tree_graph(tree_path, node_ids, ancestor_levels=ancestor_levels)
                        print(f"Tree graph: ancestor_levels={ancestor_levels}, {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
                    else:
                        print(f"Tree file not found: {tree_path}, falling back to k-NN")
                        graph_type = 'knn'

            # ----------------------------------------------------------
            # 3. k-NN builder (default, and fallback from step 2).
            #    The dense "initial graph" that used to be materialised here
            #    only fed two log lines and was discarded, so it is no longer
            #    built.
            # ----------------------------------------------------------
            if graph_type == 'knn':
                print(f"Building sparse k-NN graph (t2d-style)...")
                k = GRAPH_PARAMS.get("knn_k", 8)
                symmetric = GRAPH_PARAMS.get("knn_symmetric", True)
                max_factor = GRAPH_PARAMS.get("knn_max_distance_factor", None)

                if has_matrix_api:
                    node_ids = list(dist_source.ids)
                    dm = dist_source.data
                    use_numpy_dm = True
                elif nested_dist is not None:
                    node_ids = list(nested_dist.keys())
                    print(f"Using nested dict format with {len(node_ids):,} nodes")
                    use_numpy_dm = False
                else:
                    raise ValueError("Unsupported distance matrix type for k-NN build")

                n = len(node_ids)
                knn_edges = []

                if use_numpy_dm:
                    global_median = np.median(dm[np.triu_indices(n, k=1)])
                    max_allowed = None if max_factor is None else global_median * float(max_factor)

                    # A node has at most n - 1 real neighbours; clamping keeps
                    # the masked self entry from ever being selected.
                    k_eff = max(0, min(k, n - 1))
                    for i in range(n):
                        row = dm[i].astype(float)
                        row[i] = np.inf  # never pick the node itself
                        if k_eff < n - 1:
                            idx = np.argpartition(row, k_eff)[:k_eff]
                        else:
                            idx = np.argsort(row)[:k_eff]
                        for j in idx:
                            w = float(dm[i, j])
                            if max_allowed is not None and w > max_allowed:
                                continue
                            knn_edges.append((node_ids[i], node_ids[j], w))
                else:
                    print(f"Building k-NN from nested dict (k={k})...")
                    # Each unordered pair contributes once to the median.
                    all_distances = []
                    for node1, neighbors in nested_dist.items():
                        for node2, dist in neighbors.items():
                            if node1 < node2:
                                all_distances.append(dist)
                    if all_distances:
                        global_median = np.median(all_distances)
                        max_allowed = None if max_factor is None else global_median * float(max_factor)
                        print(f"Median distance: {global_median:.4f}, max_allowed: {max_allowed}")
                    else:
                        max_allowed = None

                    for node1 in tqdm(node_ids, desc="Building k-NN"):
                        # Exclude the self-distance so it cannot take a
                        # neighbour slot or create a self-loop (this mirrors
                        # the np.inf masking in the array path).
                        candidates = [
                            (node2, w)
                            for node2, w in nested_dist[node1].items()
                            if node2 != node1
                        ]
                        candidates.sort(key=lambda item: item[1])
                        for node2, w in candidates[:k]:
                            if max_allowed is not None and w > max_allowed:
                                continue
                            knn_edges.append((node1, node2, w))

                G = nx.Graph()
                G.add_nodes_from(node_ids)
                if symmetric:
                    # Collapse (u, v) / (v, u) duplicates, keeping the smaller weight.
                    edge_set = {}
                    for u, v, w in knn_edges:
                        a, b = (u, v) if u < v else (v, u)
                        if (a, b) not in edge_set or w < edge_set[(a, b)]:
                            edge_set[(a, b)] = w
                    for (u, v), w in edge_set.items():
                        G.add_edge(u, v, weight=w)
                else:
                    for u, v, w in knn_edges:
                        G.add_edge(u, v, weight=w)

                print(f"k-NN graph: k={k}, symmetric={symmetric}, nodes={G.number_of_nodes():,}, edges={G.number_of_edges():,}")

            # ----------------------------------------------------------
            # 4. Shared post-processing, statistics and checkpoint.
            # ----------------------------------------------------------
            if G.number_of_nodes() > 0:
                # nx.is_connected raises on a graph with no nodes, hence the guard.
                print(f"Connected: {nx.is_connected(G)}")

                avg_degree = sum(dict(G.degree()).values()) / G.number_of_nodes()
                print(f"Average degree: {avg_degree:.2f}")

                isolated = list(nx.isolates(G))
                if isolated:
                    print(f" Isolated nodes: {len(isolated)}")
                    G.remove_nodes_from(isolated)
                    print(f"After removing isolated nodes: {G.number_of_nodes():,} nodes")

            print(f" Base phylogenetic graph constructed successfully")

            print(f"\n    Performance Estimates:")
            total_samples = len(df_CASE.columns) + len(df_CONTROL.columns)
            est_time_per_sample = G.number_of_edges() * 0.00001
            est_total_minutes = (est_time_per_sample * total_samples) / 60

            print(f"   - Total samples to process: {total_samples:,}")
            print(f"   - Est. time per sample: ~{est_time_per_sample:.2f} seconds")
            print(f"   - Est. total processing time: ~{est_total_minutes:.1f} minutes")

            if G.number_of_edges() > 200_000:
                print(f"\n        WARNING: Base graph may be large for per-sample processing!")
                print(f"   Edges: {G.number_of_edges():,}")
                print(f"   Consider reducing GRAPH_PARAMS['knn_k'] or lowering 'knn_max_distance_factor'")

            checkpoint_data['base_graph'] = G
            checkpoint_data['step'] = 3
            save_checkpoint(checkpoint_data, CHECKPOINT_FILE)

        except Exception as e:
            print(f" Error building base graph: {e}")
            raise
    else:
        print("Base phylogenetic graph already built from checkpoint")
        G = checkpoint_data['base_graph']
        print(f"Graph: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")
        print(f"Connected: {nx.is_connected(G)}")



## Graph Adjustment Function

### Define Graph Adjustment Function


In [None]:
# Set once the first edge-weight fallback warning has been printed, so the
# message is not repeated for every sample.
_EDGE_WEIGHT_WARNING_SHOWN = False

def adjust_graph_to_abundance(G_in, abundance_dict, params=GRAPH_PARAMS):
    """
    Derive a sample-specific graph from the base phylogenetic graph.

    Steps, in order:
    1. Copy ``G_in`` and drop every node whose abundance is missing or zero.
    2. Store each remaining node's abundance as a one-element list under the
       node attribute ``"weight"``.
    3. Optionally drop nodes left without edges (``params["remove_isolates"]``,
       default True).
    4. Transform edge weights according to ``params["weight_transform"]``:
       ``identity`` (default, no change), ``inverse`` (``1 / (w + 1e-8)``),
       ``exponential`` (``exp(-w)``), ``binary`` (all 1.0), or any
       ``abundance_*`` strategy resolved through
       ``src.edge_weights.get_edge_weight_function``. Unrecognised strategies
       leave the weights unchanged.
    5. Optionally scale edge weights by their maximum
       (``params["normalize_weights"]``, default True).
    6. Scale node weights by their maximum.

    If an ``abundance_*`` strategy cannot be imported or raises, the original
    distances are kept for *all* edges and a one-time warning is printed.

    Parameters:
    -----------
    G_in : networkx.Graph
        Base phylogenetic graph (edge weights = phylogenetic distances).
        It is not modified.
    abundance_dict : dict
        Dictionary mapping phylogenetic IDs to abundance values
    params : dict
        Graph construction parameters (see steps above for the keys read)

    Returns:
    --------
    networkx.Graph
        Adjusted graph for the specific sample; empty if no node has a
        non-zero abundance.
    """
    global _EDGE_WEIGHT_WARNING_SHOWN

    G = G_in.copy()

    # Keep only taxa that are actually present in this sample.
    nodes_to_remove = [node for node in G.nodes() if abundance_dict.get(node, 0) == 0]
    G.remove_nodes_from(nodes_to_remove)

    if G.number_of_nodes() == 0:
        return G

    # Node feature: abundance wrapped in a list.
    for node in G.nodes():
        G.nodes[node]["weight"] = [abundance_dict.get(node, 0)]

    if params.get("remove_isolates", True):
        isolated = list(nx.isolates(G))
        if isolated:
            G.remove_nodes_from(isolated)

    weight_strategy = params.get("weight_transform", "identity")

    if G.number_of_edges() > 0:
        if weight_strategy.startswith("abundance_"):
            try:
                from src.edge_weights import get_edge_weight_function
                weight_fn = get_edge_weight_function(weight_strategy)

                # Compute every new weight before writing any of them: if the
                # strategy fails part-way, the graph must still hold the
                # untouched distances that the fallback message promises.
                pending = [
                    (
                        data,
                        weight_fn(
                            data.get("weight", 1.0),
                            abundance_dict.get(u, 0),
                            abundance_dict.get(v, 0),
                        ),
                    )
                    for u, v, data in G.edges(data=True)
                ]
            except ImportError:
                if not _EDGE_WEIGHT_WARNING_SHOWN:
                    print(f"⚠ Could not import edge_weights module, falling back to identity")
                    _EDGE_WEIGHT_WARNING_SHOWN = True
            except Exception as e:
                if not _EDGE_WEIGHT_WARNING_SHOWN:
                    print(f"⚠ Error applying {weight_strategy}: {e}, falling back to identity")
                    _EDGE_WEIGHT_WARNING_SHOWN = True
            else:
                for data, new_weight in pending:
                    data["weight"] = new_weight

        elif weight_strategy == "inverse":
            for u, v, data in G.edges(data=True):
                w = data.get("weight", 1.0)
                # The small epsilon avoids division by zero for zero distances.
                data["weight"] = 1.0 / (w + 1e-8)

        elif weight_strategy == "exponential":
            import math
            for u, v, data in G.edges(data=True):
                w = data.get("weight", 1.0)
                data["weight"] = math.exp(-w)

        elif weight_strategy == "binary":
            for u, v, data in G.edges(data=True):
                data["weight"] = 1.0

    # Scale edge weights by their maximum (skipped when the maximum is not positive).
    if params.get("normalize_weights", True) and G.number_of_edges() > 0:
        edge_weights = [data["weight"] for _, _, data in G.edges(data=True)]
        max_weight = max(edge_weights) if edge_weights else 1.0
        if max_weight > 0:
            for u, v, data in G.edges(data=True):
                data["weight"] = data["weight"] / max_weight

    # Scale node abundances by their maximum, whatever `params` says.
    node_weights = [data["weight"][0] for _, data in G.nodes(data=True) if "weight" in data]
    if node_weights:
        max_node_weight = max(node_weights)
        if max_node_weight > 0:
            for node, data in G.nodes(data=True):
                if "weight" in data:
                    data["weight"][0] = data["weight"][0] / max_node_weight

    return G

print("Optimized graph adjustment function defined")
print(f"Key optimizations:")
print(f"- Removed expensive edge sorting (5-10x faster)")
print(f"- Removed redundant per-sample edge pruning")
print(f"- Optimized node removal with list comprehensions")
print(f"Parameters: {GRAPH_PARAMS}")


### Pre-flight Diagnostics


In [None]:
# Pre-flight report printed before the per-sample graph loop: base-graph size,
# sample counts, rough time/memory estimates and feature-mapping coverage.
# Reads G, df_CASE, df_CONTROL, seq_to_ids and DISEASE; raises ValueError when
# the sample count is implausibly high (> 10,000).
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
else:
    import gc

    print("Pre-flight Diagnostics for Step 7")
    print("=" * 60)

    print(f"\n Base Graph (k-NN):")
    print(f"Nodes: {G.number_of_nodes():,}")
    print(f"Edges: {G.number_of_edges():,}")
    print(f"Avg degree: {2 * G.number_of_edges() / max(1, G.number_of_nodes()):.2f}")
    print(f"Graph density: {nx.density(G):.6f}")

    total_samples = len(df_CASE.columns) + len(df_CONTROL.columns)
    print(f"\n Samples to Process:")
    print(f"{DISEASE} cases: {len(df_CASE.columns):,}")
    print(f"Controls: {len(df_CONTROL.columns):,}")
    print(f"Total: {total_samples:,}")

    # Hard stop: a count this large indicates mis-oriented tables, so every
    # estimate below only ever sees total_samples <= 10,000.
    if total_samples > 10000:
        print(f"\n    CRITICAL WARNING: {total_samples:,} samples detected!")
        print(f"This is unusually high. Expected ~400-1000 samples for disease study.")
        print(f"  Did the transpose/filter fix in Cell 10 work correctly?")
        print(f"Please check Cell 10 output for 'DETECTED: Tables are TRANSPOSED' message")
        print(f"If you see this warning, the BIOM tables may still be in wrong orientation!")
        raise ValueError(f"Unreasonable sample count: {total_samples:,}. Check Cell 10 output.")
    elif total_samples < 50:
        print(f"\n     WARNING: Only {total_samples:,} samples detected.")
        print(f"This seems low. Expected ~400 samples for disease study.")
        print(f"Check that sample filtering in Cell 10 found matching samples.")
    else:
        print(f" Sample count looks reasonable for {DISEASE} analysis")

    # Heuristic cost model: fixed per-sample overhead plus a per-edge term.
    edge_count = G.number_of_edges()
    est_time_per_sample = 0.02 + (edge_count * 0.000005)
    est_total_time_minutes = (est_time_per_sample * total_samples) / 60

    print(f"\n  Performance Estimates:")
    print(f"Est. time per sample: ~{est_time_per_sample:.2f} seconds")
    print(f"Est. total time: ~{est_total_time_minutes:.1f} minutes ({est_total_time_minutes/60:.1f} hours)")

    # The former "> 1,000,000 samples" branches were unreachable because of
    # the ValueError above and have been removed.
    if est_total_time_minutes > 1440:
        print(f" ALERT: Processing will take > 24 hours!")
        print(f"STRONGLY recommend batch processing or subsampling")
    elif est_total_time_minutes > 60:
        print(f"  WARNING: Processing may take > 1 hour")
        print(f"Consider reducing base graph size further")
    else:
        print(f" Processing time looks reasonable")

    # Rough upper bound: assumes every sample graph is as large as the base graph.
    est_memory_per_graph_mb = (G.number_of_nodes() * 100 + G.number_of_edges() * 50) / 1024 / 1024
    est_total_memory_mb = est_memory_per_graph_mb * total_samples

    print(f"\n Memory Estimates:")
    print(f"Per graph: ~{est_memory_per_graph_mb:.1f} MB")
    print(f"Total (all graphs): ~{est_total_memory_mb:.1f} MB ({est_total_memory_mb/1024:.2f} GB)")

    if est_total_memory_mb > 8000:
        print(f"  WARNING: High memory usage expected (>{est_total_memory_mb/1024:.1f} GB)")

    # Coverage = share of abundance-table features that have an entry in
    # seq_to_ids (the lookup the per-sample loop performs), not the ratio of
    # the two collection sizes.
    all_features = set(df_CASE.index) | set(df_CONTROL.index)
    mapped_features = sum(1 for feature in all_features if feature in seq_to_ids)
    mapping_coverage = (mapped_features / len(all_features) * 100) if all_features else 0.0

    print(f"\n Feature Mapping:")
    print(f"Sequence mappings: {len(seq_to_ids):,}")
    print(f"{DISEASE} features: {len(df_CASE.index):,}")
    print(f"Control features: {len(df_CONTROL.index):,}")
    print(f"Mapping coverage: {mapping_coverage:.1f}%")

    print(f"\n Running garbage collection...")
    gc.collect()
    print(f" Memory cleaned")

    print("\n" + "=" * 60)
    print("Pre-flight check complete. Ready to process samples.")



## Sample-Specific Graph Construction

### Build Individual Graphs for Each Sample


In [None]:
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
else:
    if checkpoint_data['step'] < 4 or len(checkpoint_data.get('graphs', [])) == 0:
        print("Constructing sample-specific graphs...")
        print(f" Optimizations: Batch checkpointing + Better progress tracking")

        GRAPHS = []
        LABELS = []
        GRAPH_METADATA = []

        CHECKPOINT_INTERVAL = 50

        import time
        start_time = time.time()
        samples_processed = 0

                                                                                   
                                                                                   
        target_config = {}
        if 'EXPERIMENT_CONFIG' in globals() and EXPERIMENT_CONFIG:
            target_config = EXPERIMENT_CONFIG.get('data_extraction', {}).get('target', {})
        
        target_type = target_config.get('type', 'binary')
        num_classes = target_config.get('num_classes', 2)
        
                                                             
        if target_type == 'ibd_subtype' and DISEASE != 'IBD':
            print(f"  ⚠️ Warning: ibd_subtype target is only valid for IBD disease, falling back to binary")
            target_type = 'binary'
            num_classes = 2
        
        print(f"\n Label configuration:")
        print(f"  Target type: {target_type}")
        print(f"  Num classes: {num_classes}")
        
                                                                              
        sample_to_subtype = {}
        sample_to_age = {}
        
        if target_type in ['ibd_subtype', 'age_category']:
            # Age decade -> ordinal class. '60s' and '70+' share the top class.
            age_mapping = {'20s': 0, '30s': 1, '40s': 2, '50s': 3, '60s': 4, '70+': 4}
            # Class assigned to missing or unrecognised age_cat values.
            default_age_class = 2

            # --- Case metadata: supplies IBD subtype and age category ---
            case_metadata_path = output_dir / "metadata" / f"AGP_{DISEASE}_cases_metadata.txt"
            if case_metadata_path.exists():
                # dtype=str keeps every column as text. Without it pandas infers
                # numeric dtypes, so numeric-looking sample IDs would be stored as
                # floats and never match the string sample names used for lookup
                # (the labels would then silently fall back to the defaults).
                case_meta_df = pd.read_csv(case_metadata_path, sep='\t', dtype=str)
                sample_col = '#SampleID' if '#SampleID' in case_meta_df.columns else case_meta_df.columns[0]

                # IBD subtype: 1 = Crohn's, 2 = Ulcerative colitis.
                if 'ibd_diagnosis_refined' in case_meta_df.columns:
                    for sample_id, subtype in zip(case_meta_df[sample_col],
                                                  case_meta_df['ibd_diagnosis_refined']):
                        subtype = str(subtype)
                        if "Crohn" in subtype:
                            sample_to_subtype[str(sample_id)] = 1
                        elif "Ulcerative" in subtype:
                            sample_to_subtype[str(sample_id)] = 2
                        else:
                            # NOTE(review): unrecognised diagnoses share class 1
                            # with Crohn's - confirm this default is intended.
                            sample_to_subtype[str(sample_id)] = 1

                # Age category for cases.
                if 'age_cat' in case_meta_df.columns:
                    for sample_id, age in zip(case_meta_df[sample_col], case_meta_df['age_cat']):
                        sample_to_age[str(sample_id)] = age_mapping.get(str(age), default_age_class)

                print(f"  Loaded {len(case_meta_df)} case metadata records")

            # --- Control metadata: only the age category is used ---
            ctrl_metadata_path = output_dir / "metadata" / f"AGP_{DISEASE}_controls_metadata.txt"
            if ctrl_metadata_path.exists():
                ctrl_meta_df = pd.read_csv(ctrl_metadata_path, sep='\t', dtype=str)
                sample_col = '#SampleID' if '#SampleID' in ctrl_meta_df.columns else ctrl_meta_df.columns[0]

                if 'age_cat' in ctrl_meta_df.columns:
                    for sample_id, age in zip(ctrl_meta_df[sample_col], ctrl_meta_df['age_cat']):
                        sample_to_age[str(sample_id)] = age_mapping.get(str(age), default_age_class)

                print(f"  Loaded {len(ctrl_meta_df)} control metadata records")

        def get_sample_label(sample_id, is_case):
            """Return the class label for one sample under the active target type.

            - 'ibd_subtype': controls are 0; cases use the subtype lookup and
              fall back to 1 when the sample has no entry.
            - 'age_category': the age lookup, falling back to 2 when missing.
            - 'binary' or any other target type: 1 for cases, 0 for controls.
            """
            if target_type == 'ibd_subtype':
                return sample_to_subtype.get(sample_id, 1) if is_case else 0
            if target_type == 'age_category':
                return sample_to_age.get(sample_id, 2)
            # 'binary' and every unrecognised target type share the same rule.
            return 1 if is_case else 0

        print(f"\n Processing {DISEASE} cases...")
        case_samples = list(df_CASE.columns)
        case_features = list(df_CASE.index)

        print(f"Total {DISEASE} samples: {len(case_samples)}")

        # The feature -> phylogenetic ID mapping does not depend on the sample,
        # so resolve it once instead of once per sample.
        case_feature_ids = [
            (feature, seq_to_ids[feature])
            for feature in case_features
            if feature in seq_to_ids
        ]

        for i, sample in enumerate(tqdm(case_samples, desc=f"{DISEASE} cases")):
            # Abundance of each mapped feature in this sample, keyed by phylogenetic ID.
            abundance_dict = {
                phylogenetic_id: df_CASE.loc[feature, sample]
                for feature, phylogenetic_id in case_feature_ids
            }

            try:
                g_sample = adjust_graph_to_abundance(G, abundance_dict, GRAPH_PARAMS)

                if GRAPH_PARAMS.get("randomize_edges", False):
                    from src.graph_utils import randomize_edges
                    # Per-sample seed keeps the randomisation reproducible.
                    g_sample = randomize_edges(
                        g_sample,
                        preserve_degree=GRAPH_PARAMS.get("preserve_degree", True),
                        seed=RANDOM_SEED + i
                    )

                n_nodes = g_sample.number_of_nodes()
                n_edges = g_sample.number_of_edges()

                if n_nodes >= GRAPH_PARAMS["min_nodes"] and n_edges >= GRAPH_PARAMS["min_edges"]:
                    sample_label = get_sample_label(sample, is_case=True)

                    # Build the metadata record BEFORE appending anything: if a
                    # statistic raises (e.g. nx.is_connected), the sample is
                    # skipped as a whole and GRAPHS / LABELS / GRAPH_METADATA
                    # stay index-aligned.
                    sample_metadata = {
                        "sample_id": sample,
                        "group": DISEASE.upper(),
                        "label": sample_label,
                        "n_nodes": n_nodes,
                        "n_edges": n_edges,
                        "avg_degree": sum(dict(g_sample.degree()).values()) / n_nodes if n_nodes > 0 else 0,
                        "is_connected": nx.is_connected(g_sample)
                    }

                    g_sample.graph['label'] = sample_label
                    g_sample.graph['sample_id'] = sample

                    GRAPHS.append(g_sample)
                    LABELS.append(sample_label)
                    GRAPH_METADATA.append(sample_metadata)
                else:
                    print(f" Skipping {sample}: insufficient size ({n_nodes} nodes, {n_edges} edges)")

            except Exception as e:
                # Best effort: report the failing sample and continue with the rest.
                print(f" Error processing {sample}: {e}")

            samples_processed += 1

            if (i + 1) % CHECKPOINT_INTERVAL == 0:
                elapsed = time.time() - start_time
                rate = samples_processed / elapsed if elapsed > 0 else 0
                remaining = len(case_samples) - (i + 1)
                eta_minutes = (remaining / rate) / 60 if rate > 0 else 0

                checkpoint_data['graphs'] = GRAPHS
                checkpoint_data['labels'] = LABELS
                checkpoint_data['metadata'] = GRAPH_METADATA
                save_checkpoint(checkpoint_data, CHECKPOINT_FILE)
                print(f" Checkpoint: {i+1}/{len(case_samples)} samples | Rate: {rate:.1f} samples/sec | ETA: {eta_minutes:.1f} min")

        # Count by metadata group rather than by label value: with multi-class
        # targets (subtype / age category) the label no longer encodes case vs control.
        n_case_graphs = sum(1 for m in GRAPH_METADATA if m.get("group") != "Control")
        print(f" {DISEASE} cases processed: {n_case_graphs} valid graphs")

        print("\n Processing controls...")
        control_samples = list(df_CONTROL.columns)
        control_features = list(df_CONTROL.index)

        print(f"Total control samples: {len(control_samples)}")

        # The feature -> phylogenetic ID mapping does not depend on the sample,
        # so resolve it once instead of once per sample.
        control_feature_ids = [
            (feature, seq_to_ids[feature])
            for feature in control_features
            if feature in seq_to_ids
        ]

        for i, sample in enumerate(tqdm(control_samples, desc="Controls")):
            # Abundance of each mapped feature in this sample, keyed by phylogenetic ID.
            abundance_dict = {
                phylogenetic_id: df_CONTROL.loc[feature, sample]
                for feature, phylogenetic_id in control_feature_ids
            }

            try:
                g_sample = adjust_graph_to_abundance(G, abundance_dict, GRAPH_PARAMS)

                if GRAPH_PARAMS.get("randomize_edges", False):
                    from src.graph_utils import randomize_edges
                    # Seed is offset by the number of case samples so control
                    # seeds do not repeat the case seeds.
                    g_sample = randomize_edges(
                        g_sample,
                        preserve_degree=GRAPH_PARAMS.get("preserve_degree", True),
                        seed=RANDOM_SEED + len(case_samples) + i
                    )

                n_nodes = g_sample.number_of_nodes()
                n_edges = g_sample.number_of_edges()

                if n_nodes >= GRAPH_PARAMS["min_nodes"] and n_edges >= GRAPH_PARAMS["min_edges"]:
                    sample_label = get_sample_label(sample, is_case=False)

                    # Build the metadata record BEFORE appending anything: if a
                    # statistic raises (e.g. nx.is_connected), the sample is
                    # skipped as a whole and GRAPHS / LABELS / GRAPH_METADATA
                    # stay index-aligned.
                    sample_metadata = {
                        "sample_id": sample,
                        "group": "Control",
                        "label": sample_label,
                        "n_nodes": n_nodes,
                        "n_edges": n_edges,
                        "avg_degree": sum(dict(g_sample.degree()).values()) / n_nodes if n_nodes > 0 else 0,
                        "is_connected": nx.is_connected(g_sample)
                    }

                    g_sample.graph['label'] = sample_label
                    g_sample.graph['sample_id'] = sample

                    GRAPHS.append(g_sample)
                    LABELS.append(sample_label)
                    GRAPH_METADATA.append(sample_metadata)
                else:
                    print(f" Skipping {sample}: insufficient size ({n_nodes} nodes, {n_edges} edges)")

            except Exception as e:
                # Best effort: report the failing sample and continue with the rest.
                print(f" Error processing {sample}: {e}")

            samples_processed += 1

            if (i + 1) % CHECKPOINT_INTERVAL == 0:
                elapsed = time.time() - start_time
                rate = samples_processed / elapsed if elapsed > 0 else 0
                remaining = len(control_samples) - (i + 1)
                eta_minutes = (remaining / rate) / 60 if rate > 0 else 0

                checkpoint_data['graphs'] = GRAPHS
                checkpoint_data['labels'] = LABELS
                checkpoint_data['metadata'] = GRAPH_METADATA
                save_checkpoint(checkpoint_data, CHECKPOINT_FILE)
                print(f" Checkpoint: {i+1}/{len(control_samples)} samples | Rate: {rate:.1f} samples/sec | ETA: {eta_minutes:.1f} min")

        # Count by metadata group rather than by label value: with multi-class
        # targets (e.g. age category) label 0 does not mean "control".
        n_control_graphs = sum(1 for m in GRAPH_METADATA if m.get("group") == "Control")
        print(f" Controls processed: {n_control_graphs} valid graphs")

        total_elapsed = time.time() - start_time

        # Case/control totals come from the metadata group, not the label value:
        # for multi-class targets the label does not encode case vs control.
        summary_case_count = sum(1 for m in GRAPH_METADATA if m.get("group") != "Control")
        summary_control_count = len(GRAPH_METADATA) - summary_case_count

        # Guard the ratios so empty input or a zero-length run cannot divide by zero.
        n_input_samples = len(case_samples) + len(control_samples)
        success_rate = len(GRAPHS) / n_input_samples * 100 if n_input_samples else 0.0
        average_rate = samples_processed / total_elapsed if total_elapsed > 0 else 0.0

        print(f"\n Graph construction summary:")
        print(f"Total graphs: {len(GRAPHS)}")
        print(f"{DISEASE} cases: {summary_case_count}")
        print(f"Control graphs: {summary_control_count}")
        print(f"Success rate: {success_rate:.1f}%")
        print(f"Total time: {total_elapsed/60:.1f} minutes")
        print(f"Average rate: {average_rate:.2f} samples/second")

        # Persist the finished state; step 4 marks graph construction as complete.
        checkpoint_data['graphs'] = GRAPHS
        checkpoint_data['labels'] = LABELS
        checkpoint_data['metadata'] = GRAPH_METADATA
        checkpoint_data['step'] = 4
        save_checkpoint(checkpoint_data, CHECKPOINT_FILE)
        print(f" Final checkpoint saved")

    else:
        print("Sample-specific graphs already constructed from checkpoint")
        # Restore the three index-aligned lists saved by the construction step.
        GRAPHS = checkpoint_data['graphs']
        LABELS = checkpoint_data['labels']
        GRAPH_METADATA = checkpoint_data['metadata']

        # Count by metadata group: label values only encode case vs control
        # for the binary target.
        restored_case_count = sum(1 for m in GRAPH_METADATA if m.get("group") != "Control")
        restored_control_count = len(GRAPH_METADATA) - restored_case_count

        print(f"Total graphs: {len(GRAPHS)}")
        print(f"{DISEASE} cases: {restored_case_count}")
        print(f"Control graphs: {restored_control_count}")



## Post-Processing and Memory Cleanup

### Memory Management


In [None]:
# Release large intermediates that are no longer needed once the per-sample
# graphs exist, and report process memory before/after when psutil is available.
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping - using existing graphs")
else:
    import gc
    import sys

    print("Cleaning up memory...")

    try:
        import psutil
        import os
        process = psutil.Process(os.getpid())
        mem_before_mb = process.memory_info().rss / 1024 / 1024
        print(f"Memory before cleanup: {mem_before_mb:.1f} MB")
        has_psutil = True
    except ImportError:
        # psutil is optional; without it only the reporting is skipped.
        has_psutil = False
        print(f"(Install psutil for memory monitoring)")

    # This cell runs at module scope, so the notebook's variables live in globals().
    if 'dist_mat' in globals():
        print(f"Clearing distance matrix...")
        del dist_mat

    if 'df_CASE' in globals() and 'df_CONTROL' in globals():
        print(f"Abundance tables will be kept for reference")

    # Drop loop leftovers by name; pop() replaces exec("del ...") and is a
    # no-op when the name was never defined.
    for var_name in ['abundance_dict', 'g_sample', 'g_copy']:
        globals().pop(var_name, None)

    gc.collect()

    if has_psutil:
        mem_after_mb = process.memory_info().rss / 1024 / 1024
        mem_freed_mb = mem_before_mb - mem_after_mb
        print(f"Memory after cleanup: {mem_after_mb:.1f} MB")
        if mem_freed_mb > 0:
            print(f"Freed: {mem_freed_mb:.1f} MB")

    print(f" Memory cleanup complete\n")



### Graph Validation and Quality Control


In [None]:
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping graph validation - using existing graphs")
else:
    print("Validating and cleaning graphs...")

    def ensure_node_weight_scalar(G_in):
        """Coerce every node's "weight" attribute to a plain float, in place.

        A list or tuple weight is unwrapped by repeatedly taking its first
        element until a scalar is reached, so nested containers such as
        ``[[0.5, 2.0], 3.0]`` resolve to ``0.5`` instead of failing in
        ``float()``. An empty container, a missing weight, or ``None`` becomes
        ``0.0``.

        Args:
            G_in: Graph-like object whose ``nodes(data=True)`` yields
                ``(node, attribute_dict)`` pairs with mutable dicts.

        Returns:
            The same ``G_in`` object, with its node weights rewritten.
        """
        for _, data in G_in.nodes(data=True):
            weight = data.get("weight", 0.0)
            # Descend through nested sequences; the first element is the
            # representative value at every level.
            while isinstance(weight, (list, tuple)):
                if not weight:
                    weight = 0.0
                    break
                weight = weight[0]
            data["weight"] = 0.0 if weight is None else float(weight)
        return G_in

    if len(GRAPHS) > 0:
        print(f"Cleaning {len(GRAPHS):,} graphs...")
        GRAPHS = [ensure_node_weight_scalar(g) for g in GRAPHS]

        print(f"Validating graph properties...")
        valid_graphs = []
        valid_labels = []
        valid_metadata = []
        invalid_count = 0

        # GRAPHS, LABELS and GRAPH_METADATA are index-aligned; filter them together.
        for i, (graph, label, metadata) in enumerate(zip(GRAPHS, LABELS, GRAPH_METADATA)):
            is_valid = (
                graph.number_of_nodes() >= GRAPH_PARAMS["min_nodes"] and
                graph.number_of_edges() >= GRAPH_PARAMS["min_edges"]
            )

            # Optional stricter filter: drop graphs that are not connected.
            if GRAPH_PARAMS.get("connectivity_check", False):
                is_valid = is_valid and nx.is_connected(graph)

            if is_valid:
                valid_graphs.append(graph)
                valid_labels.append(label)
                valid_metadata.append(metadata)
            else:
                invalid_count += 1
                # Report only the first few rejections to keep the output short.
                if invalid_count <= 5:
                    print(f" Removing invalid graph {i}: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")

        if invalid_count > 5:
            print(f" ... and {invalid_count - 5} more invalid graphs")

        GRAPHS = valid_graphs
        LABELS = valid_labels
        GRAPH_METADATA = valid_metadata

        print(f" Valid graphs: {len(GRAPHS):,} (removed {invalid_count} invalid)")

        if len(GRAPHS) > 0:
            node_counts = [g.number_of_nodes() for g in GRAPHS]
            edge_counts = [g.number_of_edges() for g in GRAPHS]

            print(f"\n    Graph statistics:")
            print(f"   Nodes - Mean: {np.mean(node_counts):.1f}, Median: {np.median(node_counts):.1f}, Range: {min(node_counts)}-{max(node_counts)}")
            print(f"   Edges - Mean: {np.mean(edge_counts):.1f}, Median: {np.median(edge_counts):.1f}, Range: {min(edge_counts)}-{max(edge_counts)}")

            # Case/control split comes from the metadata group. sum(LABELS) is
            # only a case count for binary labels; subtype and age-category
            # targets use other label values.
            case_count = sum(1 for m in GRAPH_METADATA if m.get("group") != "Control")
            control_count = len(LABELS) - case_count
            balance_ratio = min(case_count, control_count) / max(case_count, control_count) if max(case_count, control_count) > 0 else 0

            print(f"\n     Class balance:")
            print(f"   {DISEASE} cases: {case_count:,} ({case_count/len(LABELS)*100:.1f}%)")
            print(f"   Controls: {control_count:,} ({control_count/len(LABELS)*100:.1f}%)")
            # The trailing space reproduces the original output, whose status
            # marker was an empty string in both branches.
            print(f"   Balance ratio: {balance_ratio:.2f} ")

            example_graph = GRAPHS[0]
            print(f"\n    Example graph (first sample):")
            print(f"   Nodes: {example_graph.number_of_nodes()}")
            print(f"   Edges: {example_graph.number_of_edges()}")
            print(f"   Density: {nx.density(example_graph):.4f}")
            print(f"   Sample node weights: {[round(w, 4) for n, w in list(example_graph.nodes(data='weight'))[:3]]}...")

            if example_graph.number_of_edges() == 0:
                print(f"WARNING: Example graph has no edges!")
            if example_graph.number_of_nodes() < 10:
                print(f"WARNING: Example graph has very few nodes!")
    else:
        print(f" No graphs to validate!")

    print("\n    Graph validation complete")



## Graph Visualization and Analysis

### Visualize Example Graphs


In [None]:
if SKIP_GRAPH_CONSTRUCTION:
    print("Skipping graph visualization - using existing graphs")
else:
    def plot_graph_example(graph, title="Graph Example", max_nodes=VIZ_PARAMS["max_nodes_to_plot"]):
        """Draw one graph with node sizes scaled by node weight.

        Graphs with more than ``max_nodes`` nodes are reduced to the subgraph
        induced by the first ``max_nodes`` nodes (in node iteration order), and
        the title is extended to say so. Depending on ``VIZ_PARAMS`` the figure
        is saved under ``graphs_output_dir / "graph_examples"`` and/or shown.

        Args:
            graph: Graph to draw; an empty graph is reported and skipped.
            title: Figure title, also used to derive the output file name.
            max_nodes: Maximum number of nodes to draw.
        """
        if graph.number_of_nodes() == 0:
            print(f" Cannot plot empty graph: {title}")
            return

        if graph.number_of_nodes() > max_nodes:
            nodes_to_plot = list(graph.nodes())[:max_nodes]
            subgraph = graph.subgraph(nodes_to_plot)
            title += f" (showing {max_nodes}/{graph.number_of_nodes()} nodes)"
        else:
            subgraph = graph

        plt.figure(figsize=VIZ_PARAMS["figsize"], dpi=VIZ_PARAMS["dpi"])

        pos = nx.spring_layout(subgraph, k=1, iterations=50)

        # Node area follows the node weight, with a floor so tiny weights stay visible.
        node_weights = [subgraph.nodes[node].get('weight', 0.1) for node in subgraph.nodes()]
        node_sizes = [max(10, w * VIZ_PARAMS["node_size_scale"]) for w in node_weights]

        nx.draw(subgraph, pos,
                node_size=node_sizes,
                node_color='lightblue',
                edge_color='gray',
                width=VIZ_PARAMS["edge_width_scale"],
                with_labels=False,
                alpha=0.7)

        plt.title(title)
        plt.axis('off')

        if VIZ_PARAMS["save_plots"]:
            # The truncation suffix contains "/" ("showing 50/120 nodes"), which
            # would otherwise be read as a directory separator and point the
            # file into a non-existent subdirectory.
            file_stem = (title.replace(' ', '_').replace('(', '').replace(')', '')
                         .replace('/', '-').replace('\\', '-'))
            plot_path = graphs_output_dir / "graph_examples" / f"{file_stem}.png"
            plot_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(plot_path, dpi=VIZ_PARAMS["dpi"], bbox_inches='tight')
            print(f" Saved plot: {plot_path}")

        if VIZ_PARAMS["show_plots"]:
            plt.show()
        else:
            plt.close()

    def plot_graph_statistics():
        """Render a 2x2 overview of the constructed graphs.

        Panels: node-count histogram, edge-count histogram, mean average degree
        per group (with standard-deviation error bars), and a pie chart of the
        group sizes. Depending on ``VIZ_PARAMS`` the figure is saved under
        ``visualizations_dir / "statistics"`` and/or shown.
        """
        if not len(GRAPHS):
            print(f" No graphs to analyze")
            return

        stats_df = pd.DataFrame(GRAPH_METADATA)

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Graph Construction Statistics', fontsize=16)

        # Top row: the two size histograms differ only in column, colour and labels.
        histogram_panels = (
            (axes[0, 0], 'n_nodes', 'skyblue', 'Number of Nodes', 'Distribution of Node Counts'),
            (axes[0, 1], 'n_edges', 'lightgreen', 'Number of Edges', 'Distribution of Edge Counts'),
        )
        for panel, column, fill_color, x_label, panel_title in histogram_panels:
            panel.hist(stats_df[column], bins=20, alpha=0.7, color=fill_color, edgecolor='black')
            panel.set_xlabel(x_label)
            panel.set_ylabel('Frequency')
            panel.set_title(panel_title)
            panel.grid(True, alpha=0.3)

        # Bottom left: mean +/- std of the per-graph average degree, per group.
        degree_panel = axes[1, 0]
        degree_by_group = stats_df.groupby('group')['avg_degree'].agg(['mean', 'std']).reset_index()
        degree_panel.bar(degree_by_group['group'], degree_by_group['mean'],
                         yerr=degree_by_group['std'], capsize=5, alpha=0.7, color=['red', 'blue'])
        degree_panel.set_ylabel('Average Degree')
        degree_panel.set_title('Average Degree by Group')
        degree_panel.grid(True, alpha=0.3)

        # Bottom right: share of graphs per group.
        balance_panel = axes[1, 1]
        group_sizes = stats_df['group'].value_counts()
        balance_panel.pie(group_sizes.values, labels=group_sizes.index, autopct='%1.1f%%',
                          colors=['red', 'blue'], wedgeprops={'alpha': 0.7})
        balance_panel.set_title('Class Balance')

        plt.tight_layout()

        if VIZ_PARAMS["save_plots"]:
            plot_path = visualizations_dir / "statistics" / "graph_statistics.png"
            plt.savefig(plot_path, dpi=VIZ_PARAMS["dpi"], bbox_inches='tight')
            print(f" Saved statistics plot: {plot_path}")

        if not VIZ_PARAMS["show_plots"]:
            plt.close()
        else:
            plt.show()

    if len(GRAPHS) > 0:
        print("Generating graph visualizations...")

        # Pick examples by metadata group, not by label value: for subtype or
        # age-category targets, label 1 / label 0 no longer mean case / control.
        case_graphs = [g for g, m in zip(GRAPHS, GRAPH_METADATA) if m.get("group") != "Control"]
        control_graphs = [g for g, m in zip(GRAPHS, GRAPH_METADATA) if m.get("group") == "Control"]

        if len(case_graphs) > 0:
            print(f" Plotting {DISEASE} case example graph...")
            plot_graph_example(case_graphs[0], f"{DISEASE} Case Example Graph")

        if len(control_graphs) > 0:
            print(f" Plotting Control example graph...")
            plot_graph_example(control_graphs[0], "Control Example Graph")

        print(f" Generating statistics plots...")
        plot_graph_statistics()

        print(f" Visualizations complete")
    else:
        print(f" No graphs available for visualization")



## Output Generation and Pipeline Summary

### Save Output Files


In [None]:
# Persist graphs, labels, per-graph metadata and the run configuration, then
# print a summary. When construction was skipped, the existing outputs are reused.
if SKIP_GRAPH_CONSTRUCTION:
    print("="*60)
    print("USING EXISTING GRAPHS - NO SAVE NEEDED")
    print("="*60)
    print(f"Graphs already exist at: {OUT_GRAPHS_PKL}")
    print(f"Labels already exist at: {OUT_LABELS_NPY}")
    print(f"Number of graphs: {len(GRAPHS_PREPARED)}")
    print(f"Number of labels: {len(LABELS)}")
    GRAPHS = GRAPHS_PREPARED
else:
    print("Saving output files...")

    if len(GRAPHS) > 0:
        # Case/control totals come from the metadata group. sum(LABELS) is only
        # a case count for binary labels; subtype and age-category targets use
        # other label values.
        n_case_graphs = sum(1 for m in GRAPH_METADATA if m.get("group") != "Control")
        n_control_graphs = len(GRAPH_METADATA) - n_case_graphs

        print(f"Saving {len(GRAPHS)} graphs to {OUT_GRAPHS_PKL}")
        with open(OUT_GRAPHS_PKL, "wb") as f:
            pickle.dump(GRAPHS, f, protocol=pickle.HIGHEST_PROTOCOL)

        print(f"Saving {len(LABELS)} labels to {OUT_LABELS_NPY}")
        np.save(OUT_LABELS_NPY, np.array(LABELS, dtype=np.int64))

        print(f"Saving metadata to {OUT_METADATA_CSV}")
        metadata_df = pd.DataFrame(GRAPH_METADATA)
        metadata_df.to_csv(OUT_METADATA_CSV, index=False)

        # Everything needed to reproduce or audit this run.
        config_output = {
            "disease": DISEASE,
            "graph_params": GRAPH_PARAMS,
            "viz_params": VIZ_PARAMS,
            "n_graphs": len(GRAPHS),
            "n_cases": n_case_graphs,
            "n_control": n_control_graphs,
            "base_graph_nodes": G.number_of_nodes(),
            "base_graph_edges": G.number_of_edges(),
            "stockholm_available": stockholm_available,
            "sequence_mappings": len(seq_to_ids),
            "pipeline_version": "1.0",
            "timestamp": datetime.now().isoformat(),
            "input_files": {
                "distance_matrix": str(PICKLE_FILE),
                "cases": str(CASE_BIOM_TSV),
                "controls": str(CONTROL_BIOM_TSV),
                "case_ids": str(CASE_IDS_FILE),
                "control_ids": str(CONTROL_IDS_FILE)
            }
        }

        # Explicit encoding so the file does not depend on the platform default.
        with open(OUT_CONFIG_JSON, "w", encoding="utf-8") as f:
            json.dump(config_output, f, indent=2)

        print(f" All files saved successfully")

        print(f"\n Final Summary:")
        print(f"Disease: {DISEASE}")
        print(f"Total graphs: {len(GRAPHS)}")
        print(f"{DISEASE} graphs: {n_case_graphs}")
        print(f"Control graphs: {n_control_graphs}")
        print(f"Base graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
        print(f"Sequence mappings: {len(seq_to_ids)}")

        print(f"\n Output files:")
        print(f"- Graphs: {OUT_GRAPHS_PKL}")
        print(f"- Labels: {OUT_LABELS_NPY}")
        print(f"- Metadata: {OUT_METADATA_CSV}")
        print(f"- Config: {OUT_CONFIG_JSON}")
        print(f"- Checkpoint: {CHECKPOINT_FILE}")

        print(f"\n Visualization files:")
        print(f"- Graph examples: {graphs_output_dir / 'graph_examples'}")
        print(f"- Statistics: {visualizations_dir / 'statistics'}")

        print(f"\n Next Steps:")
        print(f"1. Use the graphs for machine learning analysis")
        print(f"2. The metadata CSV contains graph statistics for each sample")
        print(f"3. The configuration file contains all parameters for reproducibility")
        print(f"4. Check the visualization files to understand graph structure")

        # Step 5 marks the outputs as written; record where they live.
        checkpoint_data['step'] = 5
        checkpoint_data['final_outputs'] = {
            'graphs_file': str(OUT_GRAPHS_PKL),
            'labels_file': str(OUT_LABELS_NPY),
            'metadata_file': str(OUT_METADATA_CSV),
            'config_file': str(OUT_CONFIG_JSON)
        }
        save_checkpoint(checkpoint_data, CHECKPOINT_FILE)

    else:
        print(f" No graphs to save!")
        print(f"Check the input data and parameters")

print(f"\n Graph construction pipeline complete!")
print(f" Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
