# **Diabetic retinopathy detection**
## **Author: [Dr. Rahul Remanan](https://www.linkedin.com/in/rahulremanan/)**
## **CEO, [Moad Computer](http://www.moad.computer/)**

# Configuration

In [None]:
# Base directory for input data; the zip-archive and saved-weights paths used
# later in this notebook are built from it.
ROOT_DIR = '/kaggle/input'

In [None]:
class CONFIG():
  """Notebook-wide settings, read as class attributes (e.g. ``CONFIG.BATCH_SIZE``)."""
  # Folder name under ROOT_DIR that holds the zipped training data.
  NOTEBOOK_ID = 'diabetic-retinopathy-detection'
  # File name of the labels CSV, read from the CSV extraction directory.
  TRAIN_CSV = 'trainLabels.csv'

  # When False, the training cell is skipped.
  ENABLE_TRAINING = False
  
  # Attribute name looked up on tf.keras.applications to obtain the encoder.
  BACKBONE = 'EfficientNetB0'

  # Forwarded as `weights=` to the backbone constructor.
  PRE_TRAINED_WEIGHTS = 'imagenet'

  # Size passed to tf.image.resize when images are loaded. Alternative: (512, 512)
  IMAGE_SIZE = (256, 256) # (512, 512)

  # Number of epochs passed to model.fit. Alternative: 20
  EPOCHS = 1 # 20 #
  
  # Switches read by tf_augment_batch; each enabled augmentation is applied
  # with a random coin flip.
  HORIZONTAL_FLIP = True
  VERTICAL_FLIP = True
  RANDOM_BRIGHTNESS = True
  RANDOM_SATURATION = True
  RANDOM_GAMMA = False
  RANDOM_HUE = True
  RANDOM_CONTRAST = True

  # Forwarded as `train_output_conv` to the model builders.
  TRAIN_ATTN_CONV = True
    
  # Training batch size; the validation pipeline uses BATCH_SIZE//2. Alternative: 16
  BATCH_SIZE = 24 # 16 #
    
  # NOTE(review): not referenced by the dataset pipelines visible in this file,
  # which shuffle with buffers derived from the dataframe lengths — confirm.
  SHUFFLE_BUFFER = max(BATCH_SIZE*25, 500) #

  # Passed as `dropout=` to train_model by the training cell.
  DROPOUT = 0.35

  # 'summary' prints a text summary (only when VERBOSE), 'plot' draws the model graph.
  MODEL_SUMMARY = 'summary' # 'plot' #
    
  # Directory and file names of previously saved weights that train_model
  # loads before fitting, when the file exists.
  SAVED_WEIGHTS_DIR = f'{ROOT_DIR}/diabetic-retinopathy-detection-weights'
  SAVED_WEIGHTS = 'output'
  SAVED_ATTN_WEIGHTS = 'output_attn'  
    
  # Passed as `out_weights_file` to train_model for the plain and the
  # attention classifier respectively.
  OUT_WEIGHTS = 'output.h5'  
  OUT_ATTN_WEIGHTS = 'output_attn.h5'
  
  # Gates the text model summary.
  VERBOSE = True

# Imports

In [None]:
import os, gc, cv2, numpy as np, pandas as pd, tensorflow as tf, \
       tensorflow_addons as tfa,matplotlib.pyplot as plt

from glob import glob
from skimage.io import imread
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from keras import backend as K
from keras.applications.inception_v3 import preprocess_input

%matplotlib inline 

# Manage training data

In [None]:
# Directory containing the zipped dataset.
zip_dir = os.path.join(ROOT_DIR, CONFIG.NOTEBOOK_ID)
# Extraction targets for the image archive and the labels CSV.
train_images_dir = '/tmp/train/'
train_csv_dir = './'
# Directory the JPEG paths are built from (a `train` folder inside the
# extraction target).
base_image_dir = os.path.join(train_images_dir, 'train')

## Helper function to run OS commands from within Python

In [None]:
def linux_shell(cmd):
  """Run a shell command and report failures instead of ignoring them.

  Args:
      cmd (str): Command line handed to the system shell via ``os.system``.

  Returns:
      int: The status value returned by ``os.system`` (0 on success).
  """
  status = os.system(cmd)
  if status != 0:
    # Make failures (e.g. a missing archive) visible; they used to be silent.
    print(f'Command failed with status {status}: {cmd}')
  return status

## Unzip the training data

In [None]:
%%capture
# Extract the archives only when the expected number of JPEGs is not already
# present. 35126 is hard-coded; presumably the number of training images in
# the archive — TODO confirm.
if len(glob(f'{base_image_dir}/*.jpeg')) != 35126:
  # Commands are run in order: install 7-Zip, create the target directory,
  # then extract the image archive and the labels CSV.
  cmds= ['apt-get install -y p7zip-full',
         f'mkdir {train_images_dir}',
         f'7z x {zip_dir}/train.zip.001 -o{train_images_dir}',
         f'7z x {zip_dir}/trainLabels.csv.zip -o{train_csv_dir}',]
  for c in cmds:
    linux_shell(c)

## Create and process the training dataframe

In [None]:
# Load the labels CSV; later cells read its `image` and `level` columns.
retina_df = pd.read_csv(os.path.join(train_csv_dir, CONFIG.TRAIN_CSV))

In [None]:
# Sanity check: number of extracted JPEG files vs. number of label rows.
print(len(glob(f'{train_images_dir}/train/*.jpeg')))
print(len(retina_df))

