In [1]:
import numpy as np
import pandas as pd
import os
import ast
import datetime
from os.path import isfile, join
from tqdm import tqdm
from datetime import datetime, date, timedelta
from dateutil.relativedelta import relativedelta
from openpyxl import load_workbook
from collections import Counter
import warnings 
warnings.filterwarnings('ignore')

# Project root on the shared drive. Every other path below is derived from it,
# so the drive letter only needs to change in `drive`.
drive = 'G'
base_path = drive + ':/Shared drives/CKD_Progression/'
os.chdir(base_path)

data_path = base_path + 'data/'
docs_path = base_path + 'docs/'
save_path = base_path + 'save/'
main_path = data_path + 'CKD_COHORT_Jan2010_Mar2024_v3.csv'
covariates_path = docs_path + 'covariates.csv'
removecols_path = docs_path + 'remove_columns.csv'

def get_patients_new():
    '''
    Load the cohort patient list.

    Reads save_path + 'cohort_patients.csv' (UTF-8), drops duplicate rows and
    returns the distinct ENC_HN values as a plain list, in first-seen order.
    '''
    cohort = pd.read_csv(save_path + 'cohort_patients.csv', encoding = 'utf-8').drop_duplicates()
    unique_ids = cohort['ENC_HN'].unique()
    return unique_ids.tolist()

def study_period(df, column, start_date, end_date):
    '''
    Keep only rows whose ``column`` date lies in [start_date, end_date].

    Side effect: ``column`` is converted to datetime on ``df`` itself;
    unparseable entries become NaT and are therefore excluded.
    Both bounds are inclusive.
    '''
    parsed_dates = pd.to_datetime(df[column], errors = 'coerce')
    df[column] = parsed_dates
    in_window = parsed_dates.between(start_date, end_date)
    return df.loc[in_window]

def exclusion_icd():
    ''' 
    List the ICD codes for the relevant cardiovascular / cerebrovascular conditions.
    Goal: Remove patients with ICD before CKD3

    Reads one sheet per condition from docs_path + 'diagnosis and procedure.xlsx'
    and returns {sheet name: list of codes from its 'ICD code' column}.
    Blank cells are dropped and every code is returned as a string, so callers
    can safely join the lists into a pattern.
    '''
    path = docs_path + 'diagnosis and procedure.xlsx'
    sheet_names = ['CVD', 'IHD', 'TIA', 'Hemorrhagic stroke', 'Ischemic stroke', 'Cerebrovascular']
    # Read all required sheets in one call instead of re-opening the workbook per sheet.
    sheets = pd.read_excel(path, sheet_name = sheet_names)
    ICD_CODES_DICT = {}
    for diag in sheet_names:
        # Blank cells come back as NaN; non-string codes would break a '|'.join.
        codes = sheets[diag]['ICD code'].dropna().astype(str)
        ICD_CODES_DICT[diag] = [code for code in codes if code != '']
    return ICD_CODES_DICT

def remove_nonexistent(reference, function = 'sum'):
    '''
    Variables tagged with ``function`` in the spec sheet that exist in ``reference``.

    Reads docs_path + 'ms_data_function_ver3.xlsx', selects the 'variable'
    entries whose 'function' equals ``function``, and keeps (in sheet order)
    only those present in ``reference.columns``.
    '''
    spec = pd.read_excel(docs_path + 'ms_data_function_ver3.xlsx')
    wanted = spec.loc[spec['function'] == function, 'variable'].tolist()
    present = reference.columns
    kept = []
    for name in wanted:
        if name in present:
            kept.append(name)
    return kept

