# Initialize:

In [None]:
!pip install -q pydot
!pip install -q pydotplus
!apt-get install -q graphviz

import tensorflow_addons as tfa
import tensorflow as tf
import pandas as pd
pd.options.mode.chained_assignment = None
import numpy as np
import scipy

from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import io
import os
import gc
import re

from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib
import plotly
import PIL
import cv2
import ast

In [None]:
# Human-readable class names; a label's position in this list is its integer class id.
CLASS_LABELS = ["Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle", "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative"]

# Lookup tables are derived from CLASS_LABELS itself (plain int keys) instead of
# a hard-coded count, so they always match the list.
INT_2_STR = dict(enumerate(CLASS_LABELS))
# Lower-case, underscore-separated variant of each class name.
INT_2_STR_LOWER = {i: j.lower().replace(" ", "_") for i, j in INT_2_STR.items()}
STR_2_INT_LOWER = {j: i for i, j in INT_2_STR_LOWER.items()}
STR_2_INT = {j: i for i, j in INT_2_STR.items()}

# Plot styling: a shared font spec and one "rgb(r, g, b)" colour string per class.
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#7f7f7f")
LABEL_COLORS = [px.colors.label_rgb(px.colors.convert_to_RGB_255(i)) for i in sns.color_palette("Spectral", len(CLASS_LABELS))]
LABEL_COL_MAP = {str(i): j for i, j in enumerate(LABEL_COLORS)}

print(CLASS_LABELS)
print(INT_2_STR)
print(INT_2_STR_LOWER)
print(STR_2_INT_LOWER)
print(STR_2_INT)

In [None]:
# Input locations inside the competition dataset.
train_image_path = '../input/hpa-single-cell-image-classification/train'
test_image_path = '../input/hpa-single-cell-image-classification/test'
train_tf_path = '../input/hpa-single-cell-image-classification/train_tfrecords'
# NOTE(review): 'test_tf_records' is spelled differently from 'train_tfrecords'
# above -- confirm the actual directory name before using this path.
test_tf_path = '../input/hpa-single-cell-image-classification/test_tf_records'

# Roots of the per-colour cell-tile datasets (one dataset per channel).
train_red_path = '../input/human-protein-atlas-red-cell-tile-dataset'
train_green_path = '../input/human-protein-atlas-green-cell-tile-dataset'
train_blue_path = '../input/human-protein-atlas-blue-cell-tile-dataset'
train_yellow_path = '../input/human-protein-atlas-yellow-cell-tile-dataset'

# Sorted full paths of every entry found in the train / test image folders.
train_image = sorted([os.path.join(train_image_path, i) for i in os.listdir(train_image_path)])
test_image = sorted([os.path.join(test_image_path, i) for i in os.listdir(test_image_path)])
# Training label table; the first rows are printed as a quick sanity check.
train_label = pd.read_csv('../input/hpa-single-cell-image-classification/train.csv')
print(train_label.head())

# Helper functions:

