# Comprehensive Treatment Pattern Analysis

This notebook demonstrates the complete pipeline combining:
1. Age-matched feature building
2. Observational pattern learning with matched controls
3. Bayesian propensity-response modeling

Author: Sarah Urbut  
Date: 2025-07-15

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")
import sys
sys.path.append('scripts')
%load_ext autoreload

%autoreload 2

print("Libraries imported successfully!")


Libraries imported successfully!


## 1. Load Data

Load your signature data, patient IDs, prescription data, and covariates.

In [25]:
import sys

import numpy as np
import pandas as pd

sys.path.append('scripts')

# Signature loadings and the patient IDs they belong to (row order presumably
# matches between the two arrays — confirm against how they were saved).
thetas = np.load("/Users/sarahurbut/aladynoulli2/pyScripts/thetas.npy")
processed_ids = np.load("/Users/sarahurbut/aladynoulli2/pyScripts/processed_patient_ids.npy").astype(int)

# Covariate table: strip header whitespace, name the first column 'eid',
# and store it as integers so it can be joined against processed_ids.
cov = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/matched_pce_df_400k.csv')
cov.columns = cov.columns.str.strip()
id_column = cov.columns[0]
cov = cov.rename(columns={id_column: 'eid'}).astype({'eid': int})

# Provisional covariate dict: age as of calendar year REFERENCE_YEAR
# (NOT age at enrollment). It is rebuilt with the full covariate set in
# section 3 below.
REFERENCE_YEAR = 2025
age_by_eid = dict(zip(cov['eid'], REFERENCE_YEAR - cov['birth_year']))
covariate_dicts = {'age': age_by_eid}



In [None]:
from scripts.statin_utils import *

import numpy as np
import pandas as pd
import torch

# --- Standardize signature loadings per (signature, timepoint) ----------------
# thetas: shape (N, n_signatures, n_timepoints)
mean_thetas = thetas.mean(axis=0)  # shape: (n_signatures, n_timepoints)
std_thetas = thetas.std(axis=0)    # shape: (n_signatures, n_timepoints)
# A (signature, timepoint) cell that is constant across all patients has
# std == 0; dividing by it yields NaN/inf z-scores. Substitute 1.0 there so
# those cells get z == 0 instead.
safe_std_thetas = np.where(std_thetas > 0, std_thetas, 1.0)
# Broadcasting: (N, n_signatures, n_timepoints) - (n_signatures, n_timepoints)
z_thetas = (thetas - mean_thetas[None, :, :]) / safe_std_thetas[None, :, :]

# --- Outcome tensor -------------------------------------------------------------
# NOTE(review): the printed shapes show Y with 407,878 rows while
# processed_ids has 400,000 entries. Downstream analyses presumably rely on
# the leading rows of Y being aligned with processed_ids — confirm.
Y = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt')

# `cov` was already read and cleaned in the "Load Data" cell from the same
# CSV, so it is not re-read here.

# --- GP prescriptions -----------------------------------------------------------
prescription_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/gp_scripts.txt'
gp_scripts = pd.read_csv(prescription_path, sep='\t')
gp_scripts = simple_gp_check(gp_scripts)
df, statins = basic_analysis(gp_scripts)

print(f"Signature data shape: {thetas.shape}")
print(f"Number of processed patients: {len(processed_ids)}")
print(f"Number of prescription records: {len(statins)}")
print(f"Number of patients with covariates: {len(cov)}")

=== Simple GP Scripts Check ===

Data shape: (56212343, 8)
Columns: ['eid', 'data_provider', 'issue_date', 'read_2', 'bnf_code', 'dmd_code', 'drug_name', 'quantity']

Data types:
eid                int64
data_provider      int64
issue_date        object
read_2            object
bnf_code          object
dmd_code         float64
drug_name         object
quantity          object
dtype: object

Missing values:
  issue_date: 6091
  read_2: 42056239
  bnf_code: 13189321
  dmd_code: 50300279
  drug_name: 7228777
  quantity: 7275923

