# Stratified GroupKFold TFRecords

All the datasets can be found here: [128x128](https://www.kaggle.com/prateek0x/ranzcr-128x128), [256x256](https://www.kaggle.com/prateek0x/ranzcr-256x256), [384x384](https://www.kaggle.com/prateek0x/ranzcr-384x384/), [512x512](https://www.kaggle.com/prateek0x/ranzcr-512x512).

**A sample notebook** that demonstrates Stratified GroupKFold cross-validation and trains an EfficientNet architecture on TPUs can be found [here](https://www.kaggle.com/prateek0x/stratified-groupkfold-with-efn-tfrecords).

For discussion, see this [thread](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification/discussion/208689).

### Thanks to:
1. Notebook: [How to Create TFRecords](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-stratified-tfrecords-256x256)
2. Discussion: [A simple way to split folds](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification/discussion/204638)


In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from sklearn.model_selection import GroupKFold
import re, os, cv2, random, warnings, shutil,tqdm

def seed_everything(seed=0):
    """Seed Python, NumPy and TensorFlow RNGs and request deterministic TF ops.

    NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    does not change string hashing in the already-running process.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

# Fix all RNG seeds first, then silence warnings for the rest of the notebook.
seed = 0
seed_everything(seed=seed)
warnings.filterwarnings(action='ignore')

# Configuration

In [None]:
database_base_path = '../input/ranzcr-clip-catheter-line-classification/'
PATH = database_base_path + 'train/'  # directory holding the raw training JPEGs
IMGS = os.listdir(PATH)
N_FILES = 15  # number of TFRecord shards, one per fold
HEIGHT = WIDTH = 256  # size images are resized to before encoding
IMG_QUALITY = 100  # JPEG quality used when re-encoding resized images

print(f'Image samples: {len(IMGS)}')

# Auxiliary functions

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 [HEIGHT, WIDTH, 3] tensor scaled to [0, 1].

    Uses the module-level HEIGHT and WIDTH for the output size.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    resized = tf.image.resize(scaled, [HEIGHT, WIDTH])
    # Pin the static shape so downstream batching sees a fully-defined shape.
    return tf.reshape(resized, [HEIGHT, WIDTH, 3])

def read_tfrecord(example):
    """Parse one serialized example into ``(image, target, name)``.

    ``image`` is the output of ``decode_image``. ``target`` is a list of the 11
    int64 label tensors, in the fixed order below. ``name`` is the ``PatientID``
    string tensor. ``StudyInstanceUID`` is parsed but not returned.
    """
    label_keys = (
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    )
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    scalar_int = tf.io.FixedLenFeature([], tf.int64)

    spec = {"image": scalar_string}
    spec.update((key, scalar_int) for key in label_keys)
    spec["StudyInstanceUID"] = scalar_string
    spec["PatientID"] = scalar_string

    parsed = tf.io.parse_single_example(example, spec)
    image = decode_image(parsed['image'])
    target = [parsed[key] for key in label_keys]
    return image, target, parsed['PatientID']

def load_dataset(filenames, HEIGHT, WIDTH, CHANNELS=3):
    """Build a tf.data pipeline yielding ``(image, target, name)`` per example.

    Args:
        filenames: TFRecord file path(s) to read.
        HEIGHT, WIDTH, CHANNELS: accepted for call compatibility but unused;
            resizing happens inside ``decode_image`` using the module-level
            HEIGHT/WIDTH.

    Returns:
        A ``tf.data.Dataset`` of parsed examples (see ``read_tfrecord``).
    """
    dataset = tf.data.TFRecordDataset(filenames)
    # Use AUTOTUNE directly instead of the global ``AUTO``: that name is only
    # defined in a later cell, so relying on it made this function fail with
    # NameError if called before that cell ran.
    return dataset.map(read_tfrecord,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)

def display_samples(ds, row, col):
    """Plot up to ``row * col`` samples from a batched ``(image, label, name)`` dataset.

    Each element is expected to be batched (the first item of each batch is
    shown). If ``ds`` has fewer than ``row * col`` elements, only the
    available ones are drawn instead of raising StopIteration.
    """
    plt.figure(figsize=(15, int(15 * row / col)))
    # zip stops at the shorter input; range comes first so no extra element is
    # pulled from ds once the grid is full.
    for j, (image, _label, name) in zip(range(row * col), ds):
        plt.subplot(row, col, j + 1)
        plt.axis('off')
        plt.imshow(image[0])
        title = name.numpy()[0]
        # String tensors come back as bytes; decode so the title isn't "b'...'".
        if isinstance(title, bytes):
            title = title.decode('utf-8')
        plt.title(title, fontsize=12)
    plt.show()

def count_data_items(filenames):
    """Return the total number of examples across TFRecord shards.

    Each shard's basename must encode its example count as
    ``<prefix>-<count>.<ext>`` (e.g. ``Id_train00-2000.tfrec``).

    Args:
        filenames: iterable of shard paths.

    Returns:
        The summed count as a Python ``int`` (``0`` for no files).

    Raises:
        ValueError: if a filename does not contain a ``-<digits>.`` count.
    """
    # Compile once, and require at least one digit so an empty match can't
    # reach int('').
    count_re = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Search only the basename: directories such as "ranzcr-256x256" or
        # "data-3.1" must not be mistaken for the shard's count.
        match = count_re.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"Cannot parse example count from filename: {filename!r}")
        total += int(match.group(1))
    return total


# Create TF Records
def _bytes_feature(value):
    """Wrap a bytes value (or an eager string tensor) in a ``tf.train.Feature``."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList cannot unpack an EagerTensor, so pull out the raw bytes.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap a single integer-like value (bool/enum/int/uint) in a ``tf.train.Feature``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def serialize_example(image, ETT_Abnormal, ETT_Borderline, ETT_Normal,
                      NGT_Abnormal, NGT_Borderline, NGT_Incompletely_Imaged,
                      NGT_Normal, CVC_Abnormal, CVC_Borderline, CVC_Normal,
                      SwanGanzCatheterPresent, unique_id, patient_id):
    """Serialize one image and its 11 labels into a ``tf.train.Example`` byte string.

    ``image``, ``unique_id`` and ``patient_id`` are stored as bytes features.
    Each label is stored as an int64 feature keyed by its CSV column name.
    """
    label_keys = (
        'ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal',
        'NGT - Abnormal', 'NGT - Borderline', 'NGT - Incompletely Imaged',
        'NGT - Normal', 'CVC - Abnormal', 'CVC - Borderline', 'CVC - Normal',
        'Swan Ganz Catheter Present',
    )
    label_values = (
        ETT_Abnormal, ETT_Borderline, ETT_Normal,
        NGT_Abnormal, NGT_Borderline, NGT_Incompletely_Imaged,
        NGT_Normal, CVC_Abnormal, CVC_Borderline, CVC_Normal,
        SwanGanzCatheterPresent,
    )

    feature = {'image': _bytes_feature(image)}
    for key, label in zip(label_keys, label_values):
        feature[key] = _int64_feature(label)
    feature['StudyInstanceUID'] = _bytes_feature(unique_id)
    feature['PatientID'] = _bytes_feature(patient_id)

    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

# Load data

In [None]:
# One row per study: StudyInstanceUID, the label columns, and PatientID.
train = pd.read_csv(f'{database_base_path}train.csv')
n_train = len(train)
print('Train samples: ', n_train)

display(train.head())

## Split samples into 15 different files

In [None]:
# The 11 label columns, in the order serialize_example/read_tfrecord use.
target_cols = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present',
]

