In [1]:
import numpy as np
import pandas as pd
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import KNNImputer, IterativeImputer

In [2]:
qual_cols = { "age", "TSH", "T3", "TT4", "T4U", "FTI", "TBG" }
cat_cols = { "sex", "on_thyroxine", "on_antithyroid_meds", "sick", "pregnant", "thyroid_surgery", "I131_treatment", "lithium", "goitre", "tumor", "psych", "referral_source", "target" }

In [3]:
remove_qual_cols = { "TBG" }
remove_cat_cols = set()

In [4]:
k = 10
neighbours = 2
iters = 10000

In [5]:
seed = 42

In [6]:
random_state = np.random.RandomState(seed)
random_generator = np.random.default_rng(seed)

In [7]:
sel_qual_cols = list(qual_cols.difference(remove_qual_cols))
sel_cat_cols = list(cat_cols.difference(remove_cat_cols))
display(sel_qual_cols)
display(sel_cat_cols)

['age', 'T3', 'T4U', 'TSH', 'TT4', 'FTI']

['sick',
 'sex',
 'thyroid_surgery',
 'lithium',
 'goitre',
 'referral_source',
 'on_antithyroid_meds',
 'tumor',
 'pregnant',
 'I131_treatment',
 'psych',
 'target',
 'on_thyroxine']

In [8]:
df = pd.read_csv("./data/thyroidDF.csv")
df

Unnamed: 0,age,sex,on_thyroxine,query_on_thyroxine,on_antithyroid_meds,sick,pregnant,thyroid_surgery,I131_treatment,query_hypothyroid,...,TT4,T4U_measured,T4U,FTI_measured,FTI,TBG_measured,TBG,referral_source,target,patient_id
0,29,F,f,f,f,f,f,f,f,t,...,,f,,f,,f,,other,-,840801013
1,29,F,f,f,f,f,f,f,f,f,...,128.0,f,,f,,f,,other,-,840801014
2,41,F,f,f,f,f,f,f,f,f,...,,f,,f,,t,11.0,other,-,840801042
3,36,F,f,f,f,f,f,f,f,f,...,,f,,f,,t,26.0,other,-,840803046
4,32,F,f,f,f,f,f,f,f,f,...,,f,,f,,t,36.0,other,S,840803047
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9167,56,M,f,f,f,f,f,f,f,f,...,64.0,t,0.83,t,77.0,f,,SVI,-,870119022
9168,22,M,f,f,f,f,f,f,f,f,...,91.0,t,0.92,t,99.0,f,,SVI,-,870119023
9169,69,M,f,f,f,f,f,f,f,f,...,113.0,t,1.27,t,89.0,f,,SVI,I,870119025
9170,47,F,f,f,f,f,f,f,f,f,...,75.0,t,0.85,t,88.0,f,,other,-,870119027


In [9]:
df.dropna()

Unnamed: 0,age,sex,on_thyroxine,query_on_thyroxine,on_antithyroid_meds,sick,pregnant,thyroid_surgery,I131_treatment,query_hypothyroid,...,TT4,T4U_measured,T4U,FTI_measured,FTI,TBG_measured,TBG,referral_source,target,patient_id
167,40,F,f,f,f,f,f,f,f,f,...,3.9,t,0.83,t,5.0,t,28.0,other,F,840827019
5256,35,F,f,f,f,f,f,t,f,f,...,73.0,t,1.16,t,63.0,t,37.0,other,-,851128040
6044,77,F,f,f,f,f,f,f,f,f,...,120.0,t,0.96,t,124.0,t,45.0,SVI,-,860305064
6045,73,M,f,f,f,f,f,f,f,f,...,89.0,t,0.74,t,119.0,t,24.0,SVI,-,860305065
6747,77,F,f,f,f,f,f,f,f,f,...,131.0,t,1.04,t,126.0,t,25.0,SVI,K,860702030
6773,74,F,f,f,f,f,f,f,f,f,...,116.0,t,0.81,t,143.0,t,22.0,SVI,-,860703046
6862,60,M,f,f,f,f,f,f,f,f,...,92.0,t,0.84,t,110.0,t,21.0,other,-,860710043
6863,66,F,f,f,f,f,f,f,f,f,...,138.0,t,0.8,t,173.0,t,15.0,SVI,-,860710044
6880,42,F,f,f,f,f,f,f,f,f,...,106.0,t,0.98,t,108.0,t,27.0,other,-,860711039
6934,29,F,f,f,f,f,f,f,f,f,...,122.0,t,1.14,t,107.0,t,36.0,SVI,-,860717007