def remove_outliers(df, docs_path = docs_path):
    '''
    Replace out-of-range covariate values with NaN, in place.

    Reads docs_path + 'possible_range.xlsx' (columns 'variable', 'min', 'max').
    For every listed variable that is a column of ``df``:
      * the column is coerced to numeric (unparseable entries become NaN);
      * values below 'min' or above 'max' are set to NaN.
    A blank bound becomes NaN; comparisons against NaN are False, so that side
    is effectively unbounded.

    Returns ``df`` (the same object, modified in place).
    '''
    file_path = docs_path + 'possible_range.xlsx'
    possible_range = pd.read_excel(file_path)
    check_range_columns = possible_range['variable'].tolist()
    upper_values = possible_range['max'].astype(float).tolist()
    lower_values = possible_range['min'].astype(float).tolist()

    for covariate, upper, lower in tqdm(zip(check_range_columns, upper_values, lower_values), 
                                        total = len(check_range_columns), desc = 'Removing outliers'):
        if covariate in df.columns:
            df[covariate] = pd.to_numeric(df[covariate], errors = 'coerce')
            outlier_mask = (df[covariate] < lower) | (df[covariate] > upper)
            # np.nan (not np.NaN): the NaN alias was removed in NumPy 2.0.
            df.loc[outlier_mask, covariate] = np.nan
    return df

def carry_covariates(reference = None):
    '''
    Split the columns of ``reference`` by their carry rule from the spec sheet.

    Reads docs_path + 'ms_data_function_ver3.xlsx' and, for each 'carry' tag,
    returns the 'variable' entries with that tag that are also columns of
    ``reference``. Lists are in spreadsheet order, with duplicates removed.

    Parameters
    ----------
    reference : DataFrame whose columns are matched against the sheet.
        Defaults to the notebook-global ``da`` (previous, implicit behaviour).
        NOTE(review): ``da`` is defined elsewhere in the notebook; confirm it
        is the frame intended here.

    Returns
    -------
    (forward_list, forback_list, lumping_list, fllzero_list) for the tags
    'forward', 'forward_backward', 'ignore' and 'fill_zero' respectively.

    Previously each list was computed with ``set.difference``, i.e. every
    column NOT carrying the tag, so callers filled the wrong columns.
    '''
    if reference is None:
        reference = da
    carry_df = pd.read_excel(docs_path + 'ms_data_function_ver3.xlsx')
    available = set(reference.columns)

    def _columns_for(rule):
        # Tagged variables that exist in `reference`, sheet order, de-duplicated.
        tagged = carry_df.loc[carry_df['carry'] == rule, 'variable']
        return list(dict.fromkeys(var for var in tagged if var in available))

    forward_list = _columns_for('forward')
    forback_list = _columns_for('forward_backward')
    lumping_list = _columns_for('ignore')
    fllzero_list = _columns_for('fill_zero')

    return forward_list, forback_list, lumping_list, fllzero_list

def carried_values(patient_data):
    '''
    Fill missing covariate values for one frame according to the carry rules.

    * 'forward' variables: forward-filled.
    * 'forward_backward' variables: forward-filled, then back-filled.

    The column lists come from ``carry_covariates()``, which derives them from
    a different frame. Only the columns actually present in ``patient_data``
    are touched, to avoid a KeyError.
    NOTE(review): the 'ignore' and 'fill_zero' lists are returned by
    carry_covariates but not applied here; confirm that is intended.

    Returns ``patient_data`` (modified in place).
    '''
    forward_list, forback_list, lumping_list, fllzero_list = carry_covariates()
    present = set(patient_data.columns)
    forward_cols = [col for col in forward_list if col in present]
    forback_cols = [col for col in forback_list if col in present]
    # .ffill()/.bfill() replace fillna(method=...), deprecated since pandas 2.1.
    if forward_cols:
        patient_data[forward_cols] = patient_data[forward_cols].ffill()
    if forback_cols:
        patient_data[forback_cols] = patient_data[forback_cols].ffill().bfill()
    return patient_data

