In [None]:
# Preprocessing the Data

In [None]:
# Install pydot/pydotplus and graphviz so Keras model plotting works
!pip install -q pydot
!pip install -q pydotplus
!apt-get install -q graphviz

print("\n... OTHER IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")

# Machine Learning and Data Science Imports
import tensorflow_addons as tfa; print(f"\t\t– TENSORFLOW ADDONS VERSION: {tfa.__version__}");
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");
import scipy; print(f"\t\t– SCIPY VERSION: {scipy.__version__}");

# Built In Imports
from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2
import ast

# PRESETS
# Class names; a name's position in this list is its integer class label
LBL_NAMES = ["Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle", "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative"]
# enumerate() yields plain Python int keys (np.arange produced numpy scalars) and
# derives the class count from LBL_NAMES instead of hard-coding 19
INT_2_STR = dict(enumerate(LBL_NAMES))
# Lower-case, underscore-joined names, e.g. "Nuclear Membrane" -> "nuclear_membrane"
# (the form parsed out of tile directory names and pickle file names below)
INT_2_STR_LOWER = {k: v.lower().replace(" ", "_") for k, v in INT_2_STR.items()}
STR_2_INT_LOWER = {v: k for k, v in INT_2_STR_LOWER.items()}
STR_2_INT = {v: k for k, v in INT_2_STR.items()}
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#7f7f7f")
# One colour per class, sampled from seaborn's "Spectral" palette and converted to plotly colour strings
LABEL_COLORS = [px.colors.label_rgb(px.colors.convert_to_RGB_255(x)) for x in sns.color_palette("Spectral", len(LBL_NAMES))]
LABEL_COL_MAP = {str(i): x for i, x in enumerate(LABEL_COLORS)}

print("\n\n... IMPORTS COMPLETE ...\n")

In [None]:
# Root of all attached Kaggle input datasets
ROOT_DIR = "/kaggle/input"

# Competition data
COMP_DIR = os.path.join(ROOT_DIR, "hpa-single-cell-image-classification")

# Per-class filtered true-positive id pickles (read by get_tp_id_map)
PKL_DIR = os.path.join(ROOT_DIR, "hpa-rule-based-single-cell-filtering")

# Single-channel cell-tile datasets for the cell-wise classifier, one per stain colour
RED_TILE_DIR = os.path.join(ROOT_DIR, "human-protein-atlas-red-cell-tile-dataset")
GREEN_TILE_DIR = os.path.join(ROOT_DIR, "human-protein-atlas-green-cell-tile-dataset")
BLUE_TILE_DIR = os.path.join(ROOT_DIR, "human-protein-atlas-blue-cell-tile-dataset")
YELLOW_TILE_DIR = os.path.join(ROOT_DIR, "human-protein-atlas-yellow-cell-tile-dataset")

# Competition image folders and tfrecord folders
TRAIN_IMG_DIR = os.path.join(COMP_DIR, "train")
TEST_IMG_DIR = os.path.join(COMP_DIR, "test")
TRAIN_TFREC_DIR = os.path.join(COMP_DIR, "train_tfrecords")
TEST_TFREC_DIR = os.path.join(COMP_DIR, "test_tfrecords")

# Full, sorted paths of every file in the competition image folders
TRAIN_IMG_PATHS = sorted(os.path.join(TRAIN_IMG_DIR, name) for name in os.listdir(TRAIN_IMG_DIR))
TEST_IMG_PATHS = sorted(os.path.join(TEST_IMG_DIR, name) for name in os.listdir(TEST_IMG_DIR))
print("\n... Recall that 4 training images compose one example (R,G,B,Y) ...")
print("... \t– i.e. The first 4 training files are:")
for path in (full_path.rsplit('/', 1)[1] for full_path in TRAIN_IMG_PATHS[:4]):
    print(f"... \t\t– {path}")
print(f"\n... The number of training images is {len(TRAIN_IMG_PATHS)} i.e. {len(TRAIN_IMG_PATHS)//4} 4-channel images ...")
print(f"... The number of testing images is {len(TEST_IMG_PATHS)} i.e. {len(TEST_IMG_PATHS)//4} 4-channel images ...")

