In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import image
from matplotlib import pyplot
import os
import cv2
import random
import concurrent.futures
import time
import sklearn
from sklearn.model_selection import train_test_split
print(tf.__version__)

In [None]:
# Directory holding the competition's training TFRecord shards.
PATH = '../input/ranzcr-clip-catheter-line-classification/train_tfrecords'
filepath = os.listdir(PATH)

# Open every file in the directory as one stream of serialized records.
raw_dataset = tf.data.TFRecordDataset(
    [os.path.join(PATH, shard_name) for shard_name in filepath]
)

# Decode only the first record and print it, so the stored feature names and
# value types can be inspected before writing the parsing function.
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)

In [None]:
import tensorflow as tf
# Batch size used by create_dataset below (a later cell rebinds BATCH_SIZE to 8).
BATCH_SIZE = 32
# Buffer size handed to Dataset.shuffle in create_dataset.
SHUFFLE_BUFFER = 2000
# Number of label values; _parse_function packs 11 label features per record.
NUM_CLASSES = 11
# Side length decode_image uses when reshaping each decoded image.
IMAGE_SIZE = 1024
def _parse_function(proto):
    """Parse one serialized tf.train.Example into an (image, label) pair.

    Args:
        proto: a scalar string tensor holding a serialized Example.

    Returns:
        image: the tensor produced by decode_image for the 'image' feature.
        label: a list of 11 uint8 scalars, ordered CVC (Abnormal, Borderline,
            Normal), ETT (Abnormal, Borderline, Normal), NGT (Abnormal,
            Borderline, Incompletely Imaged, Normal), Swan Ganz Catheter Present.
    """
    # Order matters: it fixes the position of every class in the label vector.
    label_names = (
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "Swan Ganz Catheter Present",
    )

    # Every label is stored as a scalar int64; image and study id are strings.
    feature_spec = {name: tf.io.FixedLenFeature([], tf.int64) for name in label_names}
    feature_spec['image'] = tf.io.FixedLenFeature([], tf.string)
    feature_spec["StudyInstanceUID"] = tf.io.FixedLenFeature([], tf.string)

    record = tf.io.parse_single_example(proto, feature_spec)
    image = decode_image(record['image'])

    # The study id is parsed but is not part of the returned pair.
    label = [tf.cast(record[name], tf.uint8) for name in label_names]
    return image, label