In [10]:
df[sel_qual_cols]

Unnamed: 0,age,T3,T4U,TSH,TT4,FTI
0,29,,,0.3,,
1,29,1.9,,1.6,128.0,
2,41,,,,,
3,36,,,,,
4,32,,,,,
...,...,...,...,...,...,...
9167,56,,0.83,,64.0,77.0
9168,22,,0.92,,91.0,99.0
9169,69,,1.27,,113.0,89.0
9170,47,,0.85,,75.0,88.0


In [11]:
df[sel_cat_cols]

Unnamed: 0,sick,sex,thyroid_surgery,lithium,goitre,referral_source,on_antithyroid_meds,tumor,pregnant,I131_treatment,psych,target,on_thyroxine
0,f,F,f,f,f,other,f,f,f,f,f,-,f
1,f,F,f,f,f,other,f,f,f,f,f,-,f
2,f,F,f,f,f,other,f,f,f,f,f,-,f
3,f,F,f,f,f,other,f,f,f,f,f,-,f
4,f,F,f,f,f,other,f,f,f,f,f,S,f
...,...,...,...,...,...,...,...,...,...,...,...,...,...
9167,f,M,f,f,f,SVI,f,f,f,f,f,-,f
9168,f,M,f,f,f,SVI,f,f,f,f,f,-,f
9169,f,M,f,f,f,SVI,f,f,f,f,f,I,f
9170,f,F,f,f,f,other,f,f,f,f,f,-,f


In [12]:
qual_na_df = df[sel_qual_cols].dropna()
qual_na_df

Unnamed: 0,age,T3,T4U,TSH,TT4,FTI
19,36,2.4,1.06,1.50,90.0,85.0
21,40,2.3,1.08,1.20,104.0,96.0
22,40,2.1,0.84,5.90,88.0,105.0
23,77,2.4,1.13,0.05,107.0,95.0
27,51,2.1,0.87,0.05,93.0,106.0
...,...,...,...,...,...,...
9129,65,1.8,0.90,0.73,85.0,94.0
9130,65,2.1,1.19,4.10,135.0,113.0
9134,74,1.0,1.25,53.00,49.0,39.0
9137,42,1.3,0.73,2.30,59.0,81.0


In [13]:
from data_split import k_fold_split

In [14]:
knn_imputed_df_map = dict()

complete_vals_df, remove_vals_df = k_fold_split(qual_na_df, k, random_generator)
missing_vals_idxs = list(remove_vals_df.index)

real_df = pd.concat([complete_vals_df, remove_vals_df])

for col in sel_qual_cols:
    missing_vals_df = remove_vals_df.copy()
    missing_vals_df[col] = np.NaN

    curr_df = pd.concat([complete_vals_df, missing_vals_df])

    knn_imputer = KNNImputer(n_neighbors=neighbours)
    knn_imputed_mat = knn_imputer.fit_transform(curr_df)
    
    knn_imputed_df = pd.DataFrame(knn_imputed_mat, columns=curr_df.columns, index=curr_df.index)
    knn_imputed_df["{} (real)".format(col)] = real_df[col]
    knn_imputed_df["{} (imputed)".format(col)] = knn_imputed_df[col]
    knn_imputed_df.drop([col], axis=1, inplace=True)

    knn_imputed_df = knn_imputed_df.loc[missing_vals_idxs]

    knn_imputed_df_map[col] = knn_imputed_df

