# In [2]:
import numpy as np
import pandas as pd
import os
import ast
import datetime
from os.path import isfile, join
from tqdm import tqdm
from datetime import datetime, date, timedelta
from dateutil.relativedelta import relativedelta
from openpyxl import load_workbook
from collections import Counter
import warnings 
# NOTE(review): this silences every warning, including pandas deprecation and
# SettingWithCopy warnings — consider narrowing it.
warnings.filterwarnings('ignore')

# Project root on the shared drive.  Every path (including the working
# directory) is derived from `drive`, so moving to another drive letter only
# requires changing this one value.
drive = 'G'
project_root = drive + ':/Shared drives/CKD_Progression/'
os.chdir(project_root)

data_path = project_root + 'data/'
docs_path = project_root + 'docs/'
save_path = project_root + 'save/'
main_path = data_path + 'CKD_COHORT_Jan2010_Mar2024_v3.csv'
covariates_path = docs_path + 'covariates.csv'
removecols_path = docs_path + 'remove_columns.csv'

def get_patients_new():
    '''
    Return the distinct ENC_HN identifiers of the study cohort.

    Reads `cohort_patients.csv` from `save_path`, drops fully duplicated rows
    and returns the unique ENC_HN values (order of first appearance) as a
    plain Python list.
    '''
    cohort = pd.read_csv(save_path + 'cohort_patients.csv', encoding = 'utf-8').drop_duplicates()
    hn_values = cohort['ENC_HN'].unique()
    return hn_values.tolist()

def study_period(df, column, start_date, end_date):
    '''
    Keep only rows whose `column` date lies in [start_date, end_date].

    `column` is converted to datetime in place on `df` (unparseable values
    become NaT and are therefore excluded).  Both bounds are inclusive.
    Returns the filtered frame.
    '''
    df[column] = pd.to_datetime(df[column], errors = 'coerce')
    in_window = df[column].between(start_date, end_date)
    return df.loc[in_window]

def exclusion_icd():
    ''' 
    Load the ICD code lists for the cardiovascular/cerebrovascular outcome groups.

    Reads the sheets 'CVD', 'IHD', 'TIA', 'Hemorrhagic stroke',
    'Ischemic stroke' and 'Cerebrovascular' of
    `docs_path + 'diagnosis and procedure.xlsx'` and returns
    {sheet name: list of codes from its 'ICD code' column}.

    Blank cells are dropped and codes are returned as stripped strings, so the
    lists can be joined into a single pattern (as determine_outcome does)
    without failing on NaN or numeric cells.

    Goal (original note): remove patients with these ICD codes before CKD3.
    '''
    path = docs_path + 'diagnosis and procedure.xlsx'
    sheet_names = ['CVD', 'IHD', 'TIA', 'Hemorrhagic stroke', 'Ischemic stroke', 'Cerebrovascular']
    # A list for sheet_name parses all sheets from one open of the workbook and
    # returns {sheet name: DataFrame}, instead of re-reading the file per sheet.
    sheets = pd.read_excel(path, sheet_name = sheet_names)
    ICD_CODES_DICT = {}
    for diag in sheet_names:
        codes = sheets[diag]['ICD code'].dropna().astype(str).str.strip()
        ICD_CODES_DICT[diag] = codes[codes != ''].tolist()
    return ICD_CODES_DICT

def remove_nonexistent(reference, function = 'sum'):
    '''
    List the variables tagged with `function` that actually exist in `reference`.

    Reads ms_data_function_ver3.xlsx from `docs_path`, selects the `variable`
    entries whose `function` column equals `function`, and keeps those that
    are columns of the DataFrame `reference` (spreadsheet order preserved).
    '''
    spec = pd.read_excel(docs_path + 'ms_data_function_ver3.xlsx')
    available = reference.columns
    wanted = spec.loc[spec['function'] == function, 'variable'].tolist()
    return [name for name in wanted if name in available]

