In [1]:
import pandas as pd
import numpy as np
import random 
import torch 
import os 
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torch import nn
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib import image
import shutil
from sklearn.mixture import GaussianMixture
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

In [2]:
groups_mapping={0:'chexpert_1', 1:'chexpert_0', 2:'mimic_0', 3:'mimic_1'}

### TRUE GROUP BALANCED DATA:

In [3]:
chex = pd.read_csv('resized_data/chexpert/full_data_chexpert_full.csv').iloc[:,1:]
mimic = pd.read_csv('resized_data/mimic/full_data_mimic_full.csv').iloc[:,1:]
df=pd.concat([chex,mimic])

In [6]:
df.new_path.unique()

array(['/scratch/paa9751/mlhc-project/resized_data/chexpert/imgs/CheXpert-v1.0trainpatient17799study2view2_lateral.npy',
       '/scratch/paa9751/mlhc-project/resized_data/chexpert/imgs/CheXpert-v1.0trainpatient34816study6view1_frontal.npy',
       '/scratch/paa9751/mlhc-project/resized_data/chexpert/imgs/CheXpert-v1.0trainpatient34722study10view1_frontal.npy',
       ...,
       '/scratch/paa9751/mlhc-project/resized_data/mimic/imgs/mimic-cxr-jpg-2.0.0.physionet.orgfilesp14p14607991s57935244d4de5d85-581e7f06-c1b0430f-62c5a6e2-8e820ff5.npy',
       '/scratch/paa9751/mlhc-project/resized_data/mimic/imgs/mimic-cxr-jpg-2.0.0.physionet.orgfilesp15p15296176s55638048ee9c46f7-07144e81-5750e091-9b4a0035-cab21b92.npy',
       '/scratch/paa9751/mlhc-project/resized_data/mimic/imgs/mimic-cxr-jpg-2.0.0.physionet.orgfilesp11p11089893s57650194f3dc58c7-b50e0ab5-a7afd93f-5d0bb2fb-b87c0889.npy'],
      dtype=object)

In [44]:
df[df.split=='train'].true_group_idx.value_counts()

true_group_idx
0    6300
2    6300
1     700
3     700
Name: count, dtype: int64

In [45]:
df[df.split=='val'].true_group_idx.value_counts()

true_group_idx
0    1350
2    1350
1     150
3     150
Name: count, dtype: int64

In [46]:
def subsample_train(data):
    return data.sample(700)
def subsample_val(data):
    return data.sample(150)
undersampled_train=df[df.split=='train'].groupby('true_group_idx').apply(subsample_train)
undersampled_val=df[df.split=='val'].groupby('true_group_idx').apply(subsample_val)
undersampled_groupstrue=pd.concat([undersampled_train,undersampled_val,df[df.split=='test']])

In [47]:
undersampled_groupstrue[undersampled_groupstrue.split=='train'].shape

(2800, 32)

In [41]:
undersampled_groupstrue[undersampled_groupstrue.dataset_idx==1].to_csv('resized_data/chexpert/full_data_chexpert_true_group_balanced.csv')
undersampled_groupstrue[undersampled_groupstrue.dataset_idx==0].to_csv('resized_data/mimic/full_data_mimic_true_group_balanced.csv')

### CREATE SNS AND Y-SNS GROUPS

Run re-weighting NURD on our dataset. 
weight = 1/p(y|z) 
run weighted ERM where weight is attached to loss 

- only train the last layer? 

