In [None]:
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm


In [None]:
# Fixed sequence width shared by the functions below: tokenize_patient_data keeps at
# most this many trailing token ids per patient, and process_patient_data allocates
# the zero-padded sequence matrix with this many columns.
MAX_SEQ_LENGTH = 1024

def calculate_mace_outcomes(antihypertensive_starts, mace_data, index_col_name):
    """Label each index-date row with whether MACE occurred 7-365 days later.

    Parameters
    ----------
    antihypertensive_starts : pd.DataFrame
        Must contain ``person_id`` and the column named by ``index_col_name``
        (the index / treatment-start date, any format ``pd.to_datetime`` parses).
    mace_data : pd.DataFrame
        Must contain ``person_id`` and ``first_mace_date``.
    index_col_name : str
        Name of the index-date column in ``antihypertensive_starts``.

    Returns
    -------
    pd.DataFrame
        Columns ``person_id`` and ``mace_12m`` (int8). ``mace_12m`` is 1 when
        ``first_mace_date - index date`` is between 7 and 365 days inclusive,
        and 0 otherwise (including when either date is missing). One row per
        row of the left join of starts onto MACE dates.

    Notes
    -----
    The inputs are not modified: date parsing happens on private copies of the
    two columns that are actually needed.
    """
    # Work on narrow copies so the caller's frames keep their original dtypes.
    starts = antihypertensive_starts[['person_id', index_col_name]].copy()
    starts[index_col_name] = pd.to_datetime(starts[index_col_name])

    mace = mace_data[['person_id', 'first_mace_date']].copy()
    mace['first_mace_date'] = pd.to_datetime(mace['first_mace_date'])

    outcomes = starts.merge(mace, on='person_id', how='left')
    print(outcomes.head())

    # Vectorised day difference; rows with a missing date yield NaN, and
    # Series.between treats NaN as "not in range", i.e. label 0.
    days_to_mace = (outcomes['first_mace_date'] - outcomes[index_col_name]).dt.days
    mace_12m = days_to_mace.between(7, 365).to_numpy().astype(np.int8)
    print(mace_12m.mean())
    return pd.DataFrame({'person_id': outcomes['person_id'], 'mace_12m': mace_12m})

def create_vocabulary(df):
    """Build the token -> integer id mapping for a frame of concept ids.

    The distinct values of ``df['concept_id']`` (cast to int32, in ascending
    numeric order) are rendered as strings and numbered from 4 upwards; ids
    0-3 belong to the special tokens ``[PAD]``, ``[CLS]``, ``[SEP]`` and
    ``[DAY]``.
    """
    special_tokens = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[DAY]': 3}

    # np.unique both de-duplicates and sorts the codes.
    code_strings = np.unique(df['concept_id'].astype(np.int32)).astype(str)

    # Concept codes first, special tokens last (same insertion order as before).
    vocab = dict(zip(code_strings, range(4, 4 + len(code_strings))))
    vocab.update(special_tokens)
    return vocab

def tokenize_patient_data(events_array, dates_array, vocab, max_seq_length=None):
    """Turn one patient's date-ordered events into a list of token ids.

    Layout: ``[CLS]``, then for every run of consecutive events sharing the
    same date a ``[DAY]`` marker followed by the concept tokens, with ``[SEP]``
    closing each day (including the last one).

    Parameters
    ----------
    events_array : iterable
        Concept ids, aligned with ``dates_array``; each is looked up in
        ``vocab`` via ``str(concept_id)``.
    dates_array : iterable
        Event dates, expected to be pre-sorted; a new day starts whenever the
        value differs from the previous one.
    vocab : dict
        Token -> id mapping; must contain ``'[PAD]'``. Tokens absent from the
        vocabulary are mapped to the ``[PAD]`` id.
    max_seq_length : int, optional
        Maximum number of ids to return. Defaults to the module-level
        ``MAX_SEQ_LENGTH``. When the sequence is longer, only the trailing
        ``max_seq_length`` ids are kept, so the leading ``[CLS]`` is dropped.

    Returns
    -------
    list of int

    Raises
    ------
    ValueError
        If ``max_seq_length`` is negative.
    """
    if max_seq_length is None:
        max_seq_length = MAX_SEQ_LENGTH
    if max_seq_length < 0:
        raise ValueError("max_seq_length must be non-negative")

    tokens = ['[CLS]']
    current_date = None

    for date, concept_id in zip(dates_array, events_array):
        if date != current_date:
            if current_date is not None:
                tokens.append('[SEP]')  # close the previous day
            tokens.append('[DAY]')
            current_date = date
        tokens.append(str(concept_id))
    tokens.append('[SEP]')

    # A slice of [-0:] would return the whole list, so handle 0 explicitly.
    if max_seq_length == 0:
        return []

    pad_id = vocab['[PAD]']
    return [vocab.get(token, pad_id) for token in tokens[-max_seq_length:]]