In [None]:
def load_image_scaled(img_id, img_dir, img_size=512, load_style="tf"):
    """Load one image as a 3-channel array with every pixel divided by 255.

    Three PNG files named ``<img_id>_red.png``, ``<img_id>_yellow.png`` and
    ``<img_id>_blue.png`` are read from ``img_dir``, resized to
    ``img_size`` x ``img_size`` and stacked in that order along axis 2.
    The green channel is not loaded by this helper.

    Args:
        img_id: Image identifier, i.e. the file name without the
            ``_<colour>.png`` suffix.
        img_dir: Directory containing the per-colour PNG files.
        img_size: Side length (pixels) of the square output.
        load_style: ``"tf"`` to read with TensorFlow, ``"PIL"`` to read with
            Pillow; any other value falls back to OpenCV.

    Returns:
        A numpy array of the three channels stacked on axis 2 (assumes each
        PNG decodes to a single-channel image -- TODO confirm for the PIL path).
    """
    def _load_with_tf(path, img_size=512):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=1)
        # Drop the trailing channel axis so every loader returns a 2-D image.
        return tf.image.resize(img, (img_size, img_size))[..., 0]

    def _load_with_pil(path, img_size=512):
        img = Image.open(path)
        img = img.resize((img_size, img_size))
        return np.asarray(img)

    def _load_with_cv2(path, img_size=512):
        # Flag 0 makes OpenCV read the file as grayscale.
        img = cv2.imread(path, 0)
        img = cv2.resize(img, (img_size, img_size))
        return img

    # Compare by value: `is` tests object identity and is only accidentally
    # true for strings (it fails for names built at runtime).
    if load_style == "tf":
        load_fn = _load_with_tf
    elif load_style == "PIL":
        load_fn = _load_with_pil
    else:
        load_fn = _load_with_cv2

    return np.stack(
        [
            np.asarray(load_fn(os.path.join(img_dir, img_id + f"_{c}.png"), img_size) / 255.)
            for c in ["red", "yellow", "blue"]
        ],
        axis=2,
    )


def decode_img(img, img_size=(224,224)):
    """Decode PNG bytes into a one-channel image resized to ``img_size``.

    Args:
        img: Encoded PNG contents (for example the result of ``tf.io.read_file``).
        img_size: Target size handed to ``tf.image.resize``.

    Returns:
        The resized image cast to ``tf.uint8``.
    """
    # Decode as a single-channel image.
    decoded = tf.image.decode_png(img, channels=1)

    # Resize, then cast the result back to uint8.
    resized = tf.image.resize(decoded, img_size)
    return tf.cast(resized, tf.uint8)


def get_color_path_maps(color_dirs, tp_id_map):
    """Build, for each colour dataset, a map from class index to sorted tile paths.

    The first entry of ``color_dirs`` drives the traversal: every sub-directory
    whose name ends in ``_256`` is visited, the class name is parsed out of that
    directory name, and each file found under ``data/train_tiles/<class>`` is
    mirrored into the other colour datasets by replacing ``"red"`` with the
    colour name in both the directory and file names.

    Args:
        color_dirs: Dataset roots in the order red, green, blue, yellow.
        tp_id_map: Optional mapping ``class index -> collection of ids``.
            When ``None`` every file is filed under its directory's class.
            Otherwise a ``.png`` file keeps its directory's class only if the
            red path contains ``"negative"`` or the file name (minus its last
            ``_``-separated part) is listed for that class; every other file is
            filed under the ``"Negative"`` class.

    Returns:
        One dict per entry of ``color_dirs`` mapping each class index to a
        sorted list of file paths; the dicts are filled in lockstep.
    """
    colors = ["red", "green", "blue", "yellow"]
    negative_idx = STR_2_INT["Negative"]

    c_p_maps = [{k: [] for k in INT_2_STR.keys()} for _ in range(len(color_dirs))]
    color_d_paths = [
        [d_path for d_path in os.listdir(color_dir) if d_path.endswith("_256")]
        for color_dir in color_dirs
    ]
    # Only the first (red) listing is iterated; the other datasets are reached
    # by name substitution below.
    for c in tqdm(color_d_paths[0], total=len(color_d_paths[0])):

        # Class name = directory name without its first and last "_" parts.
        cls = c.split("_", 1)[1].rsplit("_", 1)[0]
        cls_idx = STR_2_INT_LOWER[cls]

        # Matching class directory inside every colour dataset.
        c_dirs = [
            os.path.join(color_dir, c.replace("red", clr), "data", "train_tiles", cls)
            for clr, color_dir in zip(colors, color_dirs)
        ]

        # List the directory once and reuse it for both the loop and the bar total.
        f_names = os.listdir(c_dirs[0])
        for f_name in tqdm(f_names, total=len(f_names)):
            full_paths = [
                os.path.join(c_dir, f_name.replace("red", clr))
                for clr, c_dir in zip(colors, c_dirs)
            ]

            if tp_id_map is None:
                target_idx = cls_idx
            elif f_name.endswith(".png") and (
                "negative" in full_paths[0]
                or f_name.rsplit("_", 1)[0] in tp_id_map[cls_idx]
            ):
                target_idx = cls_idx
            else:
                # Not confirmed for this class: file it under "Negative".
                target_idx = negative_idx

            for c_p_map, full_path in zip(c_p_maps, full_paths):
                c_p_map[target_idx].append(full_path)

    return [{k: sorted(v) for k, v in c_p_map.items()} for c_p_map in c_p_maps]