def decode_image(image_data):
    """Decode an encoded JPEG string into a fixed-size uint8 image tensor.

    The JPEG is decoded with three channels, cast to uint8 and reshaped to
    [IMAGE_SIZE, IMAGE_SIZE, 3].
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    pixels = tf.cast(decoded, tf.uint8)
    target_shape = [IMAGE_SIZE, IMAGE_SIZE, 3]
    return tf.reshape(pixels, target_shape)
def create_dataset(filepath):
    """Read TFRecord files and return the first shuffled batch as tensors.

    Args:
        filepath: TFRecord path (or list of paths) given to
            tf.data.TFRecordDataset.

    Returns:
        (image, label) for ONE batch of BATCH_SIZE parsed examples.

    Note:
        Only the first batch of the pipeline is pulled and returned; the rest
        of the records are never read. Callers that need every record have to
        iterate a tf.data pipeline themselves instead of using this helper.
    """
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=4)
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    # drop_remainder keeps the batch dimension fixed at BATCH_SIZE.
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    # One iterator is all that is needed to pull the first batch.
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    image, label = iterator.get_next()
    return image, label

In [None]:
import tensorflow as tf
from tensorflow.python import keras as keras
from tensorflow.keras.applications import MobileNetV2

PATH = '../input/ranzcr-clip-catheter-line-classification/train_tfrecords'
# os.listdir returns names in arbitrary order; sort them so the unshuffled
# split below puts the same shards in train/validation on every run.
fn = sorted(os.listdir(PATH))
filepath = [os.path.join(PATH, ele) for ele in fn]
# File-level split: 13 shards for training, 3 for validation, taken in order
# (shuffle=False). The directory is expected to hold at least 16 shards.
train_path, val_path = train_test_split(filepath, test_size=3, train_size=13, random_state=None, shuffle=False, stratify=None)
# create_dataset hands back a single batch of tensors per call.
train_image, train_label = create_dataset(train_path)
val_image, val_label = create_dataset(val_path)

In [None]:
# Sanity check: display how many entries the training TFRecord directory has.
PATH = '../input/ranzcr-clip-catheter-line-classification/train_tfrecords'
fn = os.listdir(PATH)
len(fn)

In [None]:
# Resolve the TPU cluster, connect this runtime to it, initialise the TPU
# system and list the logical TPU devices that became available.
# NOTE(review): the strategy built further down is tf.distribute.MirroredStrategy,
# not a TPU strategy — confirm whether this TPU setup is actually needed.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [None]:
# Seed shared by all random augmentation layers below.
SEED = 100
# Rebinds the BATCH_SIZE defined with the parsing helpers; from here on it is 8.
BATCH_SIZE = 8
# Augmentation layers. They are only constructed here; the calls that would
# apply them inside normalize() are commented out.
# NOTE(review): RandomRotation appears to take its factor as a fraction of a
# full turn (2*pi), so 3.142/2 is probably not a 90-degree limit — confirm.
random_rotation = tf.keras.layers.experimental.preprocessing.RandomRotation(3.142/2, seed=SEED)
random_flip = tf.keras.layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED)
random_zoom = tf.keras.layers.experimental.preprocessing.RandomZoom((0, 0.25), seed=SEED)
random_translate = tf.keras.layers.experimental.preprocessing.RandomTranslation((-0, 0.25), (-0, 0.25), seed=SEED)
# Wrap the in-memory image/label tensors as tf.data pipelines on the host CPU.
# Each pipeline shuffles with a BATCH_SIZE-element buffer and then batches.
with tf.device('/cpu:0'):
    train_dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
    train_dataset = train_dataset.shuffle(BATCH_SIZE).batch(BATCH_SIZE)
    val_dataset = tf.data.Dataset.from_tensor_slices((val_image, val_label))
    val_dataset = val_dataset.shuffle(BATCH_SIZE).batch(BATCH_SIZE)
# The source tensors are deliberately NOT deleted here: removing the names made
# this cell raise NameError when re-run without re-running the loading cell.
def normalize(imgs, label):
    """Scale images by 1/255 as float16 and pass the labels through unchanged.

    No random augmentation is applied in this function.
    """
    scaled = tf.cast(imgs, tf.float16) / 255
    return scaled, label
# Apply normalize() to both pipelines. The map comes after .batch(), so
# normalize receives whole batches rather than single examples.
train_dataset = train_dataset.map(normalize, num_parallel_calls=4)
val_dataset = val_dataset.map(normalize, num_parallel_calls=4)

In [None]:
# Everything that owns variables (model, optimizer, metrics) is created inside
# the distribution strategy's scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # ImageNet-initialised MobileNetV2 backbone with global max pooling.
    base_model = tf.keras.applications.MobileNetV2(include_top=False, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
                                                   weights='imagenet', pooling='max', alpha=1.3)
    # One independent sigmoid output per label (multi-label classification).
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(NUM_CLASSES, activation='sigmoid')
    ])
    optimizer = tf.keras.optimizers.Adam(0.00001)
    epoch_auc = tf.keras.metrics.AUC(num_thresholds=200)
    val_epoch_auc = tf.keras.metrics.AUC(num_thresholds=200)
    # The model already emits sigmoid probabilities, so the loss must not treat
    # them as logits, and independent labels call for binary (not categorical)
    # cross-entropy. Reduction is NONE because the per-example losses are
    # averaged explicitly over the global batch in compute_loss.
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
    # val_loss is consumed through result()/reset_states(), which are metric
    # methods, so it is a running mean rather than a loss object.
    val_loss = tf.keras.metrics.Mean(name='val_loss')

train_loss_history = []
val_loss_history = []
# Distributed views of the datasets, iterated by the training loop.
dist_train_dataset = strategy.experimental_distribute_dataset(train_dataset)
dist_val_dataset = strategy.experimental_distribute_dataset(val_dataset)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
from sklearn.metrics import accuracy_score
with strategy.scope():
    def compute_loss(labels, predictions):
        """Return the loss averaged over the global batch.

        loss_object yields the loss for `labels` vs `predictions`;
        tf.nn.compute_average_loss then scales it by global_batch_size
        (BATCH_SIZE).
        """
        return tf.nn.compute_average_loss(
            loss_object(labels, predictions),
            global_batch_size=BATCH_SIZE,
        )
def compute_acc(labels, predictions):
    """Return sklearn's accuracy_score for the given labels and predictions."""
    score = accuracy_score(labels, predictions)
    return score