# Full, sorted paths of every tfrecord shard
TRAIN_TFREC_PATHS = sorted(os.path.join(TRAIN_TFREC_DIR, name) for name in os.listdir(TRAIN_TFREC_DIR))
TEST_TFREC_PATHS = sorted(os.path.join(TEST_TFREC_DIR, name) for name in os.listdir(TEST_TFREC_DIR))
print(f"\n... The number of training tfrecord files is {len(TRAIN_TFREC_PATHS)} ...")
print(f"... The number of testing tfrecord files is {len(TEST_TFREC_PATHS)} ...\n")

# Class index -> example count (presumably from the original competition labels — TODO confirm source)
ORIGINAL_DIST_MAP = {0: 37472, 1: 4845, 2: 12672, 3: 12882, 4: 17527, 5: 15337, 6: 10198, 7: 18825, 8: 11194, 9: 5322, 10: 7789, 11: 10, 12: 13952, 13: 21168, 14: 27494, 15: 2275, 16: 22738, 17: 5619, 18: 952}

# Training metadata CSV
TRAIN_CSV = os.path.join(ROOT_DIR, "hpa-train-data-with-additional-metadata/updated_train.csv")

print("\n... Loading massive train dataframe ...\n")
train_df = pd.read_csv(TRAIN_CSV)
# The mask_rles / mask_bboxes columns are left as strings here; parse them with
# ast.literal_eval only if a later step needs the structured values.

print("\n\nTRAIN DATAFRAME\n\n")
display(train_df.head(3))

In [None]:
def load_image_scaled(img_id, img_dir, img_size=512, load_style="tf"):
    """ Load an image by ID and compose it from 3 single-channel PNGs.

    Reads ``{img_id}_red.png``, ``{img_id}_yellow.png`` and ``{img_id}_blue.png``
    from ``img_dir`` (the green channel is NOT loaded), resizes each to
    ``img_size x img_size``, divides by 255 and stacks them on a new last axis
    in (red, yellow, blue) order.

    Args:
        img_id (str): Image identifier, i.e. the filename prefix.
        img_dir (str): Directory holding the per-channel PNG files.
        img_size (int): Output height and width in pixels.
        load_style (str): Decoding backend: "tf" (TensorFlow) or "PIL" (Pillow);
            any other value uses OpenCV.

    Returns:
        numpy array of shape (img_size, img_size, 3), assuming single-channel
        PNGs (values are in [0, 1] when the decoded pixels are 8-bit).

    Raises:
        FileNotFoundError: with the OpenCV backend, if a channel file cannot be read.
    """
    def __load_with_tf(path, img_size=512):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=1)
        # Drop the trailing channel axis: (H, W, 1) -> (H, W)
        return tf.image.resize(img, (img_size, img_size))[..., 0]

    def __load_with_pil(path, img_size=512):
        # Context manager closes the underlying file handle (it previously leaked)
        with Image.open(path) as img:
            return np.asarray(img.resize((img_size, img_size)))

    def __load_with_cv2(path, img_size=512):
        img = cv2.imread(path, 0)  # flag 0 -> read as grayscale
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.resize(img, (img_size, img_size))

    # Compare by value: `is` relied on CPython string interning (and emits a SyntaxWarning)
    if load_style == "tf":
        load_fn = __load_with_tf
    elif load_style == "PIL":
        load_fn = __load_with_pil
    else:
        load_fn = __load_with_cv2

    return np.stack(
        [np.asarray(load_fn(os.path.join(img_dir, img_id + f"_{c}.png"), img_size) / 255.)
         for c in ["red", "yellow", "blue"]],
        axis=2,
    )


def decode_img(img, img_size=(224,224)):
    """Decode encoded PNG bytes into a single-channel uint8 image of size ``img_size``.

    Args:
        img: Encoded PNG content (a string/bytes tensor).
        img_size: Target (height, width) handed to ``tf.image.resize``.

    Returns:
        uint8 tensor of shape (height, width, 1).
    """
    decoded = tf.image.decode_png(img, channels=1)
    resized = tf.image.resize(decoded, img_size)
    # tf.image.resize yields float32; convert back to uint8 for the caller
    return tf.cast(resized, tf.uint8)


