In [1]:
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import glob
import cv2
from sklearn.model_selection import KFold, StratifiedGroupKFold, StratifiedKFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

In [2]:
# Random seed shared by the notebook; get_train_df passes it as `random_state`
# to its fold splitter.
seed = 42

# Directory layout, all relative to the notebook location.
DATA_DIR = '../../data'
PNG_DIR = os.path.join(DATA_DIR, 'png_folder')
DATASET_DIR = os.path.join(DATA_DIR, 'dataset')
DCM_DIR = os.path.join(DATASET_DIR, 'train_images')
SEG_DIR = os.path.join(DATASET_DIR, 'segmentations')

# The three CSVs that live inside the dataset folder, read in this order.
image_level_label, train_series_meta, train = (
    pd.read_csv(os.path.join(DATASET_DIR, csv_name))
    for csv_name in (
        'image_level_labels.csv',
        'train_series_meta.csv',
        'train.csv',
    )
)

# df_seg.csv sits directly under DATA_DIR (not under the dataset folder).
df_seg = pd.read_csv(os.path.join(DATA_DIR, 'df_seg.csv'))

# Left join on patient_id: every row of train_series_meta is kept and gains the
# columns of train.csv for its patient.
df_train = pd.merge(train_series_meta, train, on='patient_id', how='left')
df_train

Unnamed: 0,patient_id,series_id,aortic_hu,incomplete_organ,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury
0,10004,21057,146.00,0,1,0,0,1,0,1,0,1,0,0,0,0,1,1
1,10004,51033,454.75,0,1,0,0,1,0,1,0,1,0,0,0,0,1,1
2,10005,18667,187.00,0,1,0,1,0,1,0,0,1,0,0,1,0,0,0
3,10007,47578,329.00,0,1,0,1,0,1,0,0,1,0,0,1,0,0,0
4,10026,29700,327.00,0,1,0,1,0,1,0,0,1,0,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4706,9961,2003,381.00,0,1,0,1,0,1,0,0,1,0,0,1,0,0,0
4707,9961,63032,143.75,0,1,0,1,0,1,0,0,1,0,0,1,0,0,0
4708,9980,40214,103.00,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1
4709,9980,40466,135.00,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1


In [8]:
# Inspect the table read from df_seg.csv. get_train_df reuses its `series_id`
# and `fold` columns, so the fold values shown here carry over unchanged.
df_seg

Unnamed: 0,patient_id,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,...,spleen_low,spleen_high,any_injury,series_id,aortic_hu,incomplete_organ,mask_file,png_suffix,dcm_folder,fold
0,10004,1,0,0,1,0,1,0,1,0,...,0,1,1,21057,146.00,0,../../data/dataset/segmentations/21057.nii,../../data/png_folder/10004_21057,../../data/dataset/train_images/10004/21057,2
1,10004,1,0,0,1,0,1,0,1,0,...,0,1,1,51033,454.75,0,../../data/dataset/segmentations/51033.nii,../../data/png_folder/10004_51033,../../data/dataset/train_images/10004/51033,4
2,10217,1,0,0,1,1,0,0,0,1,...,0,1,1,16066,208.00,0,../../data/dataset/segmentations/16066.nii,../../data/png_folder/10217_16066,../../data/dataset/train_images/10217/16066,2
3,10228,1,0,1,0,1,0,0,0,1,...,1,0,1,30522,145.00,0,../../data/dataset/segmentations/30522.nii,../../data/png_folder/10228_30522,../../data/dataset/train_images/10228/30522,3
4,10228,1,0,1,0,1,0,0,0,1,...,1,0,1,40471,291.00,0,../../data/dataset/segmentations/40471.nii,../../data/png_folder/10228_40471,../../data/dataset/train_images/10228/40471,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
201,65504,1,0,1,0,1,0,0,0,1,...,0,1,1,55928,144.00,0,../../data/dataset/segmentations/55928.nii,../../data/png_folder/65504_55928,../../data/dataset/train_images/65504/55928,0
202,7642,0,1,1,0,1,0,0,0,1,...,1,0,1,778,183.00,0,../../data/dataset/segmentations/778.nii,../../data/png_folder/7642_778,../../data/dataset/train_images/7642/778,4
203,8848,1,0,1,0,1,0,0,0,1,...,1,0,1,41663,238.00,0,../../data/dataset/segmentations/41663.nii,../../data/png_folder/8848_41663,../../data/dataset/train_images/8848/41663,1
204,8848,1,0,1,0,1,0,0,0,1,...,1,0,1,7384,367.00,0,../../data/dataset/segmentations/7384.nii,../../data/png_folder/8848_7384,../../data/dataset/train_images/8848/7384,1