In [15]:
for col, knn_imputed_df in knn_imputed_df_map.items():
    display("{} with KNN".format(col))
    display(knn_imputed_df)

'age with KNN'

Unnamed: 0,T3,T4U,TSH,TT4,FTI,age (real),age (imputed)
92,0.4,0.73,0.20,98.0,134.0,88.0,36.5
312,0.2,0.94,145.00,16.0,17.0,65.0,78.5
1347,2.2,0.81,0.10,153.0,188.0,33.0,35.0
4525,1.9,0.96,0.58,140.0,145.0,85.0,39.0
4328,2.3,0.98,0.33,99.0,101.0,74.0,58.5
...,...,...,...,...,...,...,...
7528,2.5,1.09,2.30,93.0,85.0,55.0,57.5
2110,2.0,0.93,1.80,66.0,71.0,42.0,63.5
3210,1.9,1.06,1.80,102.0,97.0,61.0,46.5
4874,2.4,1.07,1.80,73.0,69.0,90.0,66.5


'T3 with KNN'

Unnamed: 0,age,T4U,TSH,TT4,FTI,T3 (real),T3 (imputed)
92,88.0,0.73,0.20,98.0,134.0,0.4,1.450
312,65.0,0.94,145.00,16.0,17.0,0.2,0.225
1347,33.0,0.81,0.10,153.0,188.0,2.2,1.650
4525,85.0,0.96,0.58,140.0,145.0,1.9,1.950
4328,74.0,0.98,0.33,99.0,101.0,2.3,1.300
...,...,...,...,...,...,...,...
7528,55.0,1.09,2.30,93.0,85.0,2.5,2.050
2110,42.0,0.93,1.80,66.0,71.0,2.0,2.400
3210,61.0,1.06,1.80,102.0,97.0,1.9,2.100
4874,90.0,1.07,1.80,73.0,69.0,2.4,1.400


'T4U with KNN'

Unnamed: 0,age,T3,TSH,TT4,FTI,T4U (real),T4U (imputed)
92,88.0,0.4,0.20,98.0,134.0,0.73,0.745
312,65.0,0.2,145.00,16.0,17.0,0.94,1.020
1347,33.0,2.2,0.10,153.0,188.0,0.81,0.830
4525,85.0,1.9,0.58,140.0,145.0,0.96,0.965
4328,74.0,2.3,0.33,99.0,101.0,0.98,0.995
...,...,...,...,...,...,...,...
7528,55.0,2.5,2.30,93.0,85.0,1.09,1.070
2110,42.0,2.0,1.80,66.0,71.0,0.93,0.940
3210,61.0,1.9,1.80,102.0,97.0,1.06,1.050
4874,90.0,2.4,1.80,73.0,69.0,1.07,1.060


'TSH with KNN'

Unnamed: 0,age,T3,T4U,TT4,FTI,TSH (real),TSH (imputed)
92,88.0,0.4,0.73,98.0,134.0,0.20,1.3000
312,65.0,0.2,0.94,16.0,17.0,145.00,55.0000
1347,33.0,2.2,0.81,153.0,188.0,0.10,0.2000
4525,85.0,1.9,0.96,140.0,145.0,0.58,1.3750
4328,74.0,2.3,0.98,99.0,101.0,0.33,0.6750
...,...,...,...,...,...,...,...
7528,55.0,2.5,1.09,93.0,85.0,2.30,3.1000
2110,42.0,2.0,0.93,66.0,71.0,1.80,0.6525
3210,61.0,1.9,1.06,102.0,97.0,1.80,0.8000
4874,90.0,2.4,1.07,73.0,69.0,1.80,1.6250


'TT4 with KNN'