def get_tp_id_map(pkl_dir):
    """Load the per-class pickled lists found in ``pkl_dir``.

    Every ``*.pkl`` file in ``pkl_dir`` (plus one hard-coded extra file, see
    below) is unpickled. The class is taken from the file name: the
    ``_tp_list.pkl`` suffix is removed and the remainder is looked up in
    ``STR_2_INT_LOWER``.

    Args:
        pkl_dir: Directory containing ``<class_name>_tp_list.pkl`` files.

    Returns:
        Dict mapping class index to the unpickled object of that class's file.

    Raises:
        KeyError: If a file name does not correspond to a known class.
        OSError: If a listed file (including the hard-coded one) cannot be opened.
    """
    # Capture all relevant paths
    pkl_paths = [
        os.path.join(pkl_dir, f_name)
        for f_name in os.listdir(pkl_dir)
        if f_name.endswith(".pkl")
    ]

    # REMOVE AFTER UPDATING CLASSBASED NOTEBOOK
    pkl_paths.append("/kaggle/input/tmp-intermediate-filaments-pkl-file/intermediate_filaments_tp_list.pkl")

    tp_id_map = {}
    for path in pkl_paths:
        # basename() copes with whatever path separator os.path.join produced.
        class_name = os.path.basename(path).replace("_tp_list.pkl", "")
        class_id = STR_2_INT_LOWER[class_name]
        # SECURITY: pickle.load executes arbitrary code from the file; only use
        # this on pickle files from a trusted source.
        with open(path, "rb") as f:
            tp_id_map[class_id] = pickle.load(f)
    return tp_id_map

    
def plot_rgb(arr, figsize=(12,12)):
    """Show ``arr`` in a new matplotlib figure with a fixed title and no axes."""
    heading = "RGB Composite Image"

    plt.figure(figsize=figsize)
    plt.title(heading, fontweight="bold")
    plt.imshow(arr)
    plt.axis(False)
    plt.show()

    