First few rows:
       eid  data_provider  issue_date read_2        bnf_code  dmd_code  \
0  1000015              3  14/06/2005    NaN  06.03.02.00.00       NaN   
1  1000015              3  28/07/2014    NaN  05.01.01.03.00       NaN   
2  1000015              3  10/09/2009    NaN  05.01.01.02.00       NaN   
3  1000015              3  15/12/2004    NaN  03.01.01.03.00       NaN   
4  1000015              3  15/12/2004    NaN  03.02.00.00.00       NaN   

       

## 2. Clean and Prepare Statin Data

In [18]:
# Keep only genuine statin prescriptions, matched case-insensitively by
# generic name; rows with a missing drug_name are treated as non-matches.
STATIN_REGEX = 'simvastatin|atorvastatin|rosuvastatin|pravastatin|fluvastatin|lovastatin'
is_statin = statins['drug_name'].str.contains(STATIN_REGEX, case=False, na=False)
true_statins = statins.loc[is_statin].copy()

n_statin_patients = true_statins['eid'].nunique()
print(f"Total prescription records: {len(statins)}")
print(f"True statins after filtering: {len(true_statins)}")
print(f"Unique patients with true statins: {n_statin_patients}")

# Distribution of the most common statin products
if len(true_statins):
    top_statins = true_statins['drug_name'].value_counts().head(10)
    print("\nTop 10 statin types:")
    print(top_statins)

# Persist to CSV and immediately read back. NOTE: the round-trip resets the
# index and re-infers dtypes, so downstream code sees the reloaded frames.
true_statins.to_csv('true_statins.csv', index=False)
cov.to_csv('cov.csv', index=False)
true_statins = pd.read_csv('true_statins.csv')
cov = pd.read_csv('cov.csv')

# Every patient with at least one GP prescription record; cached to disk
# for fast loading.
prescription_patient_ids = set(gp_scripts['eid'].unique())
np.save('prescription_patient_ids.npy', np.array(list(prescription_patient_ids)))

Total prescription records: 3891973
True statins after filtering: 3584756
Unique patients with true statins: 70329

Top 10 statin types:
drug_name
Simvastatin 40mg tablets     1203845
Simvastatin 20mg tablets      535753
Atorvastatin 20mg tablets     331045
Atorvastatin 10mg tablets     283735
Atorvastatin 40mg tablets     272954
Simvastatin 10mg tablets      135675
Atorvastatin 80mg tablets      92158
SIMVASTATIN tabs 40mg          66184
Pravastatin 40mg tablets       65840
Rosuvastatin 10mg tablets      57566
Name: count, dtype: int64


## 3. Prepare Covariate Dictionary for Matching

In [38]:
import torch
from dt import *

# ------------------------------------------------------------------------------
# Covariate table (loaded and cleaned once for this cell)
# ------------------------------------------------------------------------------
cov = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/matched_pce_df_400k.csv')
cov.columns = cov.columns.str.strip()
cov = cov.rename(columns={cov.columns[0]: 'eid'})  # first column holds the participant id
cov['eid'] = cov['eid'].astype(int)
# Age at enrollment = calendar year of Enrollment_Date minus birth year
# (unparseable dates become NaT, giving a missing age).
cov['enrollment'] = pd.to_datetime(cov['Enrollment_Date'], errors='coerce')
cov['age_at_enroll'] = cov['enrollment'].dt.year - cov['birth_year']
age_at_enroll = dict(zip(cov['eid'], cov['age_at_enroll']))
eid_to_yob = dict(zip(cov['eid'], cov['birth_year']))

# PRS names and labels for plotting/interpretation
prs_names = pd.read_csv('/Users/sarahurbut/aladynoulli2/pyScripts/prs_names_with_head.csv')
prs_labels = prs_names['Names'].tolist()

# Disease names for reference (assumes the second column contains the names)
disease_names_df = pd.read_csv("/Users/sarahurbut/aladynoulli2/pyScripts/disease_names.csv")
disease_names = disease_names_df.iloc[:, 1].tolist()