def process_patient_data(data, antihypertensive_starts, mace_data, index_col_name, suffix=""):
    """Build padded token sequences and MACE labels, and save them to disk.

    Parameters
    ----------
    data : pd.DataFrame
        Event table with ``person_id``, ``concept_id`` and ``event_date``.
        It is not modified; processing happens on a private copy.
    antihypertensive_starts : pd.DataFrame
        Index-date table, passed to ``calculate_mace_outcomes``.
    mace_data : pd.DataFrame
        First-MACE table, passed to ``calculate_mace_outcomes``.
    index_col_name : str
        Name of the index-date column in ``antihypertensive_starts``.
    suffix : str, optional
        Appended to every output file name (before the extension).

    Returns
    -------
    tuple
        ``(padded_sequences, mace_outcomes, vocab, sample_id_to_index)`` where
        ``padded_sequences`` is an int32 array of shape
        ``(n_patients, MAX_SEQ_LENGTH)``, ``mace_outcomes`` a float32 array of
        length ``n_patients``, ``vocab`` the token -> id dict and
        ``sample_id_to_index`` maps ``person_id`` to its row index.

    Side effects
    ------------
    Writes the sequences, sequence lengths, outcomes (``.npy``), the vocabulary
    (pickle) and the id -> index mapping (CSV) under ``./processed_data``,
    creating that directory if needed.
    """
    print("Calculating MACE outcomes...")
    mace_outcomes_df = calculate_mace_outcomes(antihypertensive_starts, mace_data, index_col_name)
    # Keep exactly one label per patient (positive if any row is positive);
    # duplicate rows would otherwise multiply every event in the merge below.
    mace_outcomes_df = mace_outcomes_df.groupby('person_id', as_index=False)['mace_12m'].max()

    # Work on a copy of the needed columns so the caller's frame is untouched.
    print("Converting dates...")
    events = data[['person_id', 'concept_id', 'event_date']].copy()
    events['event_date'] = pd.to_datetime(events['event_date'])
    events['concept_id'] = events['concept_id'].astype(np.int32)
    events = events.sort_values(['person_id', 'event_date'])

    print("Creating vocabulary...")
    vocab = create_vocabulary(events)

    # Attach the label to every event row; patients without an outcome row
    # would carry NaN and poison the label array, so they are labelled 0.
    events = events.merge(mace_outcomes_df, on='person_id', how='left')
    unlabeled = events['mace_12m'].isna()
    if unlabeled.any():
        n_unlabeled = events.loc[unlabeled, 'person_id'].nunique()
        print(f"Warning: {n_unlabeled} patients have no MACE outcome row; labelling them 0")
        events['mace_12m'] = events['mace_12m'].fillna(0)

    # Materialise per-patient arrays once for fast access in the main loop.
    print("Grouping patient data...")
    grouped_data = {
        name: (
            group['concept_id'].values,
            group['event_date'].values,
            group['mace_12m'].iloc[0]
        )
        for name, group in tqdm(events.groupby('person_id'))
    }

    n_patients = len(grouped_data)
    padded_sequences = np.zeros((n_patients, MAX_SEQ_LENGTH), dtype=np.int32)
    mace_outcomes = np.zeros(n_patients, dtype=np.float32)
    sample_ids = np.empty(n_patients, dtype=object)
    sequence_lengths = []  # recorded here so each patient is tokenised only once

    print("Processing sequences...")
    max_seq_length = 0
    for i, (patient_id, (concept_ids, dates, mace)) in enumerate(tqdm(grouped_data.items())):
        tokens = tokenize_patient_data(concept_ids, dates, vocab)
        seq_len = min(len(tokens), MAX_SEQ_LENGTH)
        padded_sequences[i, :seq_len] = tokens[:seq_len]
        sequence_lengths.append(seq_len)
        max_seq_length = max(max_seq_length, len(tokens))
        mace_outcomes[i] = mace
        sample_ids[i] = patient_id

    sample_id_to_index = dict(zip(sample_ids, range(len(sample_ids))))

    print("Saving processed data...")
    # Fail-safe: do not lose the whole run because the output folder is missing.
    os.makedirs('./processed_data', exist_ok=True)
    np.save(f'./processed_data/transformer_input_sequences{suffix}.npy', padded_sequences)
    np.save(f'./processed_data/transformer_input_lengths{suffix}.npy', np.array(sequence_lengths))
    np.save(f'./processed_data/transformer_mace_outcomes{suffix}.npy', mace_outcomes)

    with open(f'./processed_data/transformer_vocab{suffix}.pkl', 'wb') as f:
        pickle.dump(vocab, f)

    pd.DataFrame(list(sample_id_to_index.items()),
                 columns=['person_id', 'index']).to_csv(
                 f'./processed_data/transformer_sample_id_to_index{suffix}.csv', index=False)

    print("\nSummary Statistics:")
    print(f"Max sequence length: {max_seq_length}")
    print(f"Vocabulary size: {len(vocab)}")
    print(f"Number of patients: {n_patients}")
    print(f"Shape of padded_sequences: {padded_sequences.shape}")
    print(f"Shape of MACE_outcomes: {mace_outcomes.shape}")
    print(f"Number of positive MACE outcomes: {mace_outcomes.sum()}")
    print(f"MACE rate: {mace_outcomes.mean():.3f}")

    return padded_sequences, mace_outcomes, vocab, sample_id_to_index

