In [None]:
import os,re,gc
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 

import albumentations as albu
import tensorflow as tf 
from tensorflow.keras.applications import ResNet152

# Configuration

In [None]:
# Target columns of the submission. The model's sigmoid outputs are written to
# these names positionally (column i of the predictions -> LABELS[i]).
LABELS = np.array(['ETT - Abnormal', 'ETT - Borderline',
       'ETT - Normal', 'NGT - Abnormal', 'NGT - Borderline',
       'NGT - Incompletely Imaged', 'NGT - Normal', 'CVC - Abnormal',
       'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present'])

# Derived from LABELS instead of hard-coded so the output width of the model
# and the number of submission columns can never drift apart.
N_LABELS = len(LABELS)
# Sentinel that lets tf.data tune map parallelism and prefetch buffer size at runtime.
AUTO = tf.data.experimental.AUTOTUNE

class CONFIG:
    """Inference settings shared by the dataset, model and prediction cells.

    Used as a plain namespace: attributes are read directly from the class and
    it is never instantiated.
    """

    # Number of augmented passes over the test set whose predictions are averaged.
    tta = 5
    # Images per batch fed to model.predict.
    batchsize = 32
    # Target spatial size passed to tf.image.resize and used as the model input shape.
    imsize = (512, 512)

# Dataset

In [None]:
## decoder
def decode_fn(path):
    """Load one JPEG and return it as a uint8 RGB tensor of size CONFIG.imsize.

    Args:
        path: scalar string tensor holding the image file path.

    Returns:
        uint8 tensor of shape (*CONFIG.imsize, 3).
    """
    raw_bytes = tf.io.read_file(path)
    # Force 3 channels so grayscale JPEGs match the RGB model input.
    rgb = tf.io.decode_jpeg(raw_bytes, channels=3)
    # resize returns float32; casting back to uint8 truncates fractional values
    # (no rounding), which is kept as-is to match the existing preprocessing.
    resized = tf.image.resize(rgb, CONFIG.imsize)
    return tf.cast(resized, tf.uint8)

## Test Time Augmentation
# Stochastic pipeline applied afresh on every pass over the test set:
# independent 50% horizontal and vertical flips, then CLAHE applied every time
# (p=1) with clip_limit given as a (min, max) range. No seed is set in this
# notebook, so the drawn augmentations differ between runs.
transform = albu.Compose(transforms=[
    albu.HorizontalFlip(p=0.5),
    albu.VerticalFlip(p=0.5),
    albu.CLAHE(clip_limit=(1, 10), p=1),
])

def aug_fn(image):
    """Apply the random TTA `transform` to one image and scale it to [0, 1].

    This runs inside ``tf.numpy_function`` (see ``process_data``), whose contract
    is ndarray in / ndarray out, so all work stays in NumPy.

    Args:
        image: uint8 HxWx3 NumPy array (the output of ``decode_fn`` in this
            pipeline).

    Returns:
        float32 NumPy array of shape ``(*CONFIG.imsize, 3)`` with values in [0, 1].
    """
    aug_img = transform(image=image)["image"]
    # Scale in NumPy rather than with tf.cast: numpy_function expects an ndarray
    # back, and this avoids dispatching TF ops from the Python callback per image.
    aug_img = (aug_img / 255.0).astype(np.float32)
    # Flips and CLAHE preserve the spatial size, so the input normally already
    # matches CONFIG.imsize. Only resize when it does not, because process_data
    # pins this exact static shape on the result.
    if aug_img.shape[:2] != tuple(CONFIG.imsize):
        aug_img = tf.image.resize(aug_img, CONFIG.imsize).numpy()
    return aug_img

def process_data(image):
    """Run `aug_fn` as a tf.data op and restore the output's static shape.

    tf.numpy_function returns a tensor of unknown shape, so the known shape
    ``(*CONFIG.imsize, 3)`` is set explicitly for batching and the model input.
    """
    augmented = tf.numpy_function(aug_fn, [image], tf.float32)
    augmented.set_shape((*CONFIG.imsize, 3))
    return augmented