# Genotype/PRS matrix, converted to a NumPy array
G = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/G_matrix.pt")
G = G.detach().cpu().numpy()

# ------------------------------------------------------------------------------
# Prior disease/condition flags at enrollment.
# prev_condition (from dt) is called only for its side effect: the flag
# columns it names (prev_dm, prev_dm1, ...) are read from `cov` below, so it
# must add them to `cov` in place.
# ------------------------------------------------------------------------------
prev_condition(cov, 'Dm_Any', 'Dm_censor_age', 'age_enrolled', 'prev_dm')
prev_condition(cov, 'DmT1_Any', 'DmT1_censor_age', 'age_enrolled', 'prev_dm1')
prev_condition(cov, 'Ht_Any', 'Ht_censor_age', 'age_enrolled', 'prev_ht')
prev_condition(cov, 'HyperLip_Any', 'HyperLip_censor_age', 'age_enrolled', 'prev_hl')
prev_condition(cov, 'Cad_Any', 'Cad_censor_age', 'age_enrolled', 'prev_cad')


# ------------------------------------------------------------------------------
# eid -> value lookup dictionaries for covariates and PRS
# ------------------------------------------------------------------------------
def _eid_map(frame, column):
    """Return a dict mapping each row's ``eid`` in ``frame`` to its value in ``column``."""
    return dict(zip(frame['eid'], frame[column]))


ldl_idx = prs_labels.index('LDL_SF')   # LDL PRS column in G
cad_idx = prs_labels.index('CAD')      # CAD PRS column in G

# The PRS lookups assume row i of G belongs to processed_ids[i]. The ordering
# cannot be verified here, but a G with too few rows is caught with a clear
# message instead of an IndexError part-way through building the dicts.
n_ids = len(processed_ids)
if G.shape[0] < n_ids:
    raise ValueError(
        f"G has {G.shape[0]} rows but processed_ids has {n_ids} entries; "
        "PRS lookups assume G rows are aligned with processed_ids"
    )
eid_to_ldl_prs = dict(zip(processed_ids, G[:n_ids, ldl_idx]))
eid_to_cad_prs = dict(zip(processed_ids, G[:n_ids, cad_idx]))

eid_to_dm2_prev = _eid_map(cov, 'prev_dm')
# NOTE(review): "antihtnbase" is populated from prev_ht (prior hypertension
# diagnosis), not from antihypertensive use at baseline — confirm the proxy.
eid_to_antihtnbase = _eid_map(cov, 'prev_ht')
eid_to_htn = _eid_map(cov, 'prev_ht')
eid_to_smoke = _eid_map(cov, 'SmokingStatusv2')
eid_to_dm1_prev = _eid_map(cov, 'prev_dm1')
eid_to_hl_prev = _eid_map(cov, 'prev_hl')
eid_to_sex = _eid_map(cov, 'Sex')
# NOTE(review): this uses cov['age_enrolled'], not the Enrollment_Date-based
# age_at_enroll computed above — confirm which definition matching should use.
eid_to_age = _eid_map(cov, 'age_enrolled')
eid_to_race = _eid_map(cov, 'race')
eid_to_pce_goff = _eid_map(cov, 'pce_goff')
eid_to_tchol = _eid_map(cov, 'tchol')
eid_to_hdl = _eid_map(cov, 'hdl')
eid_to_sbp = _eid_map(cov, 'SBP')
# Disease indicator / censoring-age lookups
eid_to_cad_any = _eid_map(cov, 'Cad_Any')
eid_to_cad_censor_age = _eid_map(cov, 'Cad_censor_age')
eid_to_dm_any = _eid_map(cov, 'Dm_Any')
eid_to_dm_censor_age = _eid_map(cov, 'Dm_censor_age')
eid_to_ht_any = _eid_map(cov, 'Ht_Any')
eid_to_ht_censor_age = _eid_map(cov, 'Ht_censor_age')
eid_to_hyperlip_any = _eid_map(cov, 'HyperLip_Any')
eid_to_hyperlip_censor_age = _eid_map(cov, 'HyperLip_censor_age')