def remove_outliers(df, docs_path = docs_path):
    '''
    Blank out covariate values that fall outside their plausible range.

    Ranges come from `possible_range.xlsx` in `docs_path` (columns
    `variable`, `min`, `max`).  For every listed variable present in `df`,
    the column is coerced to numeric (unparseable entries become NaN), and
    values below `min` or above `max` are set to NaN.  A missing bound (NaN)
    never triggers, because comparisons with NaN are False.

    Mutates and returns `df`.
    '''
    file_path = docs_path + 'possible_range.xlsx'
    possible_range = pd.read_excel(file_path)
    check_range_columns = possible_range['variable'].tolist()
    upper_values = possible_range['max'].astype(float).tolist()
    lower_values = possible_range['min'].astype(float).tolist()

    for covariate, upper, lower in tqdm(zip(check_range_columns, upper_values, lower_values), 
                                        total = len(check_range_columns), desc = 'Removing outliers'):
        if covariate in df.columns:
            df[covariate] = pd.to_numeric(df[covariate], errors = 'coerce')
            outlier_mask = (df[covariate] < lower) | (df[covariate] > upper)
            # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
            df.loc[outlier_mask, covariate] = np.nan
    return df

def carry_covariates(reference = None):
    '''
    Group the covariates of ms_data_function_ver3.xlsx by their `carry` rule.

    Parameters
    ----------
    reference : DataFrame, optional
        When given, each list is restricted to variables that are columns of
        `reference` (spreadsheet order kept), like remove_nonexistent() does.
        When omitted, the unfiltered lists are returned.

    Returns
    -------
    (forward_list, forback_list, lumping_list, fllzero_list)
        Variables whose `carry` value is 'forward', 'forward_backward',
        'ignore' and 'fill_zero' respectively.

    The previous version filtered against an undefined global `da` and used a
    set difference, which returned every column NOT assigned to the rule (in
    arbitrary order).
    '''
    carry_df = pd.read_excel(docs_path + 'ms_data_function_ver3.xlsx')

    def _variables(rule):
        # Variables whose `carry` column equals `rule`, optionally limited to `reference`.
        selected = carry_df.loc[carry_df['carry'] == rule, 'variable'].tolist()
        if reference is None:
            return selected
        return [var for var in selected if var in reference.columns]

    forward_list = _variables('forward')
    forback_list = _variables('forward_backward')
    lumping_list = _variables('ignore')
    fllzero_list = _variables('fill_zero')
    return forward_list, forback_list, lumping_list, fllzero_list

def carried_values(patient_data):
    '''
    Impute missing covariate values in `patient_data` according to their carry rule.

    'forward' variables are forward-filled; 'forward_backward' variables are
    forward-filled and then back-filled.  Fills run down the rows of
    `patient_data` as given.  NOTE(review): presumably one patient's rows
    sorted by date — confirm at the call site.

    Mutates and returns `patient_data`.
    '''
    forward_list, forback_list, lumping_list, fllzero_list = carry_covariates(patient_data)
    # NOTE(review): 'ignore' and 'fill_zero' variables are returned but not applied
    # here; 'fill_zero' presumably should be fillna(0) — confirm.
    if forward_list:
        patient_data[forward_list] = patient_data[forward_list].ffill()
    if forback_list:
        # ffill()/bfill() replace the deprecated fillna(method=...).
        patient_data[forback_list] = patient_data[forback_list].ffill().bfill()
    return patient_data

def determine_outcome(df, condition_patterns = None):
    '''
    Add per-row 0/1 outcome flags derived from the `diagnosis_all` text column.

    For each condition, a row is flagged 1 when `diagnosis_all` contains any of
    that condition's codes.  The codes are joined into one regular-expression
    alternation.  Within each ENC_HN the flag is then forward-filled, so a
    patient's later rows stay flagged, and remaining NaNs become 0.  If a flag
    column already exists (e.g. `CVD`), its existing non-null values are kept.
    TIA, Hemorrhagic stroke and Ischemic stroke are merged into one `stroke`
    column (row-wise max) and dropped.

    Parameters
    ----------
    df : DataFrame with `ENC_HN` and `diagnosis_all`; flag columns are added
        to it in place.  NOTE(review): assumes rows are in visit-date order
        within each patient — TODO confirm.
    condition_patterns : dict, optional
        {condition name: iterable of codes}.  It must contain 'TIA',
        'Hemorrhagic stroke' and 'Ischemic stroke'.  Defaults to
        exclusion_icd().

    Returns
    -------
    A new DataFrame without the three stroke sub-columns.
    '''
    def update_columns(df, col_name, condition):
        if col_name not in df.columns:
            df[col_name] = np.nan
        df.loc[condition, col_name] = 1
        # Once flagged, a patient stays flagged on later rows; rows before the first hit become 0.
        df[col_name] = df.groupby(['ENC_HN'])[col_name].ffill().fillna(0)

    if condition_patterns is None:
        condition_patterns = exclusion_icd()
    diagnoses = df['diagnosis_all'].astype(str)
    for condition, patterns in condition_patterns.items():
        # Blank/NaN cells would crash '|'.join; numeric cells are matched as text.
        # NOTE(review): codes are used as regex patterns ('.' matches any character);
        # wrap them in re.escape if the sheets hold literal codes only.
        codes = [str(code).strip() for code in patterns
                 if not pd.isnull(code) and str(code).strip()]
        if codes:
            matched = diagnoses.str.contains('|'.join(codes), na = False)
        else:
            # An empty alternation ('') matches every string and would flag all rows.
            matched = pd.Series(False, index = df.index)
        update_columns(df, condition, matched)
    df['stroke'] = df[['TIA', 'Hemorrhagic stroke', 'Ischemic stroke']].max(axis = 1)
    df = df.drop(['TIA', 'Hemorrhagic stroke', 'Ischemic stroke'], axis = 1)
    return df


