<h1 style="text-align: center; font-family: Verdana; font-size: 32px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; font-variant: small-caps; letter-spacing: 3px; color: #74d5dd; background-color: #ffffff;">Human Protein Atlas - Single Cell Classification</h1>
<h2 style="text-align: center; font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: underline; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">Similarity &amp; Duplicate Detection Using RAPIDS and KNN</h2>
<h5 style="text-align: center; font-family: Verdana; font-size: 12px; font-style: normal; font-weight: bold; text-decoration: None; text-transform: none; letter-spacing: 1px; color: black; background-color: #ffffff;">CREATED BY: DARIEN SCHETTLER</h5>


<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;">TABLE OF CONTENTS</h1>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#imports">0&nbsp;&nbsp;&nbsp;&nbsp;IMPORTS</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#background_information">1&nbsp;&nbsp;&nbsp;&nbsp;BACKGROUND INFORMATION</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#setup">2&nbsp;&nbsp;&nbsp;&nbsp;SETUP</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#helper_functions">3&nbsp;&nbsp;&nbsp;&nbsp;HELPER FUNCTIONS</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#feature_embedding">4&nbsp;&nbsp;&nbsp;&nbsp;FEATURE EMBEDDING</a></h3>

---

<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; background-color: #ffffff; color: navy;" id="imports">0&nbsp;&nbsp;IMPORTS</h1>

In [None]:
print("\n... IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")

# Machine Learning and Data Science Imports
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");
import torch

# Built In Imports
from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import ast
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2

# Submission Imports
import typing as t
import base64
import zlib

# PRESETS
# Class names; a name's position in this list is its integer class id.
LBL_NAMES = ["Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle", "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative"]

# id <-> name lookup tables. Built with enumerate so they always track
# len(LBL_NAMES) (no hard-coded class count) and use plain Python ints as ids.
INT_2_STR = dict(enumerate(LBL_NAMES))
INT_2_STR_LOWER = {k:v.lower().replace(" ", "_") for k,v in INT_2_STR.items()}
STR_2_INT_LOWER = {v:k for k,v in INT_2_STR_LOWER.items()}
STR_2_INT = {v:k for k,v in INT_2_STR.items()}

# Plot styling presets: one "rgb(r, g, b)" colour per class, keyed by the class id as a string.
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#7f7f7f")
LABEL_COLORS = [px.colors.label_rgb(px.colors.convert_to_RGB_255(x)) for x in sns.color_palette("Spectral", len(LBL_NAMES))]
LABEL_COL_MAP = {str(i):x for i,x in enumerate(LABEL_COLORS)}

print("\n\n... IMPORTS COMPLETE ...\n")


# Split GPU memory between TensorFlow and RAPIDS: cap TensorFlow at LIMIT GB
# so the rest of the card stays available to cuML.
LIMIT = 12             # GB of GPU RAM TensorFlow is allowed to use
TOTAL_GPU_RAM_GB = 16  # assumed size of the GPU (was hard-coded in the message) - adjust for other cards

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*LIMIT)])
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        # Only report the split once it has actually been applied.
        print('Restrict TensorFlow to max %iGB GPU RAM'%LIMIT)
        print('so RAPIDS can use %iGB GPU RAM'%(TOTAL_GPU_RAM_GB-LIMIT))
    except RuntimeError as e:
        # Raised when the virtual device cannot be configured
        # (e.g. the GPU was already initialised earlier in this kernel).
        print(e)
        print('TensorFlow GPU memory limit was NOT applied')
else:
    print('No GPU detected - TensorFlow GPU memory limit was NOT applied')

In [None]:
import cudf, cuml, cupy
from cuml.neighbors import NearestNeighbors

print('RAPIDS', cuml.__version__)

<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="background_information">1&nbsp;&nbsp;BACKGROUND INFORMATION</h1>

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.1  THE DATA</h3>

---

<b style="text-decoration: underline; font-family: Verdana;">BACKGROUND INFORMATION</b>

WIP

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.2  THE GOAL</h3>

---

WIP

<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="setup">2&nbsp;&nbsp;NOTEBOOK SETUP</h1>

In [None]:
# Define the root data directory and the two input datasets used below
ROOT_DIR = "/kaggle/input"
DATA_DIR = os.path.join(ROOT_DIR, "hpa-tfrecords-512x512-rgb-only")
COMP_DIR = os.path.join(ROOT_DIR, "hpa-single-cell-image-classification")

# Define the paths to the training and testing tfrecord and image folders respectively
TRAIN_IMG_DIR = os.path.join(COMP_DIR, "train")
TRAIN_TFREC_DIR = os.path.join(DATA_DIR, "train_slide_records")