# Covariates passed to the treatment analyses for matching/adjustment
covariate_dicts = {
    'age_at_enroll': eid_to_age,
    'sex': eid_to_sex,
    'dm2_prev': eid_to_dm2_prev,
    'antihtnbase': eid_to_antihtnbase,
    'dm1_prev': eid_to_dm1_prev,
    'smoke': eid_to_smoke,
    'ldl_prs': eid_to_ldl_prs,
    'cad_prs': eid_to_cad_prs,
    'tchol': eid_to_tchol,
    'hdl': eid_to_hdl,
    'SBP': eid_to_sbp,
    'pce_goff': eid_to_pce_goff,
    'Cad_Any': eid_to_cad_any,
    'Cad_censor_age': eid_to_cad_censor_age,
    'Dm_Any': eid_to_dm_any,
    'Dm_censor_age': eid_to_dm_censor_age,
    'Ht_Any': eid_to_ht_any,
    'Ht_censor_age': eid_to_ht_censor_age,
    'HyperLip_Any': eid_to_hyperlip_any,
    'HyperLip_censor_age': eid_to_hyperlip_censor_age
}


# ---------------------------------------------


In [None]:
# Import and run the simple treatment analysis with your variables
from scripts.simple_treatment_analysis import simple_treatment_analysis



torch.Size([407878, 348, 52])

In [116]:

# Sanity-check the inputs handed to the statin analysis.
print(f"Y shape: {Y.shape}")
print(f"processed_ids length: {len(processed_ids)}")
print(f"First few processed_ids: {processed_ids[:5]}")

# Signatures 0 through 20 inclusive.
sig_indices = list(range(21))
# Disease indices in Y forming the ASCVD composite outcome.
ASCVD_EVENT_INDICES = [112, 113, 114, 115, 116]

# Simplified treatment analysis with outcomes
statin_results = simple_treatment_analysis(
    gp_scripts=gp_scripts,
    true_statins=true_statins,
    processed_ids=processed_ids,
    thetas=thetas,
    sig_indices=sig_indices,
    covariate_dicts=covariate_dicts,
    Y=Y,
    event_indices=ASCVD_EVENT_INDICES,
    cov=cov,
)

Y shape: torch.Size([407878, 348, 52])
processed_ids length: 400000
First few processed_ids: [1000015 1000023 1000037 1000042 1000059]
=== SIMPLIFIED TREATMENT ANALYSIS WITH SELF-CHECKING ===
Every step is verified and transparent

1. Verifying patient cohort definitions:
=== PATIENT COHORT VERIFICATION ===
Patients with complete data: 178,317
All statin patients: 70,329
Treated cohort (complete data + statins): 62,186
Control cohort (complete data - statins): 116,131
✅ No overlap between treated and control cohorts

2. Extracting treated patients using ObservationalTreatmentPatternLearner:
Found 53514 treated patients
Found 116131 never-treated patients
   Treated patients from learner: 53,514
   Never-treated patients from learner: 116,131

=== TREATED PATIENT VERIFICATION ===
Claimed treated patients: 53,514
  - With statins: 53,514
  - Without statins: 0
✅ All treated patients have statins

3. Defining clean controls:
   Found 116131 never-treated patients with signature data

=== 

## Colorectal cancer (CRC) reduction among patients treated with aspirin

In [124]:
# Release any result from a previous run before re-running. Unlike
# `del aspirin_results`, rebinding also works on a fresh kernel where the
# name does not exist yet (so Restart & Run All no longer raises NameError).
aspirin_results = None

from scripts.aspirin_colorectal_analysis import *