In [3]:
# Label columns summarised per fold by check_fold.
# Order matters: get_train_df stratifies on cols[1:], i.e. on every entry
# except the leading 'any_injury'.
cols = [
    'any_injury',
    'bowel_injury',
    'extravasation_injury',
    'kidney_high',
    'spleen_high',
    'liver_high',
]

In [4]:
def get_train_df(df_train, df_seg, cols, save=False):
    """Attach a cross-validation fold to every series in ``df_train``.

    Series whose ``series_id`` appears in ``df_seg`` inherit the ``fold`` stored
    in ``df_seg``. All remaining series are split into 5 folds at patient
    level with ``MultilabelStratifiedKFold``, stratified on ``cols[1:]``.

    Reads the module-level names ``seed``, ``PNG_DIR``, ``DCM_DIR`` and
    ``DATA_DIR``.

    Parameters
    ----------
    df_train : pandas.DataFrame
        One row per series; needs ``patient_id``, ``series_id`` and the
        columns named in ``cols[1:]``.
    df_seg : pandas.DataFrame
        Needs ``patient_id``, ``series_id`` and ``fold``.
    cols : list of str
        Label columns; the first entry is skipped for stratification.
    save : bool, default False
        When true, write the result to ``train_df.csv`` under ``DATA_DIR``.

    Returns
    -------
    pandas.DataFrame
        Rows without a ``df_seg`` entry first, then rows with one, with a fresh
        index and the added columns ``fold``, ``png_suffix`` and
        ``dcm_folder``.

    Raises
    ------
    AssertionError
        If the distinct series ids outside ``df_seg`` plus the distinct series
        ids in ``df_seg`` do not add up to the distinct series ids in
        ``df_train``.
    """
    seg_series_ids = df_seg['series_id'].unique()
    in_seg = df_train['series_id'].isin(seg_series_ids)
    unsegmented = df_train[~in_seg]
    segmented = df_train[in_seg]

    # Sanity check: the two partitions must account for every series exactly once.
    n_total = df_train['series_id'].nunique()
    n_covered = unsegmented['series_id'].nunique() + df_seg['series_id'].nunique()
    assert n_covered == n_total

    # Series listed in df_seg keep the fold already recorded there.
    seg_folds = df_seg[['patient_id', 'series_id', 'fold']]
    segmented = segmented.merge(seg_folds, how='left', on=['patient_id', 'series_id'])

    # Collapse to one row per patient so all of a patient's remaining series
    # receive the same fold.
    patients = unsegmented.groupby('patient_id').first().reset_index()
    splitter = MultilabelStratifiedKFold(5, shuffle=True, random_state=seed)

    patients['fold'] = -1  # placeholder, overwritten by the validation indices below
    splits = splitter.split(X=patients, y=patients[cols[1:]])
    for fold_no, (_train_idx, valid_idx) in enumerate(splits):
        # `patients` has a fresh 0..n-1 index, so positional indices are valid labels.
        patients.loc[valid_idx, 'fold'] = fold_no

    unsegmented = unsegmented.merge(patients[['patient_id', 'fold']], how='left', on='patient_id')

    train_df = pd.concat([unsegmented, segmented], axis=0).reset_index(drop=True)

    # Diagnostic: how many patients have their series spread over 1, 2, ... folds.
    # NOTE(review): the recorded run printed 63 patients spanning two folds, so
    # patient-level separation is not guaranteed for series taken from df_seg --
    # confirm this is acceptable for the downstream validation scheme.
    folds_per_patient = train_df.groupby('patient_id')['fold'].nunique()
    print(folds_per_patient.value_counts())

    patient_str = train_df['patient_id'].astype(str)
    series_str = train_df['series_id'].astype(str)
    train_df['png_suffix'] = PNG_DIR + '/' + patient_str + '_' + series_str
    train_df['dcm_folder'] = DCM_DIR + '/' + patient_str + '/' + series_str

    if not save:
        return train_df

    train_df.to_csv(os.path.join(DATA_DIR, 'train_df.csv'), index=False)
    print('save [train_df.csv]')
    return train_df