In [None]:
# Derive per-image columns from the `image` identifier, which is split on '_':
# the first part is used as the patient id and the last part as the eye side.
retina_df['PatientId'] = retina_df['image'].map(lambda x: x.split('_')[0])
retina_df['path'] = retina_df['image'].map(
                      lambda x: os.path.join(base_image_dir, f'{x}.jpeg')
                      )
retina_df['exists'] = retina_df['path'].map(os.path.exists)
# 1 for a left eye, 0 otherwise.
retina_df['eye'] = retina_df['image'].map(
                     lambda x: 1 if x.split('_')[-1]=='left' else 0
                     )

# One-hot encode the severity level. The number of classes is computed once
# here instead of once per row inside the lambda.
n_levels = 1 + retina_df['level'].max()
retina_df['level_cat'] = retina_df['level'].map(
                           lambda x: to_categorical(x, n_levels)
                           )

# Drop incomplete rows and rows whose image file is missing on disk.
retina_df.dropna(inplace = True)
retina_df = retina_df[retina_df['exists']]

In [None]:
# retina_df has already been filtered to existing files in the previous cell,
# so both numbers printed here are the same.
print(retina_df['exists'].sum(), 'images found of', retina_df.shape[0], 'total')

In [None]:
# Show three randomly chosen rows of the processed dataframe.
display(retina_df.sample(3))

## Histogram summary of the training dataset

In [None]:
# Distribution of severity levels and eye sides over the whole dataset.
retina_df[['level', 'eye']].hist(figsize = (10, 5))

# Train-validation split

In [None]:
# Split by patient so that all images of one patient end up on the same side.
# One row per patient is required: with one row per distinct (patient, level)
# pair, a patient whose images carry different levels appears twice and can be
# assigned to both the training and the validation set. The highest level
# among a patient's images is used as the stratification label.
rr_df = retina_df.groupby('PatientId')['level'].max().reset_index()
train_ids, valid_ids = train_test_split(
                         rr_df['PatientId'], 
                         test_size = 0.25, 
                         random_state = 2018, 
                         stratify = rr_df['level']
                         )
raw_train_df = retina_df[retina_df['PatientId'].isin(train_ids)]
val_df = retina_df[retina_df['PatientId'].isin(valid_ids)]

In [None]:
# Number of image rows in each split.
print('train', raw_train_df.shape[0], 'validation', val_df.shape[0])

# Balance the training data

In [None]:
# Balance the training set: draw 75 rows, with replacement, from every
# (level, eye) group, so each group contributes the same number of rows.
train_df = raw_train_df.groupby(['level', 'eye']).apply(
             lambda x: x.sample(75, replace = True)
             ).reset_index(drop = True)
# Histogram of the balanced set for visual confirmation.
train_df[['level', 'eye']].hist(figsize = (10, 5))

In [None]:
# Row counts after and before balancing.
print('New data size:', train_df.shape[0], 'Old Size:', raw_train_df.shape[0])

# Train-validation data generators

In [None]:
@tf.function
def tf_load_image(path)->tf.Tensor:
  """ Load a JPEG, min-max scale it to [0, 255] and resize it, using only TF

  Args:
      path (tf.string): Path to the image to be loaded

  Returns:
      3 channel float image resized to CONFIG.IMAGE_SIZE
  """
  img_bytes = tf.io.read_file(path)
  img = tf.cast(tf.image.decode_jpeg(img_bytes, channels=3), tf.float32)
  lo, hi = tf.reduce_min(img), tf.reduce_max(img)
  # divide_no_nan yields 0 instead of NaN for a constant image (hi == lo).
  img = 255.*tf.math.divide_no_nan(img - lo, hi - lo)
  img = tf.image.resize(img, (tf.constant(CONFIG.IMAGE_SIZE[0]), 
                              tf.constant(CONFIG.IMAGE_SIZE[1])))
  return img

@tf.function
def tf_labels(y:tf.Tensor)->tf.Tensor:  
  """Wrap the label in a list and convert it to a float32 tensor.

  The list wrapper adds one leading axis of length 1 to the label.
  """
  return tf.convert_to_tensor([y], tf.float32)

@tf.function
def tf_img(img:tf.Tensor)->tf.Tensor:
  """Identity function; used as the "leave unchanged" branch of tf.cond."""
  return img

@tf.function
def tf_pair_cond(img, true_fn, false_fn)->tf.Tensor:
  """Apply one of two image transforms, chosen by a random coin flip.

  Args:
      img (tf.Tensor): Image (or batch of images) to transform
      true_fn (callable): Applied to `img` when the random draw is <= 0.5
      false_fn (callable): Applied to `img` otherwise

  Returns:
      The output of whichever transform was selected
  """
  p = tf.random.uniform([])<=tf.constant(0.5)
  # Use the callables supplied by the caller; a hard-coded left-right flip
  # here would make every caller perform a horizontal flip.
  img = tf.cond(
           p, 
           lambda: true_fn(img), 
           lambda: false_fn(img)
           )
  return img