# TODO: `true_aspirins` is not defined in any earlier cell shown here; it must
# be created before this cell for Restart & Run All to succeed.
aspirin_results = aspirin_colorectal_analysis(
    gp_scripts=gp_scripts,
    true_aspirins=true_aspirins,
    processed_ids=processed_ids,
    thetas=thetas,
    covariate_dicts=covariate_dicts,
    Y=Y,  # outcome tensor
    colorectal_cancer_indices=[10, 11],  # disease indices in Y defining the cancer events
    cov=cov,
)

=== ASPIRIN-COLORECTAL CANCER PREVENTION ANALYSIS ===
Expected effect: Aspirin should REDUCE colorectal cancer risk (HR < 1.0)
Expected HR from trials: ~0.7-0.8 (20-30% risk reduction)

1. Verifying patient cohort definitions:
=== PATIENT COHORT VERIFICATION (ASPIRIN) ===
Patients with complete data: 178,317
All aspirin patients: 41,110
Treated cohort (complete data + aspirin): 37,674
Control cohort (complete data - aspirin): 140,643
✅ No overlap between treated and control cohorts

2. Extracting treated patients using ObservationalTreatmentPatternLearner:
Found 33089 treated patients
Found 140643 never-treated patients
   Treated patients from learner: 33,089
   Never-treated patients from learner: 140,643

=== TREATED PATIENT VERIFICATION (ASPIRIN) ===
Claimed treated patients: 33,089
  - With aspirin: 33,089
  - Without aspirin: 0
✅ All treated patients have aspirin

3. Defining clean controls:
   Found 140643 never-treated patients with signature data

=== CONTROL PATIENT VERIFICAT

### TODO: exclude clopidogrel users and filter out patients with fewer than 1 aspirin prescription (confirm threshold)

# MI reduction among diabetics treated with metformin

In [13]:
# Import the metformin analysis functions (star import; provides
# find_metformin_basic used below).
from scripts.metformin_analysis import *

# Broad search for metformin prescriptions in the GP records.
# NOTE(review): per the printed output, the BNF-pattern part of this search
# also returns non-metformin drugs (gliclazide under BNF 06.01.02.01), so the
# result is filtered by drug name in the next cell.
metformins = find_metformin_basic(gp_scripts)


=== Basic Metformin Search ===

Found 700404 prescriptions containing 'metformin'
Found 14128 prescriptions containing 'glucophage'
Found 255561 prescriptions with BNF pattern '06.01.02.01'
Found 15735 prescriptions with BNF pattern '06010202'

DEBUG: Checking what drugs are being included...
Metformin mask matches: 714532
BNF mask matches: 271296
Combined matches: 970113

Total potential metformin prescriptions: 970113
Unique patients with metformin: 13010

Sample metformin prescriptions:
          eid  issue_date                drug_name  bnf_code
1252  1000198  07/06/2016  Metformin 500mg tablets  06010202
1256  1000198  06/12/2013  Metformin 500mg tablets  06010202
1257  1000198  09/05/2016  Metformin 500mg tablets  06010202
1258  1000198  29/10/2013  Metformin 500mg tablets  06010202
1260  1000198  16/12/2014  Metformin 500mg tablets  06010202
1263  1000198  22/10/2014  Metformin 500mg tablets  06010202
1265  1000198  18/10/2016  Metformin 500mg tablets  06010202
1269  1000198  2

In [14]:
# Inspect which BNF codes the broad metformin search picked up.
print("BNF codes in the results:")
bnf_counts = metformins['bnf_code'].value_counts().head(10)
for code, count in bnf_counts.items():
    print(f"  {code}: {count} prescriptions")

# And which drugs correspond to each of the top BNF codes.
print("\nDrugs by BNF code:")
for code in bnf_counts.index[:5]:
    drugs_with_code = metformins.loc[metformins['bnf_code'] == code, 'drug_name'].value_counts().head(3)
    print(f"\nBNF {code}:")
    for drug, count in drugs_with_code.items():
        print(f"  {drug}: {count}")