def convert_rgby_to_rgb(arr):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image.

    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual
    channels with full range px values are better.
    In annotation, we toggled the channels.
    For visualization purpose only, you can try blending the channels.
    For example,
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    What this function actually computes:
        - red   = red (yellow is NOT added to red here)
        - green = green + yellow/2
        - blue  = blue

    Args:
        arr (numpy array): The RGBY, 4 channel numpy array for a given image
            (channels on the last axis).

    Returns:
        RGB image with the same dtype as ``arr`` and 3 channels on the last
        axis. For integer dtypes the blended green channel is clipped to the
        dtype's range instead of wrapping around.
    """

    rgb_arr = np.zeros_like(arr[..., :-1])
    rgb_arr[..., 0] = arr[..., 0]
    rgb_arr[..., 2] = arr[..., 2]

    # The division promotes to float, so the sum itself cannot overflow ...
    green = arr[..., 1] + arr[..., 3] / 2
    if np.issubdtype(rgb_arr.dtype, np.integer):
        # ... but writing e.g. 300 back into a uint8 array would wrap, so
        # saturate at the dtype limits first.
        limits = np.iinfo(rgb_arr.dtype)
        green = np.clip(green, limits.min, limits.max)
    rgb_arr[..., 1] = green

    return rgb_arr
    
    
def plot_ex(arr, figsize=(20,6), title=None, plot_merged=True, rgb_only=False):
    """ Plot the individual channels side by side, optionally with a merged view.

    Args:
        arr: Image array with channels on the last axis. Four channels
            (red, green, blue, yellow) are read unless ``rgb_only`` is true,
            in which case only the first three are used.
        figsize: Figure size handed to ``plt.figure``.
        title: Optional figure super-title; ignored unless it is a string.
        plot_merged: If true, one extra panel with the merged image is appended.
        rgb_only: If true, ``arr`` is treated as RGB and no yellow panel is drawn.
    """
    channel_titles = [
        "Red Channel – Microtubles",
        "Green Channel – Protein of Interest",
        "Blue - Nucleus",
        "Yellow – Endoplasmic Reticulum",
    ]

    # One panel per plotted channel plus, optionally, the merged view. The grid
    # must hold every panel; requesting subplot 4 of a 3-wide grid raises.
    n_channels = 3 if rgb_only else 4
    n_images = n_channels + (1 if plot_merged else 0)

    plt.figure(figsize=figsize)
    if isinstance(title, str):
        plt.suptitle(title, fontsize=20, fontweight="bold")

    for i, c in enumerate(channel_titles[:n_channels]):
        # Blank 3-channel canvas: all of `arr` when it is already RGB, otherwise
        # `arr` without its last (yellow) channel.
        ch_arr = np.zeros_like(arr) if rgb_only else np.zeros_like(arr[..., :-1])
        if i < 3:
            # Red / green / blue are shown in their own colour plane.
            ch_arr[..., i] = arr[..., i]
        else:
            # Yellow is rendered by writing it to both the red and green planes.
            ch_arr[..., 0] = arr[..., i]
            ch_arr[..., 1] = arr[..., i]
        plt.subplot(1, n_images, i+1)
        plt.title(c.title(), fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)

    if plot_merged:
        # The merged view always occupies the last panel.
        plt.subplot(1, n_images, n_images)

        if rgb_only:
            plt.title("Merged RGB", fontweight="bold")
            plt.imshow(arr)
        else:
            plt.title("Merged RGBY into RGB", fontweight="bold")
            plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)

    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    
def flatten_list_of_lists(l_o_l):
    """Return one flat list holding the items of every sub-iterable, in order."""
    flat = []
    for sublist in l_o_l:
        flat.extend(sublist)
    return flat


def create_input_list(crp, cgp, cbp, cyp, shuffle=True, val_split=0.025):
    """Flatten the per-class path maps into parallel lists, optionally split.

    The four maps (red, green, blue, yellow) are flattened in ascending class
    key order into aligned lists, together with a label list built from the
    red map. When ``shuffle`` is true the examples are shuffled *before* the
    validation slice is taken, so the validation set is a random sample and
    not just the first (lowest-index) classes.

    Args:
        crp, cgp, cbp, cyp: Dicts mapping class index to a list of file paths
            for the red, green, blue and yellow channel respectively.
        shuffle: Shuffle the examples (keeping channels and label aligned).
        val_split: Fraction of examples used for validation, or ``None`` for
            no split. With ``shuffle=False`` the validation examples are the
            leading ones in class order.

    Returns:
        ``(red, green, blue, yellow, labels)`` lists when ``val_split`` is
        ``None``; otherwise a ``(train, val)`` pair of such 5-tuples.

    Raises:
        ValueError: If the four maps do not flatten to the same length.
    """
    def _flatten(path_map):
        # Concatenate the per-class lists in ascending class-key order.
        return [path for _, paths in sorted(path_map.items()) for path in paths]

    def _columns(rows):
        # Transpose example rows back into five parallel lists; zip(*[])
        # yields nothing, so the empty case is spelled out.
        if not rows:
            return [], [], [], [], []
        return tuple(list(col) for col in zip(*rows))

    labels = [k for k, paths in sorted(crp.items()) for _ in paths]
    columns = [_flatten(crp), _flatten(cgp), _flatten(cbp), _flatten(cyp), labels]
    if len({len(col) for col in columns}) != 1:
        raise ValueError("colour path maps must contain the same number of paths")

    rows = list(zip(*columns))
    if shuffle:
        random.shuffle(rows)

    if val_split is None:
        return _columns(rows)

    n_val = int(len(rows) * val_split)
    return _columns(rows[n_val:]), _columns(rows[:n_val])


def get_class_wts(single_ch_paths, n_classes=19, exclude_mitotic=True, multiplier=10, return_counts=False):
    """Compute per-class weights that down-weight heavily populated classes.

    Each class weight is ``min(1, multiplier * reference_count / class_count)``
    where ``reference_count`` is the smallest class count, or the second
    smallest when ``exclude_mitotic`` is true. A class without any samples
    gets weight ``1.0`` (the cap) instead of triggering a division by zero.

    Args:
        single_ch_paths: Indexable by class index ``0 .. n_classes-1``; each
            entry is a sized collection of that class's samples.
        n_classes: Number of classes to inspect.
        exclude_mitotic: If true, the least populated class is ignored when
            picking the reference count and its weight is forced to ``1.0``.
            The class is chosen by count only (assumed to be Mitotic Spindle --
            it is not checked by name).
        multiplier: Scale applied to the count ratio before capping at 1.
        return_counts: If true, also return the per-class counts.

    Returns:
        ``class_wts`` dict (class index -> weight), or ``(class_wts,
        class_counts)`` when ``return_counts`` is true.
    """
    # Get class counts
    class_counts = {c_idx: len(single_ch_paths[c_idx]) for c_idx in range(n_classes)}

    # Reference count: smallest, or second smallest when the rarest class is excluded.
    counts_desc = sorted(class_counts.values(), reverse=True)
    real_min_count = counts_desc[-2] if exclude_mitotic else counts_desc[-1]

    # Calculate weights; an empty class takes the cap rather than dividing by zero.
    class_wts = {
        k: min(1, multiplier * (real_min_count / v)) if v else 1.0
        for k, v in class_counts.items()
    }

    if exclude_mitotic:
        # Manually pin the rarest class to full weight.
        class_wts[min(class_counts, key=class_counts.get)] = 1.0

    if return_counts:
        return class_wts, class_counts
    return class_wts

# Preparing the data and model:

In [None]:
# Colour-dataset roots, ordered red, green, blue, yellow to match the four maps unpacked below.
train_each_path = [train_red_path, train_green_path, train_blue_path, train_yellow_path]
print(train_each_path)

# Define the paths to the training files for the tile dataset as a map from class index to list of paths
# (None is passed as tp_id_map, so no id-based filtering is requested.)
train_red_map, train_green_map, train_blue_map, train_yellow_map = get_color_path_maps(train_each_path, None)

In [None]:
# Each of train_inputs / val_inputs is a 5-tuple of parallel lists:
# red_inputs, green_inputs, blue_inputs, yellow_inputs, labels
train_inputs, val_inputs = create_input_list(
    train_red_map, 
    train_green_map, 
    train_blue_map, 
    train_yellow_map, 
    shuffle=True,
    val_split=0.075,
)

In [None]:
# Per-class weights and sample counts, computed from the red-channel path map.
class_weights, class_counts = get_class_wts(train_red_map, return_counts=True, multiplier=50)
print(class_weights)
print(class_counts)

In [None]:
# Learning-rate schedule: linear ramp-up, a short plateau, then exponential
# decay towards a floor. A ramp-up is used because a pre-trained model is fine-tuned.
N_EPOCHS = 15
LR_START = 0.001
LR_MAX = 0.0011
LR_MIN = 0.0005
LR_RAMPUP_EPOCHS = 3
LR_SUSTAIN_EPOCHS = 2
LR_EXP_DECAY = 0.75


def lrfn(epoch):
    """Return the learning rate to use for the given epoch index."""
    # Phase 1: straight line from LR_START towards LR_MAX.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    # Phase 2: hold the peak rate.
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX

    # Phase 3: decay the excess over LR_MIN by LR_EXP_DECAY per epoch.
    epochs_into_decay = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**epochs_into_decay + LR_MIN

# Preview the schedule: one learning rate per epoch index.
rng = list(range(N_EPOCHS))
y = [lrfn(epoch_idx) for epoch_idx in rng]

plt.figure(figsize=(10,4))
plt.plot(rng, y)
plt.title("CUSTOM LR SCHEDULE", fontweight="bold")
plt.show()

print(f"Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}")

In [None]:
#PARAMS
MODEL_CKPT_DIR = "/kaggle/working/ebnet_b2_wdensehead"
DROP_YELLOW = True
NO_NEG_CLASS = False

if NO_NEG_CLASS:
    # Drop the last class (index 18) from the weights/counts computed earlier.
    # These are the same `class_weights` / `class_counts` objects that training
    # consumes, so filter them under their real names.
    class_weights = {k: v for k, v in class_weights.items() if k != 18}
    class_counts = {k: v for k, v in class_counts.items() if k != 18}
    n_classes = 18
else:
    n_classes = 19

BATCH_SIZE = 32
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras releases reject.
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=LR_START)
LOSS_FN = "binary_crossentropy"
SHUFF_BUFF = 500


# AUTO-CALCULATED
# Sizes are taken from the label lists of the actual train/validation split, so
# they cover every class and always agree with the datasets built from them.
N_TRAIN = len(train_inputs[-1])
N_VAL = len(val_inputs[-1])
N_EX = N_TRAIN + N_VAL

# exist_ok makes a separate "does it exist?" check unnecessary.
os.makedirs(MODEL_CKPT_DIR, exist_ok=True)

print(f"{N_TRAIN:<7} TRAINING EXAMPLES")
print(f"{N_VAL:<7} VALIDATION EXAMPLES")

In [None]:
# TRAIN DATASET
# One dataset per parallel input list, zipped so each element is a 5-tuple
# (red path, green path, blue path, yellow path, label).
train_path_ds = tf.data.Dataset.zip(
    tuple([tf.data.Dataset.from_tensor_slices(input_ds) for input_ds in train_inputs])
)

# VALIDATION DATASET
# Built the same way from the validation lists.
val_path_ds = tf.data.Dataset.zip(
    tuple([tf.data.Dataset.from_tensor_slices(input_ds) for input_ds in val_inputs])
)

print(f"\n ... THERE ARE {N_EX} CELL TILES IN OUR FULL DATASET ... ")
print(f" ... THERE ARE {N_TRAIN} CELL TILES IN OUR TRAIN DATASET ... ")
print(f" ... THERE ARE {N_VAL} CELL TILES IN OUR VALIDATION DATASET ... \n")

print(train_path_ds)

# Show the first training element as a sanity check of the tuple layout.
for a,b,c,d,e in train_path_ds.take(1): 
    print(f"\tRed Path      --> {a}\n\t" \
          f"Green Path    --> {b}\n\t" \
          f"Blue Path     --> {c}\n\t" \
          f"Yellow Path   --> {d}\n\t" \
          f"Example Label --> {e} ({INT_2_STR[e.numpy()]})\n")

# Dataset preprocessing and augmentation helpers:

In [None]:
def preprocess_path_ds(rp, gp, bp, yp, lbl, img_size=(224,224), combine=True, drop_yellow=True, no_neg=True):
    """ Read the channel images for one example and build its label vector.

    Args:
        rp, gp, bp, yp: File paths of the red, green, blue and yellow images.
        lbl: Integer class index of the example.
        img_size: Size forwarded to ``decode_img``.
        combine: If true, the channels are stacked into a single image on the
            last axis; otherwise each channel image is returned separately.
        drop_yellow: If true, the yellow image is neither read nor returned.
        no_neg: If true, the label vector has 18 entries and class 18 maps to
            all zeros; otherwise the label is a 19-entry one-hot vector.

    Returns:
        ``(image, label)`` when ``combine`` is true, otherwise
        ``(red, green, blue[, yellow], label)``. The label is ``tf.uint8``.
    """
    
    # Adjust class output expectation
    # NOTE(review): `lbl==18` is a Python `if` on what is presumably a tensor when
    # this runs inside Dataset.map -- relies on TensorFlow converting it; confirm.
    if no_neg:
        if lbl==18:
            lbl_arr = tf.zeros((18,), dtype=tf.uint8)
        else:
            lbl_arr = tf.one_hot(lbl, 18, dtype=tf.uint8)
    else:
        lbl_arr = tf.one_hot(lbl, 19, dtype=tf.uint8)
    
    # Red, green and blue are always needed; yellow is read only on demand below.
    ri = decode_img(tf.io.read_file(rp), img_size)
    gi = decode_img(tf.io.read_file(gp), img_size)
    bi = decode_img(tf.io.read_file(bp), img_size)

    if combine and drop_yellow:
        # Index 0 drops each image's single channel axis before stacking.
        return tf.stack([ri[..., 0], gi[..., 0], bi[..., 0]], axis=-1), lbl_arr
    elif combine:
        yi = decode_img(tf.io.read_file(yp), img_size)
        return tf.stack([ri[..., 0], gi[..., 0], bi[..., 0], yi[..., 0]], axis=-1), lbl_arr
    elif drop_yellow:
        return ri, gi, bi, lbl_arr
    else:
        yi = decode_img(tf.io.read_file(yp), img_size)
        return ri, gi, bi, yi, lbl_arr
    

def augment(img_batch, lbl_batch):
    """Apply random flips, a quarter-turn rotation and colour jitter to images.

    Args:
        img_batch: Images to augment.
        lbl_batch: Labels; passed through unchanged.

    Returns:
        ``(augmented_images, lbl_batch)``.
    """
    # A single rotation count is drawn per call and used for the whole input.
    n_quarter_turns = tf.random.uniform((1,), minval=0, maxval=4, dtype=tf.dtypes.int32)[0]

    # Geometric augmentations.
    augmented = tf.image.random_flip_left_right(img_batch)
    augmented = tf.image.random_flip_up_down(augmented)
    augmented = tf.image.rot90(augmented, n_quarter_turns)

    # Photometric augmentations.
    augmented = tf.image.random_saturation(augmented, 0.85, 1.15)
    augmented = tf.image.random_brightness(augmented, 0.1)
    augmented = tf.image.random_contrast(augmented, 0.85, 1.15)

    # TODO: random crop (crop to 192x192, resize back to 224x224) is not
    # implemented yet; an earlier tf.where/tf.map_fn attempt did not work.

    return augmented, lbl_batch

# Building the cached training and validation input pipelines:

In [None]:
TRAIN_CACHE_DIR = "/kaggle/train_cache"
VAL_CACHE_DIR = "/kaggle/val_cache"

# Make sure both cache locations exist; already-present directories are kept.
for _cache_dir in (TRAIN_CACHE_DIR, VAL_CACHE_DIR):
    os.makedirs(_cache_dir, exist_ok=True)

# Map (red, green, blue, yellow, label) path tuples to preprocessed examples.
train_ds = train_path_ds.map(
    lambda rp, gp, bp, yp, lbl: preprocess_path_ds(rp, gp, bp, yp, lbl, drop_yellow=DROP_YELLOW, no_neg=NO_NEG_CLASS),
    num_parallel_calls=tf.data.AUTOTUNE,
)
val_ds = val_path_ds.map(
    lambda rp, gp, bp, yp, lbl: preprocess_path_ds(rp, gp, bp, yp, lbl, drop_yellow=DROP_YELLOW, no_neg=NO_NEG_CLASS),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Training pipeline: cache -> shuffle -> batch -> augment -> prefetch.
train_ds = train_ds.cache(TRAIN_CACHE_DIR)
train_ds = train_ds.shuffle(SHUFF_BUFF)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Validation pipeline: cache -> batch -> prefetch (no shuffling, no augmentation).
val_ds = val_ds.cache(VAL_CACHE_DIR)
val_ds = val_ds.batch(BATCH_SIZE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

# Training the model:

In [None]:
 def get_backbone(efficientnet_name="efficientnet_b0", input_shape=(224,224,3), include_top=False, weights="imagenet", pooling="avg"):
     """Instantiate the Keras EfficientNet variant named in ``efficientnet_name``.

     The name is searched for the substrings ``"b0"`` .. ``"b7"`` in that
     order and the first match selects ``tf.keras.applications.EfficientNetB<n>``.
     The remaining arguments are forwarded to that constructor unchanged.

     Raises:
         ValueError: If none of ``"b0"`` .. ``"b7"`` occurs in the name.
     """
     for variant in range(8):
         if f"b{variant}" in efficientnet_name:
             # Resolve the constructor only once a variant has matched.
             build_fn = getattr(tf.keras.applications, f"EfficientNetB{variant}")
             return build_fn(
                 include_top=include_top, weights=weights, pooling=pooling, input_shape=input_shape
             )
     raise ValueError("Invalid EfficientNet Name!!!")


 def add_head_to_bb(bb, n_classes=19, dropout=0.05, head_layer_nodes=(512,)):
     """Attach a dense classification head to the backbone model ``bb``.

     The backbone output goes through BatchNormalization and Dropout, then one
     Dense(relu) + BatchNormalization + Dropout(dropout/2) group per entry of
     ``head_layer_nodes``, and finally a sigmoid Dense layer with ``n_classes``
     units.

     Returns:
         A ``tf.keras.Model`` from ``bb.inputs`` to the sigmoid output.
     """
     layers = tf.keras.layers

     features = layers.BatchNormalization()(bb.output)
     features = layers.Dropout(dropout)(features)

     for width in head_layer_nodes:
         features = layers.Dense(width, activation="relu")(features)
         features = layers.BatchNormalization()(features)
         features = layers.Dropout(dropout/2)(features)

     predictions = layers.Dense(n_classes, activation="sigmoid")(features)
     return tf.keras.Model(inputs=bb.inputs, outputs=predictions)

 # Build the "b2" backbone, add the dense head and compile the result.
 eb = add_head_to_bb(get_backbone("b2"), n_classes, dropout=0.5)
 eb.compile(
     optimizer=OPTIMIZER,
     loss=LOSS_FN,
     metrics=["acc", tf.keras.metrics.AUC(name="auc", multi_label=True)],
 )

 # Kept as the last expression so the notebook displays the model diagram.
 tf.keras.utils.plot_model(eb, show_shapes=True, show_dtype=True, dpi=55)

In [None]:
# Checkpoint file names embed the epoch number and the validation loss.
_ckpt_pattern = os.path.join(MODEL_CKPT_DIR, "ckpt-{epoch:04d}-{val_loss:.4f}.ckpt")

# Early stopping on val_loss, a checkpoint per epoch, and the custom LR schedule.
_fit_callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint(filepath=_ckpt_pattern, verbose=1),
    tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=1),
]

history = eb.fit(
    train_ds,
    validation_data=val_ds,
    callbacks=_fit_callbacks,
    class_weight=class_weights,
    epochs=N_EPOCHS,
)

# Persist the trained model.
eb.save("./model_efficientnet_b2")