@tf.function
def tf_augment_batch(img:tf.Tensor, y:tf.Tensor)->tf.Tensor: 
  """Randomly augment `img` according to the CONFIG switches.

  Every enabled augmentation is guarded by its own random draw (<= 0.5);
  the labels `y` are passed through untouched.

  Returns:
      The (possibly augmented) image together with `y`
  """
  def randomly(image, transform):
    # Run `transform` on a successful draw, otherwise pass the image through.
    return tf.cond(
             tf.random.uniform([])<=tf.constant(0.5),
             lambda: transform(image),
             lambda: tf_img(image)
             )

  if CONFIG.HORIZONTAL_FLIP:
    img = tf_pair_cond(img, tf.image.flip_left_right, tf_img)
  if CONFIG.VERTICAL_FLIP:
    img = tf_pair_cond(img, tf.image.flip_up_down, tf_img)
  if CONFIG.RANDOM_BRIGHTNESS:
    img = randomly(img, lambda t: tf.image.random_brightness(t, 0.1))
  if CONFIG.RANDOM_CONTRAST:
    img = randomly(img, lambda t: tf.image.random_contrast(t, 0.1, 0.125))
  if CONFIG.RANDOM_GAMMA:
    img = randomly(img, lambda t: tf.image.adjust_gamma(t, 1e-6))
  if CONFIG.RANDOM_HUE:
    img = randomly(img, lambda t: tf.image.random_hue(t, 0.1))
  if CONFIG.RANDOM_SATURATION:
    img = randomly(img, lambda t: tf.image.random_saturation(t, 0.475, 0.525))
  return img, y

@tf.function
def tf_model_preprocessing_train(img:tf.Tensor, y:tf.Tensor)->tf.Tensor:
  """Rescale pixel values for the model: divide by 127.5, then subtract 1.

  The labels `y` are returned unchanged alongside the rescaled image.
  """
  scaled = img/tf.constant(127.5)
  return scaled-tf.constant(1.0), y

@tf.function
def tf_model_preprocessing_test(img:tf.Tensor, y:tf.Tensor)->tf.Tensor:
  """Rescale pixel values for inference: divide by 127.5, then subtract 1.

  The labels `y` are returned unchanged alongside the rescaled image.
  """
  scaled = img/tf.constant(127.5)
  return scaled-tf.constant(1.0), y

In [None]:
def df_make_dataset(df, autotune=None):
  """Build a tf.data dataset of (image, label) pairs from a dataframe.

  Args:
      df: Dataframe providing a `path` column and a `level_cat` column
      autotune: Forwarded as `num_parallel_calls` to `Dataset.map`

  Returns:
      Dataset whose elements are (loaded image, label tensor)
  """
  def load_pair(path, label):
    return tf_load_image(path), tf_labels(label)

  paths_and_labels = (df.path, df.level_cat.to_list())
  return tf.data.Dataset.from_tensor_slices(paths_and_labels).map(
           load_pair, num_parallel_calls=autotune
           )

In [None]:
# Shorthand passed as `num_parallel_calls` / prefetch size in the pipelines below.
AUTOTUNE = tf.data.AUTOTUNE

## Create train data generator

In [None]:
# Training pipeline: load -> shuffle over the whole set -> batch (incomplete
# final batch dropped) -> augment -> rescale -> prefetch.
train_ds = df_make_dataset(train_df, autotune=AUTOTUNE)

# NOTE(review): augmentation is mapped after batching, so each random draw is
# made once per batch rather than once per image — confirm this is intended.
# NOTE(review): the shuffle buffer is len(train_df); CONFIG.SHUFFLE_BUFFER is
# not used here.
train_ds = train_ds.shuffle(len(train_df))           \
                   .batch(CONFIG.BATCH_SIZE, 
                          drop_remainder=True)       \
                   .map(tf_augment_batch, 
                        num_parallel_calls=AUTOTUNE) \
                   .map(tf_model_preprocessing_train, 
                        num_parallel_calls=AUTOTUNE) \
                   .prefetch(AUTOTUNE)
# The dataframe is no longer needed once the dataset has been built.
del train_df

## Create validation data generator