def determine_outcome(df, condition_patterns = None):
    '''
    Add per-patient 0/1 condition flags derived from diagnosis text.

    For each condition in ``condition_patterns``, a column of that name is set
    to 1 on rows whose ``diagnosis_all`` text contains any of the condition's
    codes. The flag is then forward-filled within each ``ENC_HN`` group and
    set to 0 where no earlier row matched. The TIA / Hemorrhagic stroke /
    Ischemic stroke flags are finally merged into a single ``stroke`` column.

    Parameters
    ----------
    df : DataFrame with 'ENC_HN' and 'diagnosis_all' columns. Condition columns
        are written onto it in place. Row order within a patient determines
        the forward-fill direction.
    condition_patterns : optional dict {condition name: list of codes}.
        Defaults to ``exclusion_icd()``. It must include 'TIA',
        'Hemorrhagic stroke' and 'Ischemic stroke'.

    Codes are matched literally, as substrings (regex metacharacters such as
    '.' are escaped). NaN/empty codes are ignored. A condition with no codes
    flags nobody; previously an empty list produced the regex '', which
    matches every row.

    Returns
    -------
    DataFrame with the new flag columns, the three stroke sub-type columns
    replaced by ``stroke``.
    '''
    import re

    def update_columns(df, col_name, condition):
        # Ensure the column exists even when no row matches.
        if col_name not in df.columns:
            df[col_name] = np.nan
        df.loc[condition, col_name] = 1
        df[col_name] = df.groupby(['ENC_HN'])[col_name].ffill().fillna(0)

    if condition_patterns is None:
        condition_patterns = exclusion_icd()
    diagnosis_text = df['diagnosis_all'].astype(str)
    for condition, patterns in condition_patterns.items():
        codes = [re.escape(str(code)) for code in patterns
                 if not pd.isnull(code) and str(code) != '']
        if codes:
            matches = diagnosis_text.str.contains('|'.join(codes), na = False)
        else:
            matches = pd.Series(False, index = df.index)
        update_columns(df, condition, matches)
    df['stroke'] = df[['TIA', 'Hemorrhagic stroke', 'Ischemic stroke']].max(axis = 1)
    df = df.drop(['TIA', 'Hemorrhagic stroke', 'Ischemic stroke'], axis = 1)
    return df


def swap_dates_if_needed(column1, column2):
    '''
    Make column1[i] <= column2[i] at every position by swapping out-of-order pairs.

    Positions where either value is null are skipped.

    Side effects / return value:
      * ``column1`` is modified IN PLACE (it receives the earlier value of
        each swapped pair).
      * ``column2`` is not modified; a copy holding the later values is returned.

    NOTE(review): if ``column1`` was taken from a DataFrame (``df['x']``), the
    in-place write may not reach the frame under pandas Copy-on-Write.
    Confirm callers re-assign it explicitly.

    Raises ValueError if the inputs differ in length.
    '''
    if len(column1) != len(column2):
        raise ValueError('Columns must have the same length.')
    column2_copy = column2.copy()
    # Write by POSITION. On a pandas Series, `s[i] = v` is label-based: with a
    # non-default index it hits the wrong row or appends a new label `i`.
    setter1 = column1.iloc if isinstance(column1, pd.Series) else column1
    setter2 = column2_copy.iloc if isinstance(column2_copy, pd.Series) else column2_copy
    for i, (value1, value2) in enumerate(zip(column1, column2)):
        if pd.isnull(value1) or pd.isnull(value2):
            continue
        if value2 < value1:
            setter2[i] = value1
            setter1[i] = value2
    return column2_copy

def check_date_order(column1, column2):
    '''
    True when no position has column1 > column2 (null pairs are ignored).

    Raises ValueError if the two inputs differ in length.
    '''
    if len(column1) != len(column2):
        raise ValueError("Columns must have the same length.")

    def _out_of_order(first, second):
        if pd.isnull(first) or pd.isnull(second):
            return False
        return first > second

    return not any(_out_of_order(a, b) for a, b in zip(column1, column2))

