In [2]:
import pandas as pd

## Terminology

We say drug era B is a **subsequent drug era** of drug era A if:
1. drug eras A and B belong to the same person (same `eid`), and
2. drug_era_start_date of B > drug_era_end_date of A (i.e. B starts strictly after A ends)

We say drug era B is the **closest subsequent drug era** of drug era A if:
- among all subsequent drug eras of A, B starts first, i.e. it minimizes drug_era_start_date of B − drug_era_end_date of A \
(ties are possible, so one drug era can have several closest subsequent drug eras)

In [3]:
# Tab-separated input table; the relative path resolves against the kernel's
# working directory (usually the notebook's folder).
DATASET_PATH = "../../dataset/dataset.tsv"
df = pd.read_csv(DATASET_PATH, sep="\t")

In [17]:
def get_drug_switches_all(
    df: pd.DataFrame, date_format: str = "%d/%m/%Y"
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Find the closest subsequent drug era B for each drug era A.
    Output both valid and invalid drug eras.

    B is a subsequent era of A when both rows share the same ``eid`` and
    ``drug_era_start_date`` of B minus ``drug_era_end_date`` of A is strictly
    positive. Every B tied for the smallest such interval is kept, so a single
    era A may appear in several result rows.

    Parameters
    ----------
    df : pd.DataFrame
        Drug eras with at least the columns ``eid``, ``drug_era_id``,
        ``drug_concept_id``, ``drug_era_start_date``, ``drug_era_end_date``,
        ``drug_exposure_count`` and ``gap_days``. The frame is not modified.
    date_format : str, default "%d/%m/%Y"
        Format used to parse the two date columns when they are not already
        datetime columns.

    Returns
    -------
    result_df : pd.DataFrame
        One row per (A, closest B) pair: ``eid``, ``A_*`` and ``B_*`` columns,
        and ``switch_interval`` (a Timedelta).
    invalid_df : pd.DataFrame
        Eras without a closest subsequent era: the input columns plus
        ``reason``, which is ``"single_drug_era"`` (the patient has exactly one
        era) or ``"no_subsequent_era"``.
    """
    # Work on a copy so the caller's frame is not silently converted in place.
    eras = df.copy()
    for col in ("drug_era_start_date", "drug_era_end_date"):
        if not pd.api.types.is_datetime64_any_dtype(eras[col]):
            eras[col] = pd.to_datetime(eras[col], format=date_format)

    df_sorted = eras.sort_values(["eid", "drug_era_start_date"]).reset_index(drop=True)

    # Patients with exactly one era can never have a subsequent era.
    eid_counts = df_sorted.groupby("eid").size()
    single_era_eids = eid_counts[eid_counts == 1].index
    is_single = df_sorted["eid"].isin(single_era_eids)
    single_era_df = df_sorted[is_single].copy()
    single_era_df["reason"] = "single_drug_era"

    multi_era_df = df_sorted[~is_single]

    # Self-join within each eid (k eras -> k*k pairs). `_row` is the position in
    # df_sorted and identifies era A unambiguously even if drug_era_id repeats
    # across patients.
    eras_with_row = multi_era_df.assign(_row=multi_era_df.index.to_numpy())
    pairs = eras_with_row.merge(eras_with_row, on="eid", suffixes=("_A", "_B"))

    pairs["switch_interval"] = (
        pairs["drug_era_start_date_B"] - pairs["drug_era_end_date_A"]
    )
    # B must start strictly after A ends; this also discards A paired with
    # itself whenever its end date is not before its start date.
    valid_mask = pairs["switch_interval"] > pd.Timedelta(0)

    # Eras that never appear as A in a valid pair have no subsequent era.
    rows_with_successor = pairs.loc[valid_mask, "_row_A"].unique()
    no_subsequent_df = multi_era_df[
        ~multi_era_df.index.isin(rows_with_successor)
    ].copy()
    no_subsequent_df["reason"] = "no_subsequent_era"

    invalid_df = pd.concat([single_era_df, no_subsequent_df], ignore_index=True)

    # Keep every B tied for the smallest interval (the closest subsequent eras).
    valid_pairs = pairs[valid_mask]
    min_intervals = valid_pairs.groupby("_row_A")["switch_interval"].transform("min")
    valid_pairs = valid_pairs[valid_pairs["switch_interval"] == min_intervals]

    # Map merged "<field>_<side>" columns to "<side>_<field>", preserving order.
    column_map = {"eid": "eid"}
    for side in ("A", "B"):
        for field in (
            "drug_era_id",
            "drug_concept_id",
            "drug_era_start_date",
            "drug_era_end_date",
            "drug_exposure_count",
            "gap_days",
        ):
            column_map[f"{field}_{side}"] = f"{side}_{field}"
    column_map["switch_interval"] = "switch_interval"

    result_df = valid_pairs[list(column_map)].rename(columns=column_map)

    return result_df, invalid_df

In [34]:
def get_drug_switches(df: pd.DataFrame, date_format: str = "%d/%m/%Y") -> pd.DataFrame:
    """
    Find the closest subsequent drug era B for each drug era A.
    Output only valid drug eras.

    B is a subsequent era of A when both rows share the same ``eid`` and
    ``drug_era_start_date`` of B minus ``drug_era_end_date`` of A is strictly
    positive. Every B tied for the smallest such interval is kept, so a single
    era A may appear in several rows. Eras without any subsequent era do not
    appear as A.

    Parameters
    ----------
    df : pd.DataFrame
        Drug eras with at least the columns ``eid``, ``drug_era_id``,
        ``drug_concept_id``, ``drug_era_start_date``, ``drug_era_end_date``,
        ``drug_exposure_count`` and ``gap_days``. The frame is not modified.
    date_format : str, default "%d/%m/%Y"
        Format used to parse the two date columns when they are not already
        datetime columns.

    Returns
    -------
    pd.DataFrame
        One row per (A, closest B) pair: ``eid``, ``A_*`` and ``B_*`` columns,
        and ``switch_interval`` (a Timedelta).
    """
    # Work on a copy so the caller's frame is not silently converted in place.
    eras = df.copy()
    for col in ("drug_era_start_date", "drug_era_end_date"):
        if not pd.api.types.is_datetime64_any_dtype(eras[col]):
            eras[col] = pd.to_datetime(eras[col], format=date_format)

    df_sorted = eras.sort_values(["eid", "drug_era_start_date"]).reset_index(drop=True)
    # `_row` identifies era A unambiguously even if drug_era_id repeats across eids.
    df_sorted = df_sorted.assign(_row=df_sorted.index.to_numpy())

    # Self-join within each eid (k eras -> k*k candidate pairs).
    pairs = df_sorted.merge(df_sorted, on="eid", suffixes=("_A", "_B"))
    pairs["switch_interval"] = (
        pairs["drug_era_start_date_B"] - pairs["drug_era_end_date_A"]
    )

    # B must start strictly after A ends.
    pairs = pairs[pairs["switch_interval"] > pd.Timedelta(0)]

    # Keep every B tied for the smallest interval (the closest subsequent eras).
    min_intervals = pairs.groupby("_row_A")["switch_interval"].transform("min")
    pairs = pairs[pairs["switch_interval"] == min_intervals]

    # Map merged "<field>_<side>" columns to "<side>_<field>", preserving order.
    column_map = {"eid": "eid"}
    for side in ("A", "B"):
        for field in (
            "drug_era_id",
            "drug_concept_id",
            "drug_era_start_date",
            "drug_era_end_date",
            "drug_exposure_count",
            "gap_days",
        ):
            column_map[f"{field}_{side}"] = f"{side}_{field}"
    column_map["switch_interval"] = "switch_interval"

    return pairs[list(column_map)].rename(columns=column_map)

In [None]:
# Only the closest-switch table is needed here; get_drug_switches_all(df)
# would additionally return the drug eras that have no closest subsequent era.
result_df = get_drug_switches(df=df)

In [None]:
# Closest-switch rows where A and B are the same drug concept. Counts are per
# result row, so an era A with several tied closest eras B contributes several rows.
same_drug_switches = result_df[
    result_df["A_drug_concept_id"] == result_df["B_drug_concept_id"]
]
print(f"Number of switches between same drug: {len(same_drug_switches)}")
# Guard the percentage: an empty result would otherwise raise ZeroDivisionError.
if len(result_df):
    print(
        f"Percentage of total switches: {len(same_drug_switches) / len(result_df) * 100:.2f}%"
    )
else:
    print("Percentage of total switches: n/a (no switches found)")

In [None]:
# Total drug eras, then the number of distinct drug eras A that have a closest
# subsequent era. The complementary invalid-era count requires the second
# return value of get_drug_switches_all.
print(len(df))
print(result_df["A_drug_era_id"].nunique(dropna=False))

In [None]:
# Number of (A, closest B) rows; this can exceed the number of distinct A eras
# when several B eras tie for the smallest switch interval.
print(result_df.shape[0])

## Sanity check on a random sample

In [5]:
# Fixed seed so the sample (and every count derived from it) is reproducible;
# the size is capped so the cell also runs when df has fewer rows than SAMPLE_SIZE.
# NOTE: this samples individual rows (drug eras), not patients, so a patient's
# history may be incomplete and closest switches can differ from the full run.
SAMPLE_SIZE = 100_000
RANDOM_SEED = 42
df_small = df.sample(min(SAMPLE_SIZE, len(df)), random_state=RANDOM_SEED)

In [32]:
# Closest switches on the sample only (valid pairs, no invalid-era table).
result_df_small = get_drug_switches(df=df_small)

In [None]:
# Row count of the sample's closest-switch table (displayed as the cell output).
result_df_small.shape[0]

In [15]:
# Recompute on the sample, this time also keeping the eras that have no closest switch.
result_df_small, invalid_df_small = get_drug_switches_all(df=df_small)

In [None]:
# The first two numbers are expected to agree when every sampled era is either
# an A with a closest switch or listed as invalid (assumes drug_era_id is unique).
n_eras_with_switch = result_df_small["A_drug_era_id"].nunique(dropna=False)
n_invalid_eras = invalid_df_small["drug_era_id"].nunique(dropna=False)
print(len(df_small))
print(n_eras_with_switch + n_invalid_eras)
print(n_eras_with_switch)
print(n_invalid_eras)