In [None]:
# Validation pipeline: load -> shuffle -> batch at half the training batch
# size (incomplete final batch dropped) -> rescale -> prefetch. No augmentation.
val_ds = df_make_dataset(val_df, autotune=AUTOTUNE)
# NOTE(review): this maps tf_model_preprocessing_train although a
# tf_model_preprocessing_test function exists; the two bodies are identical.
val_ds = val_ds.shuffle(len(val_df)//5)               \
               .batch(CONFIG.BATCH_SIZE//2, 
                      drop_remainder=True)            \
               .map(tf_model_preprocessing_train, 
                    num_parallel_calls=AUTOTUNE)      \
               .prefetch(AUTOTUNE)
# The dataframe is no longer needed once the dataset has been built.
del val_df

# Visualize the data generator outputs

In [None]:
def plot_data(x, y, fig_size=(16, 8)):
  """Show up to eight images of a batch in a 2x4 grid, titled by severity.

  Args:
      x: Batch of pre-processed images
      y: Batch of one-hot labels matching `x`
      fig_size: Figure size forwarded to `plt.subplots`
  """
  _, axes = plt.subplots(2, 4, figsize=fig_size)
  for image, label, ax in zip(x, y, axes.flatten()):
    # Map the rescaled pixel values back to displayable uint8.
    pixels = np.clip(image*127+127, 0, 255).astype(np.uint8)
    severity = np.argmax(label, -1)
    ax.imshow(pixels)
    ax.set_title('Severity {}'.format(severity))
    ax.axis('off')

## Visualize train data generator

In [None]:
# Plot one training batch. t_x and t_y stay defined after the loop and are
# reused by later cells to derive the input shape and number of classes.
for t_x, t_y in train_ds.take(1):
  plot_data(t_x, t_y)

## Visualize validation data generator

In [None]:
# Plot one validation batch. v_x and v_y stay defined after the loop and are
# reused later to pick the sample image for the attention visualisation.
for v_x, v_y in val_ds.take(1):
  plot_data(v_x, v_y)

In [None]:
# Uses the training batch left over from the visualisation loop above.
print('Input shape: ', t_x.shape[1:], '\nNumber of classes:', t_y.shape[-1])

# Attention mechanism

In [None]:
def conv_attention(x:tf.Tensor, 
                   base_model, 
                   filter_dim:int=8, 
                   batch_size:int=8, 
                   dropout:float=0.2, 
                   padding:str='same', 
                   name:str='conv_attn', 
                   activation:str='relu',
                   train_output_conv:bool=False, 
                   eager_execution:bool=True)->tuple:
  """Build a convolutional attention branch on top of feature map `x`.

  Args:
      x: Feature map the attention mask is computed from and applied to
      base_model: Backbone model; only the channel count of its output is read
      filter_dim: Base filter count; the 1x1 convolutions use 8x, 2x and 1x this value
      batch_size: Not used in this function
      dropout: Dropout rate used after every layer of the branch
      padding: Padding of the intermediate 1x1 convolutions
      name: Prefix for all layer names created here
      activation: Activation of the intermediate 1x1 convolutions
      train_output_conv: Whether the final channel-expanding convolution is trainable
      eager_execution: Selects tf.ones (True) or np.ones (False) for the
        initial weights of the final convolution

  Returns:
      (out_feat, attn): `x` multiplied by the mask, and the mask itself after
      expansion to the backbone's channel count
  """
  attn = tf.keras.layers.Dropout(dropout, name=f'{name}_attn_dropout_in')(x)
  # Stack of 1x1 convolutions with decreasing filter counts, each followed by dropout.
  for i, ff in enumerate([8, 2, 1]):
    attn = tf.keras.layers.Conv2D(ff*filter_dim, 
                                  kernel_size=(1,1), 
                                  padding=padding, 
                                  activation=activation,
                                  name=f'{name}_attn_conv2D_{ff*filter_dim}')(attn)
    attn = tf.keras.layers.Dropout(dropout, name=f'{name}_attn_dropout_{i}')(attn) 
  # Single-channel sigmoid mask.
  attn = tf.keras.layers.Conv2D(1, kernel_size=(1,1), 
                                padding='valid', 
                                activation='sigmoid',
                                name=f'{name}_attn_conv2D_1')(attn)
  attn = tf.keras.layers.Dropout(dropout, name=f'{name}_attn_dropout_4')(attn)  
  
  # Channel count of the backbone output; the mask is expanded to this depth.
  base_depth = base_model.get_output_shape_at(0)[-1]
  if eager_execution:
    up_conv_wt = tf.ones((1, 1, 1, base_depth))
  else:
    up_conv_wt = np.ones((1, 1, 1, base_depth))

  # Bias-free linear 1x1 convolution initialised with all-ones weights, so it
  # starts out copying the single mask channel into every output channel.
  up_conv_2D = tf.keras.layers.Conv2D(base_depth, kernel_size=(1,1), padding='same', 
                                      activation='linear', use_bias=False, 
                                      weights=[up_conv_wt], name=f'{name}_attn_final')
  up_conv_2D.trainable = train_output_conv
  attn = up_conv_2D(attn)
  # Element-wise product of mask and features.
  out_feat = tf.keras.layers.multiply([attn, x], name=f'{name}_attn_out')
  return out_feat, attn

In [None]:
def dense_attention(x:tf.Tensor, 
                    dropout:float=0.2, 
                    activation:str='relu', 
                    name:str='dense_attn',
                    train_output_conv:bool=False, 
                    eager_execution:bool=True)->tuple:
  """Build a dense attention branch on top of feature map `x`.

  The feature map is flattened, passed through one Dense layer with as many
  units as the feature map has elements, reshaped back to the feature-map
  shape and multiplied element-wise with the features.

  Args:
      x: 4-D feature map (batch axis first)
      dropout: Dropout rate applied to `x` before the attention branch
      activation: Activation of the Dense layer
      name: Prefix for all layer names created here
      train_output_conv: Not used; kept for signature parity with conv_attention
      eager_execution: Not used; kept for signature parity with conv_attention

  Returns:
      (out_feat, attn): the masked features and the attention mask
  """
  x = tf.keras.layers.Dropout(dropout, name=f'{name}_attn_dropout_in')(x) 
  feat_shape = (x.shape[1], x.shape[2], x.shape[3])
  n_units = feat_shape[0]*feat_shape[1]*feat_shape[2]
  # Flatten first: a Dense layer on the 4-D tensor would act on the last axis
  # only and could not be reshaped back to the feature-map shape.
  flat = tf.keras.layers.Flatten(name=f'{name}_attn_flatten')(x)
  attn = tf.keras.layers.Dense(n_units, 
                               activation=activation, name=f'{name}_attn_dense')(flat)
  # Reshape the Dense output (not `x`) so the mask really comes from the Dense layer.
  attn = tf.keras.layers.Reshape(feat_shape,
                                 name=f'{name}_attn_reshape')(attn)
  out_feat = tf.keras.layers.multiply([attn, x], name=f'{name}_attn_out')
  return out_feat, attn

# Diabetic retinopathy classifier

## Diabetic retinopathy classifier without the attention mechanism

In [None]:
def DiabeticRetinopathyClassifier(input_shape, 
                                  num_classes, 
                                  model, 
                                  fc_size=128, 
                                  batch_size=8,
                                  dropout=0.25, 
                                  pre_trained_weights=None, 
                                  activation='relu', 
                                  name='retina_model', 
                                  train_all:bool=True, 
                                  train_output_conv:bool=False, 
                                  eager_execution:bool=True,
                                  feature_dense:bool=True):
  """Build the classifier without attention: backbone -> pooling -> softmax.

  Args:
      input_shape: Shape of one input image
      num_classes: Number of output classes
      model: Backbone constructor, called with input_shape, include_top=False
        and weights=pre_trained_weights
      fc_size: Not used in this function
      batch_size: Not used in this function
      dropout: Dropout rate shared by all dropout layers
      pre_trained_weights: Forwarded as `weights=` to the backbone
      activation: Activation of the optional feature Dense layer
      name: Prefix for all layer names created here
      train_all: See the note on `base_model.trainable` below
      train_output_conv: Not used in this function
      eager_execution: Not used in this function
      feature_dense: Whether to insert a Dense layer before the output layer

  Returns:
      tf.keras.Model mapping an image to softmax class probabilities
  """
  base_model = model(input_shape=input_shape, include_top=False, weights=pre_trained_weights)
  # The backbone is frozen whenever pre-trained weights are given OR train_all
  # is True. NOTE(review): with the default train_all=True this freezes the
  # backbone, which looks inverted relative to the parameter name — confirm.
  base_model.trainable = False if (pre_trained_weights is not None or train_all) else True
  inp = tf.keras.layers.Input(input_shape)
  x = base_model(inp)
  x = tf.keras.layers.Dropout(dropout, name=f'{name}_enc_dropout')(x) 
  x = tf.keras.layers.BatchNormalization(name=f'{name}_bn_inp')(x) 
  x = tf.keras.layers.Dropout(dropout, name=f'{name}_bn_dropout')(x)
  gap = tf.keras.layers.GlobalAveragePooling2D(name=f'{name}_gap_')(x)
  x = tf.keras.layers.Dropout(dropout, name=f'{name}_gap_dropout')(gap)
  if feature_dense:
    # NOTE(review): this Dense reads `gap`, not `x`, so the gap dropout above
    # is bypassed on this path; it also has num_classes units while fc_size
    # goes unused — confirm both are intended.
    ft = tf.keras.layers.Dense(num_classes, activation=activation, name=f'{name}_feat')(gap)
    x = tf.keras.layers.Dropout(dropout, name=f'{name}_feat_dropout')(ft)
  # Insert an axis of length 1 after the batch axis when the feature width
  # equals num_classes; presumably so the output lines up with labels that
  # carry an extra length-1 axis — TODO confirm.
  if x.shape[1]==num_classes:  
    x = tf.expand_dims(x, axis=1, name=f'{name}_reshape')  
  out = tf.keras.layers.Dense(num_classes, activation='softmax', name=f'{name}_output')(x)
  return tf.keras.Model(inputs=[inp], outputs=[out])

## Diabetic retinopathy classifier with attention mechanism

In [None]:
def DiabeticRetinopathyAttnClassifier(input_shape, 
                                      num_classes, 
                                      model, 
                                      fc_size=128,
                                      batch_size=8, 
                                      dropout=0.25, 
                                      pre_trained_weights=None, 
                                      activation='relu', 
                                      name='retina_model', 
                                      train_all=True,
                                      train_output_conv=False, 
                                      eager_execution=True):
  """Build the classifier with the convolutional attention mechanism.

  The backbone features are masked by `conv_attention`; the pooled masked
  features are divided by the pooled mask, then passed through a Dense layer
  and a softmax output layer.

  Args:
      input_shape: Shape of one input image
      num_classes: Number of output classes
      model: Backbone constructor, called with input_shape, include_top=False
        and weights=pre_trained_weights
      fc_size: Units of the fully connected layer before the output
      batch_size: Forwarded to conv_attention
      dropout: Dropout rate shared by all dropout layers
      pre_trained_weights: Forwarded as `weights=` to the backbone
      activation: Activation of the attention convolutions and the Dense layer
      name: Prefix for all layer names created here
      train_all: See the note on `base_model.trainable` below
      train_output_conv: Forwarded to conv_attention
      eager_execution: Forwarded to conv_attention

  Returns:
      tf.keras.Model mapping an image to softmax class probabilities
  """
  base_model = model(input_shape=input_shape, include_top=False, weights=pre_trained_weights)
  # The backbone is frozen whenever pre-trained weights are given OR train_all
  # is True. NOTE(review): this looks inverted relative to the parameter
  # name — confirm.
  base_model.trainable = False if (pre_trained_weights is not None or train_all) else True
  inp = tf.keras.layers.Input(input_shape)
  x = base_model(inp)
  x = tf.keras.layers.BatchNormalization(name=f'{name}_bn_inp')(x)
  # conv_attention returns (masked features, mask) in that order; unpacking
  # them the other way round would invert the rescaling ratio below.
  out_feat, attn = conv_attention(x, base_model, filter_dim=8, batch_size=batch_size,
                                  padding='same', name=name, activation=activation,
                                  dropout=dropout, train_output_conv=train_output_conv,
                                  eager_execution=eager_execution)
  gap_feat = tf.keras.layers.GlobalAveragePooling2D(name=f'{name}_gap_feat')(out_feat)
  gap_mask = tf.keras.layers.GlobalAveragePooling2D(name=f'{name}_gap_mask')(attn)
  # Add an axis of length 1 after the batch axis to both pooled tensors.
  gap_feat = tf.keras.layers.Reshape((1,gap_mask.shape[1]),
                                     name=f'{name}_reshape_feat')(gap_feat)   
  gap_mask = tf.keras.layers.Reshape((1,gap_mask.shape[1]),
                                     name=f'{name}_reshape_mask')(gap_mask) 
  # Rescale: pooled masked features divided by the pooled mask.
  gap = tf.keras.layers.Lambda(lambda x: x[0]/x[1], 
                               name=f'{name}_gap_rescale')([gap_feat, gap_mask])
  gap = tf.keras.layers.Dropout(dropout, name=f'{name}_gap_dropout')(gap)
  fc = tf.keras.layers.Dense(fc_size, activation=activation,
                             name=f'{name}_fc_dense_{fc_size}')(gap)
  fc = tf.keras.layers.Dropout(dropout, name=f'{name}_fc_dropout')(fc)
  out = tf.keras.layers.Dense(num_classes, activation='softmax', name=f'{name}_output')(fc)
  return tf.keras.Model(inputs=[inp], outputs=[out])

In [None]:
# Resolve the backbone constructor by name from tf.keras.applications.
encoder = getattr(tf.keras.applications, CONFIG.BACKBONE)

In [None]:
# Derive the model input shape and class count from the training batch left
# over from the visualisation cell.
input_shape, num_classes = t_x.shape[1:], t_y.shape[-1]
# Commented-out alternative with fixed values:
# input_shape, num_classes = (512, 512, 3), 5
print(input_shape, num_classes)

In [None]:
def get_retina_model(eager_execution:bool=True):
  """Build the plain classifier from the notebook-level settings.

  Reads `input_shape`, `num_classes` and `encoder` from module scope and the
  remaining options from CONFIG.
  """
  build_kwargs = dict(
    batch_size=CONFIG.BATCH_SIZE,
    train_output_conv=CONFIG.TRAIN_ATTN_CONV,
    pre_trained_weights=CONFIG.PRE_TRAINED_WEIGHTS,
    eager_execution=eager_execution,
  )
  return DiabeticRetinopathyClassifier(input_shape, num_classes, encoder,
                                       **build_kwargs)

retina_model = get_retina_model()

In [None]:
def get_retina_attn_model(eager_execution:bool=True):
  """Build the attention classifier from the notebook-level settings.

  Reads `input_shape`, `num_classes` and `encoder` from module scope and the
  remaining options from CONFIG; layers are prefixed 'retina_attn_model'.
  """
  build_kwargs = dict(
    batch_size=CONFIG.BATCH_SIZE,
    train_output_conv=CONFIG.TRAIN_ATTN_CONV,
    pre_trained_weights=CONFIG.PRE_TRAINED_WEIGHTS,
    name='retina_attn_model',
    eager_execution=eager_execution,
  )
  return DiabeticRetinopathyAttnClassifier(input_shape, num_classes, encoder,
                                           **build_kwargs)

retina_attn_model = get_retina_attn_model()

# Model summary

In [None]:
# Show the plain classifier either as a graph plot or as a text summary.
if CONFIG.MODEL_SUMMARY=='plot':
  display(tf.keras.utils.plot_model(retina_model))
elif CONFIG.MODEL_SUMMARY=='summary' and CONFIG.VERBOSE:
  # summary() prints the table itself and returns None, so it is not wrapped
  # in print() (that would emit a stray "None").
  retina_model.summary()

In [None]:
# Show the attention classifier either as a graph plot or as a text summary.
if CONFIG.MODEL_SUMMARY=='plot':
  display(tf.keras.utils.plot_model(retina_attn_model))
elif CONFIG.MODEL_SUMMARY=='summary' and CONFIG.VERBOSE:
  # summary() prints the table itself and returns None, so it is not wrapped
  # in print() (that would emit a stray "None").
  retina_attn_model.summary()

In [None]:
# Drop both models and reset the Keras session; later cells rebuild the
# models they need.
del retina_model, retina_attn_model; tf.keras.backend.clear_session()

# Custom callbacks for training

In [None]:
class GarbageCollectorCallback(tf.keras.callbacks.Callback):
  """Keras callback that runs cleanup at the end of every epoch.

  It calls ``gc.collect()`` followed by ``tf.keras.backend.clear_session()``.
  """
  def on_epoch_end(self, epoch, logs=None):
    # The return value of gc.collect() is not needed.
    _ =  gc.collect()
    # NOTE(review): clearing the Keras session while a model is still being
    # fitted may not be safe — confirm before enabling this callback.
    tf.keras.backend.clear_session()

# Train model

In [None]:
def train_model(train_ds, 
                val_ds, 
                epochs=1, 
                model=None, 
                fold=None,  
                custom_objects=None, 
                opt='Adam', 
                metrics=None, 
                saved_weights_dir='./', 
                saved_weights='output.h5', 
                out_dir='./',
                dropout=0.25, 
                out_weights_file='output.h5', 
                learning_rate=1e-4):
  """Compile `model`, optionally resume from saved weights, and fit it.

  Args:
      train_ds: Training dataset passed to `model.fit`
      val_ds: Validation dataset passed to `model.fit`
      epochs: Number of epochs
      model: Keras model to train (required)
      fold: Optional fold identifier added to the checkpoint file name
      custom_objects: Not used; kept for backward compatibility
      opt: Name of an optimizer class in `tf.keras.optimizers`
      metrics: Metrics for `model.compile`; defaults to
        ['categorical_accuracy'] when None
      saved_weights_dir: Directory searched for weights to resume from
      saved_weights: File name of the weights to resume from
      out_dir: Directory the best checkpoint is written to
      dropout: Not used; kept for backward compatibility
      out_weights_file: File name of the checkpoint; a trailing '.h5'/'.hdf5'
        is kept as the extension, otherwise '.h5' is appended
      learning_rate: Learning rate handed to the optimizer

  Returns:
      The History object returned by `model.fit`

  Raises:
      ValueError: If `model` is None
  """
  if model is None:
    raise ValueError('train_model requires a model to train')
  # Avoid a shared mutable default argument.
  metrics = ['categorical_accuracy'] if metrics is None else list(metrics)

  opt = getattr(tf.keras.optimizers, opt)(learning_rate)

  loss = 'categorical_crossentropy'

  lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.75, 
                                               patience=2, verbose=1, mode='min')
  es_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4, mode='min',
                                           verbose=1, restore_best_weights=True)

  # Build the checkpoint name without doubling the extension: a name that
  # already ends in '.h5' would otherwise become '<name>.h5.h5'.
  stem, ext = os.path.splitext(out_weights_file)
  if ext.lower() not in ('.h5', '.hdf5'):
    stem, ext = out_weights_file, '.h5'
  fold_suffix = f'_{fold}' if fold is not None else ''
  ckpt_file = os.path.join(out_dir, f'{stem}{fold_suffix}{ext}')
  ckpt_cb = tf.keras.callbacks.ModelCheckpoint(ckpt_file, monitor='val_loss', 
                                               mode='min', save_best_only=True)
  # GarbageCollectorCallback is intentionally left out of the callback list.
  cb = [es_cb, ckpt_cb, lr_cb]

  model.compile(optimizer=opt, loss=loss, metrics=metrics)

  # Resume from previously saved weights when they are available.
  resume_path = os.path.join(saved_weights_dir, saved_weights)
  if os.path.exists(resume_path):
    model.load_weights(resume_path)
    print('Loaded weights: ', resume_path)
    
  return model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=cb)