# Capture all the relevant full tfrec paths.
# Sorted so the shard order (and therefore the record order used for the
# embeddings / image ids later on) does not depend on directory listing order.
TRAIN_TFREC_PATHS = sorted(tf.io.gfile.glob(os.path.join(TRAIN_TFREC_DIR, '*.tfrec')))
print(f"\n... The number of training tfrecord files is {len(TRAIN_TFREC_PATHS)} ...\n")

# Anchor the CSV to ROOT_DIR like every other path instead of relying on a
# "../input" path relative to the current working directory.
TRAIN_CSV_PATH = os.path.join(COMP_DIR, "train.csv")
train_df = pd.read_csv(TRAIN_CSV_PATH)

train_df.head()

<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="helper_functions">3&nbsp;&nbsp;HELPER FUNCTIONS</h1>

In [None]:
def load_image_from_channel_paths(channel_paths):
    """ Load One Image Per Path And Stack Them Along A New Last Axis

    Args:
        channel_paths (iterable of str): Paths of the single-channel images,
            in the order the channels should appear in the output

    Returns:
        uint8 numpy array whose last axis indexes the entries of `channel_paths`
    """
    channels = []
    for channel_path in channel_paths:
        channels.append(np.asarray(Image.open(channel_path), np.uint8))
    return np.stack(channels, axis=-1)

def load_image(img_id, img_dir):
    """ Compose One Multi-Channel Image From Its Four Per-Colour PNG Files

    Reads `<img_dir>/<img_id>_<colour>.png` for red, green, blue and yellow
    (in that order) and stacks the results along a new last axis.

    Args:
        img_id (str): The image identifier (file name prefix)
        img_dir (str): Directory containing the per-colour PNG files

    Returns:
        uint8 numpy array with the four colour channels on the last axis
    """
    def _read_channel(colour):
        # One file per colour: "<id>_<colour>.png"
        channel_path = os.path.join(img_dir, img_id + "_" + colour + ".png")
        return np.asarray(Image.open(channel_path), np.uint8)

    channels = [_read_channel(colour) for colour in ("red", "green", "blue", "yellow")]
    return np.stack(channels, axis=-1)

