In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from hmmlearn import hmm

np.random.seed(42)

In [2]:
df = pd.read_csv("../../dataset/df_cleaned_1atc.tsv", sep="\t")
df.head()

Unnamed: 0,eid,drug_era_id,drug_concept_id,drug_era_start_date,drug_era_end_date,drug_exposure_count,gap_days,concept_name,atc_code,duration,atc_level3
0,6021257,1236950609195,19005129,2014-05-12,2014-06-10,1,0,clobetasone,D07AB01,30,D07A
1,3430966,721554547993,19008994,2010-10-12,2010-11-07,1,0,mebeverine,A03AA04,27,A03A
2,2127268,910533073010,755695,2006-10-23,2007-10-04,6,23,fluoxetine,N06AB03,347,N06A
3,2441156,901943201395,1549080,2010-01-18,2010-04-17,1,0,"estrogens, conjugated (USP)",G03CA57,90,G03C
4,5489554,1649267519173,19011773,2011-03-28,2011-03-28,1,0,ascorbic acid,A11GA01,1,A11G


In [3]:
df["drug_era_start_date"] = pd.to_datetime(df["drug_era_start_date"])
df["drug_era_end_date"] = pd.to_datetime(df["drug_era_end_date"])

In [4]:
def split_eid_df(eid_df):
    # Convert dates to datetime if they aren't already
    eid_df = eid_df.copy()

    # Sort by start date
    eid_df = eid_df.sort_values("drug_era_start_date")

    # Find overlapping periods
    overlap_groups = []
    current_group = []

    for i in range(len(eid_df) - 1):
        current_end = eid_df["drug_era_end_date"].iloc[i]
        next_start = eid_df["drug_era_start_date"].iloc[i + 1]

        if current_end > next_start:
            if not current_group:  # Start new group
                current_group.append(i)
            current_group.append(i + 1)
        elif current_group:  # End of an overlap group
            overlap_groups.append(eid_df.iloc[current_group])
            current_group = []

    # Don't forget to add the last group if it exists
    if current_group:
        overlap_groups.append(eid_df.iloc[current_group])

    return overlap_groups

In [5]:
# for each eid, split the data into multiple pieces
all_drug_eras = []
all_eids = df["eid"].unique()
sampled_eids = np.random.choice(all_eids, 1000, replace=False)
for eid in tqdm(sampled_eids):
    eid_df = df[df["eid"] == eid]
    all_drug_eras.extend(split_eid_df(eid_df))

print(f"number of drug eras: {len(all_drug_eras)}")

100%|██████████| 1000/1000 [00:05<00:00, 191.62it/s]

number of drug eras: 14867





In [6]:
def create_non_overlapping_periods(overlap_df):
    # Get all unique dates (both start and end)
    dates = np.concatenate(
        [
            overlap_df["drug_era_start_date"].values,
            overlap_df["drug_era_end_date"].values,
        ]
    )
    dates = pd.to_datetime(np.unique(dates))
    dates.sort_values()

    # Create periods between each adjacent pair of dates
    periods = []
    for i in range(len(dates) - 1):
        period_start = dates[i]
        period_end = dates[i + 1]

        # Find all drugs active during this period
        active_drugs = overlap_df[
            (overlap_df["drug_era_start_date"] <= period_start)
            & (overlap_df["drug_era_end_date"] >= period_end)
        ]

        if not active_drugs.empty:
            periods.append(
                {
                    "period_start": period_start,
                    "period_end": period_end,
                    "duration": (period_end - period_start).days,
                    "concept_names": ", ".join(
                        np.sort(active_drugs["concept_name"].unique())
                    ),
                    "drug_concept_ids": ", ".join(
                        np.sort(active_drugs["drug_concept_id"].astype(str).unique())
                    ),
                    "atc_codes": ", ".join(
                        np.sort(active_drugs["atc_code"].astype(str).unique())
                    ),
                    "atc_level3_codes": ", ".join(
                        np.sort(active_drugs["atc_level3"].astype(str).unique())
                    ),
                    "eid": active_drugs["eid"].iloc[
                        0
                    ],  # Assuming same eid for all rows
                }
            )

    return pd.DataFrame(periods)

In [7]:
non_overlapping_periods = []
for index, drug_era in tqdm(enumerate(all_drug_eras)):
    this_data = create_non_overlapping_periods(drug_era)
    this_data["sequence_id"] = index
    non_overlapping_periods.append(this_data)
non_overlapping_periods[0]

14867it [00:18, 822.51it/s]


Unnamed: 0,period_start,period_end,duration,concept_names,drug_concept_ids,atc_codes,atc_level3_codes,eid,sequence_id
0,2005-08-03,2005-08-23,20,"albuterol, beclomethasone, prednisolone","1115572, 1154343, 1550557","A01AC04, A07EA07, R03AC02","A01A, A07E, R03A",3894654,0
1,2005-08-23,2005-09-01,9,"albuterol, beclomethasone","1115572, 1154343","A07EA07, R03AC02","A07E, R03A",3894654,0


#### columns of the dataset
- **eid**: unique identifier for each person
- **drug_concept_id**: unique identifier for each drug 
- **drug_era_id**: unique identifier for each drug era. each drug era is a continuous period of drug use. each person can take the same drug in multiple drug eras. 
- **drug_era_start_date**: start date of the drug era 
- **drug_era_end_date**: end date of the drug era 
- **duration of the drug era**: calculated as drug_era_end_date - drug_era_start_date + 1 day


#### HMM

Each unique drug (identified by drug_concept_id or atc_code) could represent a distinct hidden state in the HMM. \
The observed sequence represents the order of drug eras. \
*Emissions* are the durations of each drug era.


In [18]:
emissions = []
for sequence_data in non_overlapping_periods:
    emissions.extend(list(sequence_data["atc_level3_codes"]))

print(f"number of total emission values: {len(emissions)}")
emissions = np.unique(emissions)
print(f"number of unique emission values: {len(emissions)}")

number of total emission values: 46951
number of unique emission values: 7629


In [21]:
# encode emissions as integers
emissions_dict = {emissions[i]: i for i in range(len(emissions))}

In [23]:
# convert periods to sequence of emissions
sequences = []
for sequence_data in non_overlapping_periods:
    this_sequence = []
    for index, row in sequence_data.iterrows():
        this_sequence.extend(
            [emissions_dict[row["atc_level3_codes"]]] * row["duration"]
        )
    sequences.append(this_sequence)

In [24]:
cutoff_index = int(len(sequences) * 0.8)
train_sequences = sequences[:cutoff_index]
test_sequences = sequences[cutoff_index:]

train_lengths = [len(sequence) for sequence in train_sequences]
test_lengths = [len(sequence) for sequence in test_sequences]
print(f"number of train sequences: {len(train_sequences)}")
print(f"number of test sequences: {len(test_sequences)}")
print(f"average length of train sequences: {np.mean(train_lengths)}")
print(f"average length of test sequences: {np.mean(test_lengths)}")

train_sequences = [[value] for sequence in train_sequences for value in sequence]
test_sequences = [[value] for sequence in test_sequences for value in sequence]

number of train sequences: 11893
number of test sequences: 2974
average length of train sequences: 248.86193559236526
average length of test sequences: 275.68493611297913


In [25]:
model = hmm.CategoricalHMM(n_components=10)
model.fit(train_sequences, train_lengths)


In [26]:
model.score(train_sequences, train_lengths)


-9893217.569793448

In [27]:
model.score(test_sequences, test_lengths)

-inf