In [None]:
# Train both classifiers one after the other, only when training is enabled.
if CONFIG.ENABLE_TRAINING:
  # Plain classifier.
  retina_model = get_retina_model()  
  train_model(train_ds, val_ds, model=retina_model, 
              epochs=CONFIG.EPOCHS, dropout=CONFIG.DROPOUT,
              saved_weights_dir=CONFIG.SAVED_WEIGHTS_DIR, 
              saved_weights=CONFIG.SAVED_WEIGHTS, 
              out_weights_file=CONFIG.OUT_WEIGHTS)
  # Drop the first model and reset the session before building the second.
  del retina_model; tf.keras.backend.clear_session()
  
  # Attention classifier, with its own saved-weights and output file names.
  retina_attn_model = get_retina_attn_model()  
  train_model(train_ds, val_ds, model=retina_attn_model, 
              epochs=CONFIG.EPOCHS, dropout=CONFIG.DROPOUT,
              saved_weights_dir=CONFIG.SAVED_WEIGHTS_DIR, 
              saved_weights=CONFIG.SAVED_ATTN_WEIGHTS, 
              out_weights_file=CONFIG.OUT_ATTN_WEIGHTS)

# Visualize attention mechanism

In [None]:
# Take the first image and label of the validation batch left over from the
# visualisation cell and convert them to NumPy arrays.
sample_img, sample_label = v_x[0].numpy(), v_y[0].numpy()