def add_one_day(dx, column1, column2, days = 2):
    '''
    Shift ``column2`` forward when it falls on the same calendar day as ``column1``.

    Both columns are converted to datetime in place. On rows whose two values
    share a calendar date, ``days`` days are added to ``column2``.

    NOTE(review): despite the function name, the shift applied here has
    always been 2 days; the default ``days = 2`` preserves that. Confirm
    which offset the analysis intends.

    Returns ``dx`` (modified in place).
    '''
    dx[column1] = pd.to_datetime(dx[column1])
    dx[column2] = pd.to_datetime(dx[column2])
    same_date_rows = dx[column1].dt.date == dx[column2].dt.date
    dx.loc[same_date_rows, column2] += timedelta(days = days)
    return dx

def sub_one_day(dx, column1, column2):
    '''
    Move ``column1`` back one day where its calendar date is after ``column2``'s.

    Both columns are converted to datetime in place; returns ``dx``.
    '''
    for col in (column1, column2):
        dx[col] = pd.to_datetime(dx[col])
    starts_after_end = dx[column1].dt.date > dx[column2].dt.date
    dx.loc[starts_after_end, column1] = dx.loc[starts_after_end, column1] - timedelta(days = 1)
    return dx

def sort_dates(row):
    '''
    Return the state labels with a non-null date, ordered by that date.

    ``row`` must provide the seven '<STATE>_date' keys read below ('DEAD_date'
    maps to the label 'DEATH'). Ties keep the listed order (stable sort).
    '''
    label_to_column = (
        ('CKD3A', 'CKD3A_date'),
        ('CKD3B', 'CKD3B_date'),
        ('CKD4', 'CKD4_date'),
        ('CKD5A', 'CKD5A_date'),
        ('CKD5B', 'CKD5B_date'),
        ('CVD', 'CVD_date'),
        ('DEATH', 'DEAD_date'),
    )
    dated = [(label, row[col]) for label, col in label_to_column]
    observed = [pair for pair in dated if not pd.isnull(pair[1])]
    return [label for label, _ in sorted(observed, key = lambda pair: pair[1])]


def add_state_column(dataframe, column_date, column_name):
    '''Add (in place) an int column ``column_name``: 1 where ``column_date`` is non-null, else 0.'''
    has_date = dataframe[column_date].notna()
    dataframe[column_name] = has_date.astype(int)
    return dataframe

# Guard on the cohort: this notebook expects exactly this many distinct ENC_HN values.
# An explicit raise (not `assert`) so the check survives `python -O`
# and reports the count actually found.
EXPECTED_COHORT_SIZE = 23693
patients = get_patients_new() 
n_unique_patients = pd.Series(patients).nunique()
if n_unique_patients != EXPECTED_COHORT_SIZE:
    raise ValueError(f'Expected {EXPECTED_COHORT_SIZE} unique patients, found {n_unique_patients}.')

In [4]:
# Observed state sequences (in date order) -> number of patients following each.
transitions_data = {
    ('CKD3A', 'CVD'): 5184,
    ('CKD3A', 'CKD3B', 'CVD'): 4038,
    ('CKD3A',): 3477,
    ('CKD3A', 'CKD3B'): 2679,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'): 1840,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD'): 1118,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD'): 1025,
    ('CKD3A', 'CKD3B', 'CKD4'): 880,
    ('CKD3A', 'CVD', 'CKD3B'): 824,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4'): 431,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A'): 269,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD'): 236,
    ('CKD3A', 'CKD3B', 'CKD4', 'DEATH'): 182,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4'): 177,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'): 160,
    ('CKD3A', 'CKD3B', 'DEATH'): 135,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 122,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A'): 112,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A'): 87,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD', 'DEATH'): 78,
    ('CKD3A', 'DEATH'): 74,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'DEATH'): 49,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A'): 47,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'DEATH'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'DEATH'): 45,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'DEATH'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'DEATH'): 44,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'DEATH'): 36,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'): 31,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B'): 30,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B'): 29,
    ('CKD3A', 'CVD', 'DEATH'): 28,
    ('CKD3A', 'CVD', 'CKD3B', 'DEATH'): 18,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 17,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'): 16,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'DEATH'): 13,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B', 'DEATH'): 10,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B', 'DEATH'): 9,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 5
}