def swap_dates_if_needed(column1, column2):
    '''
    Make each (column1[i], column2[i]) pair chronologically ordered.

    Where both values are present and column2 precedes column1, the two values
    are swapped.  Missing values (NaN/NaT/None) are left alone.

    Side effect: `column1` is modified in place.  `column2` is not touched;
    the corrected copy of it is returned.

    Elements are addressed by position (`.iloc` for pandas objects), so a
    Series with a non-default index is swapped correctly.  Previously
    `series[i] = ...` was label-based and could append new labels.

    Raises ValueError when the inputs differ in length.
    '''
    if len(column1) != len(column2):
        raise ValueError('Columns must have the same length.')
    column2_copy = column2.copy()
    # Positional setters: .iloc for Series, plain indexing for lists/arrays.
    set_first = column1.iloc if hasattr(column1, 'iloc') else column1
    set_second = column2_copy.iloc if hasattr(column2_copy, 'iloc') else column2_copy
    for i, (value1, value2) in enumerate(zip(column1, column2)):
        if pd.isnull(value1) or pd.isnull(value2):
            continue
        if value2 < value1:
            set_second[i] = value1
            set_first[i] = value2
    return column2_copy

def check_date_order(column1, column2):
    '''
    Return True when column1[i] <= column2[i] for every pair where both values exist.

    Pairs containing a missing value are ignored.  The scan stops at the first
    out-of-order pair.  Raises ValueError when the inputs differ in length.
    '''
    if len(column1) != len(column2):
        raise ValueError("Columns must have the same length.")
    return not any(
        earlier > later
        for earlier, later in zip(column1, column2)
        if not (pd.isnull(earlier) or pd.isnull(later))
    )

def add_one_day(dx, column1, column2, days = 2):
    '''
    Push `column2` later wherever it falls on the same calendar day as `column1`.

    Both columns are converted to datetime in place.  Rows whose dates share a
    calendar day get `days` added to `column2`.

    NOTE: despite the function name, the default shift is 2 days.  That is
    what this function has always applied, and the default preserves it;
    pass `days = 1` for a one-day shift.

    Mutates and returns `dx`.
    '''
    dx[column1] = pd.to_datetime(dx[column1])
    dx[column2] = pd.to_datetime(dx[column2])
    # Compare calendar days (time of day ignored); NaT never matches.
    same_date_rows = dx[column1].dt.date == dx[column2].dt.date
    dx.loc[same_date_rows, column2] += timedelta(days = days)
    return dx

def sub_one_day(dx, column1, column2, days = 1):
    '''
    Move `column1` earlier by `days` wherever its calendar day is after `column2`'s.

    Both columns are converted to datetime in place.  A single shift is
    applied, so a row where `column1` is more than `days` days late stays
    out of order.

    Mutates and returns `dx`.
    '''
    dx[column1] = pd.to_datetime(dx[column1])
    dx[column2] = pd.to_datetime(dx[column2])
    # Rows where column1 falls on a later calendar day than column2 (not "same day").
    first_after_second = dx[column1].dt.date > dx[column2].dt.date
    dx.loc[first_after_second, column1] -= timedelta(days = days)
    return dx