# Disable eager execution

In [None]:
# Switch TensorFlow to graph mode and reset the Keras session; the attention
# model is rebuilt after this point.
tf.compat.v1.disable_eager_execution()
tf.keras.backend.clear_session()

In [None]:
# Rebuild the attention model; eager_execution=False makes conv_attention
# create its initial weights with NumPy instead of TensorFlow.
retina_attn_model = get_retina_attn_model(eager_execution=False)

In [None]:
# Locate weights for the ATTENTION model. The names must be the attention
# variants: the plain classifier's weights belong to a different architecture.
trained_weights = os.path.join('./', CONFIG.OUT_ATTN_WEIGHTS)
saved_weights = os.path.join(CONFIG.SAVED_WEIGHTS_DIR, CONFIG.SAVED_ATTN_WEIGHTS)
# Prefer freshly trained weights in the working directory (also accepting a
# file name with one more '.h5' appended), then fall back to the saved ones.
for weights_path in (trained_weights, f'{trained_weights}.h5', saved_weights):
  if os.path.exists(weights_path):
    retina_attn_model.load_weights(weights_path)
    print('Loaded weights: ', weights_path)
    break
else:
  print('No attention-model weights found; continuing without loading.')

In [None]:
def get_attention_layer(attn_model, attn_layer=None):
  """Return the attention layer of `attn_model`.

  Args:
      attn_model: Model to search
      attn_layer: Optional layer name; when given, that layer is returned

  Returns:
      The named layer, or, when no name is given, the first layer whose
      output is 4-D with a single channel. None if no such layer exists.
  """
  if attn_layer is not None:
    return attn_model.get_layer(attn_layer)
  # Search the model that was passed in, not a notebook-level global.
  for _layer in attn_model.layers:
    _shape = _layer.get_output_shape_at(0)
    if len(_shape)==4 and _shape[-1]==1:
      print(_layer)
      print(_layer.name)
      return _layer
  return None

