In [None]:
from sklearn.model_selection import GroupKFold 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Number of cross-validation folds to generate.
N_SPLITS = 5

# Directory holding the competition files. The trailing slash matters because
# file names are appended to this value by plain string concatenation.
COMP_PATH = '../input/ranzcr-clip-catheter-line-classification/'

In [None]:
# Read train.csv from the competition directory into the working DataFrame.
train_csv_path = COMP_PATH + 'train.csv'
df_train = pd.read_csv(train_csv_path)

In [None]:
# Names of the df_train columns that are treated as targets, one per line.
# The order is kept as-is because the list is used to select columns later.
targets = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present',
]

In [None]:
# Seed for the row shuffle so that the row order of the saved CSV is the same
# on every Restart & Run All.
SHUFFLE_SEED = 42

# Shuffle the rows and rebuild a 0..n-1 index. The positional indices yielded
# by `split` are used as `.loc` labels below, which is only valid with that index.
df_train = df_train.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)

# -1 marks rows that have not been assigned to a validation fold yet.
df_train["fold"] = -1

gkf = GroupKFold(n_splits=N_SPLITS)

# Group by PatientID so that all rows of one patient end up in the same fold.
# NOTE: GroupKFold only balances fold sizes by group; passing the targets as
# `y` does not make the split stratified.
for fold, (train_idx, val_idx) in enumerate(
        gkf.split(df_train, df_train[targets], groups=df_train['PatientID'])):
    print(len(train_idx), len(val_idx))
    df_train.loc[val_idx, 'fold'] = fold

# Sanity checks: every row received a fold, and no patient spans several folds.
assert (df_train['fold'] >= 0).all(), 'some rows were not assigned to a fold'
assert df_train.groupby('PatientID')['fold'].nunique().max() == 1, \
    'a PatientID appears in more than one fold'

In [None]:
# Bar chart of how many training rows were assigned to each fold.
fig, ax = plt.subplots(figsize=(16, 6))

# Pass the column by keyword: current seaborn releases accept `data` as the
# only positional argument, so a positional Series is no longer read as `x`.
sns.countplot(x=df_train['fold'], ax=ax)

# Rows per fold, ordered by fold id, used to build the tick labels below.
fold_sizes = df_train['fold'].value_counts().sort_index()

ax.tick_params(axis='x', labelsize=20)
ax.tick_params(axis='y', labelsize=20)
# Each tick reads "<fold id> (<row count with thousands separator>)".
ax.set_xticklabels([f'{value} ({count:,})' for value, count in fold_sizes.items()])
ax.set_xlabel('Folds', size=20, labelpad=20)
ax.set_ylabel('Samples', size=20, labelpad=20)

ax.set_title('Training Set Number of Samples in Folds', size=20, pad=20)

plt.show()
# Sum of every target column per fold, reshaped to one row per target with one
# column per fold. Selecting `targets` before `.sum()` keeps the aggregation
# off the non-target columns, and the fold ids stay as the column labels so
# each bar is labelled with the fold it really belongs to.
splits = (
    df_train.groupby('fold')[targets].sum()
    .T
    .rename_axis('Target')
    .reset_index()
)

# Long format: one row per (target, fold) pair. `var_name` is given explicitly
# because the column axis is named 'fold' after the transpose.
splits = pd.melt(splits, id_vars=['Target'], var_name='variable', value_name='Count')
# Order the targets by their overall sum so the largest ones are drawn first.
splits['Total'] = splits.groupby('Target')['Count'].transform('sum')
splits = splits.sort_values(by=['Total', 'Target'], ascending=False).reset_index(drop=True)
splits['variable'] = 'Fold ' + splits['variable'].astype(str)

fig = plt.figure(figsize=(16, 8), dpi=100)

sns.barplot(data=splits, x='Count', y='Target', hue='variable')

plt.xlabel('')
plt.ylabel('')
plt.tick_params(axis='x', labelsize=15)
plt.tick_params(axis='y', labelsize=15)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0, prop={'size': 20})
# The title names the splitter that is actually used: plain GroupKFold, with
# no multi-label stratification applied.
plt.title('GroupKFold Target Counts per Fold', size=18, pad=18)

plt.show()

In [None]:
# Save the frame including its 'fold' column; the file name embeds N_SPLITS.
# NOTE(review): `index=False` is not passed, so the DataFrame index is written
# as an unnamed first column — confirm that downstream readers expect it.
fold_csv_path = f'train_{N_SPLITS}_kfolds.csv'
df_train.to_csv(fold_csv_path)