# (origin, destination) state pairs to tally.
transitions_of_interest = [
    ('CKD3A', 'CKD3B'),
    ('CKD3A', 'CKD4'),
    ('CKD3A', 'CKD5A'),
    ('CKD3A', 'CKD5B'),
    ('CKD3A', 'CVD'),
    ('CKD3A', 'DEATH'),
    ('CKD3B', 'CKD4'),
    ('CKD3B', 'CKD5A'),
    ('CKD3B', 'CKD5B'),
    ('CKD3B', 'CVD'),
    ('CKD3B', 'DEATH'),
    ('CKD4', 'CKD5A'),
    ('CKD4', 'CKD5B'),
    ('CKD4', 'CVD'),
    ('CKD4', 'DEATH'),
    ('CKD5A', 'CKD5B'),
    ('CKD5A', 'CVD'),
    ('CKD5A', 'DEATH'),
    ('CKD5B', 'CVD'),
    ('CKD5B', 'DEATH'),
    ('CVD', 'CKD3B'),
    ('CVD', 'CKD4'),
    ('CVD', 'CKD5A'),
    ('CVD', 'CKD5B'),
    ('CVD', 'DEATH')
]

# Cumulative tally: a sequence counts toward (A, B) whenever A's first
# occurrence precedes B's, even with other states in between.
transition_counts = dict.fromkeys(transitions_of_interest, 0)
for sequence, n_patients in transitions_data.items():
    first_seen = {}
    for position, state in enumerate(sequence):
        first_seen.setdefault(state, position)
    for origin, destination in transitions_of_interest:
        both_present = origin in first_seen and destination in first_seen
        if both_present and first_seen[origin] < first_seen[destination]:
            transition_counts[(origin, destination)] += n_patients

for transition, count in transition_counts.items():
    print(f"{transition}: {count}")

('CKD3A', 'CKD3B'): 14930
('CKD3A', 'CKD4'): 7192
('CKD3A', 'CKD5A'): 4392
('CKD3A', 'CKD5B'): 3335
('CKD3A', 'CVD'): 13875
('CKD3A', 'DEATH'): 1133
('CKD3B', 'CKD4'): 7192
('CKD3B', 'CKD5A'): 4392
('CKD3B', 'CKD5B'): 3335
('CKD3B', 'CVD'): 7532
('CKD3B', 'DEATH'): 1031
('CKD4', 'CKD5A'): 4392
('CKD4', 'CKD5B'): 3335
('CKD4', 'CVD'): 2778
('CKD4', 'DEATH'): 834
('CKD5A', 'CKD5B'): 3335
('CKD5A', 'CVD'): 1519
('CKD5A', 'DEATH'): 547
('CKD5B', 'CVD'): 1196
('CKD5B', 'DEATH'): 241
('CVD', 'CKD3B'): 1131
('CVD', 'CKD4'): 961
('CVD', 'CKD5A'): 482
('CVD', 'CKD5B'): 177
('CVD', 'DEATH'): 460