# Look the attention layer up by name: the single-channel sigmoid convolution
# created by conv_attention under the 'retina_attn_model' prefix.
attn_layer = get_attention_layer(
               retina_attn_model, attn_layer='retina_attn_model_attn_conv2D_1'
               )

In [None]:
def get_attention(x, attn_model, attn_layer):
  """Evaluate the output of `attn_layer` for input `x` in inference mode.

  Args:
      x: A batch of pre-processed images, or a single image without the
        batch axis (the axis is added automatically)
      attn_model: Model whose input feeds the attention layer
      attn_layer: Layer whose output is returned

  Returns:
      A one-element list holding the layer output for the batch
  """
  x = np.asarray(x)
  if x.ndim == 3:
    # The model input is batched; a single image needs a leading batch axis.
    x = np.expand_dims(x, axis=0)
  attn_fn = K.function(
              inputs=[attn_model.input, K.learning_phase()],
              outputs=[attn_layer.output]
              )
  # Learning phase 0 = inference.
  return attn_fn([x, 0])

In [None]:
def heatmap_overlay(img, heatmap, threshold=0.8, read_file=False):
  """Blend a colour-mapped heatmap over an image.

  Args:
      img: Image array, or a file path when `read_file` is True
      heatmap: Heatmap; it is multiplied by 255 and cast to uint8
      threshold: Blend weight of the image; the heatmap gets 1-threshold
      read_file: Whether `img` should be read from disk with cv2.imread

  Returns:
      (blended image, colour-mapped heatmap)
  """
  base = cv2.imread(img) if read_file else img
  height, width = base.shape[0], base.shape[1]
  # Bring the heatmap to the image size, then to 8-bit and the JET colour map.
  resized = cv2.resize(heatmap, (width, height))
  colored = cv2.applyColorMap(np.uint8(255 * resized), cv2.COLORMAP_JET)
  blended = cv2.addWeighted(base, threshold, colored, 1-threshold, 0)
  return blended, colored