def sort_dates(row):
    '''
    Return the names of the states present in `row`, ordered by their date.

    States and their date columns are CKD3A..CKD5B, CVD and DEATH
    (DEAD_date).  States with a missing date are omitted.  The sort is
    stable, so states sharing a date keep the order listed below.
    '''
    state_columns = {
        'CKD3A': 'CKD3A_date',
        'CKD3B': 'CKD3B_date',
        'CKD4':  'CKD4_date',
        'CKD5A': 'CKD5A_date',
        'CKD5B': 'CKD5B_date',
        'CVD':   'CVD_date',
        'DEATH': 'DEAD_date',
    }
    observed = [(state, row[column]) for state, column in state_columns.items()
                if not pd.isnull(row[column])]
    ordered = sorted(observed, key = lambda pair: pair[1])
    return [state for state, _ in ordered]


def add_state_column(dataframe, column_date, column_name):
    '''
    Add `column_name` as a 0/1 flag: 1 where `column_date` has a value, 0 where it is missing.

    Mutates and returns `dataframe`.
    '''
    has_date = ~dataframe[column_date].isna()
    dataframe[column_name] = has_date.astype(int)
    return dataframe

# Load the cohort's patient IDs and sanity-check the cohort size: the assert
# stops the notebook if the number of distinct patients is not 23693.
patients = get_patients_new() 
assert pd.Series(patients).nunique() == 23693

# In [3]:
# Visit-level cohort extract, restricted to the study patients.
data = pd.read_csv(main_path, encoding = 'utf-8')
in_cohort = data['ENC_HN'].isin(patients)
data = data.loc[in_cohort]

# In [None]:
# Per-visit outcome table.  .copy() makes dx an independent frame, so the
# column assignments below do not write into a slice of `data`.
dx = data[['ENC_HN', 'visit_date', 'CKD_stage', 'CVD','death_date']].copy()
dx['visit_date'] = pd.to_datetime(dx['visit_date'])
dx['death_date'] = pd.to_datetime(dx['death_date'])

# Each stage date column is set to visit_date on every visit at or beyond that stage.
# NOTE: CKD_stage is compared as a string, so this relies on the lexicographic
# order 'stage_3a' < 'stage_3b' < 'stage_4' < 'stage_5_15' < 'stage_5_6';
# a differently spelled label would sort unexpectedly.
dx.loc[(dx['CKD_stage'] >= 'stage_3a'),    'CKD3A_date'] = dx['visit_date']
dx.loc[(dx['CKD_stage'] >= 'stage_3b'),    'CKD3B_date'] = dx['visit_date']
dx.loc[(dx['CKD_stage'] >= 'stage_4'),      'CKD4_date'] = dx['visit_date']
dx.loc[(dx['CKD_stage'] >= 'stage_5_15'),  'CKD5A_date'] = dx['visit_date']
dx.loc[(dx['CKD_stage'] >= 'stage_5_6'),   'CKD5B_date'] = dx['visit_date']
# np.nan (the np.NaN alias was removed in NumPy 2.0) marks visits without a CVD flag.
dx['CVD_date'] = [visit if cvd == 1.0 else np.nan for visit, cvd in zip(dx['visit_date'], dx['CVD'])]
dx['DEAD_date'] = dx['death_date'].copy()

# In [None]:
# Per-visit 0/1 flags: 1 when the paired date column holds a real Timestamp
# (NaT / NaN -> 0).  pd.Timestamp is the public name of the same class as
# pd._libs.tslibs.timestamps.Timestamp.
flag_sources = [
    ('CKD3A', 'CKD3A_date'),
    ('CKD3B', 'CKD3B_date'),
    ('CKD4',  'CKD4_date'),
    ('CKD5A', 'CKD5A_date'),
    ('CKD5B', 'CKD5B_date'),
    ('DEATH', 'death_date'),
]
for flag_column, date_column in flag_sources:
    dx[flag_column] = [int(isinstance(value, pd.Timestamp)) for value in dx[date_column]]