## Remove dox patients from PT cohort and process PT data

In [None]:
data = pd.read_csv('./raw_data/dox_patients_1024_events_prior.csv',usecols=['person_id','concept_id','event_date'])

In [None]:
antihypertensive_data = pd.read_csv('./raw_data/antihypertensive_1024_events_prior_to_med_start.csv',usecols=['person_id','concept_id','event_date'])
antihypertensive_data = antihypertensive_data[~antihypertensive_data['person_id'].isin(data['person_id'])]

In [None]:
dob_antihx = pd.read_csv('./raw_data/dob_antihypertensives.csv')
age = antihypertensive_data.sort_values('event_date', ascending=True).drop_duplicates('person_id', keep='first').merge(dob_antihx, how='inner', on='person_id')

In [None]:
age['event_date'] = pd.to_datetime(age['event_date'])
age['birth_datetime'] = pd.to_datetime(age['birth_datetime'], format='mixed', dayfirst=False)
age['age'] = (age['event_date']-age['birth_datetime']).dt.days / 365.25
age[['age','person_id']].to_csv('./processed_data/age.csv')

In [None]:
dob_dox = pd.read_csv('./raw_data/dox_dob.csv')
age_dox = data.sort_values('event_date', ascending=True).drop_duplicates('person_id', keep='first').merge(dob_dox, how='inner', on='person_id')

In [None]:
age_dox['event_date'] = pd.to_datetime(age_dox['event_date'])
age_dox['birth_datetime'] = pd.to_datetime(age_dox['birth_datetime'], format='mixed', dayfirst=False)
age_dox['age'] = (age_dox['event_date']-age_dox['birth_datetime']).dt.days / 365.25
age_dox[['age','person_id']].to_csv('./processed_data/age_dox.csv')

In [None]:
antihypertensive_starts = pd.read_csv('./raw_data/antihypertensive_start.csv', usecols=['person_id','first_antihypertensive_date'])
antihypertensive_starts = antihypertensive_starts[~antihypertensive_starts['person_id'].isin(data['person_id'])]

mace_data = pd.read_csv('./raw_data/first_mace_post_antihypertensives.csv').drop('Unnamed: 0',axis=1)
mace_data = mace_data[~mace_data['person_id'].isin(data['person_id'])]


In [None]:
# Process everything
sequences, outcomes, vocab, id_mapping = process_patient_data(antihypertensive_data, antihypertensive_starts, mace_data, 'first_antihypertensive_date')

In [None]:
# Calculate mortality outcomes
deaths = pd.read_csv('./raw_data/death.csv')
deaths['first_exposure_date'] = pd.to_datetime(deaths['first_exposure_date'])
deaths['death_date'] = pd.to_datetime(deaths['death_date'])

# Calculate days to death and create binary outcome
deaths['days_to_death'] = (deaths['death_date'] - deaths['first_exposure_date']).dt.total_seconds() / (24 * 3600)
deaths['death_12m'] = ((deaths['days_to_death'] >= 0) & 
                      (deaths['days_to_death'] <= 365)).fillna(0).astype(np.int8)