def get_color_path_maps(color_dirs, tp_id_map):
    """ Build, for each colour, a map from class index to the sorted list of tile paths.

    ``color_dirs`` must be ordered (red, green, blue, yellow). Sub-directories of
    the red root ending in ``_256`` are enumerated; the matching green/blue/yellow
    directories and file names are derived by replacing "red" with the colour name.
    Only ``.png`` files are considered.

    Args:
        color_dirs (list[str]): The four tile-dataset roots in R, G, B, Y order.
        tp_id_map (dict[int, Iterable[str]] | None): Optional class index -> ids of
            tiles confirmed as true positives for that class. When given, a tile
            of a non-negative class whose id (file name up to its last "_") is not
            listed is moved to the "Negative" class. When None, every tile keeps
            the class of its directory.

    Returns:
        list[dict[int, list[str]]]: One map per colour (same order as ``color_dirs``),
        keyed by every class index, each value a sorted list of paths.

    Raises:
        KeyError: if a directory's class name is unknown, or ``tp_id_map`` lacks
            an entry for a non-negative class that has tiles.
    """
    colours = ["red", "green", "blue", "yellow"]
    neg_idx = STR_2_INT["Negative"]
    c_p_maps = [{k: [] for k in INT_2_STR.keys()} for _ in range(len(color_dirs))]

    # Set lookups instead of scanning a list for every tile (was O(n^2))
    tp_id_sets = None if tp_id_map is None else {k: set(v) for k, v in tp_id_map.items()}

    # Only the red root is enumerated; other colours' paths are derived from it
    red_sub_dirs = [d_path for d_path in os.listdir(color_dirs[0]) if d_path.endswith("_256")]
    for sub_dir in tqdm(red_sub_dirs, total=len(red_sub_dirs)):

        # Class name is the text between the first and the last "_" of the directory name
        cls = sub_dir.split("_", 1)[1].rsplit("_", 1)[0]
        cls_idx = STR_2_INT_LOWER[cls]

        # The same class directory inside each colour's dataset
        c_dirs = [
            os.path.join(color_dir, sub_dir.replace("red", clr), "data", "train_tiles", cls)
            for clr, color_dir in zip(colours, color_dirs)
        ]

        # Skip non-PNG entries: they used to be kept (tp_id_map=None) or misfiled as Negative
        f_names = [f_name for f_name in os.listdir(c_dirs[0]) if f_name.endswith(".png")]
        for f_name in tqdm(f_names, total=len(f_names)):
            full_paths = [
                os.path.join(c_dir, f_name.replace("red", clr))
                for clr, c_dir in zip(colours, c_dirs)
            ]
            keep_class = (
                tp_id_sets is None
                or cls_idx == neg_idx
                or f_name.rsplit("_", 1)[0] in tp_id_sets[cls_idx]
            )
            target_idx = cls_idx if keep_class else neg_idx
            for c_p_map, full_path in zip(c_p_maps, full_paths):
                c_p_map[target_idx].append(full_path)

    return [{k: sorted(v) for k, v in c_p_map.items()} for c_p_map in c_p_maps]


def get_tp_id_map(
    pkl_dir,
    extra_pkl_paths=("/kaggle/input/tmp-intermediate-filaments-pkl-file/intermediate_filaments_tp_list.pkl",),
):
    """ Load per-class true-positive id collections from ``<class_name>_tp_list.pkl`` files.

    Args:
        pkl_dir (str): Directory scanned (non-recursively) for ``*.pkl`` files.
        extra_pkl_paths (Iterable[str]): Additional pickle files to load. The
            default is a temporary stand-in for a class missing from ``pkl_dir``.
            An extra file is skipped when a pickle with the same basename was
            already found in ``pkl_dir``, so the directory copy wins once it exists.

    Returns:
        dict[int, object]: Class index -> unpickled object (presumably a
        collection of tile ids — TODO confirm against the producing notebook).

    Raises:
        KeyError: if a pickle's basename does not map to a known class name.
        FileNotFoundError: if an extra path does not exist.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted datasets.
    pkl_paths = [
        os.path.join(pkl_dir, f_name)
        for f_name in os.listdir(pkl_dir)
        if f_name.endswith(".pkl")
    ]
    seen_names = {os.path.basename(p) for p in pkl_paths}

    # Temporary extras (original note: "REMOVE AFTER UPDATING CLASSBASED NOTEBOOK");
    # avoid double-loading a class once pkl_dir provides it.
    for extra_path in extra_pkl_paths:
        extra_name = os.path.basename(extra_path)
        if extra_name not in seen_names:
            pkl_paths.append(extra_path)
            seen_names.add(extra_name)

    tp_id_map = {}
    for path in pkl_paths:
        class_id = STR_2_INT_LOWER[os.path.basename(path).replace("_tp_list.pkl", "")]
        with open(path, "rb") as f:
            tp_id_map[class_id] = pickle.load(f)
    return tp_id_map

    
def plot_rgb(arr, figsize=(12,12)):
    """ Display a 3-channel (RGB) microscopy image on a single axes without ticks.

    Args:
        arr: Image array accepted by ``imshow``.
        figsize (tuple): Matplotlib figure size in inches.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title("RGB Composite Image", fontweight="bold")
    ax.imshow(arr)
    ax.set_axis_off()
    plt.show()

    