Unnamed: 0,age,T3,T4U,TSH,FTI,TT4 (real),TT4 (imputed)
92,88.0,0.4,0.73,0.20,134.0,98.0,114.5
312,65.0,0.2,0.94,145.00,17.0,16.0,24.5
1347,33.0,2.2,0.81,0.10,188.0,153.0,162.5
4525,85.0,1.9,0.96,0.58,145.0,140.0,135.0
4328,74.0,2.3,0.98,0.33,101.0,99.0,120.5
...,...,...,...,...,...,...,...
7528,55.0,2.5,1.09,2.30,85.0,93.0,78.5
2110,42.0,2.0,0.93,1.80,71.0,66.0,83.0
3210,61.0,1.9,1.06,1.80,97.0,102.0,93.5
4874,90.0,2.4,1.07,1.80,69.0,73.0,70.5


'FTI with KNN'

Unnamed: 0,age,T3,T4U,TSH,TT4,FTI (real),FTI (imputed)
92,88.0,0.4,0.73,0.20,98.0,134.0,95.5
312,65.0,0.2,0.94,145.00,16.0,17.0,9.0
1347,33.0,2.2,0.81,0.10,153.0,188.0,123.5
4525,85.0,1.9,0.96,0.58,140.0,145.0,146.5
4328,74.0,2.3,0.98,0.33,99.0,101.0,90.5
...,...,...,...,...,...,...,...
7528,55.0,2.5,1.09,2.30,93.0,85.0,101.5
2110,42.0,2.0,0.93,1.80,66.0,71.0,73.5
3210,61.0,1.9,1.06,1.80,102.0,97.0,118.5
4874,90.0,2.4,1.07,1.80,73.0,69.0,66.5


In [16]:
mice_imputed_df_map = dict()

complete_vals_df, remove_vals_df = k_fold_split(qual_na_df, k, random_generator)
missing_vals_idxs = list(remove_vals_df.index)

real_df = pd.concat([complete_vals_df, remove_vals_df])

for col in sel_qual_cols:
    missing_vals_df = remove_vals_df.copy()
    missing_vals_df[col] = np.NaN

    curr_df = pd.concat([complete_vals_df, missing_vals_df])

    mice_imputer = IterativeImputer(random_state=random_state, max_iter=iters)
    mice_imputed_mat = mice_imputer.fit_transform(curr_df)
    
    mice_imputed_df = pd.DataFrame(mice_imputed_mat, columns=curr_df.columns, index=curr_df.index)
    mice_imputed_df["{} (real)".format(col)] = real_df[col]
    mice_imputed_df["{} (imputed)".format(col)] = mice_imputed_df[col]
    mice_imputed_df.drop([col], axis=1, inplace=True)

    mice_imputed_df = mice_imputed_df.loc[missing_vals_idxs]

    mice_imputed_df_map[col] = mice_imputed_df

In [17]:
for col, mice_imputed_df in mice_imputed_df_map.items():
    display("{} with MICE".format(col))
    display(mice_imputed_df)

'age with MICE'

Unnamed: 0,T3,T4U,TSH,TT4,FTI,age (real),age (imputed)
734,2.0,1.41,5.60,134.0,95.0,27.0,77.768345
6971,1.1,0.84,13.00,110.0,131.0,81.0,77.771484
4563,2.7,1.07,2.30,95.0,88.0,70.0,77.764124
3350,2.0,1.04,0.84,101.0,97.0,73.0,77.766604
2235,1.6,1.07,0.60,95.0,89.0,74.0,77.764614
...,...,...,...,...,...,...,...
916,0.7,0.77,0.60,68.0,89.0,74.0,77.762158
6829,2.8,1.26,1.20,181.0,144.0,46.0,77.782530
603,1.9,0.92,0.40,106.0,115.0,69.0,77.770482
8241,2.3,0.93,1.20,98.0,105.0,46.0,77.767755


'T3 with MICE'