# In [None]:
# Collapse visits to one row per patient: the earliest date of each state.
# (CKD_stage is a string here, so its min is lexicographic.)
stage_date_columns = ['CKD3A_date', 'CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date']
per_patient_columns = ['visit_date', 'CKD_stage', 'DEATH'] + stage_date_columns + ['CVD_date', 'DEAD_date']
dx = dx.groupby(['ENC_HN'])[per_patient_columns].min().reset_index()

# Break ties: a CVD date equal to a CKD stage date moves one day later.  The
# stages are checked in order, so one shift can cascade into the next stage.
for stage_column in stage_date_columns:
    coincides = (dx['CVD_date'] == dx[stage_column])
    dx.loc[coincides, 'CVD_date'] += timedelta(days = 1)

# In [None]:
# Where CVD and death fall on the same calendar day, push DEAD_date later
# (add_one_day shifts by 2 days as written), presumably so CVD sorts before death.
dx = add_one_day(dx, 'CVD_date', 'DEAD_date')
# True only if CVD_date <= DEAD_date for every patient with both dates.
# False means some patients still have CVD recorded after death; those
# sequences end in ('DEATH', 'CVD') and are handled by the transition-rule cell.
check_date_order(dx['CVD_date'], dx['DEAD_date'])

# Out: False

# In [None]:
# Per patient: state names ordered by the date each state was first reached
# (states with a missing date are skipped), e.g. ('CKD3A', 'CKD3B', 'CVD').
dx['transition'] = dx.apply(sort_dates, axis = 1)
# Convert to tuples so the sequences are hashable (unique(), equality lookups).
dx['transition'] = dx['transition'].apply(tuple)

# In [None]:
# 0/1 "state reached" columns, one per state date (created in this order).
status_sources = [
    ('CKD3A_date', 'CKD3A_status'),
    ('CKD3B_date', 'CKD3B_status'),
    ('CKD4_date',  'CKD04_status'),
    ('CKD5A_date', 'CKD5A_status'),
    ('CKD5B_date', 'CKD5B_status'),
    ('CVD_date',   'CVD00_status'),
    ('DEAD_date',  'DEATH_status'),
]
for date_column, status_column in status_sources:
    dx = add_state_column(dataframe = dx, column_date = date_column, column_name = status_column)

# In [None]:
# Distinct state sequences, in order of first appearance in dx.
transition = pd.unique(dx['transition'])
# Fill-in date (31 Dec 2023) used by the next cell for states a patient never reached.
last_recent_date = pd.Timestamp(2023, 12, 31)

# In [126]:
# Date columns edited below (positions in dx after the aggregation cell):
# 4  CKD3A
# 5  CKD3B
# 6  CKD4
# 7  CKD5A
# 8  CKD5B
# 9  CVD
# 10 DEATH
#
# Per-sequence fill-in rules for the date columns of states a patient did not
# (yet) reach.  The rules used to be addressed as transition[0] ... transition[43],
# i.e. by the position of each sequence in dx['transition'].unique().  That order
# depends on the row order of the data, so a changed cohort silently applied
# the wrong rule, or raised IndexError if there were fewer than 44 sequences.
# The rules are now keyed by the state sequence itself, taken from the comments
# that documented each index; the old index is kept as "[i]".
#
# Each rule is a list of (target column, source) steps applied in order to the row:
#   LAST_DATE       -> last_recent_date
#   DEATH_PLUS_ONE  -> the row's current DEAD_date + 1 day
#   a column name   -> the row's current value in that column
LAST_DATE = object()
DEATH_PLUS_ONE = object()

def _fill(source, *targets):
    '''Steps that set every target column from the same source.'''
    return [(target, source) for target in targets]