## Plot the predictions and the attention map

In [None]:
def normalize_arr(arr):
  """Min-max scale `arr` to the range [0, 1].

  A constant array (max == min) is mapped to all zeros instead of producing
  NaN through a division by zero.
  """
  lo = np.min(arr)
  span = np.max(arr) - lo
  # Dividing by 1 for a zero span leaves the (all-zero) numerator unchanged.
  return (arr - lo)/(span if span != 0 else 1)

In [None]:
def plot_preds(img, label, model, attn_layer):
  """Plot an image, its attention overlay and the attention map side by side.

  Args:
      img: One pre-processed image without the batch axis
      label: One-hot label of the image
      model: Attention classifier used for the prediction
      attn_layer: Layer whose output is visualised as the attention map
  """
  # Both the attention function and the classifier take a batch of
  # pre-processed images, so add the batch axis once and reuse it for both.
  batch = np.expand_dims(img, axis=0)
  attn_img = np.array(get_attention(batch, model, attn_layer))
  # Predict on the pre-processed input, not on the uint8 display copy below.
  pred_cat = model.predict(batch)[0]
  # De-normalised copy, used for display only.
  img = np.clip(img[:,:,:]*127+127, 0, 255).astype(np.uint8)
  attn_img = normalize_arr(attn_img[0, 0, :])
  out_img, attn_img = heatmap_overlay(img, attn_img)
  real_cat = np.argmax(label)

  fig, (img_ax, over_ax, attn_ax) = plt.subplots(1, 3, figsize = (8, 4))
  for c_ax in (img_ax, over_ax, attn_ax):
    c_ax.axis('off')
  img_ax.imshow(img)
  over_ax.imshow(out_img)
  attn_ax.imshow(attn_img, cmap='viridis', vmin=0, vmax=1, interpolation='lanczos')
  img_ax.set_title(
      'Eye image\nCat:%2d' % (real_cat))
  over_ax.set_title(
      'Overlay of attention map\nCat (Pred):%2d (%1d)' % (
          real_cat,np.argmax(pred_cat)
      )
        )  
  attn_ax.set_title('Attention map\nProb:%2.2f%%' % (
      np.max(100*pred_cat[0,real_cat])
      )
        )

In [None]:
# Plot the prediction and attention map for the sample image; the same image
# is plotted three times.
img, label = sample_img, sample_label
for i in range(3):
  plot_preds(img, label, retina_attn_model, attn_layer)

# References:
**1. [This notebook is forked and modified from the diabetic retinopathy detection notebook authored by @manifoldix](https://www.kaggle.com/code/manifoldix/inceptionv3-for-retinopathy-gpu-hr)**

**2. [Kaggle diabetic retinopathy dataset](https://www.kaggle.com/competitions/diabetic-retinopathy-detection)**