def convert_rgby_to_rgb(arr):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image.
    
    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual 
    channels with full range px values are better. 
    In annotation, we toggled the channels. 
    For visualization purpose only, you can try blending the channels. 
    For example, 
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    NOTE: unlike the advice quoted above, this implementation leaves the red
    channel untouched; yellow is only blended (at half weight) into green.

    For integer inputs the blended green channel is clipped to the valid range
    of the dtype before it is stored, so bright pixels saturate instead of
    wrapping around (e.g. 200 + 200/2 = 300 becomes 255 for uint8).
        
    Args:
        arr (numpy array): The RGBY, 4 channel numpy array for a given image
    
    Returns:
        RGB Image (numpy array with the same dtype as `arr` and 3 channels)
    """
    
    rgb_arr = np.zeros_like(arr[..., :-1])
    rgb_arr[..., 0] = arr[..., 0]
    rgb_arr[..., 2] = arr[..., 2]

    # The division promotes integer inputs to float, so the sum itself cannot
    # overflow; only the store back into the integer output could.
    green = arr[..., 1] + arr[..., 3] / 2
    if np.issubdtype(rgb_arr.dtype, np.integer):
        limits = np.iinfo(rgb_arr.dtype)
        green = np.clip(green, limits.min, limits.max)
    rgb_arr[..., 1] = green
    
    return rgb_arr

def plot_ex(arr, figsize=(20,6), title=None, plot_merged=True, rgb_only=False):
    """ Plot The Individual Channels Of An Image Side By Side

    Each channel is drawn in its own subplot (red/green/blue in their own
    colour plane, yellow as red+green). Optionally a final subplot shows the
    merged image.

    Args:
        arr (numpy array): Image with channels on the last axis; 4 channels
            (RGBY) by default, 3 channels (RGB) when `rgb_only` is True
        figsize (tuple, optional): Matplotlib figure size
        title (str, optional): Figure super-title; ignored unless it is a string
        plot_merged (bool, optional): Whether to append the merged image
        rgb_only (bool, optional): Whether `arr` has only the 3 RGB channels

    Returns:
        None; the figure is shown with `plt.show()`
    """
    channel_titles = [
        "Red Channel – Microtubles",
        "Green Channel – Protein of Interest",
        "Blue - Nucleus",
        "Yellow – Endoplasmic Reticulum",
    ]

    # One subplot per plotted channel, plus one for the merged view if requested.
    # (Deriving the count this way keeps the grid correct for every flag combination.)
    n_channels = 3 if rgb_only else 4
    n_images = n_channels + (1 if plot_merged else 0)

    plt.figure(figsize=figsize)
    if isinstance(title, str):
        plt.suptitle(title, fontsize=20, fontweight="bold")

    for i, c in enumerate(channel_titles[:n_channels]):
        # 3-channel canvas to paint the single channel into
        ch_arr = np.zeros_like(arr) if rgb_only else np.zeros_like(arr[..., :-1])
        if i < 3:
            # R, G and B are shown in their own colour plane
            ch_arr[..., i] = arr[..., i]
        else:
            # Yellow is rendered as red + green
            ch_arr[..., 0] = arr[..., i]
            ch_arr[..., 1] = arr[..., i]
        plt.subplot(1,n_images,i+1)
        plt.title(f"{c.title()}", fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)
        
    if plot_merged:
        plt.subplot(1,n_images,n_images)
        
        if rgb_only:
            plt.title(f"Merged RGB", fontweight="bold")
            plt.imshow(arr)
        else:
            plt.title(f"Merged RGBY into RGB", fontweight="bold")
            plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)
        
    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    

def flatten_list_of_lists(l_o_l):
    """ Flatten One Level Of Nesting

    Args:
        l_o_l (iterable of iterables): The nested collection to flatten

    Returns:
        A new list holding the items of every sub-iterable, in order
    """
    flat = []
    for sublist in l_o_l:
        flat.extend(sublist)
    return flat


def get_class_wts(df, low_idx=4):
    """ Compute Per-Class Weights From The `Label` Column Of A Dataframe

    Every `Label` value is split on "|" and the resulting class ids are
    counted. The `low_idx`-th smallest count is used as the reference: a class
    gets weight `reference / count`, capped at 1.0.

    Args:
        df (pd.DataFrame): Dataframe with a string `Label` column of
            "|"-separated integer class ids
        low_idx (int, optional): 1-based rank (ascending) of the class count
            used as the reference count

    Returns:
        dict mapping integer class id -> weight, ordered by class id
    """
    tally = Counter()
    for labels in df.Label.str.split("|").to_list():
        tally.update(labels)

    # Not the lowest count by default, as the rarest class is very underrepresented
    reference_count = sorted(tally.values())[low_idx - 1]

    weights = {}
    for cls in sorted(tally, key=int):
        weights[int(cls)] = min(1.0, reference_count / tally[cls])
    return weights


def decode_image(image_data, n_channels=1, resize_to=(512,512), cast_to=tf.uint8):
    """ Decode PNG Bytes Into A Resized Image Tensor

    Args:
        image_data (tf.string): The encoded PNG bytes
        n_channels (int, optional): Number of colour channels to decode
        resize_to (tuple, optional): Target size passed to `tf.image.resize`
        cast_to (tf.DType, optional): Dtype of the returned tensor

    Returns:
        The decoded image, resized to `resize_to` and cast to `cast_to`
    """
    decoded = tf.image.decode_png(image_data, channels=n_channels)
    resized = tf.image.resize(decoded, resize_to)
    return tf.cast(resized, cast_to)


def str_2_multi_hot_encoding(tfstring, n_classes=19):
    """ Turn A "|"-Separated String Of Class Ids Into A Multi-Hot Vector

    Args:
        tfstring (tf.string): Class ids joined by "|"
        n_classes (int, optional): Length of the multi-hot vector

    Returns:
        Tensor of length `n_classes` that is 1 at every listed class id
    """
    id_tokens = tf.strings.split(tfstring, sep="|")
    class_ids = tf.strings.to_number(id_tokens, out_type=tf.int32)
    # One one-hot row per listed class, merged with an element-wise max
    return tf.reduce_max(tf.one_hot(class_ids, depth=n_classes), axis=-2)


def decode(serialized_example, multihot=False, n_channels=1, resize_to=(512,512)):
    """ Parse One Serialized Example Into (image, image_name, label)
        
        It is used as a map function for `dataset.map`

    Args:
        serialized_example (tf.Example): A serialized example holding the
            string features 'image', 'image_name' and 'target'
        multihot (bool, optional): Whether to convert the target string into a
            multi-hot vector instead of returning it unchanged
        n_channels (int, optional): Number of channels to decode the image with
        resize_to (tuple, optional): Size the decoded image is resized to
        
    Returns:
        Tuple of (decoded image tensor, image name string tensor, label), where
        label is the raw target string or its multi-hot encoding
    """
    # All three keys are required scalar strings, so no defaults are specified.
    string_feature = tf.io.FixedLenFeature(shape=(), dtype=tf.string)
    schema = {key: string_feature for key in ('image', 'image_name', 'target')}

    parsed = tf.io.parse_single_example(serialized_example, features=schema)
    image = decode_image(parsed['image'], n_channels, resize_to)
    label = str_2_multi_hot_encoding(parsed["target"]) if multihot else parsed["target"]
    return image, parsed["image_name"], label


def preprocess_tfrec_ds(red, green, blue, yellow=None, drop_yellow=True, return_id=True):
    """ Merge Per-Colour (image, name, label) Triples Into One Multi-Channel Example

    The first channel of each colour's image is stacked on a new last axis.
    The name and label are taken from the red triple.

    Args:
        red, green, blue (tuple): (image, name, label) for each colour
        yellow (tuple, optional): (image, name, label) for the yellow channel
        drop_yellow (bool, optional): Whether to ignore `yellow` even if given
        return_id (bool, optional): Whether to also return the image id

    Returns:
        (combined image, image id, label) if `return_id` else (combined image, label)
    """
    red_img, red_name, red_lbl = red
    green_img, _, _ = green
    blue_img, _, _ = blue
    planes = [red_img[..., 0], green_img[..., 0], blue_img[..., 0]]

    use_yellow = (yellow is not None) and (not drop_yellow)
    if use_yellow:
        yellow_img, _, _ = yellow
        planes.append(yellow_img[..., 0])

    combo_img = tf.stack(planes, axis=-1)

    if not return_id:
        return combo_img, red_lbl

    img_id = tf.strings.substr(red_name, pos=0, len=36) # 36 is length of id (always)
    return combo_img, img_id, red_lbl
    
def augment(img_batch, lbl_batch):
    """ Apply Random Geometric And Colour Augmentations To The Images

    Args:
        img_batch: Image tensor(s) to augment
        lbl_batch: Labels, returned unchanged

    Returns:
        Tuple of (augmented images, labels)
    """
    # Number of 90 degree rotations, drawn uniformly from {0, 1, 2, 3}
    quarter_turns = tf.random.uniform((1,), minval=0, maxval=4, dtype=tf.dtypes.int32)[0]

    # Applied strictly in this order: flips, rotation, then colour jitter
    steps = (
        tf.image.random_flip_left_right,
        tf.image.random_flip_up_down,
        lambda imgs: tf.image.rot90(imgs, quarter_turns),
        lambda imgs: tf.image.random_saturation(imgs, 0.875, 1.125),
        lambda imgs: tf.image.random_brightness(imgs, 0.1125),
        lambda imgs: tf.image.random_contrast(imgs, 0.825, 1.175),
    )
    for step in steps:
        img_batch = step(img_batch)

    return img_batch, lbl_batch

<h1 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="feature_embedding">4&nbsp;&nbsp;FEATURE EMBEDDING</h1>

WIP

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">4.1 DATASET</h3>

---

WIP

In [None]:
BATCH_SIZE = 64

# Records are read sequentially (num_parallel_reads=None) so that the order of
# the image ids collected below matches the order of the embeddings later on.
raw_ds = tf.data.TFRecordDataset(TRAIN_TFREC_PATHS, num_parallel_reads=None)
decoded_ds = raw_ds.map(lambda x: decode(x, n_channels=3, multihot=True, resize_to=(256,256)))

# See examples
for i, (img, image_name, lbl) in enumerate(decoded_ds.take(3)):
    print(f"IMAGE SHAPE : {img.shape}")
    print(f"IMG #{i+1} -- IMAGE NAME  : {image_name.numpy().decode()}")
    print(f"IMAGE LABEL : {lbl}\n")
    plt.imshow(img.numpy().astype(np.uint8))
    plt.show()
    
# get ordered image ids
# Parse only the 'image_name' feature from the raw records: this avoids
# decoding and resizing every PNG just to read its name.
name_only_schema = {'image_name': tf.io.FixedLenFeature(shape=(), dtype=tf.string)}
name_ds = raw_ds.map(lambda x: tf.io.parse_single_example(x, features=name_only_schema)['image_name'])
img_ids = [name.decode() for name in name_ds.as_numpy_iterator()]

# Images only, batched for model.predict
train_ds = decoded_ds.map(lambda x,y,z: (x)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">4.2 MODEL</h3>

---

WIP

In [None]:
def get_full_model(backbone, preprocessing_fn, input_shape=(256,256,3)):
    """ Wrap A Backbone So It Accepts Raw uint8 Images

    The returned model casts its uint8 input to float32, applies
    `preprocessing_fn`, and feeds the result through `backbone`.

    Args:
        backbone (callable): Model applied to the preprocessed input
        preprocessing_fn (callable): Preprocessing applied to the float32 input
        input_shape (tuple, optional): Shape of a single input image

    Returns:
        tf.keras.Model mapping uint8 images to the backbone's output
    """
    raw_images = tf.keras.layers.Input(input_shape, dtype=tf.uint8)
    float_images = tf.cast(raw_images, tf.float32)
    features = backbone(preprocessing_fn(float_images))
    return tf.keras.Model(inputs=[raw_images], outputs=[features])

# Backbone: ImageNet-pretrained EfficientNetB0 without its classification head,
# with average pooling applied to the final feature map.
bb = tf.keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    pooling="avg",
)
# Matching input preprocessing for the EfficientNet family
pp = tf.keras.applications.efficientnet.preprocess_input

# Full embedding model: uint8 image in, backbone output out
model = get_full_model(backbone=bb, preprocessing_fn=pp, input_shape=(256,256,3))
model.summary()

# Persist the embedding model so it can be reused outside this notebook
model.save("/kaggle/working/eb0_embedding_model_at_256.h5")

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">4.3 FIND SIMILAR IMAGES WITH RAPIDS KNN AND IMAGE EMBEDDINGS</h3>

---

WIP

In [None]:
# Run every batch of training images through the embedding model;
# the result has one row per image, in dataset order.
image_embeddings = model.predict(train_ds, verbose=1)
print(f"image embeddings shape is {image_embeddings.shape}")

In [None]:
# Number of nearest neighbours retrieved for every image embedding
IMG_KNN_N = 75

# Use a dedicated name for the KNN index: reusing `model` would overwrite the
# Keras embedding model and break any re-run of the embedding cell.
knn_model = NearestNeighbors(n_neighbors=IMG_KNN_N)
knn_model.fit(image_embeddings)

# Row k holds the distances to / indices of the neighbours of embedding k
distances, indices = knn_model.kneighbors(image_embeddings)

print(distances.shape, indices.shape)

In [None]:
# Neighbours closer than this embedding distance are treated as duplicates
DUPLICATE_DISTANCE_THRESHOLD = 2.5

# Build the id lookup array once instead of on every loop iteration
img_id_arr = np.array(img_ids)

# preds[k] = ids of all neighbours of image k that fall under the threshold,
# in the order returned by the KNN search.
# NOTE(review): downstream code treats the first entry as image k itself -
# presumably each image is its own nearest neighbour; confirm for exact ties.
preds = []
for k in tqdm(range(len(distances)), total=len(distances)):
    close_idx = np.where(distances[k,]<DUPLICATE_DISTANCE_THRESHOLD)[0]
    neighbour_rows = indices[k,close_idx]
    preds.append(list(img_id_arr[neighbour_rows]))

In [None]:
# Map each "root" image id to the ids flagged as its duplicates
# (only for images that have at least one neighbour besides the first entry).
preds_map = {p[0]:p[1:] for p in preds if len(p)>1}

# Every image that is listed as a duplicate of a surviving root is itself
# dropped as a root, so each duplicate group is reported only once.
# A set gives O(1) membership checks and removes repeated entries.
keys_to_remove = set()
for pred_key, pred_vals in tqdm(preds_map.items(), total=len(preds_map)):
    if pred_key not in keys_to_remove:
        keys_to_remove.update(v for v in pred_vals if v in preds_map)
        
for k in keys_to_remove:
    # pop with a default instead of a bare `except: pass`
    preds_map.pop(k, None)
    
len(preds_map)

In [None]:
# Maximum number of duplicates plotted per root image
MAX_DUPS = 2

def _plot_labelled_image(img_id, heading):
    """ Plot the channels of `img_id` with its ID and Label (from train_df) in the title """
    row = train_df[train_df.ID==img_id]
    plot_ex(load_image(img_id, TRAIN_IMG_DIR), title=f"\n{heading}\nID  : {row.ID.values[0]}\nLBL : {row.Label.values[0]}")

for root_id, duplicate_ids in tqdm(preds_map.items(), total=len(preds_map)):
    print("\n------------------------------------------------------------------------\n")
    _plot_labelled_image(root_id, "ROOT")
    print("\n------------------------------------------------------------------------\n")
    # Show at most MAX_DUPS duplicates (the limit is driven by the constant itself)
    for _id in duplicate_ids[:MAX_DUPS]:
        _plot_labelled_image(_id, "DUPLICATE")
    if len(duplicate_ids) > MAX_DUPS:
        print(f"\n... ONLY DISPLAYING {MAX_DUPS} DUPLICATES FOR BREVITY ...\n")