def train_step(inputs):
    """Run one optimisation step on a single replica's batch.

    Args:
        inputs: an (images, labels) pair for this replica.

    Returns:
        The replica's loss from compute_loss (already scaled by the global
        batch size).

    Note:
        This function is executed through strategy.run inside a tf.function.
        Python-side effects such as appending to a list would only run while
        the function is traced and would store symbolic tensors, so no history
        is recorded here; record losses from the eager caller instead.
    """
    images, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss_value = compute_loss(labels, predictions)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Running training AUC for the current epoch.
    epoch_auc.update_state(labels, predictions)
    return loss_value
@tf.function
def distributed_train_step(dist_inputs):
    """Run train_step on every replica and return the summed loss."""
    replica_losses = strategy.run(train_step, args=(dist_inputs,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        replica_losses,
        axis=None,
    )
def val_step(inputs):
    """Evaluate one replica's validation batch and update validation metrics.

    Args:
        inputs: an (images, labels) pair for this replica.

    Returns:
        The replica's loss averaged over the global batch (BATCH_SIZE), so the
        caller's SUM reduction across replicas has a value to reduce.

    Note:
        val_loss must expose update_state (i.e. be a Keras metric such as a
        running mean). No Python list is appended to here because this
        function runs inside a tf.function, where such side effects happen
        only at trace time.
    """
    images, labels = inputs
    predictions = model(images, training=False)
    # loss_object is configured with Reduction.NONE, so this is not yet averaged.
    per_example_loss = loss_object(labels, predictions)
    val_loss.update_state(per_example_loss)
    val_epoch_auc.update_state(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)
@tf.function
def distributed_val_step(dist_inputs):
    """Run val_step on every replica and return the SUM-reduced result."""
    replica_results = strategy.run(val_step, args=(dist_inputs,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        replica_results,
        axis=None,
    )
def _progress_bar(step, total, width=20):
    """Return (bar, percent) for the 1-based `step` out of `total` steps."""
    percent = float(step) * 100 / total
    arrow = '-' * int(percent / 100 * width - 1) + '>'
    return arrow + ' ' * (width - len(arrow)), percent


def train(epochs, verbose=1):
    """Run the custom training/validation loop for `epochs` epochs.

    Each epoch iterates dist_train_dataset with distributed_train_step, then
    dist_val_dataset with distributed_val_step, and finally resets the
    epoch-level metrics (epoch_auc, val_epoch_auc, val_loss).

    Args:
        epochs: number of passes over the training data.
        verbose: truthy to draw a live progress bar with running loss/AUC and
            timings; falsy to print only the summed training loss per epoch.
    """
    # Batch counts do not change between steps, so look them up once.
    num_train_batches = len(train_dataset)
    num_val_batches = len(val_dataset)

    for epoch in range(epochs):
        start = time.time()
        print ('\nEpoch {}/{} '.format(epoch+1, epochs))

        ####################### Train Loop #########################
        num_batches = 0
        total_loss = 0.0
        for data in dist_train_dataset:
            total_loss += distributed_train_step(data)
            num_batches += 1
            if verbose:
                bar, percent = _progress_bar(num_batches, num_train_batches)
                print('\rTraining: [%s] %d %% - Training Loss: %f - Training AUC: %f'% (bar, percent, total_loss/num_batches, epoch_auc.result()), end='', flush=True)
        if not verbose:
            # float() works whether or not any batch ran (total_loss may still
            # be the plain Python 0.0 it started as).
            print(' Epoch Loss: ', float(total_loss))
        if verbose:
            print(" -", int(time.time()-start), "s", end="")
            print()
        start = time.time()

        ####################### Validation Loop #########################
        for step, data in enumerate(dist_val_dataset, start=1):
            distributed_val_step(data)
            if verbose:
                bar, percent = _progress_bar(step, num_val_batches)
                print('\rValidate: [%s] %d %% - Validation Loss: %f - Validation AUC: %f'% (bar, percent, val_loss.result(), val_epoch_auc.result()), end='', flush=True)
        if verbose:
            print(" -", int(time.time()-start), "s")

        # Start the next epoch with clean running metrics.
        epoch_auc.reset_states()
        val_epoch_auc.reset_states()
        val_loss.reset_states()

In [None]:
# Launch training for 50 epochs with the live progress display enabled.
train(50, verbose=1)

In [None]:
# Persist the model to ./mobilenetv2.h5.
model.save('./mobilenetv2.h5')

In [None]:
# Debug aid: print every batch yielded by the distributed training dataset.
for data in dist_train_dataset:
    print(data)