# Build the fold table without writing train_df.csv (save=False).
# The Series printed by the call counts patients by the number of distinct
# folds their series fall into.
train_df = get_train_df(df_train, df_seg, cols, save=False)
train_df

1    3084
2      63
Name: fold, dtype: int64


Unnamed: 0,patient_id,series_id,aortic_hu,incomplete_organ,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,...,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury,fold,png_suffix,dcm_folder
0,10005,18667,187.0,0,1,0,1,0,1,0,...,1,0,0,1,0,0,0,4,../../data/png_folder/10005_18667,../../data/dataset/train_images/10005/18667
1,10007,47578,329.0,0,1,0,1,0,1,0,...,1,0,0,1,0,0,0,1,../../data/png_folder/10007_47578,../../data/dataset/train_images/10007/47578
2,10026,29700,327.0,0,1,0,1,0,1,0,...,1,0,0,1,0,0,0,2,../../data/png_folder/10026_29700,../../data/dataset/train_images/10026/29700
3,10026,42932,122.0,0,1,0,1,0,1,0,...,1,0,0,1,0,0,0,2,../../data/png_folder/10026_42932,../../data/dataset/train_images/10026/42932
4,10051,17486,345.0,0,1,0,1,0,1,0,...,1,0,0,0,1,0,1,0,../../data/png_folder/10051_17486,../../data/dataset/train_images/10051/17486
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4706,65504,55928,144.0,0,1,0,1,0,1,0,...,0,1,0,0,0,1,1,0,../../data/png_folder/65504_55928,../../data/dataset/train_images/65504/55928
4707,7642,778,183.0,0,0,1,1,0,1,0,...,0,1,0,0,1,0,1,4,../../data/png_folder/7642_778,../../data/dataset/train_images/7642/778
4708,8848,41663,238.0,0,1,0,1,0,1,0,...,0,1,0,0,1,0,1,1,../../data/png_folder/8848_41663,../../data/dataset/train_images/8848/41663
4709,8848,7384,367.0,0,1,0,1,0,1,0,...,0,1,0,0,1,0,1,1,../../data/png_folder/8848_7384,../../data/dataset/train_images/8848/7384


In [21]:
def check_fold(df, cols):
    """Print, for each label column, how many 0s and 1s every fold contains.

    For each column the output is the column name, then one line
    ``fold <fold> : <count of 0> <count of 1>`` per fold present in
    ``df['fold']`` (ascending), then an empty line.

    A value that never occurs in a fold is reported as 0 instead of raising
    ``KeyError``, and the folds are taken from the data rather than assumed
    to be exactly 0..4.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``fold`` column and every column named in ``cols``.
    cols : iterable of str
        Label columns to summarise; their values are counted under the
        labels 0 and 1.

    Returns
    -------
    None
    """
    for col in cols:
        # fold x value count table; absent (fold, value) pairs become 0 and
        # both value columns are guaranteed to exist.
        counts = (
            df.groupby('fold')[col]
            .value_counts()
            .unstack(fill_value=0)
            .reindex(columns=[0, 1], fill_value=0)
        )
        print(col)
        for fold, row in counts.iterrows():
            print(f'fold {fold} :', int(row.loc[0]), int(row.loc[1]))
        print()

# Print the per-fold 0/1 counts of every label column in `cols` to eyeball how
# evenly the labels are spread across folds.
check_fold(train_df, cols)

any_injury
fold 0 : 697 273
fold 1 : 691 258
fold 2 : 685 241
fold 3 : 682 259
fold 4 : 668 257

bowel_injury
fold 0 : 952 18
fold 1 : 926 23
fold 2 : 907 19
fold 3 : 921 20
fold 4 : 901 24

extravasation_injury
fold 0 : 904 66
fold 1 : 893 56
fold 2 : 869 57
fold 3 : 881 60
fold 4 : 864 61

kidney_high
fold 0 : 950 20
fold 1 : 931 18
fold 2 : 906 20
fold 3 : 916 25
fold 4 : 899 26

spleen_high
fold 0 : 912 58
fold 1 : 907 42
fold 2 : 877 49
fold 3 : 903 38
fold 4 : 883 42

liver_high
fold 0 : 954 16
fold 1 : 927 22
fold 2 : 910 16
fold 3 : 920 21
fold 4 : 909 16