In [None]:
# Assign every study to one of N_FILES patient-disjoint folds.
# NOTE: GroupKFold only balances groups (PatientID); it does not use the label
# matrix passed as `y`, so these folds are grouped but not label-stratified.
groups = np.array(train['PatientID'].values)
gkf = GroupKFold(n_splits=N_FILES)
train['fold'] = -1
fold_col = train.columns.get_loc('fold')

for fold_n, (train_idx, val_idx) in enumerate(gkf.split(train["StudyInstanceUID"], train[target_cols], groups)):
    print('Fold: %s has %s samples' % (fold_n+1, len(val_idx)))
    # Write through the frame itself with positional indexing: the chained form
    # `train['fold'].loc[val_idx] = ...` writes to a temporary under pandas
    # Copy-on-Write (its warning is hidden by the blanket warnings filter).
    train.iloc[val_idx, fold_col] = fold_n

display(train.head())
train.to_csv('train.csv', index=False)

# Generate TF records

In [None]:
# Write one TFRecord shard per fold. The example count is embedded in the
# filename (...-<count>.tfrec) so count_data_items() can recover it later.
for tfrec_num in range(N_FILES):
    print('\nWriting TFRecord %i of %i...' % (tfrec_num + 1, N_FILES))
    samples = train[train['fold'] == tfrec_num]
    n_samples = len(samples)
    print(f'{n_samples} samples')
    # Select labels by column name (same order as serialize_example's
    # parameters) instead of positional itertuples fields like row._2.
    labels = samples[target_cols].to_numpy()
    rows = zip(samples['StudyInstanceUID'], samples['PatientID'], labels)
    with tf.io.TFRecordWriter('Id_train%.2i-%i.tfrec' % (tfrec_num, n_samples)) as writer:
        for study_uid, patient_id, row_labels in tqdm.tqdm(rows, total=n_samples):
            image_name = study_uid + ".jpg"
            img_path = f'{PATH}{image_name}'

            img = cv2.imread(img_path)
            if img is None:  # cv2.imread signals failure by returning None
                raise FileNotFoundError(f'Could not read image: {img_path}')
            # cv2.resize takes dsize as (width, height).
            img = cv2.resize(img, (WIDTH, HEIGHT))
            ok, encoded = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, IMG_QUALITY))
            if not ok:
                raise RuntimeError(f'JPEG encoding failed for: {img_path}')

            example = serialize_example(
                encoded.tobytes(),  # tobytes() replaces the deprecated tostring()
                *(int(value) for value in row_labels),
                str.encode(image_name),
                str.encode(patient_id),
            )
            writer.write(example)