## Make CLAHE Data
def make_clahe_dataset(paths, cache_dir=None):
    """Build an endless, batched tf.data pipeline of flip/CLAHE-augmented images.

    Pipeline: decode + resize (``decode_fn``) -> optional cache -> random
    augmentation (``process_data``) -> repeat forever -> batch -> prefetch.
    The cache sits *before* augmentation, so every pass gets fresh random
    augmentations of the same decoded images. Because of ``repeat()``, images
    come out pass-major (all of pass 0, then all of pass 1, ...), and callers
    must bound iteration themselves (e.g. ``model.predict(..., steps=...)``).

    Args:
        paths: image file paths, in any form accepted by
            ``tf.data.Dataset.from_tensor_slices``.
        cache_dir: optional directory for an on-disk cache of the decoded,
            un-augmented images. Falsy (the default) disables caching.

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of ``CONFIG.batchsize``
        images of shape ``(*CONFIG.imsize, 3)``.
    """
    dset = tf.data.Dataset.from_tensor_slices(paths)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
        # Dataset.cache() uses its argument as a file-name *prefix*. Point it at a
        # file inside the directory so the cache files actually land in cache_dir
        # instead of next to it.
        dset = dset.cache(os.path.join(cache_dir, "decoded"))
    dset = dset.map(process_data, num_parallel_calls=AUTO)
    dset = dset.repeat()
    dset = dset.batch(CONFIG.batchsize)
    dset = dset.prefetch(AUTO)
    return dset

In [None]:
# The sample submission lists the test studies. Each StudyInstanceUID maps to one JPEG.
sub_df = pd.read_csv(
    "../input/ranzcr-clip-catheter-line-classification/sample_submission.csv"
)
test_paths = (
    "../input/ranzcr-clip-catheter-line-classification/test/"
    + sub_df.StudyInstanceUID
    + ".jpg"
)
clahe_dset = make_clahe_dataset(paths=test_paths)

# Preview Augmented Test Images

In [None]:
def view_image(ds, num=4):
    """Plot up to `num` images from the first batch of `ds`.

    Args:
        ds: dataset (or any iterable) yielding batches of images, e.g. the
            output of ``make_clahe_dataset``. Only the first batch is drawn.
        num: maximum number of images to show. Nothing is plotted if <= 0.
    """
    if num <= 0:
        return
    images = next(iter(ds))
    # Size the grid from `num`. A fixed 3x4 grid raises ValueError for num > 12.
    ncols = min(num, 4)
    nrows = -(-num // ncols)  # ceil(num / ncols) with integer arithmetic
    # Keep the original per-row height (22 / 3 inches) regardless of row count.
    fig = plt.figure(figsize=(22, 22 * nrows / 3))
    for i, img in enumerate(images):
        if i == num:
            break
        ax = fig.add_subplot(nrows, ncols, i + 1, xticks=[], yticks=[])
        # np.asarray accepts both eager tensors and plain NumPy arrays.
        ax.imshow(np.asarray(img))
    plt.show()

In [None]:
# Sanity-check the augmentation: first four images of the first TTA batch.
view_image(ds=clahe_dset, num=4)

# Model

In [None]:
def create_model(config):
    """Build the multi-label classifier used for inference.

    ResNet152 backbone (``weights=None``, no top) -> global average pooling ->
    Dense(N_LABELS) with sigmoid, so each label gets an independent probability.
    Prints the Keras summary as a side effect.

    Args:
        config: object exposing ``imsize``, the spatial input size.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    # Layers are created in the same order as before so auto-generated names match.
    backbone = ResNet152(
        include_top=False,
        weights=None,
        input_shape=(*config.imsize, 3),
    )
    pooling = tf.keras.layers.GlobalAveragePooling2D()
    classifier = tf.keras.layers.Dense(N_LABELS, activation="sigmoid")
    model = tf.keras.Sequential([backbone, pooling, classifier])
    model.summary()
    return model

In [None]:
# Build the architecture for CONFIG's input size, then restore the trained weights.
model = create_model(CONFIG)
model.load_weights(filepath="../input/model-clahe-512/model_nb13_4_0.h5")

# Inference

In [None]:
TEST_NUM = sub_df.shape[0]

# The dataset repeats forever, so predict just enough batches to cover
# CONFIG.tta full passes over the test set (ceiling division), then drop the
# overshoot from the final batch.
n_needed = CONFIG.tta * TEST_NUM
steps = -(-n_needed // CONFIG.batchsize)
raw_pred = model.predict(clahe_dset, steps=steps, verbose=1)[:n_needed]
# Rows are pass-major (every image for pass 0, then every image for pass 1, ...),
# so view them as (tta, TEST_NUM, N_LABELS) and average over the passes.
pred = raw_pred.reshape(CONFIG.tta, TEST_NUM, N_LABELS).mean(axis=0)

# Submission

In [None]:
# Write (or overwrite) one probability column per label, positionally matching
# the columns of `pred`, then save and preview the submission.
sub_df[list(LABELS)] = pred
sub_df.to_csv('submission.csv', index=False)
sub_df.head()