In [2]:
cd ..

d:\min\research_projects\MedicalDataKit


In [69]:
import numpy as np
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
import pandas as pd
from typing import Union
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedKFold


class MechTester:
    """Measure how predictable each column's missingness is from the data.

    For every column that is only partially missing, a logistic-regression
    model is fitted to predict that column's missingness indicator from the
    mean-imputed data. The L2 norm of the fitted coefficients is reported as
    the statistic ``T``. Optionally the same statistic is averaged over fits
    on randomly permuted indicators to give a reference value.

    Parameters
    ----------
    use_self : bool, default False
        If True, the (imputed) column under test stays among the predictors;
        otherwise it is removed from the design matrix.
    model : {'logitcv', 'logit'}, default 'logitcv'
        'logitcv' uses ``LogisticRegressionCV`` with 5 stratified folds,
        'logit' uses a plain ``LogisticRegression``; both use an L2 penalty.
    perform_permutation : bool, default False
        Whether to also compute the permuted-indicator reference statistic.
    permutation_repetitions : int, default 10
        Number of permuted fits averaged per column.
    random_state : int or None, default None
        Seed for the permutations. ``None`` keeps the previous behaviour of
        drawing from NumPy's global random state.
    """

    def __init__(
        self,
        use_self: bool = False,
        model: str = 'logitcv',
        perform_permutation: bool = False,
        permutation_repetitions: int = 10,
        random_state: Union[int, None] = None,
    ):
        self.use_self = use_self
        self.model = model
        self.permutation_repetitions = permutation_repetitions
        self.perform_permutation = perform_permutation
        self.random_state = random_state

    @staticmethod
    def _impute_column_means(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Return a copy of ``data`` with NaNs replaced by the column mean (0.0 if a column is all NaN)."""
        observed_counts = (~mask).sum(axis=0)
        observed_sums = np.where(mask, 0.0, data).sum(axis=0)
        # Columns without any observed value keep the 0.0 from `out`, so the
        # result always has the same shape as `mask`.
        column_means = np.divide(
            observed_sums,
            observed_counts,
            out=np.zeros(data.shape[1], dtype=float),
            where=observed_counts > 0,
        )
        return np.where(mask, column_means, data)

    def _build_model(self):
        """Create a fresh, unfitted classifier for the configured ``model`` name."""
        if self.model == 'logitcv':
            return LogisticRegressionCV(
                cv=StratifiedKFold(n_splits=5), penalty='l2', max_iter=2000, solver='lbfgs'
            )
        if self.model == 'logit':
            return LogisticRegression(penalty='l2', max_iter=2000, solver='lbfgs')
        raise ValueError(f"Model {self.model} not supported")

    @staticmethod
    def _fit_coef_norm(model, X: np.ndarray, y: np.ndarray) -> float:
        """Fit ``model`` on ``(X, y)`` and return the L2 norm of its flattened coefficients."""
        model.fit(X, y)
        return np.linalg.norm(model.coef_.flatten(), ord=2)

    def compute_mech_params(self, data: Union[np.ndarray, pd.DataFrame]):
        """Compute the coefficient-norm statistic for each partially missing column.

        Parameters
        ----------
        data : np.ndarray or pd.DataFrame
            Two-dimensional numeric table in which missing entries are NaN.
            The input is copied and never modified.

        Returns
        -------
        (np.ndarray, np.ndarray)
            Two 1-D arrays with one entry per tested column, in column order.
            Columns with no missing values or with only missing values are
            skipped. The first array holds ``T``; the second holds the mean
            ``T`` over permuted indicators, or 0 when permutation is disabled.

        Raises
        ------
        ValueError
            If ``data`` is not two-dimensional, or if ``self.model`` is not a
            supported name (raised when the first column is actually tested).
        """
        if isinstance(data, pd.DataFrame):
            data = data.to_numpy()

        # np.array(..., dtype=float) always copies, so the caller's data is untouched.
        data = np.array(data, dtype=float)
        if data.ndim != 2:
            raise ValueError(f"data must be 2-dimensional, got {data.ndim} dimension(s)")

        n_rows, n_cols = data.shape
        mask = np.isnan(data)

        # Shape-preserving mean imputation: sklearn's SimpleImputer drops all-NaN
        # columns by default, which would desynchronise `data` and `mask` columns.
        data = self._impute_column_means(data, mask)

        # Seeded generator when requested; otherwise the legacy global state.
        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.default_rng(self.random_state)

        T_list = []
        T_perm_list = []
        for col_idx in range(n_cols):
            y = mask[:, col_idx]
            n_missing = int(y.sum())

            # A constant indicator (nothing or everything missing) cannot be modelled.
            if n_missing == 0 or n_missing == n_rows:
                continue

            if self.use_self:
                X = data
            else:
                X = np.delete(data, col_idx, axis=1)

            model = self._build_model()
            T_list.append(self._fit_coef_norm(model, X, y))

            if self.perform_permutation:
                # Refit on shuffled indicators to see what norm arises by chance.
                T_perms = [
                    self._fit_coef_norm(model, X, rng.permutation(y))
                    for _ in range(self.permutation_repetitions)
                ]
                T_perm_list.append(np.mean(T_perms))
            else:
                T_perm_list.append(0)

        return np.array(T_list), np.array(T_perm_list)

In [1]:
import numpy as np
import pandas as pd
from MedDataKit.utils import column_check

In [37]:
# Enable autoreload so edits to imported project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
from MedDataKit.dataset.clinical_dataset import ARI2Dataset

dataset = ARI2Dataset()
# NOTE(review): this return value is overwritten at the end of the cell; the call is kept
# because it presumably loads state that generate_ml_task_dataset needs -- confirm.
data = dataset.load_raw_data()
# Build the 'predict_Y_death' task table with the encoding / missing-data options given in `config`.
dataset.generate_ml_task_dataset(
    task_name = 'predict_Y_death', 
    config = {'numerical_encoding': 'quantile', 'missing_strategy': 'impute_cat'}, 
    verbose = True
)
dataset.ml_task_dataset.show_dataset_info()
# Rebind `data` to the processed task table; the following cell reads it.
data = dataset.ml_task_dataset.data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
{'missing_strategy': 'impute_cat', 'missing_drop_thres': 0.6, 'ordinal_as_numerical': False, 'categorical_encoding': 'ordinal', 'numerical_encoding': 'quantile', 'drop_unused_targets': True}
Raw data shape:  (4552, 97)
After setting target feature:  (4552, 96)
After feature engineering:  (4552, 96)
After handling missing data:  (4552, 94)
Final ml task dataset shape:  (4552, 94)
Task name: predict_Y_death  Task type: classification
Target: Y_death Num classes: 4
Data Shape: (4552, 94) (num 19 cat 74)
Missing ratio:  4.2%
Feature Groups:
    - demongraphic: 15 features (e.g., afe,country,smi2 ... nut,wam,mvm)
    - past_history: 22 features (e.g., twb,hcm,clin ... saogp,hcs,hfa)
    - lab_vital_signs: 17 features (e.g., rr,sickj,sickl ... hrat,s3,temp)
    - clinical_status: 39 features (e.g., slpm,ova,gru ... abb,wake,str)


In [40]:
# For each country: print its code, its missingness summary, then the summed
# MechTester statistics (observed and permuted).
for country_code in data['country'].unique():
    print(country_code)
    country_rows = data[data['country'] == country_code]

    na_per_column = country_rows.isna().sum()
    print(na_per_column.sum(), sum(na_per_column > 0))

    tester = MechTester(
        use_self=True,
        model='logitcv',
        perform_permutation=True,
        permutation_repetitions=5,
    )
    T_values, T_perm_values = tester.compute_mech_params(country_rows)
    print(sum(T_values), sum(T_perm_values))


0.0
3044 9
11.7196926891298 0.7392375222455289
1.0
2111 10
24.31855691981666 0.557142508743546
2.0
1901 8
7.5153349757669226 0.3786602282498723
3.0
10752 8
43.75980576114043 0.7457309815566736


In [14]:
import pandas as pd

In [41]:
# Read the bold_dataset CSV (path is relative to the current working directory).
# This rebinds `data`, replacing the table assigned by earlier cells.
data = pd.read_csv('data/boldlink/bold_dataset.csv')

In [42]:
# Display (rows, columns) of the table currently bound to `data`.
data.shape

(49093, 142)

In [None]:
# Keep only the eICU records, then print a missingness summary for every
# hospital that contributes more than 100 rows.
data_eicu = data[data['source_db'] == 'eicu']

center_list = data_eicu['hospitalid'].unique()
for hospital_id in center_list:
    hospital_rows = data_eicu[data_eicu['hospitalid'] == hospital_id]
    n_rows = len(hospital_rows)
    if n_rows <= 100:
        continue

    n_cells = n_rows * hospital_rows.shape[1]
    na_per_column = hospital_rows.isna().sum()
    # Columns printed: hospital id, row count, overall missing ratio, number of columns with any NaN.
    print(hospital_id, n_rows, na_per_column.sum() / n_cells, sum(na_per_column > 0))

71 104 0.2917118093174431 110
73 1293 0.29114516954783615 111
79 650 0.25730227518959914 112
92 124 0.2613584734211722 110
142 809 0.22093873500583228 111
143 227 0.2388782031395421 110
140 177 0.187992360945333 109
141 405 0.23974960876369328 110
144 228 0.1881640721522115 110
154 427 0.22070125672065177 111
175 109 0.22070034888228454 111
165 697 0.21779457231192031 112
171 315 0.22615694164989938 112
176 815 0.21385120539186037 113
148 866 0.21363399798328075 112
167 1360 0.24033761391880695 113
157 563 0.2266529907687689 112
155 178 0.25486627630954267 110
146 263 0.2474963851550367 112
152 499 0.2651782438115668 112
183 639 0.23013511428508454 111
181 302 0.20832944687995522 109
184 273 0.20507661352731776 110
180 143 0.18029153944646903 103
188 1263 0.32848795066519465 111
197 290 0.2575036425449247 109
195 275 0.22332906530089627 111
198 371 0.2810447591207623 108
206 310 0.20231712857791911 111
199 795 0.2223757640180707 111
208 1379 0.3124891480864885 112
210 175 0.28631790744

In [25]:
# Same missingness summary as per hospital, but grouped by source database.
center_list = data['source_db'].unique()
for source_name in center_list:
    source_rows = data[data['source_db'] == source_name]
    n_rows = len(source_rows)
    if n_rows <= 100:
        continue

    n_cells = n_rows * source_rows.shape[1]
    na_per_column = source_rows.isna().sum()
    # Columns printed: source name, row count, overall missing ratio, number of columns with any NaN.
    print(source_name, n_rows, na_per_column.sum() / n_cells, sum(na_per_column > 0))

eicu 43438 0.25843617809810193 115
mimic_iii 740 0.18859916254282452 112
mimic_iv 4915 0.15167996790509078 111


In [45]:
# Load the eICU feature table and the patient table, then attach `hospitalid`
# to each feature row through the shared stay identifier.
stay_key = 'patientunitstayid'

data = pd.read_csv('data/eicu_v1/data.csv')
data_patient = pd.read_csv('data/eicu_v1/patient.csv')

hospital_lookup = data_patient[[stay_key, 'hospitalid']]
data = pd.merge(data, hospital_lookup, on=stay_key, how='inner')

In [51]:
# Remove the three identifier columns from the table.
# NOTE: re-running this cell raises KeyError because the columns are already gone.
identifier_columns = ['uniquepid', 'patienthealthsystemstayid', 'patientunitstayid']
data = data.drop(columns=identifier_columns)

In [54]:
# Display (rows, columns) of `data` after the identifier columns were dropped.
data.shape

(200859, 41)

In [70]:
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# For every hospital with more than 100 and at most 2000 rows: rescale its rows,
# run MechTester on the result, and print the missingness summary with the summed T.
center_list = data['hospitalid'].unique()
for center in center_list:
    data_center = data[data['hospitalid'] == center]
    n_rows = len(data_center)
    if not (100 < n_rows <= 2000):
        continue

    # Standardise first, then min-max rescale; both steps are kept in this order.
    standardized = StandardScaler().fit_transform(data_center.copy())
    X = MinMaxScaler().fit_transform(standardized)

    mech_tester = MechTester(
        use_self=True, model='logitcv', perform_permutation=False, permutation_repetitions=5
    )
    ret = mech_tester.compute_mech_params(X)

    # Missingness is summarised on the unscaled rows.
    n_cells = n_rows * data_center.shape[1]
    na_per_column = data_center.isna().sum()
    print(
        center, n_rows,
        na_per_column.sum() / n_cells,
        sum(na_per_column > 0), sum(ret[0])
    )

66 1002 0.3056813202862568 40 327.0913803378128
63 1826 0.39296877087062215 40 973.3343245666113
68 692 0.38791766530382066 40 382.281697331558
59 854 0.3201576512252242 40 472.1167859402534
58 321 0.3925233644859813 40 249.80306473238113
71 1021 0.35450658130479445 40 589.3685544709274
60 458 0.3236766428799659 40 1858.092429173857
56 325 0.33951219512195124 40 234.79462391307914
69 697 0.33901389229100326 40 445.5697758905933
67 496 0.28314319433516916 40 510.49494248287624
61 233 0.301685334449911 40 36.631153626161115
79 1313 0.198502777106979 40 188.1489443074869
108 354 0.2049056083781177 40 139.2288286195742
92 586 0.25838674769000247 40 159.9686946103424
95 256 0.2461890243902439 40 252.72084272637727
112 373 0.62499182632577 40 623.0427467103264
133 228 0.6114676936243046 40 668.3480127402592
131 131 0.43641779929249674 40 865.2397717137532
120 111 0.6011865524060646 40 1002.1674837853081
125 247 0.40090846252592083 40 1271.0246466689475
138 176 0.6632483370288248 40 2084.3832