# Visualize created TF records


In [None]:
AUTO = tf.data.experimental.AUTOTUNE
FILENAMES = tf.io.gfile.glob('Id_train*.tfrec')
print(f'TFRecords files: {FILENAMES}')
n_created = count_data_items(FILENAMES)
print(f'Created image samples: {n_created}')

# Batch size 1 so display_samples can take image[0] / name[0] from each element.
preview_ds = load_dataset(FILENAMES, HEIGHT, WIDTH).batch(1)
display_samples(preview_ds, 6, 6)

## Samples in each TFRecord

In [None]:
fig = plt.figure(figsize=(16, 6))

# Pass the Series as `x=`: newer seaborn treats a positional first argument
# as `data`, not as the x variable.
ax = sns.countplot(x=train['fold'])

ax.tick_params(axis='x', labelsize=20)
ax.tick_params(axis='y', labelsize=20)
# Folds are stored 0-based; label the bars 1..N_FILES rather than a
# hard-coded 1..15 so the labels follow the configured number of files.
ax.set_xticklabels([f'{value} ' for value in range(1, N_FILES + 1)])
ax.set_xlabel('Folds', size=20, labelpad=20)
ax.set_ylabel('Samples', size=20, labelpad=20)

plt.title('Training Set Number of Samples in Folds', size=20, pad=20)

plt.show()

## Label distribution for each TFRecord

In [None]:
# Positive-label counts in fold 0 (the first TFRecord file, Id_train00-*).
tfrec_1 = train[train["fold"] == 0]
label_dist = {tar: int(tfrec_1[tar].sum()) for tar in target_cols}

fig = plt.figure(figsize=(7, 8), dpi=100)

# One bar per label; derived from the dict instead of a hard-coded range(0, 11).
bar_labels = [f"Count : {count} : {tar}" for tar, count in label_dist.items()]
ax = sns.barplot(y=bar_labels,
                 x=list(label_dist.values()),
                 hue=list(label_dist.keys()),
                 # hue is one-to-one with y; dodge=False keeps each bar full
                 # width in seaborn versions that would otherwise offset them.
                 dodge=False,
                 )
plt.show()

In [None]:
# Positive counts per (fold, target). Select the label columns before summing
# so the string ID columns (StudyInstanceUID, PatientID) are not aggregated.
fold_counts = train.groupby('fold')[target_cols].sum()
# Folds are stored 0-based; relabel every fold k as k + 1 for display. A
# `{fold - 1: fold}` mapping over 0-based folds would leave the last fold
# unrenamed and collide with the renamed second-to-last one.
fold_counts.index = fold_counts.index + 1

splits = fold_counts.T.rename_axis('Target').reset_index()

# Long format: one row per (Target, fold) with its Count.
splits = pd.melt(splits, id_vars=['Target'], var_name='variable', value_name='Count')
splits['Total'] = splits.groupby('Target')['Count'].transform('sum')
splits = splits.sort_values(by=['Total', 'Target'], ascending=False).reset_index(drop=True)
splits['variable'] = 'Fold ' + splits['variable'].astype(str)

fig = plt.figure(figsize=(16, 20), dpi=100)

sns.barplot(x=splits['Count'],
            y=splits['Target'],
            hue=splits['variable'])

plt.xlabel('')
plt.ylabel('')
plt.tick_params(axis='x', labelsize=15)
plt.tick_params(axis='y', labelsize=15)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0, prop={'size': 20})
plt.title('Multi Label Stratified GroupKFold Target Counts', size=18, pad=18)

plt.show()