<div class = "alert alert-block alert-info">
    <h2><font color = "red">DISCLAIMER</font></h2>
    <p>The approach shown here is simply a fun experiment. It can provide ideas or encourage ...</p>
</div>


---

### Version Update 2

[WIP]: Working on adding **Model interpretability with Integrated Gradients**. FYI, the `LamdaDenseNet` model seems able to take advantage of the TPU; MXU utilization seems to remain high. That's really interesting. 

References:

- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
- Original [implementation](https://github.com/ankurtaly/Integrated-Gradients)
- Reusing code from [here](https://keras.io/examples/vision/integrated_gradients/) by [Aakash Kumar Nain](https://www.kaggle.com/aakashnain)


### Version Update 1

Previously we built a custom **mixed depth-wise group convolutional** layer in `tf.keras` and added it on top of the pre-trained model. In this version we extend that idea and try something new: we modify the pre-trained model **internally** and add some custom layers. Mainly, we will be working with the following components. 

```
- Dense Net 121
- Mixed Depth-wise Group Convolution
- Lambda Layer: Long-Range Interactions without Attention.
```

The final architecture, named **LamdaDenseNet-121**, looks something like this: 

![LamdaDenseNet](https://user-images.githubusercontent.com/17668390/103923527-68532100-513f-11eb-97f5-d16806865f08.png)


The feature flow is roughly as follows:

- First, we start the convolutional stem of `DenseNet-121` with a custom `MixDepthGroupConvolution2D` layer.
- Second, we iterate over the consecutive Dense Blocks and add a custom `LambdaLayer` between each Dense Block and its Transition Layer, until Dense Block 5. 

## <font color = "seagreen">-)</font>

In [None]:
!pip install efficientnet -q
!pip install einops -q

import numpy as np
import pandas as pd
import seaborn as sns
from glob import glob
import albumentations as A 
from pylab import rcParams
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import os, gc, cv2, random, warnings, math, sys, json, pprint

# sklearn
from sklearn.model_selection import  GroupKFold
from sklearn.metrics import roc_auc_score

# tf 
import tensorflow as tf
import efficientnet.tfkeras as efn
from tensorflow.keras import backend as K

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
warnings.simplefilter('ignore')

In [None]:
SEED  = 101
TRAIN_DF       = '../input/ranzcr-clip-catheter-line-classification/train.csv'
TRAIN_IMG_PATH = '../input/ranzcr-clip-catheter-line-classification/train/'
TEST_IMG_PATH  = '../input/ranzcr-clip-catheter-line-classification/test/'
CLASS_MAP      = '../input/ranzcr-clip-catheter-line-classification/train_annotations.csv'

In [None]:
df = pd.read_csv(TRAIN_DF)
df.head()

## <font color = "seagreen">RANZCR-CLiP Dataloader</font>

In [None]:
from kaggle_datasets import KaggleDatasets

def auto_select_accelerator():
    """
    Reference: 
        * https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
        * https://www.kaggle.com/xhlulu/ranzcr-efficientnet-tpu-training
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    return strategy


def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img
    
    def decode_with_labels(path, label):
        return decode(path), label
    
    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    def augment(img):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        return img
    
    def augment_with_labels(img, label):
        return augment(img), label
    
    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024, 
                  cache_dir=""):
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)
    
    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)
    
    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)
    
    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)
    
    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize).prefetch(AUTO)
    
    return dset

strategy    = auto_select_accelerator()
BATCH_SIZE  = strategy.num_replicas_in_sync * 16
GCS_DS_PATH = KaggleDatasets().get_gcs_path('ranzcr-clip-catheter-line-classification')

print(BATCH_SIZE)

In [None]:
from sklearn.model_selection import train_test_split

paths = GCS_DS_PATH + "/train/" + df['StudyInstanceUID'] + '.jpg'
sub_df = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/sample_submission.csv')
test_paths = GCS_DS_PATH + "/test/" + sub_df['StudyInstanceUID'] + '.jpg'

# Get the multi-labels
label_cols = sub_df.columns[1:]
labels = df[label_cols].values

# Train test split
(
    train_paths, valid_paths, 
    train_labels, valid_labels
) = train_test_split(paths, labels, test_size=0.15, random_state=42)

In [None]:
# Build the tensorflow datasets
IMSIZES = (224, 240, 260, 300, 380, 456, 528, 600, 850)
im_size = IMSIZES[3]

decoder = build_decoder(with_labels=True, target_size=(im_size, im_size))
test_decoder = build_decoder(with_labels=False, target_size=(im_size, im_size))

train_dataset = build_dataset(
    train_paths, train_labels, bsize=BATCH_SIZE, decode_fn=decoder
)

valid_dataset = build_dataset(
    valid_paths, valid_labels, bsize=BATCH_SIZE, decode_fn=decoder,
    repeat=False, shuffle=False, augment=False
)

test_dataset = build_dataset(
    test_paths, cache=False, bsize=BATCH_SIZE, decode_fn=test_decoder,
    repeat=False, shuffle=False, augment=False
)

## Mixed Depthwise Convolution 

Mixed Depthwise Convolution, or **MDConv**, is a core building block of the [MixNet](https://arxiv.org/abs/1907.09595) architecture. This is the layer that we will add before the dense blocks. A mixed convolution is a group of convolutions with varying kernel sizes within a single layer. Typically a layer such as `Conv2D` uses one constant kernel size. In a mixed convolution, the channels are split into groups and each group gets its own kernel size. 

<p align="center">
  <img width="600" height="200" src="https://user-images.githubusercontent.com/17668390/103163501-dc7cf300-4828-11eb-813f-d23e5abb89d9.png">
</p>

However, unlike `pytorch`, `tf.keras` currently has no `group` parameter on its convolutional layers that conveniently does this job. Instead, we need to do something like the following: 

<p align="center">
  <img width="500" height="200" src="https://user-images.githubusercontent.com/17668390/103163565-82c8f880-4829-11eb-85d9-645273c80e1c.png">
</p>
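
The split-and-concatenate idea in the diagram can be sketched in plain Python. This is only a minimal illustration of how channels might be divided among kernel sizes; the 64-channel feature map and the `[3, 5, 7]` kernel list are assumed values chosen for the example.

```python
def split_channels(total_channels, num_groups):
    """Split `total_channels` into `num_groups` near-equal integer parts."""
    split = [total_channels // num_groups] * num_groups
    # Any remainder goes to the first group so the parts sum to the total.
    split[0] += total_channels - sum(split)
    return split


kernels = [3, 5, 7]                       # one kernel size per channel group (assumed)
sizes = split_channels(64, len(kernels))  # a hypothetical 64-channel feature map
for k, c in zip(kernels, sizes):
    print(f"{c} channels -> {k}x{k} depthwise kernel")
```

Each group is then convolved with its own kernel size and the results are concatenated back along the channel axis, so the output has the same number of channels as the input.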


I found the original implementation a bit messy to use with a custom dataset. Here I reuse its code base and write a simple `tf.keras` layer for mixed depth-wise convolution that is easy to use. It's actually pretty simple, and with reasonable effort the whole `MixNet` is achievable. 

- [Official Code](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet).

---

## <font color = "seagreen">1. Mix-Depth Group Convolution</font>


The code is adapted from the official [code](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) and modified accordingly. To understand the operation, we must first understand what depthwise convolution is. The following diagram illustrates it best. 

<p align="center">
  <img width="600" height="200" src="https://user-images.githubusercontent.com/17668390/103163823-5fa04800-482d-11eb-94be-3465988b7d6e.png">
</p>
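
As a rough illustration of the diagram, here is a naive NumPy sketch of a depthwise convolution (implemented as cross-correlation, as deep learning frameworks do). It is not the `tf.keras` implementation; the function name `depthwise_conv2d`, the shapes, and the random inputs are assumptions made for the example.

```python
import numpy as np


def depthwise_conv2d(x, kernels):
    """Naive depthwise convolution with 'same' zero padding and stride 1.

    x:       (H, W, C) feature map
    kernels: (k, k, C) array, one k x k filter per channel (k odd)
    """
    h, w, c = x.shape
    k = kernels.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros_like(x, dtype=np.float64)
    for i in range(h):
        for j in range(w):
            patch = xp[i:i + k, j:j + k, :]  # (k, k, C) window
            # Each channel is reduced with its own filter; channels never mix.
            out[i, j, :] = np.sum(patch * kernels, axis=(0, 1))
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 6, 4))
kernels = rng.normal(size=(5, 5, 4))
y = depthwise_conv2d(x, kernels)
print(y.shape)  # (6, 6, 4)
```

The key property is that changing one input channel only affects the matching output channel, which is what distinguishes this from a regular `Conv2D`.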


Mixed depthwise convolution is simply an extension of it. Instead of one constant kernel size (`5 x 5` above), `MixConv` applies multiple kernel sizes, such as `3 x 3` and `5 x 5`, to the split channels, as shown below:

<p align="center">
  <img width="600" height="200" src="https://user-images.githubusercontent.com/17668390/103163501-dc7cf300-4828-11eb-813f-d23e5abb89d9.png">
</p>

In [None]:
class MixDepthGroupConvolution2D(tf.keras.layers.Layer):
    def __init__(self, kernels=(3, 5),
                 conv_kwargs=None,
                 **kwargs):
        super(MixDepthGroupConvolution2D, self).__init__(**kwargs)

        if conv_kwargs is None:
            conv_kwargs = {
                'strides': (1, 1),
                'padding': 'same',
                'dilation_rate': (1, 1),
                'use_bias': False,
            }
        self.channel_axis = -1 
        self.kernels = kernels
        self.groups = len(self.kernels)
        self.strides = conv_kwargs.get('strides', (1, 1))
        self.padding = conv_kwargs.get('padding', 'same')
        self.dilation_rate = conv_kwargs.get('dilation_rate', (1, 1))
        self.use_bias = conv_kwargs.get('use_bias', False)
        self.conv_kwargs = conv_kwargs or {}

        self.layers = [tf.keras.layers.DepthwiseConv2D(kernels[i],
                                       strides=self.strides,
                                       padding=self.padding,
                                       activation=tf.nn.relu,                
                                       dilation_rate=self.dilation_rate,
                                       use_bias=self.use_bias,
                                       kernel_initializer='he_normal')
                        for i in range(self.groups)]

    def call(self, inputs, **kwargs):
        if len(self.layers) == 1:
            return self.layers[0](inputs)
        filters = K.int_shape(inputs)[self.channel_axis]
        splits  = self.split_channels(filters, self.groups)
        x_splits  = tf.split(inputs, splits, self.channel_axis)
        x_outputs = [c(x) for x, c in zip(x_splits, self.layers)]
        return tf.keras.layers.concatenate(x_outputs, 
                                           axis=self.channel_axis)

    def split_channels(self, total_filters, num_groups):
        split = [total_filters // num_groups for _ in range(num_groups)]
        split[0] += total_filters - sum(split)
        return split

    def get_config(self):
        # Only constructor arguments are serialized, so that
        # `from_config` can rebuild the layer through `__init__`.
        config = {
            'kernels': list(self.kernels),
            'conv_kwargs': self.conv_kwargs,
        }
        base_config = super(MixDepthGroupConvolution2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

## <font color = "seagreen">2. Lamda Layer</font>

LambdaNetworks are, basically, an alternative to the attention mechanism. The main intuition is to capture long-range interactions between an input and structured contextual information. 


<p align="center">
  <img width="600" height="600" src="https://user-images.githubusercontent.com/17668390/103927561-be769300-5144-11eb-8c4f-d4e885913017.png">
</p>

However, the paper is currently under double-blind review, which is why there's no official implementation yet. If I'm not wrong, the authors show a performance boost after adding the lambda layer to the `ResNet` family and `MobileNet`. Check the experimental details on page 17, and page 19 for the hybrid lambda network. The paper claims that the lambda layer captures such long-range interactions by transforming the available contexts into linear functions. Here, we'll use an unofficial implementation of `LambdaLayer` from [here](https://github.com/lucidrains/lambda-networks). You can find `PyTorch` code there as well. 
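
To make "transforming contexts into linear functions" a bit more concrete, here is a small NumPy sketch of the content path only, for a single example. It mirrors the einsum pattern used in the layer code in this notebook, but the function name `content_lambda` and all sizes are assumptions picked for illustration; the positional path is left out.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def content_lambda(q, k, v):
    """Content path of a lambda layer for one example.

    q: (heads, dim_k, n)  queries at n positions
    k: (dim_u, dim_k, m)  keys at m context positions
    v: (dim_u, dim_v, m)  values at m context positions
    """
    k = softmax(k, axis=-1)                  # normalise keys over context positions
    lam = np.einsum('ukm,uvm->kv', k, v)     # one (dim_k, dim_v) linear map
    return np.einsum('hkn,kv->nhv', q, lam)  # apply that map to every query


rng = np.random.default_rng(0)
heads, dim_k, dim_v, dim_u, n = 4, 16, 8, 1, 25
q = rng.normal(size=(heads, dim_k, n))
k = rng.normal(size=(dim_u, dim_k, n))
v = rng.normal(size=(dim_u, dim_v, n))
y = content_lambda(q, k, v)
print(y.shape)  # (25, 4, 8)
```

Note that no `n x m` attention map is ever formed: the context is summarised once into a small `dim_k x dim_v` matrix, and that matrix is applied to each query.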

In [None]:
import tensorflow as tf 
from tensorflow.keras import backend as K

from einops.layers.tensorflow import Rearrange
from tensorflow.keras import initializers
from tensorflow import einsum, nn, meshgrid
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers as initializations

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, concatenate, ZeroPadding2D
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.layers import Conv2D, Conv3D,  Softmax, Lambda, Add, Layer
from tensorflow.keras.layers import AveragePooling2D, GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization

In [None]:
# helper functions
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_rel_pos(n):
    pos = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
    pos = Rearrange('n i j -> (i j) n')(pos)             # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                     # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

# lambda layer
class LambdaLayer(Layer):
    def __init__(
        self,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super(LambdaLayer, self).__init__()
        '''
        Ref: https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/tfkeras.py
        '''

        self.out_dim = dim_out
        self.u = dim_u  # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        self.dim_v = dim_out // heads
        self.dim_k = dim_k
        self.heads = heads

        self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
        self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
        self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)

        self.norm_q = BatchNormalization()
        self.norm_v = BatchNormalization()

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
        else:
            assert exists(n), 'You must specify the window length (n = h = w)'
            rel_length = 2 * n - 1
            self.rel_pos_emb = self.add_weight(name='pos_emb',
                                               shape=(rel_length, rel_length, dim_k, dim_u),
                                               initializer=initializers.random_normal,
                                               trainable=True)
            self.rel_pos = calc_rel_pos(n)

    def call(self, x, **kwargs):
        b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)

        k = nn.softmax(k)

        Lc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, Lc)

        if self.local_contexts:
            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
            Lp = self.pos_conv(v)
            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
            Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
        else:
            rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
            Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)

        Y = Yc + Yp
        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
        return out

    def compute_output_shape(self, input_shape):
        # Input is (batch, height, width, channels); only the channel size changes.
        return (*input_shape[:-1], self.out_dim)

    def get_config(self):
        config = {'output_dim': (*self.input_shape[:2], self.out_dim)}
        base_config = super(LambdaLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
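
The relative-position indexing done by `calc_rel_pos` can be hard to picture, so here is a NumPy re-creation on a tiny `3 x 3` grid. It is a sketch for inspection only: `calc_rel_pos_np` and the random `table` are hypothetical stand-ins for the TensorFlow helper and the learned embedding.

```python
import numpy as np


def calc_rel_pos_np(n):
    """Relative (row, col) offsets between all pairs of cells in an n x n grid."""
    ii, jj = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
    pos = np.stack([ii.ravel(), jj.ravel()], axis=-1)  # (n*n, 2), pos[p] = (i, j)
    rel = pos[None, :, :] - pos[:, None, :]            # (n*n, n*n, 2)
    return rel + (n - 1)                               # shift [-n+1, n-1] to [0, 2n-2]


n, dim_k = 3, 16
rel = calc_rel_pos_np(n)
print(rel.shape)  # (9, 9, 2)

# Look up one embedding vector per (query position, context position) pair,
# analogous to the `tf.gather_nd` call in the layer.
table = np.random.default_rng(0).normal(size=(2 * n - 1, 2 * n - 1, dim_k))
emb = table[rel[..., 0], rel[..., 1]]
print(emb.shape)  # (9, 9, 16)
```

Because the table is indexed by offset rather than by absolute position, every pair of cells with the same relative displacement shares the same embedding.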

## <font color = "seagreen">3. Dense Net 121</font>

For a pre-trained model like this, using the source code from `tf.keras.applications` directly would be nice, but for a more concise implementation I will use [this](https://github.com/flyyufelix/DenseNet-Keras) unofficial implementation. The main problem is that the author last updated the repo four years ago, so much of the code is outdated and won't work with today's API. We need to modify many things. 

Here I'm porting the required code and upgrading it to the recent API. I'm choosing `DenseNet 121`; feel free to pick another variant. As a reminder, this is our end network:

![LamdaDenseNet](https://user-images.githubusercontent.com/17668390/103923527-68532100-513f-11eb-97f5-d16806865f08.png)

Let's build piece by piece. 

### Dense Block

In [None]:
def dense_block(x, stage, nb_layers, nb_filter, 
                growth_rate, dropout_rate=None, weight_decay=1e-4, grow_nb_filters=True):
    ''' Build a dense_block where the output of each conv_block is fed to subsequent ones
        # Arguments
            x: input tensor
            stage: index for dense block
            nb_layers: the number of layers of conv_block to append to the model.
            nb_filter: number of filters
            growth_rate: growth rate
            dropout_rate: dropout rate
            weight_decay: weight decay factor
            grow_nb_filters: flag to decide to allow number of filters to grow
    '''

    eps = 1.1e-5
    concat_feat = x

    for i in range(nb_layers):
        branch = i+1
        x = conv_block(concat_feat, stage, branch, 
                       growth_rate, dropout_rate, weight_decay)
        concat_feat = concatenate([concat_feat, x], 
                                  axis=concat_axis, name='concat_'+str(stage)+'_'+str(branch))

        if grow_nb_filters:
            nb_filter += growth_rate

    return concat_feat, nb_filter

### Conv Block

In [None]:
def conv_block(x, stage, branch, nb_filter, dropout_rate=None, weight_decay=1e-4):
    '''Apply BatchNorm, Relu, bottleneck 1x1 Conv2D, 3x3 Conv2D, and optional dropout
        # Arguments
            x: input tensor 
            stage: index for dense block
            branch: layer index within each dense block
            nb_filter: number of filters
            dropout_rate: dropout rate
            weight_decay: weight decay factor
    '''
    eps = 1.1e-5
    conv_name_base = 'conv' + str(stage) + '_' + str(branch)
    relu_name_base = 'relu' + str(stage) + '_' + str(branch)

    # 1x1 Convolution (Bottleneck layer)
    inter_channel = nb_filter * 4  
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base+'_x1_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base+'_x1_scale')(x)
    x = Activation('relu', name=relu_name_base+'_x1')(x)
    x = Conv2D(inter_channel, 1, 1, name=conv_name_base+'_x1', use_bias=False)(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

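    # Second convolution. The reference DenseNet uses a 3x3 kernel here;
    # this port uses a 1x1 kernel with stride 1 (see the trailing "# 3, 3" note).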
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base+'_x2_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base+'_x2_scale')(x)
    x = Activation('relu', name=relu_name_base+'_x2')(x)
    x = Conv2D(nb_filter, 1, 1, name=conv_name_base+'_x2', use_bias=False)(x) # 3, 3

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    return x

### Transition Block

In [None]:
def transition_block(x, stage, nb_filter, compression=1.0, dropout_rate=None, weight_decay=1E-4):
    ''' Apply BatchNorm, 1x1 Convolution, averagePooling, optional compression, dropout 
        # Arguments
            x: input tensor
            stage: index for dense block
            nb_filter: number of filters
            compression: calculated as 1 - reduction. Reduces the number of feature maps in the transition block.
            dropout_rate: dropout rate
            weight_decay: weight decay factor
    '''
    eps = 1.1e-5
    conv_name_base = 'conv' + str(stage) + '_blk'
    relu_name_base = 'relu' + str(stage) + '_blk'
    pool_name_base = 'pool' + str(stage) 

    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base+'_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base+'_scale')(x)
    x = Activation('relu', name=relu_name_base)(x)
    x = Conv2D(int(nb_filter * compression), 1, 1, name=conv_name_base, use_bias=False)(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    x = AveragePooling2D((2, 2), strides=(2, 2), name=pool_name_base)(x)

    return x

**A custom `Scale` layer, applied after each BatchNormalization in DenseNet**

In [None]:
class Scale(Layer):
    '''Custom Layer for DenseNet used for BatchNormalization.
    
    Learns a set of weights and biases used for scaling the input data.
    The output is simply an element-wise multiplication of the input
    by the weights, plus the biases:
        out = in * gamma + beta,
    where 'gamma' and 'beta' are the learned weights and biases.
    '''
    def __init__(self, weights=None, axis=-1, momentum = 0.9, beta_init='zero', gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # Tensorflow >= 1.0.0 compatibility
        self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        #self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
        #self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
        self._trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]
        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
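
The arithmetic inside `Scale.call` is just a per-channel affine transform with broadcasting. A minimal NumPy sketch, with made-up `gamma` and `beta` values standing in for the learned parameters:

```python
import numpy as np


def scale(x, gamma, beta, axis=-1):
    """Per-channel affine transform: out = x * gamma + beta."""
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]  # e.g. [1, 1, 1, C] for channels-last input
    return x * gamma.reshape(shape) + beta.reshape(shape)


x = np.ones((2, 4, 4, 3))            # (batch, height, width, channels)
gamma = np.array([1.0, 2.0, 3.0])    # assumed example values
beta = np.array([0.0, 0.5, -1.0])    # assumed example values
out = scale(x, gamma, beta)
print(out[0, 0, 0].tolist())  # [1.0, 2.5, 2.0]
```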

## <font color = "seagreen">MixDepthConvLamdaDenseNet</font>

Let's combine the pieces above and build the whole model. In the following function we add `Dense-Block + Lambda-Layer + Transition-Layer` iteratively, as shown in the diagram below.

![LamdaDenseNet](https://user-images.githubusercontent.com/17668390/103923527-68532100-513f-11eb-97f5-d16806865f08.png)
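
Before building the model, it can help to trace how the channel count evolves through the blocks. The sketch below is plain Python and only does the bookkeeping. It assumes the settings used later in this notebook (initial 64 filters, growth rate 32, `reduction=0.5`, and three dense blocks with 6, 12 and 16 layers); `trace_channels` is a hypothetical helper, not part of the model code.

```python
def trace_channels(nb_filter=64, growth_rate=32, layers=(6, 12, 16), compression=0.5):
    """Track channel counts through dense blocks and transition blocks."""
    trace = []
    for idx, nb_layers in enumerate(layers):
        # Each conv_block inside a dense block adds `growth_rate` channels.
        nb_filter += nb_layers * growth_rate
        trace.append((f"dense_block_{idx + 2}", nb_filter))
        if idx < len(layers) - 1:  # no transition after the last block
            nb_filter = int(nb_filter * compression)
            trace.append((f"transition_{idx + 2}", nb_filter))
    return trace


for name, channels in trace_channels():
    print(f"{name:>16}: {channels}")
```

Under these assumptions the dense blocks output 256, 512 and 768 channels. The lambda layer placed after the first two blocks keeps the channel count unchanged (`dim_out = nb_filter`), and both 256 and 512 are divisible by the 8 heads it uses.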

In [None]:
def MixDepthConvLamdaDenseNet(image_size, 
                              nb_dense_block=3, # should be 4; set to 3 for fast prototyping 
                              growth_rate=32, 
                              nb_filter=64, reduction=0.0, 
                              dropout_rate=0.0, weight_decay=1e-4, 
                              classes=1000, weights_path=None):
    '''Instantiate the DenseNet 121 architecture,
        # Arguments
            nb_dense_block: number of dense blocks to add to end
            growth_rate: number of filters added by each conv_block inside a dense block
            nb_filter: initial number of filters
            reduction: reduction factor of transition blocks.
            dropout_rate: dropout rate
            weight_decay: weight decay factor
            classes: optional number of classes to classify images
            weights_path: path to pre-trained weights
        # Returns
            A Keras model instance.
    '''
    eps = 1.1e-5

    # compute compression factor
    compression = 1.0 - reduction

    # Handle Dimension Ordering for different backends
    global concat_axis
    if K.image_data_format() == 'channels_last':
        concat_axis = -1
        img_input = Input(shape=(image_size, image_size, 3), name='data')
    else:
        concat_axis = 1
        img_input = Input(shape=(3, image_size, image_size), name='data')

    # From architecture for ImageNet (Table 1 in the paper)
    nb_filter = 64
    nb_layers = [6,12,24,16] # For DenseNet-121

    # Initial convolution
    x = MixDepthGroupConvolution2D(kernels=[3,5,7])(img_input)
    x = ZeroPadding2D((3, 3), name='conv1_zeropadding')(x)
    x = Conv2D(nb_filter, 7, 2, name='conv1', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name='conv1_bn')(x)
    x = Scale(axis=concat_axis, name='conv1_scale')(x)
    x = Activation('relu', name='relu1')(x)
    x = ZeroPadding2D((1, 1), name='pool1_zeropadding')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), name='pool1')(x)

    # Add dense blocks
    for block_idx in range(nb_dense_block - 1):
        stage = block_idx+2
        x, nb_filter = dense_block(x, stage, nb_layers[block_idx], nb_filter, 
                                   growth_rate, dropout_rate=dropout_rate, weight_decay=weight_decay)
        
        # Add lambda layer
        x = LambdaLayer(
            dim_out = nb_filter, # channels out
            r = 23,       # the receptive field for relative positional encoding (23 x 23)
            dim_k = 16,   # key dimension
            heads = 8,   # number of heads, for multi-query; values dimension must be divisible by number of heads for multi-head query
            dim_u = 1     # 'intra-depth' dimension
        )(x)

        # Add transition_block
        x = transition_block(x, stage, nb_filter, compression=compression, 
                             dropout_rate=dropout_rate, weight_decay=weight_decay)
        nb_filter = int(nb_filter * compression)

    final_stage = stage + 1
    x, nb_filter = dense_block(x, final_stage, nb_layers[-1], nb_filter, 
                               growth_rate, dropout_rate=dropout_rate, weight_decay=weight_decay)

    x = BatchNormalization(epsilon=eps, axis=concat_axis, name='conv'+str(final_stage)+'_blk_bn')(x)
    x = Scale(axis=concat_axis, name='conv'+str(final_stage)+'_blk_scale')(x)
    x = Activation('relu', name='relu'+str(final_stage)+'_blk')(x)
    x = GlobalAveragePooling2D(name='pool'+str(final_stage))(x)
 
    x = Dense(classes, name='fc6')(x)
    x = Activation('sigmoid', name='prob')(x)

    model = Model(img_input, x, name='densenet')

    return model

In [None]:
with strategy.scope():
    model = MixDepthConvLamdaDenseNet(
        image_size=im_size,
        reduction=0.5, 
        classes=labels.shape[1])
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
        loss='binary_crossentropy',
        metrics=[tf.keras.metrics.AUC(multi_label=True)])
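
With `multi_label=True`, the AUC metric is computed separately for each label and then averaged across labels. The plain-Python sketch below shows that idea with an exact pairwise AUC. It is not the Keras implementation, which approximates the curve with a fixed set of thresholds, so values can differ slightly; the tiny label and score matrices are invented for the example.

```python
def binary_auc(y_true, y_score):
    """Exact ROC AUC via pairwise comparison (ties count as half a win)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def multilabel_auc(Y_true, Y_score):
    """Per-label AUC, plus the unweighted mean across labels."""
    num_labels = len(Y_true[0])
    per_label = [
        binary_auc([row[j] for row in Y_true], [row[j] for row in Y_score])
        for j in range(num_labels)
    ]
    return sum(per_label) / num_labels, per_label


# Four samples, two labels (made-up data).
Y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
Y_score = [[0.9, 0.2], [0.1, 0.8], [0.8, 0.3], [0.3, 0.4]]
mean_auc, per_label = multilabel_auc(Y_true, Y_score)
print(per_label, mean_auc)  # [1.0, 0.75] 0.875
```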

In [None]:
import time

class TimeHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        # Record the wall-clock duration of the epoch in seconds.
        self.times.append(time.time() - self.epoch_time_start)
        

def set_callback():
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'model.h5', save_best_only=True, 
        save_weights_only=True,
        monitor='val_auc', mode='max')

    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_auc", patience=3, min_lr=1e-6, mode='max')

    csv_loger = tf.keras.callbacks.CSVLogger('his.csv')
    
    time_his = TimeHistory()
    
    return [checkpoint, lr_reducer, csv_loger, time_his]

In [None]:
# print out the model params
trainable_count = np.sum([K.count_params(w) \
                          for w in model.trainable_weights]) 
non_trainable_count = np.sum([K.count_params(w) \
                              for w in model.non_trainable_weights])
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

steps_per_epoch = train_paths.shape[0] // BATCH_SIZE

model.fit(
    train_dataset, 
    epochs=10,
    verbose=1,
    callbacks=set_callback(),
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_dataset)

In [None]:
history = pd.read_csv('his.csv') 

# find the epoch with the highest validation AUC
print(history.loc[history['val_auc'].idxmax()])
history.head()

In [None]:
plt.figure(figsize=(19,6))

plt.subplot(121)
plt.plot(history.epoch, history.loss, label="loss")
plt.plot(history.epoch, history.val_loss, label="val_loss")
plt.legend()


plt.subplot(122)
plt.plot(history.epoch, history.auc, label="auc")
plt.plot(history.epoch, history.val_auc, label="val_auc")
plt.legend()

plt.show()

# Model interpretability with Integrated Gradients

In [None]:
def get_img_array(img_path, size=(300, 300)):
    # `img` is a PIL image of size 300
    img = tf.keras.preprocessing.image.load_img(img_path, target_size=size)
    # `array` is a float32 Numpy array of shape (300, 300, 3)
    array = tf.keras.preprocessing.image.img_to_array(img)
    # We add a dimension to transform our array into a "batch"
    # of size (1, 300, 300, 3)
    array = np.expand_dims(array, axis=0)
    return array

In [None]:
from tensorflow.keras.applications import densenet

def get_gradients(img_input, top_pred_idx):
    """Computes the gradients of outputs w.r.t input image.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Predicted label for the input image

    Returns:
        Gradients of the predictions w.r.t img_input
    """
    images = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        tape.watch(images)
        preds = model(images)
        top_class = preds[:, top_pred_idx]

    grads = tape.gradient(top_class, images)
    return grads


def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Computes Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): The baseline image to start with for interpolation
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    # If baseline is not provided, start with a black image
    # having same size as the input image.
    if baseline is None:
        baseline = np.zeros(img_size).astype(np.float32)
    else:
        baseline = baseline.astype(np.float32)

    # 1. Do interpolation.
    img_input = img_input.astype(np.float32)
    interpolated_image = [
        baseline + (step / num_steps) * (img_input - baseline)
        for step in range(num_steps + 1)
    ]
    interpolated_image = np.array(interpolated_image).astype(np.float32)

    # 2. Preprocess the interpolated images
    interpolated_image = densenet.preprocess_input(interpolated_image)

    # 3. Get the gradients
    grads = []
    for img in interpolated_image:
        img = tf.expand_dims(img, axis=0)
        grad = get_gradients(img, top_pred_idx=top_pred_idx)
        grads.append(grad[0])
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the integral using the trapezoidal rule
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = tf.reduce_mean(grads, axis=0)

    # 5. Calculate integrated gradients and return
    integrated_grads = (img_input - baseline) * avg_grads
    return integrated_grads
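
The five steps above can be checked on a function whose gradient is known in closed form. The following is a minimal NumPy sketch, not the notebook's model: `toy_model`, `toy_gradient`, and `toy_integrated_gradients` are hypothetical names introduced only for illustration. It mirrors the interpolation, trapezoidal averaging, and scaling steps, then checks the completeness property described in the Integrated Gradients paper (the attributions sum to the difference between the output at the input and the output at the baseline).

```python
import numpy as np


def toy_model(x):
    """Hypothetical scalar 'model': f(x) = sum(x_i ** 2)."""
    return np.sum(x ** 2)


def toy_gradient(x):
    """Analytic gradient of toy_model: df/dx_i = 2 * x_i."""
    return 2.0 * x


def toy_integrated_gradients(x, baseline, num_steps=50):
    # 1. Interpolate between the baseline and the input (num_steps + 1 points)
    alphas = np.linspace(0.0, 1.0, num_steps + 1)
    path = np.stack([baseline + a * (x - baseline) for a in alphas])

    # 2. Evaluate the gradient at every point on the path
    grads = np.stack([toy_gradient(p) for p in path])

    # 3. Trapezoidal rule: average neighbouring gradients, then take the mean
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = grads.mean(axis=0)

    # 4. Scale by the difference between the input and the baseline
    return (x - baseline) * avg_grads


x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
attributions = toy_integrated_gradients(x, baseline)

# Completeness: the attributions should sum to f(x) - f(baseline)
completeness_gap = attributions.sum() - (toy_model(x) - toy_model(baseline))
print(attributions, completeness_gap)
```

Because the toy gradient is linear along the path, the trapezoidal rule is exact here; for a deep network the gap shrinks as `num_steps` grows.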

In [None]:
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Generates a number of random baseline images.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients.
            The number of steps determines the integral approximation error.
            By default, num_steps is set to 50.
        num_runs: number of baseline images to generate

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # 1. List to keep track of Integrated Gradients (IG) for all the images
    integrated_grads = []

    # 2. Get the integrated gradients for all the baselines
    for run in range(num_runs):
        baseline = np.random.random(img_size) * 255
        igrads = get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=baseline,
            num_steps=num_steps,
        )
        integrated_grads.append(igrads)

    # 3. Return the average integrated gradients for the image
    integrated_grads = tf.convert_to_tensor(integrated_grads)
    return tf.reduce_mean(integrated_grads, axis=0)
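
To see why several baselines are averaged, it helps to note that the attributions depend on the baseline that is chosen. The sketch below is a toy NumPy illustration under assumed names (`f`, `grad_f`, `integrated_gradients` are hypothetical and unrelated to the notebook's model): it computes integrated gradients for a few random baselines and averages them, as the function above does for images.

```python
import numpy as np


def f(x):
    """Hypothetical scalar function used in place of a model."""
    return np.sum(x ** 2)


def grad_f(x):
    """Analytic gradient of f, applied element-wise."""
    return 2.0 * x


def integrated_gradients(x, baseline, num_steps=50):
    # Points on the straight path from baseline to x, shape (num_steps + 1, n)
    alphas = np.linspace(0.0, 1.0, num_steps + 1).reshape(-1, 1)
    path = baseline + alphas * (x - baseline)
    grads = grad_f(path)
    # Trapezoidal rule followed by the mean over the path
    grads = (grads[:-1] + grads[1:]) / 2.0
    return (x - baseline) * grads.mean(axis=0)


rng = np.random.default_rng(0)  # fixed seed so the run is repeatable
x = np.array([1.0, -2.0, 3.0])
num_runs = 4

# One random baseline per run, then one attribution vector per baseline
baselines = [rng.random(x.shape) for _ in range(num_runs)]
per_run = np.stack([integrated_gradients(x, b) for b in baselines])

# Average the attributions over all baselines
averaged = per_run.mean(axis=0)

# Each run satisfies completeness for its own baseline, so the average
# sums to f(x) minus the mean baseline output.
expected_total = f(x) - np.mean([f(b) for b in baselines])
print(averaged, averaged.sum(), expected_total)
```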

**Gradient Visualization**

In [None]:
from scipy import ndimage


class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image."""

    def __init__(self, positive_channel=None, negative_channel=None):
        if positive_channel is None:
            self.positive_channel = [0, 255, 0]
        else:
            self.positive_channel = positive_channel

        if negative_channel is None:
            self.negative_channel = [255, 0, 0]
        else:
            self.negative_channel = negative_channel

    def apply_polarity(self, attributions, polarity):
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        else:
            return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
            m - e
        ) + lower_end

        # 3. Make sure that the sign of transformed attributions is the same as original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
        return transformed_attributions

    def get_thresholded_attributions(self, attributions, percentage):
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the attributions
        flatten_attr = attributions.flatten()

        # 2. Get the sum of the attributions
        total = np.sum(flatten_attr)

        # 3. Sort the attributions from largest to smallest.
        sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]

        # 4. Calculate the percentage of the total sum that each attribution
        # and the values above it contribute.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 5. Threshold the attributions by the percentage
        indices_to_consider = np.where(cum_sum >= percentage)[0][0]

        # 6. Select the desired attributions and return
        attributions = sorted_attributions[indices_to_consider]
        return attributions

    def binarize(self, attributions, threshold=0.001):
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened

    def draw_outlines(
        self, attributions, percentage=90, connected_component_structure=np.ones((3, 3))
    ):
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        # 4. Sum up the attributions for each component
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sum = np.sum(attributions[mask])
            component_sums.append((component_sum, mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        sorted_sums = list(zip(*sorted_sums_and_masks))[0]
        cumulative_sorted_sums = np.cumsum(sorted_sums)
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        if cutoff_idx > 2:
            cutoff_idx = 2

        # 6. Set the values for the kept components
        border_mask = np.zeros_like(attributions)
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                "Allowed polarity values: 'positive' or 'negative', "
                f"but provided {polarity!r}"
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity
        if polarity == "positive":
            attributions = self.apply_polarity(attributions, polarity=polarity)
            channel = self.positive_channel
        else:
            attributions = self.apply_polarity(attributions, polarity=polarity)
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )
        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7. Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        # 1. Make two copies of the original image
        img1 = np.copy(image)
        img2 = np.copy(image)

        # 2. Process the normal gradients
        grads_attr = self.process_grads(
            image=img1,
            attributions=gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # 3. Process the integrated gradients
        igrads_attr = self.process_grads(
            image=img2,
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()
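
The thresholding in `get_thresholded_attributions` is easier to follow on a tiny array. The sketch below is a standalone copy of that logic under a hypothetical name (`threshold_by_top_percentage`), applied to a hand-made 2x2 array of non-negative attributions; it is meant only to illustrate how the cumulative sum selects the threshold value.

```python
import numpy as np


def threshold_by_top_percentage(attributions, percentage):
    """Standalone copy of the cumulative-sum thresholding logic."""
    if percentage == 100.0:
        return np.min(attributions)

    # Flatten, then sort the absolute values from largest to smallest
    flat = attributions.flatten()
    total = np.sum(flat)
    sorted_attr = np.sort(np.abs(flat))[::-1]

    # Percentage of the total contributed by each value and all larger ones
    cum_sum = 100.0 * np.cumsum(sorted_attr) / total

    # First position where the cumulative share reaches the requested level
    idx = np.where(cum_sum >= percentage)[0][0]
    return sorted_attr[idx]


# Sorted values are [0.5, 0.25, 0.125, 0.125], so the cumulative
# shares are [50, 75, 87.5, 100] percent of the total.
attr = np.array([[0.5, 0.25], [0.125, 0.125]])

t50 = threshold_by_top_percentage(attr, 50)
t75 = threshold_by_top_percentage(attr, 75)
t100 = threshold_by_top_percentage(attr, 100.0)
print(t50, t75, t100)
```

In `apply_linear_transformation`, two such thresholds (`m` and `e`) fix the endpoints of the linear rescaling.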

In [None]:
img_path = [
    '../input/ranzcr-clip-catheter-line-classification/test/1.2.826.0.1.3680043.8.498.10003659706701445041816900371598078663.jpg',
    '../input/ranzcr-clip-catheter-line-classification/test/1.2.826.0.1.3680043.8.498.10003890246067211044742686138544513464.jpg'
]

# 1. Convert the image to numpy array
img = get_img_array(img_path[0])

# 2. Keep a copy of the original image
orig_img = np.copy(img[0]).astype(np.uint8)

# 3. Preprocess the image
img_processed = tf.cast(densenet.preprocess_input(img), dtype=tf.float32)

# 4. Get model predictions
preds = model.predict(img_processed)

In [None]:
img_size = (300, 300, 3)

top_pred_idx = tf.argmax(preds[0])

# 5. Get the gradients of the predicted label's output w.r.t. the input image
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)

# 6. Get the integrated gradients
igrads = random_baseline_integrated_gradients(
    np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
)