def convert_rgby_to_rgb(arr, red_yellow_wt=0.0, green_yellow_wt=0.5):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image for display.

    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual
    channels with full range px values are better.
    In annotation, we toggled the channels.
    For visualization purpose only, you can try blending the channels.
    For example,
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    With the default weights this computes red = red, green = green + yellow/2,
    blue = blue (yellow is NOT added to red). Pass ``red_yellow_wt=1.0`` to follow
    the advice above exactly.

    Args:
        arr (numpy array): The RGBY image, channels on the last axis in
            (red, green, blue, yellow) order.
        red_yellow_wt (float): Weight of the yellow channel added to red.
        green_yellow_wt (float): Weight of the yellow channel added to green.

    Returns:
        numpy array with the same dtype as ``arr`` and one fewer channel
        (channels after the third stay zero). Integer inputs are clipped to the
        dtype's range instead of wrapping around; float inputs are not clipped
        (values may exceed 1.0).
    """
    arr = np.asarray(arr)
    is_int = np.issubdtype(arr.dtype, np.integer)
    # Blend integer images in float64 so green + yellow/2 cannot overflow before clipping
    work_dtype = np.float64 if is_int else arr.dtype
    yellow = arr[..., 3]

    red = arr[..., 0]
    green = arr[..., 1]
    if red_yellow_wt:
        red = red + red_yellow_wt * yellow
    if green_yellow_wt:
        green = green + green_yellow_wt * yellow

    rgb_arr = np.zeros_like(arr[..., :-1], dtype=work_dtype)
    rgb_arr[..., 0] = red
    rgb_arr[..., 1] = green
    rgb_arr[..., 2] = arr[..., 2]

    if is_int:
        info = np.iinfo(arr.dtype)
        # astype truncates toward zero, matching the previous implicit cast for in-range values
        return np.clip(rgb_arr, info.min, info.max).astype(arr.dtype)
    return rgb_arr
    
    
def plot_ex(arr, figsize=(20,6), title=None, plot_merged=True, rgb_only=False):
    """ Plot each channel side by side, optionally followed by a merged view.

    Args:
        arr: (H, W, 4) RGBY image, or (H, W, 3) RGB image when ``rgb_only`` is True.
        figsize (tuple): Matplotlib figure size in inches.
        title (str | None): Optional figure super-title.
        plot_merged (bool): Append a panel showing all channels merged.
        rgb_only (bool): Plot only the red/green/blue channels (no yellow panel);
            the merged panel then shows ``arr`` directly.
    """
    channel_titles = [
        "Red Channel – Microtubules",
        "Green Channel – Protein of Interest",
        "Blue - Nucleus",
        "Yellow – Endoplasmic Reticulum",
    ]
    n_channels = 3 if rgb_only else 4
    # One panel per plotted channel, plus one for the merged view.
    # (The old branch table gave 3 panels for 4 channels when plot_merged=False.)
    n_images = n_channels + (1 if plot_merged else 0)

    plt.figure(figsize=figsize)
    if isinstance(title, str):
        plt.suptitle(title, fontsize=20, fontweight="bold")

    for i, c in enumerate(channel_titles[:n_channels]):
        # Render each channel in its own colour on an otherwise black RGB canvas
        ch_arr = np.zeros_like(arr[..., :3])
        if i < 3:
            ch_arr[..., i] = arr[..., i]
        else:
            # Yellow is shown as equal red + green
            ch_arr[..., 0] = arr[..., i]
            ch_arr[..., 1] = arr[..., i]
        plt.subplot(1, n_images, i + 1)
        plt.title(c.title(), fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)

    if plot_merged:
        plt.subplot(1, n_images, n_images)
        if rgb_only:
            plt.title("Merged RGB", fontweight="bold")
            plt.imshow(arr)
        else:
            plt.title("Merged RGBY into RGB", fontweight="bold")
            plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)

    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    
def flatten_list_of_lists(l_o_l):
    """Concatenate an iterable of iterables into one flat list, preserving order."""
    flat = []
    for sublist in l_o_l:
        flat.extend(sublist)
    return flat


def create_input_list(crp, cgp, cbp, cyp, shuffle=True, val_split=0.025, seed=None):
    """Flatten per-class path maps into aligned (R, G, B, Y, label) lists and split off validation.

    Args:
        crp, cgp, cbp, cyp (dict[int, list[str]]): Class index -> tile paths for the
            red, green, blue and yellow channels. For each class, the i-th entries
            of the four lists must belong to the same tile.
        shuffle (bool): Shuffle examples (keeping channels and labels aligned)
            BEFORE splitting, so the validation set is drawn from all classes.
            Previously the split took the head of the class-sorted list, making
            validation almost entirely class 0. With ``shuffle=False`` the split
            still takes the head of the class-ordered list.
        val_split (float | None): Fraction of examples held out for validation
            (``int(n * val_split)`` of them). ``None`` disables the split.
        seed (int | None): Seed for a private ``random.Random``; ``None`` uses the
            global ``random`` module state.

    Returns:
        If ``val_split`` is None: a 5-tuple of lists ``(red, green, blue, yellow, labels)``.
        Otherwise: ``(train_tuple, val_tuple)``, each such a 5-tuple (possibly of empty lists).

    Raises:
        ValueError: if the four channel maps do not flatten to the same length.
    """
    def _flatten(path_map):
        # Classes visited in ascending key order so all channels and labels line up
        return [p for _, paths in sorted(path_map.items()) for p in paths]

    labels = [cls for cls, paths in sorted(crp.items()) for _ in paths]
    columns = [_flatten(m) for m in (crp, cgp, cbp, cyp)] + [labels]
    n_rows = len(labels)
    if any(len(col) != n_rows for col in columns):
        raise ValueError(
            f"Channel maps are misaligned: flattened lengths {[len(col) for col in columns[:4]]}"
        )

    # Keep each example's four paths and label together while shuffling and splitting
    rows = list(zip(*columns))
    if shuffle:
        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(rows)

    def _unzip(subset):
        # Always return 5 lists, even for an empty subset (zip(*[]) cannot be unpacked)
        if not subset:
            return tuple([] for _ in range(5))
        return tuple(list(col) for col in zip(*subset))

    if val_split is None:
        return _unzip(rows)
    n_val = int(n_rows * val_split)
    return _unzip(rows[n_val:]), _unzip(rows[:n_val])


def get_class_wts(single_ch_paths, n_classes=19, exclude_mitotic=True, multiplier=10, return_counts=False, mitotic_idx=None):
    """ Compute per-class loss weights from the number of tiles in each class.

    Each class gets ``min(1, multiplier * (min_count / count))``: frequent classes
    are down-weighted and rare ones are capped at 1. ``min_count`` is the smallest
    non-empty count among the classes that are not excluded.

    Args:
        single_ch_paths (dict[int, list]): Class index -> tile paths of one channel.
        n_classes (int): Classes ``0 .. n_classes-1`` are weighted.
        exclude_mitotic (bool): Leave the mitotic spindle class out of ``min_count``
            and force its weight to 1.0.
        multiplier (float): Scale applied before capping at 1.
        return_counts (bool): Also return the per-class counts.
        mitotic_idx (int | None): Index of the mitotic spindle class. When None,
            the class with the fewest tiles is assumed to be mitotic spindle.

    Returns:
        class_wts (dict[int, float]), or ``(class_wts, class_counts)`` when
        ``return_counts`` is True. Classes with zero tiles get weight 1.0
        (previously a ZeroDivisionError).
    """
    class_counts = {c_idx: len(single_ch_paths[c_idx]) for c_idx in range(n_classes)}

    excluded_idx = None
    if exclude_mitotic:
        excluded_idx = mitotic_idx if mitotic_idx is not None else min(class_counts, key=class_counts.get)

    # Reference count; equals the old "second smallest" when mitotic is the rarest class
    ref_counts = [v for k, v in class_counts.items() if k != excluded_idx and v > 0]
    real_min_count = min(ref_counts) if ref_counts else 0

    class_wts = {
        k: (min(1, multiplier * (real_min_count / v)) if v > 0 else 1.0)
        for k, v in class_counts.items()
    }

    if excluded_idx is not None:
        # Manually set mitotic spindle to a more appropriate value
        class_wts[excluded_idx] = 1.0

    if return_counts:
        return class_wts, class_counts
    return class_wts

In [None]:
# Tile-dataset roots in the R, G, B, Y order get_color_path_maps expects
# (it enumerates the first, red, root and derives the other colours' paths from it)
TILE_DIRS = [RED_TILE_DIR, GREEN_TILE_DIR, BLUE_TILE_DIR, YELLOW_TILE_DIR]

# Optional true-positive filtering: pass TP_ID_MAP instead of None to move
# tiles whose id is not listed for their class into "Negative".
# TP_ID_MAP = get_tp_id_map(PKL_DIR)

# One dict per colour: class index -> sorted list of tile paths
(RED_FILE_MAP,
 GREEN_FILE_MAP,
 BLUE_FILE_MAP,
 YELLOW_FILE_MAP) = get_color_path_maps(TILE_DIRS, tp_id_map=None)

In [None]:
# Fraction of all tiles held out for validation
VAL_FRAC = 0.075

# train_inputs / val_inputs are each a 5-tuple of aligned lists:
# (red_paths, green_paths, blue_paths, yellow_paths, labels)
train_inputs, val_inputs = create_input_list(
    RED_FILE_MAP, GREEN_FILE_MAP, BLUE_FILE_MAP, YELLOW_FILE_MAP,
    val_split=VAL_FRAC,
    shuffle=True,
)

In [None]:
# weight = min(1, multiplier * min_count / count); an earlier run used multiplier=23.203
class_wts, class_cnts = get_class_wts(RED_FILE_MAP, multiplier=50, return_counts=True)

print("\n ... CLASSWISE COUNTS ... \n")
display(class_cnts)

print("\n ... CLASS WEIGHTING ... \n")
display(class_wts)

In [None]:
# Learning-rate schedule hyper-parameters
N_EPOCHS = 10
LR_START = 0.0005        # LR at epoch 0
LR_MAX = 0.0011          # peak LR reached after the warm-up
LR_MIN = 0.0005          # floor that the exponential decay approaches
LR_RAMPUP_EPOCHS = 3     # epochs of linear warm-up from LR_START to LR_MAX
LR_SUSTAIN_EPOCHS = 2    # epochs held at LR_MAX
LR_EXP_DECAY = 0.75      # per-epoch decay factor applied to (LR_MAX - LR_MIN)

def lrfn(epoch):
    """Learning rate for a 0-based ``epoch``: linear warm-up, plateau, then exponential decay to LR_MIN."""
    plateau_end = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS
    if epoch >= plateau_end:
        return (LR_MAX - LR_MIN) * LR_EXP_DECAY ** (epoch - plateau_end) + LR_MIN
    if epoch >= LR_RAMPUP_EPOCHS:
        return LR_MAX
    warmup_step = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS
    return warmup_step * epoch + LR_START

# VIEW SCHEDULE
rng = [i for i in range(N_EPOCHS)]
y = [lrfn(x) for x in rng]

# Explicit axes API, with labelled axes so the figure stands on its own
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(rng, y)
ax.set_title("CUSTOM LR SCHEDULE", fontweight="bold")
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning rate")
plt.show()

print(f"Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}")

In [None]:
# PARAMS
MODEL_CKPT_DIR = "/kaggle/working/ebnet_b2_wdensehead"
DROP_YELLOW = True
NO_NEG_CLASS = False

if NO_NEG_CLASS:
    # Drop the "Negative" class from the weights/counts used for training
    class_wts = {k: v for k, v in class_wts.items() if k != STR_2_INT["Negative"]}
    class_cnts = {k: v for k, v in class_cnts.items() if k != STR_2_INT["Negative"]}
    n_classes = len(LBL_NAMES) - 1
else:
    n_classes = len(LBL_NAMES)

BATCH_SIZE = 32
# `lr` is a deprecated alias that newer Keras releases reject; use `learning_rate`
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=LR_START)
LOSS_FN = "binary_crossentropy"
SHUFF_BUFF = 500


# AUTO-CALCULATED
# Take sizes from the split actually produced by create_input_list. The previous
# len(RED_FILE_MAP[0]) counted only class-0 (Nucleoplasm) tiles, not all examples.
N_TRAIN = len(train_inputs[0])
N_VAL = len(val_inputs[0])
N_EX = N_TRAIN + N_VAL

# exist_ok=True already makes this a no-op when the directory exists
os.makedirs(MODEL_CKPT_DIR, exist_ok=True)

print(f"{N_TRAIN:<7} TRAINING EXAMPLES")
print(f"{N_VAL:<7} VALIDATION EXAMPLES")