# Comprehensive Treatment Pattern Analysis

This notebook demonstrates the complete pipeline combining:
1. Age-matched feature building
2. Observational pattern learning with matched controls
3. Bayesian propensity-response modeling

Author: Sarah Urbut  
Date: 2025-07-15

In [None]:
# Import required libraries
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Silence all library warnings so they do not clutter the cell outputs.
warnings.filterwarnings(action='ignore')

# Plot appearance: matplotlib's default style with seaborn's "husl" colour cycle.
plt.style.use('default')
sns.set_palette(palette="husl")

print("Libraries imported successfully!")

## 1. Load Data

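Load your signature data (`thetas`), patient IDs (`processed_ids`), prescription data (`statins`), and covariates (`cov`). The rest of the notebook expects these four variables to exist.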

In [None]:
# Load your data here - adjust paths as needed
# Example:
# thetas = np.load('path/to/your/signature_loadings.npy')  # N x K x T
# processed_ids = np.load('path/to/your/processed_ids.npy')
# statins = pd.read_csv('path/to/your/statin_prescriptions.csv')
# cov = pd.read_csv('path/to/your/covariates.csv')

# For this example, assuming you already have these loaded:
# thetas, processed_ids, statins, cov

# Fail fast with an actionable message when the inputs have not been loaded.
# On a fresh kernel the prints below would otherwise stop at the first
# undefined name without saying which other inputs are also missing.
_required_inputs = ('thetas', 'processed_ids', 'statins', 'cov')
_missing_inputs = [name for name in _required_inputs if name not in globals()]
if _missing_inputs:
    raise NameError(
        "Load the following variables before running this notebook "
        "(see the commented example at the top of this cell): "
        + ", ".join(_missing_inputs)
    )

# Quick sanity report on the size of each input
print(f"Signature data shape: {thetas.shape}")
print(f"Number of processed patients: {len(processed_ids)}")
print(f"Number of prescription records: {len(statins)}")
print(f"Number of patients with covariates: {len(cov)}")

## 2. Clean and Prepare Statin Data

In [None]:
# Keep only the prescription rows whose drug name mentions one of the known
# statins (case-insensitive). Rows with a missing drug name are excluded.
statin_name_pattern = '|'.join([
    'simvastatin',
    'atorvastatin',
    'rosuvastatin',
    'pravastatin',
    'fluvastatin',
    'lovastatin',
])
is_statin_row = statins['drug_name'].str.contains(statin_name_pattern, case=False, na=False)
true_statins = statins[is_statin_row].copy()  # copy so later edits never touch `statins`

print(f"Total prescription records: {len(statins)}")
print(f"True statins after filtering: {len(true_statins)}")
print(f"Unique patients with true statins: {true_statins['eid'].nunique()}")

# Frequency of the ten most common drug_name strings among the retained rows
if len(true_statins):
    statin_counts = true_statins['drug_name'].value_counts().head(10)
    print("\nTop 10 statin types:")
    print(statin_counts)

## 3. Prepare Covariate Dictionary for Matching

In [None]:
# Build one lookup table per covariate, each mapping patient eid -> value.
covariate_dicts = {}
available_columns = cov.columns

# Age is derived from birth year relative to a fixed reference year.
if 'birth_year' in available_columns:
    current_year = 2025  # Adjust as needed
    ages = current_year - cov['birth_year']
    covariate_dicts['age'] = dict(zip(cov['eid'], ages))
    print(f"Added age information for {len(covariate_dicts['age'])} patients")

# Sex is copied through unchanged when the column exists.
if 'sex' in available_columns:
    covariate_dicts['sex'] = dict(zip(cov['eid'], cov['sex']))
    print(f"Added sex information for {len(covariate_dicts['sex'])} patients")

# Further covariates (BMI, smoking status, ...) can be added the same way:
# if 'bmi' in cov.columns:
#     covariate_dicts['bmi'] = dict(zip(cov['eid'], cov['bmi']))

print(f"\nCovariate dictionary keys: {list(covariate_dicts.keys())}")

## 4. Run Comprehensive Analysis

In [None]:
# Entry point of the project-local analysis pipeline
from scripts.comprehensive_treatment_analysis import run_comprehensive_analysis

# Signature subset to analyse: None means "use all";
# pass e.g. list(range(10)) to restrict to the first 10 signatures.
sig_indices = None

# Outcome data for the Bayesian step (binary outcomes for treated patients).
# Leave as None when no outcome data is available.
# outcomes = your_outcome_data
outcomes = None

print(
    "Starting comprehensive analysis...",
    "This may take several minutes depending on data size.\n",
    sep="\n",
)

In [None]:
# Collect every pipeline input in one place, then run the full analysis.
analysis_inputs = {
    'signature_loadings': thetas,
    'processed_ids': processed_ids,
    'statin_prescriptions': true_statins,
    'covariates': cov,
    'covariate_dicts': covariate_dicts,
    'sig_indices': sig_indices,
    'outcomes': outcomes,
}
results = run_comprehensive_analysis(**analysis_inputs)

## 5. Examine Results

### 5.1 Matching Quality

In [None]:
# Report how many treated patients were paired with a control
n_matched = len(results['matched_treated_indices'])
print(f"Successfully matched {n_matched} treated-control pairs")

# If we have age data, show the age distribution of the treated patients.
# (Only the treated group is plotted here; no control ages are drawn.)
if 'age' in covariate_dicts:
    # Look up each treated patient's age; patients without an entry become NaN
    treated_ages = [covariate_dicts['age'].get(eid, np.nan) 
                   for eid in results['enhanced_learner'].treatment_patterns['treated_patients']]
    # Drop patients whose age is unknown so they do not distort the statistics
    treated_ages = [age for age in treated_ages if not np.isnan(age)]
    
    if len(treated_ages) > 0:
        plt.figure(figsize=(10, 6))
        plt.hist(treated_ages, bins=20, alpha=0.7, label='Treated patients', density=True)
        plt.xlabel('Age')
        plt.ylabel('Density')
        plt.title('Age Distribution of Treated Patients')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()
        
        # Fixed: the separator was the mis-encoded sequence "Â±" instead of "±"
        print(f"Mean age of treated patients: {np.mean(treated_ages):.1f} ± {np.std(treated_ages):.1f}")

### 5.2 Treatment Pattern Clusters

In [None]:
# Summarise each pre-treatment cluster: its size and mean treatment age.
cluster_analysis = results['cluster_analysis']

if cluster_analysis is None:
    print("No clustering results available")
else:
    print("Pre-treatment signature clusters:")
    print("=" * 50)

    cluster_patterns = cluster_analysis['cluster_patterns']
    for cluster_id, pattern in cluster_patterns.items():
        print(f"Cluster {cluster_id}:")
        print(f"  - {pattern['n_patients']} patients")
        print(f"  - Mean treatment age: {pattern['mean_treatment_age']:.1f} years")
        print()

### 5.3 Treatment Readiness Signatures

In [None]:
# Tabulate the first ten (signature, readiness score) pairs together with
# the matching "concerning pattern" fractions for each signature.
responsive_patterns = results['responsive_patterns']

if responsive_patterns is None:
    print("No responsive patterns available")
else:
    readiness_sigs = responsive_patterns['treatment_readiness_signatures']
    table_rule = "=" * 60

    print("Top 10 treatment readiness signatures:")
    print(table_rule)
    print(f"{'Signature':<10} {'Readiness Score':<15} {'Trend Fraction':<15} {'Accel Fraction'}")
    print(table_rule)

    for sig_idx, score in readiness_sigs[:10]:
        concerning = responsive_patterns['concerning_patterns'][sig_idx]
        trend_frac = concerning['concerning_trend_fraction']
        accel_frac = concerning['accelerating_fraction']
        print(f"{sig_idx:<10} {score:<15.3f} {trend_frac:<15.3f} {accel_frac:<15.3f}")

### 5.4 Matched Control Comparison

In [None]:
# Tabulate the ten signatures with the largest absolute effect size between
# treated patients and their matched controls.
matched_comparison = results['matched_comparison']

if matched_comparison is None:
    print("No matched comparison results available")
else:
    table_rule = "=" * 80
    print("Treated vs Matched Controls (Top 10 signatures by effect size):")
    print(table_rule)
    print(f"{'Sig':<4} {'Treated':<10} {'Control':<10} {'Difference':<12} {'Effect Size':<12} {'P-value':<10} {'Sig?'}")
    print(table_rule)

    # Largest |effect size| first
    sorted_sigs = sorted(
        matched_comparison.items(),
        key=lambda item: abs(item[1]['effect_size']),
        reverse=True,
    )

    for sig_idx, row in sorted_sigs[:10]:
        # Flag rows below the 0.05 threshold (p-values are used as provided)
        if row['p_value'] < 0.05:
            sig_marker = "***"
        else:
            sig_marker = ""
        print(
            f"{sig_idx:<4} {row['treated_mean']:<10.4f} {row['control_mean']:<10.4f} "
            f"{row['difference']:<12.4f} {row['effect_size']:<12.4f} "
            f"{row['p_value']:<10.4f} {sig_marker}"
        )

### 5.5 Prediction Model Performance

In [None]:
# Report the prediction model's cross-validated AUC and, when the model
# exposes feature importances, the importance summed per signature.
predictor = results['predictor']

if predictor is None:
    print("No prediction model results available")
else:
    print(f"Treatment Readiness Prediction Model:")
    print(f"Cross-validation AUC: {predictor['cv_auc']:.3f}")

    fitted_model = predictor['model']
    if hasattr(fitted_model, 'feature_importances_'):
        importances = fitted_model.feature_importances_
        n_features_per_sig = 4  # trend, level, level_vs_typical, accel
        # NOTE(review): assumes features are laid out signature by signature,
        # four per signature; any trailing remainder is ignored -- confirm
        # against the feature builder.
        n_signatures = len(importances) // n_features_per_sig

        # One (position, summed importance) pair per consecutive group of four.
        # NOTE(review): the position equals the signature index only when the
        # features were built for signatures 0..K-1 in order -- confirm.
        sig_importances = [
            (s, np.sum(importances[s * n_features_per_sig:s * n_features_per_sig + n_features_per_sig]))
            for s in range(n_signatures)
        ]

        # Most important first
        sig_importances = sorted(sig_importances, key=lambda pair: pair[1], reverse=True)

        table_rule = "=" * 40
        print("\nTop 10 most important signatures for prediction:")
        print(table_rule)
        print(f"{'Signature':<10} {'Importance':<12}")
        print(table_rule)

        for sig_idx, importance in sig_importances[:10]:
            print(f"{sig_idx:<10} {importance:<12.4f}")

### 5.6 Bayesian Analysis Results

In [None]:
# Examine Bayesian analysis results: sampler dimensions, trace/posterior
# plots and posterior summaries for the key parameters.
trace = results['bayesian_trace']
bayesian_model = results['bayesian_model']

if trace is not None:
    print("Bayesian Propensity-Response Analysis:")
    print(f"Number of chains: {len(trace.posterior.chain)}")
    print(f"Number of draws per chain: {len(trace.posterior.draw)}")
    
    has_treatment_effect = 'treatment_effect' in trace.posterior
    
    # One row per parameter, two columns per row
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    
    # Propensity baseline.
    # ArviZ's keyword is `axes` (not `ax`) and it must be a 2-D array of shape
    # (n_variables, 2), hence the 0:1 slice that keeps the row dimension.
    # compact=True keeps a multi-dimensional variable on a single row.
    # plot_trace draws the posterior density in the LEFT column and the
    # sampled trace in the RIGHT column, so the titles follow that layout.
    az.plot_trace(trace, var_names=['propensity_baseline'], compact=True, axes=axes[0:1, :])
    axes[0, 0].set_title('Propensity Baseline (Posterior)')
    axes[0, 1].set_title('Propensity Baseline (Trace)')
    
    # Treatment effect (if available)
    if has_treatment_effect:
        az.plot_trace(trace, var_names=['treatment_effect'], compact=True, axes=axes[1:2, :])
        axes[1, 0].set_title('Treatment Effect (Posterior)')
        axes[1, 1].set_title('Treatment Effect (Trace)')
    else:
        # Nothing to draw on the second row; hide its empty panels
        for unused_ax in axes[1, :]:
            unused_ax.set_visible(False)
    
    plt.tight_layout()
    plt.show()
    
    # Summary statistics
    print("\nPosterior Summary:")
    print(az.summary(trace, var_names=['propensity_baseline']))
    
    if has_treatment_effect:
        print("\nTreatment Effect Summary:")
        print(az.summary(trace, var_names=['treatment_effect']))
        
else:
    print("No Bayesian analysis results available")
    print("This may be due to insufficient data or computational issues")

## 6. Advanced Visualizations

### 6.1 Signature Comparison Heatmap

In [None]:
# Two single-row heatmaps across signatures: the treated-minus-control
# difference (left) and a 0/1 indicator of p < 0.05 (right).
if matched_comparison is not None:
    signatures = list(matched_comparison.keys())
    differences = [matched_comparison[s]['difference'] for s in signatures]
    p_values = [matched_comparison[s]['p_value'] for s in signatures]

    # True where the signature passes the 0.05 threshold
    significant = np.array(p_values) < 0.05

    # Shared x-axis ticks: roughly ten evenly spaced signature labels
    tick_step = max(1, len(signatures) // 10)
    tick_positions = range(0, len(signatures), tick_step)
    tick_labels = [str(signatures[i]) for i in tick_positions]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # (axis, column vector of values, colormap, extra imshow options, title)
    panels = [
        (ax1, np.array(differences).reshape(-1, 1), 'RdBu_r', {},
         'Signature Differences (Treated - Control)'),
        (ax2, significant.astype(int).reshape(-1, 1), 'RdYlBu_r', {'vmin': 0, 'vmax': 1},
         'Statistical Significance (p < 0.05)'),
    ]
    for panel_ax, column_values, cmap_name, imshow_options, panel_title in panels:
        # Transposed so the signatures run along the x-axis in a single row
        image = panel_ax.imshow(column_values.T, cmap=cmap_name, aspect='auto', **imshow_options)
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Signature Index')
        panel_ax.set_xticks(tick_positions)
        panel_ax.set_xticklabels(tick_labels)
        panel_ax.set_yticks([])
        plt.colorbar(image, ax=panel_ax)

    plt.tight_layout()
    plt.show()

    print(f"Number of significantly different signatures: {np.sum(significant)}")

### 6.2 Treatment Timing Analysis

In [None]:
# Three-panel view of treatment timing: age at initiation, prescriptions
# per calendar year, and age at initiation split by cluster.
learner = results['enhanced_learner']
treatment_times = np.array(learner.treatment_patterns['treatment_times'])
# Convert time offsets to ages by adding the learner's starting age
treatment_ages = treatment_times + learner.time_start_age

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
age_ax, year_ax, cluster_ax = axes

# Panel 1: distribution of age at treatment initiation, with the median marked
median_age = np.median(treatment_ages)
age_ax.hist(treatment_ages, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
age_ax.set_xlabel('Age at Treatment Initiation')
age_ax.set_ylabel('Number of Patients')
age_ax.set_title('Treatment Initiation Age Distribution')
age_ax.grid(True, alpha=0.3)
age_ax.axvline(median_age, color='red', linestyle='--', label=f'Median: {median_age:.1f}')
age_ax.legend()

# Panel 2: prescriptions per year (only when prescription dates are available).
# Dates that do not parse as day/month/year become NaT and drop out of the counts.
if 'issue_date' in true_statins.columns:
    issue_dates = pd.to_datetime(true_statins['issue_date'], format='%d/%m/%Y', errors='coerce')
    true_statins['year'] = issue_dates.dt.year
    year_counts = true_statins.groupby('year').size()

    year_ax.plot(year_counts.index, year_counts.values, marker='o')
    year_ax.set_xlabel('Year')
    year_ax.set_ylabel('Number of Prescriptions')
    year_ax.set_title('Statin Prescriptions Over Time')
    year_ax.grid(True, alpha=0.3)
    year_ax.tick_params(axis='x', rotation=45)

# Panel 3: age at treatment per cluster (only when clustering is available).
# NOTE(review): assumes `clusters` is aligned element-for-element with
# `treatment_times` -- confirm against the learner.
if cluster_analysis is not None:
    clusters = cluster_analysis['clusters']
    unique_clusters = np.unique(clusters)

    for cluster_id in unique_clusters:
        cluster_ages = treatment_ages[clusters == cluster_id]
        cluster_ax.hist(cluster_ages, alpha=0.6, label=f'Cluster {cluster_id}', bins=15)

    cluster_ax.set_xlabel('Age at Treatment')
    cluster_ax.set_ylabel('Number of Patients')
    cluster_ax.set_title('Treatment Age by Cluster')
    cluster_ax.legend()
    cluster_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Summary and Key Insights

In [None]:
# Text summary of every stage, built from the objects created in the
# earlier cells (results, cluster_analysis, matched_comparison, predictor, trace).
banner = "=" * 60
print(banner)
print("COMPREHENSIVE ANALYSIS SUMMARY")
print(banner)

# 1. Matching
n_matched = len(results['matched_treated_indices'])
print("\n1. MATCHING RESULTS:")
print(f"   - Successfully matched {n_matched} treated-control pairs")
print("   - Matching controlled for age and other available covariates")

# 2. Pattern discovery
if cluster_analysis is not None:
    n_clusters = cluster_analysis['n_clusters']
    print("\n2. PATTERN DISCOVERY:")
    print(f"   - Identified {n_clusters} distinct pre-treatment signature patterns")

    for cluster_id, pattern in cluster_analysis['cluster_patterns'].items():
        print(f"   - Cluster {cluster_id}: {pattern['n_patients']} patients, "
              f"avg age {pattern['mean_treatment_age']:.1f}")

# 3. Signature differences
if matched_comparison is not None:
    significant_sigs = [
        sig for sig, sig_stats in matched_comparison.items()
        if sig_stats['p_value'] < 0.05
    ]
    print("\n3. SIGNATURE DIFFERENCES (Treated vs Matched Controls):")
    print(f"   - {len(significant_sigs)} signatures show significant differences (p < 0.05)")

    if significant_sigs:
        # The strongest effect is taken over ALL signatures, not only the
        # significant ones; it is merely reported when at least one is significant.
        top_sig = max(matched_comparison.items(),
                      key=lambda item: abs(item[1]['effect_size']))
        print(f"   - Strongest effect: Signature {top_sig[0]} "
              f"(effect size = {top_sig[1]['effect_size']:.3f})")

# 4. Prediction performance
if predictor is not None:
    cv_auc = predictor['cv_auc']
    print("\n4. PREDICTION MODEL:")
    print(f"   - Treatment readiness model AUC: {cv_auc:.3f}")

    # Coarse verbal rating of the AUC
    if cv_auc > 0.7:
        performance_label = "good"
    elif cv_auc > 0.6:
        performance_label = "moderate"
    else:
        performance_label = "limited"
    print(f"   - Model shows {performance_label} predictive performance")

# 5. Bayesian analysis
print("\n5. BAYESIAN CAUSAL INFERENCE:")
if trace is None:
    print("   - Analysis not completed (may need outcome data)")
else:
    print("   - Successfully completed propensity-response modeling")
    print(f"   - {len(trace.posterior.chain)} chains, "
          f"{len(trace.posterior.draw)} draws per chain")

    if 'treatment_effect' in trace.posterior:
        treatment_effect_mean = float(trace.posterior['treatment_effect'].mean())
        print(f"   - Estimated treatment effect: {treatment_effect_mean:.3f}")

print("\n" + banner)
print("ANALYSIS COMPLETE")
print(banner)

## 8. Save Results (Optional)

In [None]:
# Persist the analysis outputs; every file name carries the same timestamp.
import pickle
from datetime import datetime

# One timestamp shared by all files written from this run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Full results dictionary -> pickle
output_file = f"comprehensive_analysis_results_{timestamp}.pkl"
with open(output_file, 'wb') as pickle_handle:
    pickle.dump(results, pickle_handle)
print(f"Results saved to: {output_file}")

# Per-signature comparison statistics -> CSV (one row per signature)
if matched_comparison is not None:
    comparison_file = f"signature_comparison_{timestamp}.csv"
    comparison_df = pd.DataFrame(matched_comparison).T
    comparison_df.to_csv(comparison_file)
    print(f"Signature comparison saved to: {comparison_file}")

# Bayesian trace -> NetCDF, when the Bayesian step produced one
if trace is not None:
    trace_file = f"bayesian_trace_{timestamp}.nc"
    az.to_netcdf(trace, trace_file)
    print(f"Bayesian trace saved to: {trace_file}")

print("\nAll results saved successfully!")