Unnamed: 0,age,T4U,TSH,TT4,FTI,T3 (real),T3 (imputed)
734,27.0,1.41,5.60,134.0,95.0,2.0,2.653621
6971,81.0,0.84,13.00,110.0,131.0,1.1,1.847232
4563,70.0,1.07,2.30,95.0,88.0,2.7,1.892578
3350,73.0,1.04,0.84,101.0,97.0,2.0,1.932987
2235,74.0,1.07,0.60,95.0,89.0,1.6,1.901187
...,...,...,...,...,...,...,...
916,74.0,0.77,0.60,68.0,89.0,0.7,1.299452
6829,46.0,1.26,1.20,181.0,144.0,2.8,2.966864
603,69.0,0.92,0.40,106.0,115.0,1.9,1.880045
8241,46.0,0.93,1.20,98.0,105.0,2.3,1.794005


'T4U with MICE'

Unnamed: 0,age,T3,TSH,TT4,FTI,T4U (real),T4U (imputed)
734,27.0,2.0,5.60,134.0,95.0,1.41,1.228634
6971,81.0,1.1,13.00,110.0,131.0,0.84,0.875253
4563,70.0,2.7,2.30,95.0,88.0,1.07,1.045581
3350,73.0,2.0,0.84,101.0,97.0,1.04,1.012983
2235,74.0,1.6,0.60,95.0,89.0,1.07,1.006599
...,...,...,...,...,...,...,...
916,74.0,0.7,0.60,68.0,89.0,0.77,0.816409
6829,46.0,2.8,1.20,181.0,144.0,1.26,1.276063
603,69.0,1.9,0.40,106.0,115.0,0.92,0.945619
8241,46.0,2.3,1.20,98.0,105.0,0.93,0.961926


'TSH with MICE'

Unnamed: 0,age,T3,T4U,TT4,FTI,TSH (real),TSH (imputed)
734,27.0,2.0,1.41,134.0,95.0,5.60,13.278634
6971,81.0,1.1,0.84,110.0,131.0,13.00,2.535249
4563,70.0,2.7,1.07,95.0,88.0,2.30,9.379132
3350,73.0,2.0,1.04,101.0,97.0,0.84,8.362239
2235,74.0,1.6,1.07,95.0,89.0,0.60,11.683056
...,...,...,...,...,...,...,...
916,74.0,0.7,0.77,68.0,89.0,0.60,9.184800
6829,46.0,2.8,1.26,181.0,144.0,1.20,-6.587121
603,69.0,1.9,0.92,106.0,115.0,0.40,3.705222
8241,46.0,2.3,0.93,98.0,105.0,1.20,4.894284


'TT4 with MICE'

Unnamed: 0,age,T3,T4U,TSH,FTI,TT4 (real),TT4 (imputed)
734,27.0,2.0,1.41,5.60,95.0,134.0,140.687933
6971,81.0,1.1,0.84,13.00,131.0,110.0,104.474669
4563,70.0,2.7,1.07,2.30,88.0,95.0,101.141237
3350,73.0,2.0,1.04,0.84,97.0,101.0,103.624700
2235,74.0,1.6,1.07,0.60,89.0,95.0,100.133954
...,...,...,...,...,...,...,...
916,74.0,0.7,0.77,0.60,89.0,68.0,66.689368
6829,46.0,2.8,1.26,1.20,144.0,181.0,163.486578
603,69.0,1.9,0.92,0.40,115.0,106.0,104.236303
8241,46.0,2.3,0.93,1.20,105.0,98.0,98.433522


'FTI with MICE'

Unnamed: 0,age,T3,T4U,TSH,TT4,FTI (real),FTI (imputed)
734,27.0,2.0,1.41,5.60,134.0,95.0,83.938754
6971,81.0,1.1,0.84,13.00,110.0,131.0,131.273540
4563,70.0,2.7,1.07,2.30,95.0,88.0,88.409761
3350,73.0,2.0,1.04,0.84,101.0,97.0,96.701175
2235,74.0,1.6,1.07,0.60,95.0,89.0,85.229518
...,...,...,...,...,...,...,...
916,74.0,0.7,0.77,0.60,68.0,89.0,93.227539
6829,46.0,2.8,1.26,1.20,181.0,144.0,155.645817
603,69.0,1.9,0.92,0.40,106.0,115.0,117.470016
8241,46.0,2.3,0.93,1.20,98.0,105.0,108.776557
