In [1]:
import pandas as pd
import numpy as np

## Terminology

We say drug era B is a **subsequent drug era** of drug era A if:
1. drug eras A and B belong to the same person (same `eid`), and
2. `drug_era_start_date` of B > `drug_era_end_date` of A (i.e. B starts strictly after A ends).

We say drug era B is a **closest subsequent drug era** of drug era A if:
- among all subsequent drug eras of A, B starts first, i.e. it minimises the gap `drug_era_start_date` of B - `drug_era_end_date` of A.
  Because several subsequent eras can start on the same date, one drug era can have more than one closest subsequent drug era.

In [None]:
# Load the tab-separated drug-era table and preview it with the eid column hidden.
df = pd.read_csv(
    "../../dataset/dataset.tsv",
    sep="\t",
)
display(df.loc[:, ~df.columns.isin(["eid"])].head())

In [4]:
# Per-era fields copied into the switch table, once with an "A_" prefix and once with "B_".
_ERA_FIELDS = (
    "drug_era_id",
    "drug_concept_id",
    "drug_era_start_date",
    "drug_era_end_date",
    "drug_exposure_count",
    "gap_days",
)
# Column order of the switch table: eid, A_* fields, B_* fields, then the gap.
_SWITCH_COLUMNS = [
    "eid",
    *(f"A_{field}" for field in _ERA_FIELDS),
    *(f"B_{field}" for field in _ERA_FIELDS),
    "switch_interval",
]


def _forward_gaps(start_dates: np.ndarray, end_dates: np.ndarray) -> np.ndarray:
    """Return the (n, n) matrix of gaps ``start_dates[j] - end_dates[i]``.

    The eras must already be sorted by start date. Only entries with ``j > i``
    and a strictly positive gap are kept; every other entry is NaT.
    """
    n_rows = len(start_dates)
    gaps = (start_dates[np.newaxis, :] - end_dates[:, np.newaxis]).astype(
        "timedelta64[ns]"
    )
    # Only later-sorted eras (upper triangle, j > i) can be subsequent eras.
    later = np.triu(np.ones((n_rows, n_rows), dtype=bool), k=1)
    keep = later & (gaps > np.timedelta64(0, "ns"))
    gaps[~keep] = np.timedelta64("NaT")
    return gaps


def _switch_row(eid, row_a: pd.Series, row_b: pd.Series, gap) -> dict:
    """Build one output record pairing era A with one of its closest subsequent eras B."""
    record = {"eid": eid}
    record.update({f"A_{field}": row_a[field] for field in _ERA_FIELDS})
    record.update({f"B_{field}": row_b[field] for field in _ERA_FIELDS})
    record["switch_interval"] = gap
    return record


def process_drug_switches(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Find the closest subsequent drug era(s) B for each drug era A.

    B is a subsequent era of A when both rows share the same ``eid`` and B
    starts strictly after A ends (``B.drug_era_start_date > A.drug_era_end_date``).
    The closest subsequent eras are those with the smallest gap
    ``B.drug_era_start_date - A.drug_era_end_date``. When several eras tie for
    the smallest gap, one output row is emitted per tied B.

    Method:
        For each eid:
            sort eras by start date, compute the gap from the end of every era A
            to the start of every later era, discard non-positive gaps, and emit
            one row per (A, B) pair whose gap equals A's minimum gap.

    Args:
        df (pd.DataFrame): Input dataframe containing drug era information.
            Required columns: eid, drug_era_id, drug_concept_id, drug_era_start_date,
            drug_era_end_date, drug_exposure_count, gap_days.
            Date columns that are not already datetime-typed are parsed with
            format ``%d/%m/%Y``. The caller's frame is not modified.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]:
            - Switch table: one row per (A, closest B) pair, with ``eid``,
              ``A_*`` and ``B_*`` era fields, and ``switch_interval``
              (the timedelta gap between A's end and B's start).
            - Invalid eras: the input rows of eras without any subsequent era,
              plus a ``reason`` column that is either ``"single_drug_era"``
              (the patient has only one era) or ``"no_subsequent_era"``.
            Both frames keep their columns even when they have no rows.
    """
    # Parse dates into a new frame instead of assigning into the caller's df.
    parsed_dates = {
        col: pd.to_datetime(df[col], format="%d/%m/%Y")
        for col in ("drug_era_start_date", "drug_era_end_date")
        if not pd.api.types.is_datetime64_any_dtype(df[col])
    }
    df = df.assign(**parsed_dates)

    result_rows: list[dict] = []
    invalid_rows: list[dict] = []

    for eid, patient_df in df.groupby("eid"):
        # Stable sort: eras sharing a start date keep their input order, so
        # tied closest eras are emitted in a deterministic order.
        patient_df = patient_df.sort_values(
            "drug_era_start_date", kind="stable"
        ).reset_index(drop=True)

        if len(patient_df) < 2:  # patient has only one drug era
            invalid_rows.append(
                {**patient_df.iloc[0].to_dict(), "reason": "single_drug_era"}
            )
            continue

        gaps = _forward_gaps(
            patient_df["drug_era_start_date"].values,
            patient_df["drug_era_end_date"].values,
        )

        for idx, row_gaps in enumerate(gaps):
            has_gap = ~np.isnat(row_gaps)
            if not has_gap.any():  # no subsequent drug era for this era
                invalid_rows.append(
                    {**patient_df.iloc[idx].to_dict(), "reason": "no_subsequent_era"}
                )
                continue

            # Every era at the minimum gap is a closest subsequent era (ties kept).
            min_gap = row_gaps[has_gap].min()
            row_a = patient_df.iloc[idx]
            for b_idx in np.flatnonzero(row_gaps == min_gap):
                result_rows.append(
                    _switch_row(eid, row_a, patient_df.iloc[b_idx], min_gap)
                )

    switches = pd.DataFrame(result_rows, columns=_SWITCH_COLUMNS)
    invalid = pd.DataFrame(invalid_rows, columns=[*df.columns, "reason"])
    return switches, invalid

In [None]:
# Pair each drug era with its closest subsequent era(s); eras without one are returned separately.
result_df, invalid_df = process_drug_switches(df=df)

In [None]:
# Switches where the closest subsequent era B has the same drug concept as era A.
same_drug_switches = result_df[
    result_df["A_drug_concept_id"] == result_df["B_drug_concept_id"]
]
n_total_switches = len(result_df)
# Guard against division by zero when no switches were found.
same_drug_pct = (
    len(same_drug_switches) / n_total_switches * 100
    if n_total_switches
    else float("nan")
)
print(f"Number of switches between same drug: {len(same_drug_switches)}")
print(f"Percentage of total switches: {same_drug_pct:.2f}%")

In [None]:
# Sanity check: every drug era either appears as an A era in result_df or is
# listed in invalid_df, so "accounted for" should equal the total. This assumes
# drug_era_id is unique and every row has an eid (groupby drops missing eids).
n_eras = len(df)
n_switch_sources = result_df["A_drug_era_id"].nunique(dropna=False)
n_invalid_eras = invalid_df["drug_era_id"].nunique(dropna=False)
print(f"Total drug eras: {n_eras}")
print(f"Eras accounted for (A + invalid): {n_switch_sources + n_invalid_eras}")
print(f"Distinct A eras with a closest subsequent era: {n_switch_sources}")
print(f"Distinct invalid eras: {n_invalid_eras}")

In [None]:
# An era A contributes one switch row per tied closest era B, so the number of
# switch rows can exceed the number of distinct A eras.
print(f"Switch rows (A, closest B pairs): {len(result_df)}")
print(f"Invalid drug eras: {len(invalid_df)}")