Let's try a transfer learning model that was specifically **trained on chest X-ray images!**

#### CheXNet - Keras

* CheXNet is based on DenseNet-121. It was pretrained on ImageNet and then fine-tuned on ChestX-ray14, which contains 112,120 frontal-view greyscale X-rays from 30,805 patients. 
    * For more about CheXNet, see the original article or the GitHub repository with the trained model: https://github.com/brucechou1983/CheXNet-Keras
* Loading the weights naively won't work, because the saved weights include CheXNet's 14-class output layer. I provide a workaround here.
* Keras - for ease of use! :) 
    
* Data loading code copied from the kernel [Baseline: Transfer Learning+RandomForest](https://www.kaggle.com/titericz/baseline-transfer-learning-randomforest-gpu/) 
* Transfer learning best practices are **not** yet applied here. The usual recipe is to first train only the output layer on a frozen base model, then unfreeze all layers and fine-tune more gently.
    * Removing the added dense layer at the end may improve things (just be sure to handle the logits)
* Note that this is just a starter kernel: there's much more that could be done to improve the model, the transfer learning, etc.
* In this simple notebook we'll fine-tune CheXNet and see how it compares with ImageNet-pretrained models.
* For static feature extraction, see: [danofer: ranzcr-chexnet-x-ray-transfer-learning-extractor](https://www.kaggle.com/danofer/ranzcr-chexnet-x-ray-transfer-learning-extractor)


   

In [None]:
import os

# import efficientnet.tfkeras as efn
import numpy as np
import pandas as pd
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split
import tensorflow as tf
# import tensorflow.keras.applications.efficientnet as efn


import tensorflow as tf
from tensorflow.keras import Sequential
from keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.applications import densenet
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, AveragePooling2D

In [None]:
# Width of the globally pooled feature vector produced by the backbone.
# 1024 for DenseNet121; other backbones differ (check before swapping models).
POOLED_OUTPUT_SIZE: int = 1024

# Pretrained CheXNet (Keras) weights, attached to the notebook as a Kaggle dataset.
chexnet_weights_path: str = "../input/chexnet-keras-weights/brucechou1983_CheXNet_Keras_0.3.0_weights.h5"

# FAST_RUN = False # use only a few rows, for fast debugging

In [None]:
### enable mixed precision - may affect results, but should speed things up
#### https://www.tensorflow.org/guide/mixed_precision
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Compute in float16 while keeping variables in float32 (Keras mixed precision).
MIXED_PRECISION_POLICY = 'mixed_float16'
policy = mixed_precision.Policy(MIXED_PRECISION_POLICY)
mixed_precision.set_policy(policy)

## Helper functions

The following functions are hidden:
```python
auto_select_accelerator()

build_decoder(with_labels=True, target_size=(256, 256), ext='jpg')

build_augmenter(with_labels=True)

build_dataset(paths, labels=None, bsize=64, cache=True,
              decode_fn=None, augment_fn=None,
              augment=True, repeat=True, shuffle=1024, 
              cache_dir="")
```

Unhide below to see:

In [None]:
def auto_select_accelerator():
    """Return a TPU distribution strategy if a TPU is reachable, else the default one.

    A ``ValueError`` raised while resolving/initialising the TPU is treated as
    "no TPU available" and falls back to ``tf.distribute.get_strategy()``.
    The number of replicas in sync is printed either way.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        chosen = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # No TPU found: use the default (CPU / single-GPU) strategy.
        chosen = tf.distribute.get_strategy()

    print(f"Running on {chosen.num_replicas_in_sync} replicas")
    return chosen


def build_decoder(with_labels=True, target_size=(300, 300), ext='jpg'):
    """Build a tf.data map function that loads an image file into a float tensor.

    The returned function reads ``path``, decodes it as a 3-channel PNG or
    JPEG (chosen by ``ext``), scales pixel values to [0, 1] and resizes to
    ``target_size``.  With ``with_labels=True`` it takes ``(path, label)`` and
    passes the label through unchanged; otherwise it takes only ``path``.

    An unsupported ``ext`` raises ``ValueError`` when the function is first
    called (i.e. when it is traced by ``Dataset.map``), not at build time.
    """
    def decode(path):
        raw_bytes = tf.io.read_file(path)

        if ext in ['jpg', 'jpeg']:
            image = tf.image.decode_jpeg(raw_bytes, channels=3)
        elif ext == 'png':
            image = tf.image.decode_png(raw_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        image = tf.cast(image, tf.float32) / 255.0
        return tf.image.resize(image, target_size)

    if not with_labels:
        return decode

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels

## todo: more augmentations, e.g. rotation, albumention - https://www.kaggle.com/bjoernholzhauer/inference-for-trained-fastai-efficientnet-b4
def build_augmenter(with_labels=True):
    """Build a tf.data map function applying random left/right and up/down flips.

    With ``with_labels=True`` the function takes ``(img, label)`` and returns
    the label untouched; otherwise it takes and returns only the image.
    """
    def augment(img):
        mirrored = tf.image.random_flip_left_right(img)
        return tf.image.random_flip_up_down(mirrored)

    if not with_labels:
        return augment

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024, 
                  cache_dir=""):
    """Assemble a batched, prefetched ``tf.data`` pipeline over image paths.

    Pipeline order: decode -> (cache) -> (augment) -> (repeat) -> (shuffle)
    -> batch -> prefetch.

    Args:
        paths: image file paths.
        labels: optional per-image labels; when given, elements are
            ``(image, label)`` pairs, otherwise just images.
        bsize: batch size.
        cache: whether to cache decoded images.  ``cache_dir=""`` caches in
            memory; a non-empty value is passed to ``Dataset.cache`` (and
            ``os.makedirs`` is called on it first when ``cache is True``).
        decode_fn / augment_fn: map functions; default to
            ``build_decoder`` / ``build_augmenter`` matching ``labels``.
        augment, repeat: toggle the augmentation map and infinite repetition.
        shuffle: shuffle buffer size; any falsy value disables shuffling.
    """
    has_labels = labels is not None

    if cache is True and cache_dir != "":
        os.makedirs(cache_dir, exist_ok=True)

    decode_fn = build_decoder(has_labels) if decode_fn is None else decode_fn
    augment_fn = build_augmenter(has_labels) if augment_fn is None else augment_fn

    autotune = tf.data.experimental.AUTOTUNE
    ds = tf.data.Dataset.from_tensor_slices((paths, labels) if has_labels else paths)
    ds = ds.map(decode_fn, num_parallel_calls=autotune)
    if cache:
        ds = ds.cache(cache_dir)
    if augment:
        # augment after caching so every epoch sees fresh random flips
        ds = ds.map(augment_fn, num_parallel_calls=autotune)
    if repeat:
        ds = ds.repeat()
    if shuffle:
        ds = ds.shuffle(shuffle)
    return ds.batch(bsize).prefetch(autotune)

## Variables and configurations

In [None]:
COMPETITION_NAME = "ranzcr-clip-catheter-line-classification"
strategy = auto_select_accelerator()

# 32 images per replica (was 16 in the kernels env, without mixed precision).
PER_REPLICA_BATCH_SIZE = 32
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
# GCS_DS_PATH = KaggleDatasets().get_gcs_path(COMPETITION_NAME)  # only needed when reading from GCS (e.g. TPU)

## Preparing dataset

### Loading and preprocess CSVs

In [None]:
load_dir = f"/kaggle/input/{COMPETITION_NAME}/"
df = pd.read_csv(load_dir + 'train.csv')
sub_df = pd.read_csv(load_dir + 'sample_submission.csv')

# Each image lives at <load_dir><split>/<StudyInstanceUID>.jpg
paths = load_dir + "train/" + df['StudyInstanceUID'] + '.jpg'
test_paths = load_dir + "test/" + sub_df['StudyInstanceUID'] + '.jpg'

# Multi-label targets: every submission column except the first
# (presumably the StudyInstanceUID id column - confirm against the CSV).
label_cols = sub_df.columns[1:]
labels = df[label_cols].values

In [None]:
# Hold out 10% of the training studies for validation (fixed seed for reproducibility).
train_paths, valid_paths, train_labels, valid_labels = train_test_split(
    paths, labels, test_size=0.1, random_state=42
)

Results are notably better with larger image sizes, but training is slower and may run into the GPU/kernel time limit!

In [None]:
# Build the tensorflow datasets
IMSIZES = (224, 240, 260, 300, 380, 456, 528, 600) # (224, 240, 260, 300, 380, 456, 528, 600)
# index i is the input size used for EfficientNet-B{i}
size = IMSIZES[3] # [2]

# CheXNet-Keras was trained on [0, 1] images standardised with the ImageNet
# per-channel mean/std, so apply the same transform on top of build_decoder's
# [0, 1] output to match what the pretrained weights (incl. BatchNorm stats) expect.
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

_decode_unit_range = build_decoder(with_labels=True, target_size=(size, size))
_test_decode_unit_range = build_decoder(with_labels=False, target_size=(size, size))


def decoder(path, label):
    """Decode ``path`` to an ImageNet-standardised image; pass ``label`` through."""
    img, label = _decode_unit_range(path, label)
    return (img - IMAGENET_MEAN) / IMAGENET_STD, label


def test_decoder(path):
    """Decode ``path`` to an ImageNet-standardised image (no label)."""
    img = _test_decode_unit_range(path)
    return (img - IMAGENET_MEAN) / IMAGENET_STD


# The datasets are already batched, so model.fit/predict need no batch_size.
# Train and valid must NOT share a cache path: Dataset.cache keys its cache
# files/lockfile on this path, so a shared path lets the validation pipeline
# collide with (or read back) the training cache.
dtrain = build_dataset(
    train_paths, train_labels, bsize=BATCH_SIZE, 
    cache_dir='/kaggle/tf_cache/train', decode_fn=decoder
)

dvalid = build_dataset(
    valid_paths, valid_labels, bsize=BATCH_SIZE, 
    repeat=False, shuffle=False, augment=False, 
    cache_dir='/kaggle/tf_cache/valid', decode_fn=decoder
)

dtest = build_dataset(
    test_paths, bsize=BATCH_SIZE, repeat=False, 
    shuffle=False, augment=False, cache=False, 
    decode_fn=test_decoder
)

## Modeling

In [None]:
n_labels = labels.shape[1]


with strategy.scope():
    # DenseNet121 backbone without a classifier head; the CheXNet weights are loaded into it below.
    backbone = densenet.DenseNet121(weights=None,
                                    include_top=False,
                                    input_shape=(size, size, 3)
                                   )
    ## workaround - the CheXNet checkpoint also stores a 14-output "predictions" head, so attach a
    ## matching dummy head, load the weights, then build the new model from the backbone output
    ## (the dummy head is simply left out of the new model).
    predictions = tf.keras.layers.Dense(14, activation='sigmoid', name='predictions')(backbone.output)
    base = tf.keras.Model(inputs=backbone.input, outputs=predictions) 

    base.load_weights(chexnet_weights_path)
    print("CheXNet loaded")
    # base.trainable=False # freeze most layers - for better finetuning procedure - TODO

    # Pool the backbone's final 'relu' activation: the same tensor the CheXNet head consumed.
    # (The previous base.layers[-3] picked DenseNet121's preceding 'bn' layer, dropping that ReLU.)
    new_model = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    ### OPT: use Flatten instead of global pooling. Opt: add dropout, fully connected layers after
    # The notebook sets a global mixed_float16 policy; keep the output layer in float32 so the
    # sigmoid and binary cross-entropy are computed in full precision (TF mixed precision guide).
    new_model = tf.keras.layers.Dense(n_labels, activation='sigmoid', dtype='float32')(new_model)

    model = tf.keras.Model(base.input, new_model)

    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        # An explicit name keeps the logged key 'val_auc' (monitored by the callbacks) stable;
        # without it, re-running this cell would name the metric 'auc_1', 'auc_2', ...
        metrics=[tf.keras.metrics.AUC(multi_label=True, name='auc')])
#     model.summary()

In [None]:
# Inspect the layers just before the dummy CheXNet head
# (use model.layers[-4:] instead to look at the new model's tail).
inspected_layers = base.layers[-5:-1]
inspected_layers

In [None]:
# ############### Train the model ###############
# Number of full batches per pass over the training split (dtrain repeats forever).
steps_per_epoch = train_paths.shape[0] // BATCH_SIZE

# Every callback tracks the validation multi-label AUC, where higher is better.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'model.h5', monitor='val_auc', mode='max', save_best_only=True)
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_auc", mode='max', patience=3, min_lr=1e-6)
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', mode='max', patience=6, min_delta=0.0001)

In [None]:
history = model.fit(
    dtrain,
    validation_data=dvalid,
    steps_per_epoch=steps_per_epoch,
    epochs=36,  # training for longer results in a better model - but timed out on kernels
    callbacks=[checkpoint, lr_reducer, early_stop],
    verbose=1,
)

In [None]:
# Restore the best (highest val_auc) weights saved by the ModelCheckpoint callback.
best_checkpoint_path = 'model.h5'
model.load_weights(best_checkpoint_path)

## Save history

In [None]:
# One row per epoch, one column per logged loss/metric.
hist_df = pd.DataFrame(data=history.history)
hist_df.to_csv(path_or_buf='history.csv')

## Submission

In [None]:
# Per-label probabilities for each test image, in the same order as sub_df.
test_probs = model.predict(dtest, verbose=1)
sub_df[label_cols] = test_probs
sub_df.to_csv('submission.csv', index=False)

sub_df.head()