TRANSITION_RULES = {
    # [0]  NOTE(review): unreached CKD stages take the CVD date here, whereas
    #      ('CKD3A', 'CKD3B', 'CVD') uses last_recent_date — confirm intended.
    ('CKD3A', 'CVD'):
        _fill('CVD_date', 'CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date')
        + _fill(LAST_DATE, 'DEAD_date'),
    # [1]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD'):
        _fill(LAST_DATE, 'CKD5B_date', 'DEAD_date'),
    # [2]
    ('CKD3A',):
        _fill(LAST_DATE, 'CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'CVD_date', 'DEAD_date'),
    # [3]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'):
        _fill(LAST_DATE, 'CVD_date', 'DEAD_date'),
    # [4]
    ('CKD3A', 'CKD3B', 'CVD'):
        _fill(LAST_DATE, 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [5]
    ('CKD3A', 'CKD3B'):
        _fill(LAST_DATE, 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'CVD_date', 'DEAD_date'),
    # [6]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD'):
        _fill(LAST_DATE, 'DEAD_date'),
    # [7]  NOTE(review): CKD5B copies the CKD5A date here, whereas the analogous
    #      ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A') uses last_recent_date.
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A'):
        _fill('CKD5A_date', 'CKD5B_date') + _fill(LAST_DATE, 'DEAD_date'),
    # [8]
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD'):
        _fill(LAST_DATE, 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [9]
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4'):
        _fill(LAST_DATE, 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [10]
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B'):
        _fill(LAST_DATE, 'DEAD_date'),
    # [11] NOTE(review): other sequences ending in DEATH fill unreached states with
    #      the death date; this one uses last_recent_date — confirm intended.
    ('CKD3A', 'CKD3B', 'CKD4', 'DEATH'):
        _fill(LAST_DATE, 'CKD5A_date', 'CKD5B_date', 'CVD_date'),
    # [12]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B'):
        _fill(LAST_DATE, 'DEAD_date'),
    # [13]
    ('CKD3A', 'CKD3B', 'DEATH'):
        _fill('DEAD_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'CVD_date'),
    # [14]
    ('CKD3A', 'CVD', 'CKD3B'):
        _fill(LAST_DATE, 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [15] NOTE(review): the sibling sequences [20]/[23]/[24] set CKD5B to the death
    #      date; this one uses last_recent_date — confirm intended.
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'DEATH'):
        _fill(LAST_DATE, 'CKD5B_date'),
    # [16]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'):
        _fill('DEAD_date', 'CKD5B_date', 'CVD_date'),
    # [17]
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4'):
        _fill(LAST_DATE, 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [18]
    ('CKD3A', 'CKD3B', 'CKD4'):
        _fill(LAST_DATE, 'CKD5A_date', 'CKD5B_date', 'CVD_date', 'DEAD_date'),
    # [19]
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'DEATH'):
        _fill('DEAD_date', 'CKD5A_date', 'CKD5B_date'),
    # [20]
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'DEATH'):
        _fill('DEAD_date', 'CKD5B_date'),
    # [21]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A'):
        _fill(LAST_DATE, 'CKD5B_date', 'CVD_date', 'DEAD_date'),
    # [22]
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A'):
        _fill(LAST_DATE, 'CKD5B_date', 'DEAD_date'),
    # [23]
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'):
        _fill('DEAD_date', 'CKD5B_date'),
    # [24]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'DEATH'):
        _fill('DEAD_date', 'CKD5B_date'),
    # [25]
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B'):
        _fill(LAST_DATE, 'DEAD_date'),
    # [26] every state observed: nothing to fill
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B', 'DEATH'): [],
    # [27]
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A'):
        _fill(LAST_DATE, 'CKD5B_date', 'DEAD_date'),
    # [28]
    ('CKD3A', 'CKD3B', 'CVD', 'DEATH'):
        _fill('DEAD_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date'),
    # [29] every state observed: nothing to fill
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD', 'DEATH'): [],
    # [30]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'):
        _fill('DEAD_date', 'CVD_date'),
    # [31]
    ('CKD3A', 'CVD', 'DEATH'):
        _fill('DEAD_date', 'CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date'),
    # [32]
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'):
        _fill(LAST_DATE, 'DEAD_date'),
    # [33]
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'DEATH'):
        _fill('DEAD_date', 'CKD5A_date', 'CKD5B_date'),
    # [34]
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'DEATH'):
        _fill('DEAD_date', 'CKD5A_date', 'CKD5B_date'),
    # [35]
    ('CKD3A', 'CVD', 'CKD3B', 'DEATH'):
        _fill('DEAD_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date'),
    # [36] every state observed: nothing to fill
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): [],
    # [37]
    ('CKD3A', 'DEATH'):
        _fill('DEAD_date', 'CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'CVD_date'),
    # [38] death recorded before CVD: shift death (and unreached stages) to the day after the old death date
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH', 'CVD'):
        _fill(DEATH_PLUS_ONE, 'CKD5B_date', 'DEAD_date'),
    # [39]
    ('CKD3A', 'CKD3B', 'CKD4', 'DEATH', 'CVD'):
        _fill(DEATH_PLUS_ONE, 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [40] every state observed: nothing to fill
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B', 'DEATH'): [],
    # [41]
    ('CKD3A', 'CKD3B', 'DEATH', 'CVD'):
        _fill(DEATH_PLUS_ONE, 'CKD4_date', 'CKD5A_date', 'CKD5B_date', 'DEAD_date'),
    # [42] every state observed: nothing to fill
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): [],
    # [43]
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH', 'CVD'):
        _fill(DEATH_PLUS_ONE, 'DEAD_date'),
}

# Resolve column names to positions once (dx has a RangeIndex after reset_index).
positions = {name: dx.columns.get_loc(name)
             for name in ('CKD3B_date', 'CKD4_date', 'CKD5A_date', 'CKD5B_date',
                          'CVD_date', 'DEAD_date', 'transition')}

# Sequences without a rule are left untouched, as before; report them instead of staying silent.
unhandled = set(dx['transition']) - set(TRANSITION_RULES)
if unhandled:
    print(f'{len(unhandled)} transition sequence(s) have no rule and are left unchanged: {sorted(unhandled)}')

for idx in tqdm(range(0, len(dx))):
    for target, source in TRANSITION_RULES.get(dx.iloc[idx, positions['transition']], ()):
        if source is LAST_DATE:
            value = last_recent_date
        elif source is DEATH_PLUS_ONE:
            # Read the current DEAD_date, so steps before the DEAD_date step see the old value.
            value = dx.iloc[idx, positions['DEAD_date']] + timedelta(days = 1)
        else:
            value = dx.iloc[idx, positions[source]]
        dx.iloc[idx, positions[target]] = value

# 100%|██████████| 23693/23693 [00:45<00:00, 515.10it/s]


# In [127]:
# Directed state-pair indicator columns, named "<from>_<to>".
transitions = [
    'CKD3A_CKD3B', 'CKD3A_CKD4', 'CKD3A_CKD5A', 'CKD3A_CKD5B', 'CKD3A_CVD', 'CKD3A_DEATH',
    'CKD3B_CKD4', 'CKD3B_CKD5A', 'CKD3B_CKD5B', 'CKD3B_CVD', 'CKD3B_DEATH',
    'CKD4_CKD5A', 'CKD4_CKD5B', 'CKD4_CVD', 'CKD4_DEATH',
    'CKD5A_CKD5B', 'CKD5A_CVD', 'CKD5A_DEATH',
    'CKD5B_CVD', 'CKD5B_DEATH',
    'CVD_CKD3B', 'CVD_CKD4', 'CVD_CKD5A', 'CVD_CKD5B', 'CVD_DEATH']

# One 0-initialised indicator column per pair.  The loop variable is not called
# `transition`: that name holds the array of unique state sequences created
# earlier in the notebook, and rebinding it here would clobber that array for
# any cell re-run later.
for transition_name in transitions:
    dx[transition_name] = 0

def populate_transitions(row):
    '''
    Mark the consecutive state pairs of one patient's `transition` sequence.

    For each adjacent pair (a, b) in row['transition'], sets row['a_b'] = 1
    when the row has such a column.  Pairs without a matching column are
    ignored.  Column membership is checked on the row itself rather than on
    the global `dx`, so the function works with any frame passed to apply().

    Returns a modified copy of `row` (the object passed in is not mutated).
    '''
    row = row.copy()
    states = row['transition']
    for current_state, next_state in zip(states, states[1:]):
        transition = f"{current_state}_{next_state}"
        if transition in row.index:
            row[transition] = 1
    return row

# Row-wise: set the pair-indicator columns for each patient's state sequence.
# NOTE(review): apply(axis=1) rebuilds dx from per-row Series; verify the date
# columns keep their datetime dtype afterwards if later cells depend on it.
dx = dx.apply(populate_transitions, axis = 1)

# In [91]:
def calculate_month_difference(dataframe, date_column1, date_column2, new_column_name):
    '''
    Store the whole-month gap from `date_column1` to `date_column2` in `new_column_name`.

    The gap is years * 12 + months of relativedelta(date2, date1): the
    remaining days are dropped, and the value is negative when date2
    precedes date1.  Rows where either date is missing get NaN; previously a
    missing date made relativedelta fail.

    Mutates and returns `dataframe`.
    '''
    # Local import keeps this helper self-contained (dateutil ships with pandas).
    from dateutil.relativedelta import relativedelta

    def _whole_months(start, end):
        # Build the relativedelta once per row (the original built it twice).
        if pd.isnull(start) or pd.isnull(end):
            return np.nan
        delta = relativedelta(end, start)
        return delta.years * 12 + delta.months

    # A list comprehension (unlike apply(axis=1)) also works on an empty frame.
    dataframe[new_column_name] = [
        _whole_months(start, end)
        for start, end in zip(dataframe[date_column1], dataframe[date_column2])
    ]
    return dataframe

def convert_negative_to_zero(dataframe, start = 19):
    '''
    Return a copy of `dataframe` with negative values replaced by 0.

    Only numeric columns at position `start` or later are changed; NaN is
    kept.  The default 19 matches this notebook's dx, where the first 19
    columns are IDs, dates, the transition sequence and status flags; the
    later columns are presumably the pair indicators and *_months gaps.
    Confirm if the column layout changes.  The input frame is not modified.
    '''
    df = dataframe.copy()
    for column in df.columns[start:]:
        if pd.api.types.is_numeric_dtype(df[column]):
            # Vectorised replacement of the former per-element lambda; dtype is preserved.
            df[column] = df[column].mask(df[column] < 0, 0)
    return df

execute = True
if execute:
    # (start date column, end date column, output column), created in this order.
    month_gaps = [
        ('CKD3A_date', 'CKD3B_date', 'ckd3a_to_ckd3b_months'),
        ('CKD3A_date', 'CKD4_date',  'ckd3a_to_ckd04_months'),
        ('CKD3A_date', 'CKD5A_date', 'ckd3a_to_ckd5a_months'),
        ('CKD3A_date', 'CKD5B_date', 'ckd3a_to_ckd5b_months'),
        ('CKD3A_date', 'CVD_date',   'ckd3a_to_cvd00_months'),
        ('CKD3A_date', 'DEAD_date',  'ckd3a_to_death_months'),

        ('CKD3B_date', 'CKD4_date',  'ckd3b_to_ckd04_months'),
        ('CKD3B_date', 'CKD5A_date', 'ckd3b_to_ckd5a_months'),
        ('CKD3B_date', 'CKD5B_date', 'ckd3b_to_ckd5b_months'),
        ('CKD3B_date', 'CVD_date',   'ckd3b_to_cvd00_months'),
        ('CKD3B_date', 'DEAD_date',  'ckd3b_to_death_months'),

        ('CKD4_date',  'CKD5A_date', 'ckd04_to_ckd5a_months'),
        ('CKD4_date',  'CKD5B_date', 'ckd04_to_ckd5b_months'),
        ('CKD4_date',  'CVD_date',   'ckd04_to_cvd00_months'),
        ('CKD4_date',  'DEAD_date',  'ckd04_to_death_months'),

        ('CKD5A_date', 'CKD5B_date', 'ckd5a_to_ckd5b_months'),
        ('CKD5A_date', 'CVD_date',   'ckd5a_to_cvd00_months'),
        ('CKD5A_date', 'DEAD_date',  'ckd5a_to_death_months'),

        ('CKD5B_date', 'CVD_date',   'ckd5b_to_cvd00_months'),
        ('CKD5B_date', 'DEAD_date',  'ckd5b_to_death_months'),

        ('CVD_date',   'CKD3B_date', 'cvd00_to_ckd3b_months'),
        ('CVD_date',   'CKD4_date',  'cvd00_to_ckd04_months'),
        ('CVD_date',   'CKD5A_date', 'cvd00_to_ckd5a_months'),
        ('CVD_date',   'CKD5B_date', 'cvd00_to_ckd5b_months'),
        ('CVD_date',   'DEAD_date',  'cvd00_to_death_months'),
    ]
    for start_column, end_column, gap_column in month_gaps:
        dx = calculate_month_difference(dx, start_column, end_column, gap_column)

    # Negative gaps (end before start) are clipped to 0.
    dx = convert_negative_to_zero(dx)

# In [128]:
# Write the per-patient table built above to save/multistage_ver004.csv (index not written).
dx.to_csv(save_path + 'multistage_ver004.csv', index = False)