# The broad search also matched non-metformin drugs (e.g. gliclazide), so keep
# only rows whose drug name mentions metformin (case-insensitive; missing
# names count as non-matches). Use the helper-provided 'drug_name_str' column
# when present, otherwise fall back to 'drug_name'.
name_col = 'drug_name_str' if 'drug_name_str' in metformins.columns else 'drug_name'
is_metformin = metformins[name_col].str.contains('metformin', case=False, na=False)
# .copy() makes an independent frame. It is passed to the analysis below, and
# writes to a boolean-mask slice are ambiguous (SettingWithCopyWarning, which
# this notebook silences globally).
metformins_clean = metformins.loc[is_metformin].copy()

print(f"\nAfter filtering for 'metformin' in name:")
print(f"Total metformin prescriptions: {len(metformins_clean):,}")
print(f"Unique patients with metformin: {metformins_clean['eid'].nunique():,}")

# Most common products remaining after the filter
print(f"\nMost common metformin drugs:")
top_drugs = metformins_clean['drug_name'].value_counts().head(5)
for drug, count in top_drugs.items():
    print(f"  {drug}: {count}")

BNF codes in the results:
  06.01.02.02.00: 569485 prescriptions
  06.01.02.01.00: 255561 prescriptions
  0601022B0AAABAB: 31618 prescriptions
  06010202: 15735 prescriptions
  06.01.02.03.00: 6143 prescriptions
  0601022B0AAADAD: 3035 prescriptions
  06010203: 363 prescriptions
  0601022B0BBAAAB: 175 prescriptions

Drugs by BNF code:

BNF 06.01.02.02.00:
  Metformin 500mg tablets: 383555
  Metformin 850mg tablets: 74758
  Metformin 500mg modified-release tablets: 65913

BNF 06.01.02.01.00:
  Gliclazide 80mg tablets: 180018
  Gliclazide 30mg modified-release tablets: 15563
  Gliclazide 40mg tablets: 12155

BNF 0601022B0AAABAB:
  Metformin Hydrochloride  Tablets  500 mg: 18509
  Metformin Hydrochloride TABS 500MG: 7351
  METFORMIN HYDROCHLORIDE TABLETS 500MG: 5375

BNF 06010202:
  Metformin 500mg tablets: 11153
  Metformin Hydrochloride  M/R tablets  500 mg: 1067
  Metformin 500mg modified-release tablets: 1052

BNF 06.01.02.03.00:
  Rosiglitazone 2mg / Metformin 1g tablets: 1967
  Rosi

In [18]:
from scripts.metformin_analysis import *

# Outcome: MI only (disease index 112 in Y).
MI_INDICES = [112]

# Full metformin analysis restricted to the name-filtered prescriptions.
metformin_diabetic_results = metformin_diabetics_analysis(
    gp_scripts=gp_scripts,
    true_metformins=metformins_clean,
    processed_ids=processed_ids,
    thetas=thetas,
    covariate_dicts=covariate_dicts,
    Y=Y,
    mi_indices=MI_INDICES,
    cov=cov,
)

=== METFORMIN-MI PREVENTION ANALYSIS IN DIABETICS ===
Expected effect: Metformin should REDUCE MI risk in diabetics (HR < 1.0)
Expected HR from UKPDS trial: ~0.6-0.7 (30-40% risk reduction)

1. Verifying patient cohort definitions:
=== PATIENT COHORT VERIFICATION (METFORMIN) ===
Patients with complete data: 178,317
All metformin patients: 12,684
Treated cohort (complete data + metformin): 11,276
Control cohort (complete data - metformin): 167,041
✅ No overlap between treated and control cohorts

2. Extracting treated patients using ObservationalTreatmentPatternLearner:
Found 9661 treated patients
Found 167041 never-treated patients
   Treated patients from learner: 9,661
   Never-treated patients from learner: 167,041

=== TREATED PATIENT VERIFICATION (METFORMIN) ===
Claimed treated patients: 9,661
  - With metformin: 9,661
  - Without metformin: 0
✅ All treated patients have metformin

3. Defining diabetic controls:
   Found 22142 total diabetics
   Found 11005 diabetic controls (not 