In [7]:
middle_range=[0.25,0.75]
chex = pd.read_csv('resized_data/chexpert/full_data_chexpert_full.csv').iloc[:,1:]
mimic = pd.read_csv('resized_data/mimic/full_data_mimic_full.csv').iloc[:,1:]
df=pd.concat([chex,mimic])
df=df[(df.observed_prob<=middle_range[0])|(df.observed_prob>=middle_range[1])]
for n in [2,4,8,16,32]:
    n_groups = n
    group_label = 'new_sns_'+str(n)
    
    #cuts=list(np.linspace(0,1,num=n+1))

    cuts = list(np.linspace(0,middle_range[0],num=(n//2)+1)) + list(np.linspace(middle_range[1],1,num=(n//2)+1))
    
    df[f'{str(n)}_prob_range']=pd.cut(df.observed_prob,cuts)
    df[group_label]=pd.cut(df.observed_prob,cuts,labels=False)

    unique_values = sorted(df[group_label].unique())
    complete_range = list(range(unique_values[0], unique_values[-1] + 1))
    missing_values = set(complete_range) - set(unique_values)
    mapping = {}
    shift = 0
    for value in complete_range:
        if value in unique_values:
            mapping[value] = value - shift
        else:
            shift += 1
    df[group_label] = df[group_label].map(mapping)
    print(n)
    print(df[group_label].unique())
    
    x_group = df[group_label].values.reshape(-1,1)
    y_group = df['true_group_idx'].values
    
    clf = DecisionTreeClassifier(random_state=0, max_depth=5, class_weight='balanced')
    clf = clf.fit(x_group, y_group)
    print(f'sns {n} groups classification = ')
    print(balanced_accuracy_score(clf.predict(x_group), y_group))
    
df['new_sns_y_2']=df.new_sns_2.astype(str) + df.Cardiomegaly.astype(int).astype(str)
df['new_sns_y_2']=pd.factorize(df.new_sns_y_2)[0]

# df['new_sns_y_3']=df.new_sns_3.astype(str) + df.Cardiomegaly.astype(int).astype(str)
# df['new_sns_y_3']=pd.factorize(df.new_sns_y_3)[0]

df['new_sns_y_4']=df.new_sns_4.astype(str) + df.Cardiomegaly.astype(int).astype(str)
df['new_sns_y_4']=pd.factorize(df.new_sns_y_4)[0]

# df['new_sns_y_6']=df.new_sns_group_6.astype(str) + df.Cardiomegaly.astype(int).astype(str)
# df['new_sns_y_6']=pd.factorize(df.new_sns_y_6)[0]

df['new_sns_y_8']=df.new_sns_8.astype(str) + df.Cardiomegaly.astype(int).astype(str)
df['new_sns_y_8']=pd.factorize(df.new_sns_y_8)[0]

df['new_sns_y_16']=df.new_sns_16.astype(str) + df.Cardiomegaly.astype(int).astype(str)
df['new_sns_y_16']=pd.factorize(df.new_sns_y_16)[0]

df['new_sns_y_32']=df.new_sns_32.astype(str) + df.Cardiomegaly.astype(int).astype(str)
df['new_sns_y_32']=pd.factorize(df.new_sns_y_32)[0]

2
[1 0]
sns 2 groups classification = 
0.4949517652767075
4
[2 3 0 1]
sns 4 groups classification = 
0.48494001999960995
8
[5 7 0 6 4 3 2 1]
sns 8 groups classification = 
0.4778360090229562
16
[11 14  0 15 13  9  8 12  1  6 10  5  7  3  4  2]
sns 16 groups classification = 
0.47842061990729634
32
[23 29  0 22 30 31 27 26 28 18 17 25 24  3 16 13 21 11 20  1  2 15  6 14
  8 19 10  5  7  9  4 12]
sns 32 groups classification = 
0.4902672237350779




In [8]:
x_group = df['new_sns_y_4'].values.reshape(-1,1)
y_group = df['true_group_idx'].values
clf = DecisionTreeClassifier(random_state=0, max_depth=5, class_weight='balanced')
clf = clf.fit(x_group, y_group)
print(f'y-sns4 groups classification = ')
print(balanced_accuracy_score(clf.predict(x_group), y_group))

y-sns4 groups classification = 
0.9296157822472518


In [10]:
df['new_sns_y_4'].unique()

array([0, 1, 2, 3, 4, 5, 6, 7])

In [13]:
df[['new_sns_4','observed_prob']].groupby('new_sns_4')['observed_prob'].min()

new_sns_4
0    0.000098
1    0.200012
2    0.600037
3    0.800037
Name: observed_prob, dtype: float64

In [14]:
df[df.dataset_idx==1].to_csv('resized_data/chexpert/full_data_chexpert_new_groups_0.4_0.6.csv')
df[df.dataset_idx==0].to_csv('resized_data/mimic/full_data_mimic_new_groups_0.4_0.6.csv')

In [21]:
##CHECK GROUPS ARE CORRECT
chexpert_dir = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_new_groups.csv' #removing low accuracy subgroup
mimic_dir = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_new_groups.csv'

In [22]:
fulldf=pd.concat([pd.read_csv(chexpert_dir),pd.read_csv(mimic_dir)])

In [23]:
fulldf[['new_sns_y_2','observed_prob']].groupby('new_sns_y_2')['observed_prob'].max()

new_sns_y_2
0    0.998765
1    0.299927
2    0.999970
3    0.299898
Name: observed_prob, dtype: float64

In [24]:
fulldf.columns

Index(['Unnamed: 0.1', 'Unnamed: 0', 'subject_id', 'Cardiomegaly', 'old_path',
       'new_path', 'split', 'dataset_idx', 'true_group', 'true_group_idx',
       'predicted_prob', 'sns_group', 'observed_prob', 'majority_group',
       'sns_group_2', 'sns_group_3', 'sns_group_4', 'sns_group_8',
       'sns_group_16', 'sample_split', 'pred', 'sns_group_6', 'sns_group_',
       'sns_group_5', 'max_accuracy_model_predictions',
       'min_loss_model_predictions', 'predprob_new', 'sns_y_2', 'sns_y_3',
       'sns_y_4', 'sns_y_6', 'sns_y_8', 'sns_y_16', '2_prob_range',
       'new_sns_2', '4_prob_range', 'new_sns_4', '8_prob_range', 'new_sns_8',
       '16_prob_range', 'new_sns_16', '32_prob_range', 'new_sns_32',
       'new_sns_y_2', 'new_sns_y_4', 'new_sns_y_8', 'new_sns_y_16',
       'new_sns_y_32'],
      dtype='object')

### BALANCING:

In [15]:
chex = pd.read_csv('resized_data/chexpert/full_data_chexpert_new_groups_0.4_0.6.csv').iloc[:,1:]
mimic = pd.read_csv('resized_data/mimic/full_data_mimic_new_groups_0.4_0.6.csv').iloc[:,1:]
df=pd.concat([chex,mimic])

#### 1. 4 GROUPS (with and without label) 

In [22]:
df[df.split=='train'].new_sns_y_4.value_counts()

new_sns_y_4
0    4509
6    3856
5    1529
2    1177
4     474
3     380
1     373
7     370
Name: count, dtype: int64

In [23]:
df[df.split=='val'].new_sns_y_4.value_counts()

new_sns_y_4
0    944
6    831
5    325
2    278
4    117
3     75
7     67
1     66
Name: count, dtype: int64

In [24]:
def subsample_train(data):
    return data.sample(370) #replace accordingly 
def subsample_val(data):
    return data.sample(66) #replace accordingly 
undersampled_train=df[df.split=='train'].groupby('new_sns_y_4').apply(subsample_train) #new_sns_y_4
undersampled_val=df[df.split=='val'].groupby('new_sns_y_4').apply(subsample_val) #new_sns_y_4
undersampled_groups4=pd.concat([undersampled_train,undersampled_val,df[df.split=='test']])

In [30]:
print(undersampled_groups4.shape)
undersampled_groups4[undersampled_groups4.split=='train'].new_sns_y_4.value_counts() # new_sns_y_4

(5990, 47)


new_sns_y_4
0    370
1    370
2    370
3    370
4    370
5    370
6    370
7    370
Name: count, dtype: int64

In [29]:
undersampled_groups4[undersampled_groups4.split=='train'].shape

(2960, 47)

In [31]:
undersampled_groups4[undersampled_groups4.dataset_idx==1].to_csv('resized_data/chexpert/full_data_chexpert_y4group_balanced_0.4_0.6.csv')
undersampled_groups4[undersampled_groups4.dataset_idx==0].to_csv('resized_data/mimic/full_data_mimic_y4group_balanced_0.4_0.6.csv')

#### 2. 2 GROUPS

In [19]:
df[df.split=='val'].new_sns_y_2.value_counts()

new_sns_y_2
0    1113
2    1027
1     111
3     103
Name: count, dtype: int64

In [20]:
def subsample_train(data):
    return data.sample(547)
def subsample_val(data):
    return data.sample(103)
undersampled_train=df[df.split=='train'].groupby('new_sns_y_2').apply(subsample_train)
undersampled_val=df[df.split=='val'].groupby('new_sns_y_2').apply(subsample_val)
undersampled_groups2=pd.concat([undersampled_train,undersampled_val,df[df.split=='test']])

In [22]:
undersampled_groups2[undersampled_groups2.split=='val'].new_sns_y_2.value_counts()#.idxmin()howcanilivewithoutishahowisitevenpossibleimsoinlove

new_sns_y_2
0    103
1    103
2    103
3    103
Name: count, dtype: int64

In [23]:
undersampled_groups2[undersampled_groups2.dataset_idx==1].to_csv('resized_data/chexpert/full_data_chexpert_y2group_balanced.csv')
undersampled_groups2[undersampled_groups2.dataset_idx==0].to_csv('resized_data/mimic/full_data_mimic_y2group_balanced.csv')

#### 2. 8 GROUPS

In [25]:
df[df.split=='train'].new_sns_y_8.value_counts()

new_sns_y_8
1     2541
7     2082
0     1362
14    1199
11     827
4      820
12     677
3      499
15     153
6      152
5      150
2      137
8      137
9      132
10     130
13     127
Name: count, dtype: int64

In [27]:
def subsample_train(data):
    return data.sample(127)
def subsample_val(data):
    return data.sample(19)
undersampled_train=df[df.split=='train'].groupby('new_sns_y_8').apply(subsample_train)
undersampled_val=df[df.split=='val'].groupby('new_sns_y_8').apply(subsample_val)
undersampled_groups8=pd.concat([undersampled_train,undersampled_val,df[df.split=='test']])
undersampled_groups8[undersampled_groups8.split=='train'].new_sns_y_8.value_counts()#.idxmin()howcanilivewithoutishahowisitevenpossibleimsoinlove

new_sns_y_8
0     127
1     127
2     127
3     127
4     127
5     127
6     127
7     127
8     127
9     127
10    127
11    127
12    127
13    127
14    127
15    127
Name: count, dtype: int64

In [28]:
undersampled_groups8[undersampled_groups8.dataset_idx==1].to_csv('resized_data/chexpert/full_data_chexpert_y8group_balanced.csv')
undersampled_groups8[undersampled_groups8.dataset_idx==0].to_csv('resized_data/mimic/full_data_mimic_y8group_balanced.csv')

In [29]:
undersampled_groups8[undersampled_groups8.split=='train'].true_group_idx.value_counts()

true_group_idx
2    669
0    602
3    414
1    347
Name: count, dtype: int64

In [30]:
undersampled_groups4[undersampled_groups4.split=='train'].true_group_idx.value_counts()

true_group_idx
2    702
0    638
3    430
1    366
Name: count, dtype: int64

In [31]:
undersampled_groups2[undersampled_groups2.split=='train'].true_group_idx.value_counts()

true_group_idx
2    727
0    653
3    441
1    367
Name: count, dtype: int64

In [16]:
#check balanced data: 
yc = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_y2group_balanced.csv'
ym = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_y2group_balanced.csv'
ysns2 = pd.concat([pd.read_csv(yc),pd.read_csv(ym)])
yc = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_y4group_balanced.csv'
ym = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_y4group_balanced.csv'
ysns4 = pd.concat([pd.read_csv(yc),pd.read_csv(ym)])
yc = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_y8group_balanced.csv'
ym = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_y8group_balanced.csv'
ysns8 = pd.concat([pd.read_csv(yc),pd.read_csv(ym)])


c = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_2group_balanced.csv'
m = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_2group_balanced.csv'
sns2 = pd.concat([pd.read_csv(c),pd.read_csv(m)])
c = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_4group_balanced.csv'
m = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_4group_balanced.csv'
sns4 = pd.concat([pd.read_csv(c),pd.read_csv(m)])
c = '/scratch/paa9751/mlhc-project/resized_data/chexpert/full_data_chexpert_8group_balanced.csv'
m = '/scratch/paa9751/mlhc-project/resized_data/mimic/full_data_mimic_8group_balanced.csv'
sns8 = pd.concat([pd.read_csv(c),pd.read_csv(m)])

In [32]:
#train:
y2=ysns2[ysns2.split=='train'].true_group_idx.value_counts()#.plot.bar()
s2=sns2[sns2.split=='train'].true_group_idx.value_counts()
pd.concat([s2,y2],axis=1)#.plot.bar()

Unnamed: 0_level_0,count,count
true_group_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
2,719,727
0,700,653
3,442,441
1,375,367


In [33]:
y4=ysns4[ysns4.split=='train'].true_group_idx.value_counts()#.plot.bar()
s4=sns4[sns4.split=='train'].true_group_idx.value_counts()
pd.concat([s4,y4],axis=1)#.plot.bar()

Unnamed: 0_level_0,count,count
true_group_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
2,719,702
0,658,638
3,445,430
1,374,366


In [34]:
y8=ysns8[ysns8.split=='train'].true_group_idx.value_counts()#.plot.bar()
s8=sns8[sns8.split=='train'].true_group_idx.value_counts()
pd.concat([s8,y8],axis=1)#.plot.bar()

Unnamed: 0_level_0,count,count
true_group_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
2,677,669
0,625,602
3,409,414
1,361,347