# Create array aligned with our sequence data using the id_mapping
death_outcomes = np.zeros(len(id_mapping), dtype=np.int8)
death_dict = dict(zip(deaths['person_id'], deaths['death_12m']))

# Use id_mapping dictionary directly
for person_id, idx in id_mapping.items():
    death_outcomes[idx] = death_dict.get(person_id, 0)

# Save to file
np.save('./processed_data/transformer_death_outcomes.npy', death_outcomes)

# Print summary
print(f"Number of deaths within 12 months: {death_outcomes.sum()}")
print(f"12-month mortality rate: {death_outcomes.mean():.3f}")

## Rest of data processing

In [None]:
# Load your data files - skip if you already have them loaded
data = pd.read_csv('./raw_data/dox_patients_1024_events_prior.csv',usecols=['person_id','concept_id','event_date'])
antihypertensive_starts = pd.read_csv('./raw_data/dox_start.csv')
mace_data = pd.read_csv('./raw_data/first_mace_post_dox.csv')


In [None]:

# Process everything
sequences, outcomes, vocab, id_mapping = process_patient_data(data, antihypertensive_starts, mace_data, 'first_doxorubicin_date','_FT')

In [None]:
# Calculate mortality outcomes
deaths = pd.read_csv('./raw_data/dox_death.csv')
deaths['first_exposure_date'] = pd.to_datetime(deaths['first_exposure_date'])
deaths['death_date'] = pd.to_datetime(deaths['death_date'])

# Calculate days to death and create binary outcome
deaths['days_to_death'] = (deaths['death_date'] - deaths['first_exposure_date']).dt.total_seconds() / (24 * 3600)
deaths['death_12m'] = ((deaths['days_to_death'] >= 0) & 
                      (deaths['days_to_death'] <= 365)).fillna(0).astype(np.int8)

# Create array aligned with our sequence data using the id_mapping
death_outcomes = np.zeros(len(id_mapping), dtype=np.int8)
death_dict = dict(zip(deaths['person_id'], deaths['death_12m']))

# Use id_mapping dictionary directly
for person_id, idx in id_mapping.items():
    death_outcomes[idx] = death_dict.get(person_id, 0)

# Save to file
np.save('./processed_data/transformer_death_outcomes_FT.npy', death_outcomes)

# Print summary
print(f"Number of deaths within 12 months: {death_outcomes.sum()}")
print(f"12-month mortality rate: {death_outcomes.mean():.3f}")

In [None]:
mace_data = pd.read_csv('./raw_data/first_mace_post_antihypertensives.csv').drop('Unnamed: 0',axis=1)


In [None]:
dox_mace = mace_data[mace_data['person_id'].isin(data['person_id'])].drop('Unnamed: 0',axis=1)

In [None]:
dox_mace['first_exposure_date'] = pd.to_datetime(dox_mace['first_exposure_date'])
dox_mace['first_mace_date'] = pd.to_datetime(dox_mace['first_mace_date'])


In [None]:
samson = dox_mace[
    (dox_mace['first_mace_date'] - dox_mace['first_exposure_date']).dt.days.between(7, 365, inclusive='both')
]
samson.head()

In [None]:
shreya = pd.read_csv('./shreya_positive_class.csv')
shreya.head()

In [None]:
df = samson.merge(shreya, how='outer',left_on='person_id', right_on='patient_id')
df['prediction_time'] = pd.to_datetime(df['prediction_time'])
diff_pats = np.union1d((df[df['prediction_time'] != df['first_exposure_date']])['person_id'], (df[df['prediction_time'] != df['first_exposure_date']])['patient_id'])
df[df['prediction_time'] != df['first_exposure_date']]

In [None]:
dox_mace.to_csv('./samson_mace_data_dox_cohort.csv')

In [None]:
# Convert to datetime
dox_mace['first_exposure_date'] = pd.to_datetime(dox_mace['first_exposure_date'])
dox_mace['first_mace_date'] = pd.to_datetime(dox_mace['first_mace_date'])

# Calculate days and label using the same logic as original code
days_to_mace = np.zeros(len(dox_mace), dtype=np.float32)
valid_dates = dox_mace['first_mace_date'].notna()

if len(valid_dates) > 0:
    days_to_mace[valid_dates] = (dox_mace.loc[valid_dates, 'first_mace_date'] - 
                                dox_mace.loc[valid_dates, 'first_exposure_date']).apply(lambda x: x.days if pd.notna(x) else 0)

dox_mace['label'] = np.logical_and(days_to_mace >= 7, days_to_mace <= 365).astype(np.int8)