In [6]:
# Define the data as a dictionary
# (observed state sequence in date order -> number of patients following it)
transitions_data = {
    ('CKD3A', 'CVD'): 5184,
    ('CKD3A', 'CKD3B', 'CVD'): 4038,
    ('CKD3A',): 3477,
    ('CKD3A', 'CKD3B'): 2679,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'): 1840,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD'): 1118,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD'): 1025,
    ('CKD3A', 'CKD3B', 'CKD4'): 880,
    ('CKD3A', 'CVD', 'CKD3B'): 824,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4'): 431,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A'): 269,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD'): 236,
    ('CKD3A', 'CKD3B', 'CKD4', 'DEATH'): 182,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4'): 177,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'): 160,
    ('CKD3A', 'CKD3B', 'DEATH'): 135,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 122,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A'): 112,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A'): 87,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'CVD', 'DEATH'): 78,
    ('CKD3A', 'DEATH'): 74,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'DEATH'): 49,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A'): 47,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'DEATH'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'DEATH'): 45,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'DEATH'): 46,
    ('CKD3A', 'CKD3B', 'CVD', 'DEATH'): 44,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'DEATH'): 36,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B'): 31,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B'): 30,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B'): 29,
    ('CKD3A', 'CVD', 'DEATH'): 28,
    ('CKD3A', 'CVD', 'CKD3B', 'DEATH'): 18,
    ('CKD3A', 'CKD3B', 'CVD', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 17,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'DEATH'): 16,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'DEATH'): 13,
    ('CKD3A', 'CKD3B', 'CKD4', 'CVD', 'CKD5A', 'CKD5B', 'DEATH'): 10,
    ('CKD3A', 'CKD3B', 'CKD4', 'CKD5A', 'CVD', 'CKD5B', 'DEATH'): 9,
    ('CKD3A', 'CVD', 'CKD3B', 'CKD4', 'CKD5A', 'CKD5B', 'DEATH'): 5
}

# Define the transitions of interest
transitions_of_interest = [
    ('CKD3A', 'CKD3B'),
    ('CKD3A', 'CKD4'),
    ('CKD3A', 'CKD5A'),
    ('CKD3A', 'CKD5B'),
    ('CKD3A', 'CVD'),
    ('CKD3A', 'DEATH'),
    ('CKD3B', 'CKD4'),
    ('CKD3B', 'CKD5A'),
    ('CKD3B', 'CKD5B'),
    ('CKD3B', 'CVD'),
    ('CKD3B', 'DEATH'),
    ('CKD4', 'CKD5A'),
    ('CKD4', 'CKD5B'),
    ('CKD4', 'CVD'),
    ('CKD4', 'DEATH'),
    ('CKD5A', 'CKD5B'),
    ('CKD5A', 'CVD'),
    ('CKD5A', 'DEATH'),
    ('CKD5B', 'CVD'),
    ('CKD5B', 'DEATH'),
    ('CVD', 'CKD3B'),
    ('CVD', 'CKD4'),
    ('CVD', 'CKD5A'),
    ('CVD', 'CKD5B'),
    ('CVD', 'DEATH')
]

# Initialize a dictionary to store the counts for each transition
transition_counts = {transition: 0 for transition in transitions_of_interest}

# Direct (one-step) tally: only CONSECUTIVE states in a sequence count as a
# transition. zip with the 1-offset slice yields (s[0], s[1]), (s[1], s[2]), ...
# Adjacent pairs not in transitions_of_interest are ignored.
for sequence, count in transitions_data.items():
    for transition in zip(sequence, sequence[1:]):
        if transition in transition_counts:
            transition_counts[transition] += count

# Print the results
for transition, count in transition_counts.items():
    print(f"{transition}: {count}")


('CKD3A', 'CKD3B'): 13799
('CKD3A', 'CKD4'): 0
('CKD3A', 'CKD5A'): 0
('CKD3A', 'CKD5B'): 0
('CKD3A', 'CVD'): 6343
('CKD3A', 'DEATH'): 74
('CKD3B', 'CKD4'): 6520
('CKD3B', 'CKD5A'): 0
('CKD3B', 'CKD5B'): 0
('CKD3B', 'CVD'): 4754
('CKD3B', 'DEATH'): 153
('CKD4', 'CKD5A'): 4204
('CKD4', 'CKD5B'): 0
('CKD4', 'CVD'): 1259
('CKD4', 'DEATH'): 241
('CKD5A', 'CKD5B'): 3297
('CKD5A', 'CVD'): 323
('CKD5A', 'DEATH'): 257
('CKD5B', 'CVD'): 1196
('CKD5B', 'DEATH'): 163
('CVD', 'CKD3B'): 1131
('CVD', 'CKD4'): 672
('CVD', 'CKD5A'): 188
('CVD', 'CKD5B'): 38
('CVD', 'DEATH'): 245
