# Single-cell classification

## Part ONE: Preprocessing  
1. create segmentation masks for all images  
2. extract bounding boxes from segmentation masks  
3. load img_ids, bboxes, class_labels into one dataframe and save to pickle file  

## Part TWO: Visualization  
1. create plot of label distribution  
2. create bar plot of img dimensions  
3. create bar plot of # of imgs with labels per img  
4. show some channel-combined sample images from dataset  
5. show some images with bounding boxes
6. create color histograms?  

## Part THREE: Single-Cell Multi-Label Classification Network  
### = "Which cell type is this cropped cell?"  
this is trained with cropped cells and their labels  
1. load dataframe from pickle file  
2. shuffle dataframe  
3. partition dataframe  
4. split dataframe into train_df and val_df  
5. set hyperparameters  
6. set callbacks: early stopping and checkpoint save
7. set training visualizations/plots from history  
8. train network  

## Part FOUR: Binary Classification Network  
### = "Where are the cells in this image?"  
this is trained with the bounding boxes and the images  
1. load dataframe from pickle file  
2. shuffle dataframe  
3. partition dataframe  
4. split dataframe into train_df and val_df  
5. set hyperparameters  
6. set callbacks: early stopping and checkpoint save
7. set training visualizations/plots from history  
8. train network  

## Part FIVE: Combine Networks
### Build a two-stage detector and classifier: the first stage proposes cell bounding boxes, the second stage classifies each detected cell

# Part ONE: Preprocessing

In [None]:
%%script echo skipping
#constants
# Folder with the per-channel training PNGs; file names are built elsewhere in
# this notebook as <img_id>_red.png / _green.png / _blue.png / _yellow.png.
img_folder_path="../input/hpa-single-cell-image-classification/train/"
# CSV read with pandas below; its "ID" and "Label" columns are used.
csv_file_path="../input/hpa-single-cell-image-classification/train.csv"
# Output folder for the run-length encoded masks (mask_rle_<img_id>.npy).
mask_folder_path="./masks"

In [None]:
# Class names in class-id order; the position in this tuple is the integer
# label used in the "Label" column of the training data.
_LABEL_NAMES = (
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Intermediate filaments",
    "Actin filaments",
    "Microtubules",
    "Mitotic spindle",
    "Centrosome",
    "Plasma membrane",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Vesicles and punctate cytosolic patterns",
    "Negative",
)

# Mapping from integer class id to human-readable class name.
LABELS = dict(enumerate(_LABEL_NAMES))

In [None]:
%%script echo skipping
#code to load csv file goes here

#id_labels_array: dataframe of all image ids and their corresponding labels
#id_array: list of all image ids
#labels_dict: dictionary mapping each image id to its label string
import pandas as pd
# Read the training CSV into a dataframe.
id_labels_array = pd.read_csv(csv_file_path)

# Not needed by the tf pipeline any more (it splits the label string itself);
# kept so the parsed integer lists stay available for inspection.
id_labels_array_separated = id_labels_array["Label"].apply(
    lambda raw: [int(token) for token in raw.split("|")]
)

# Plain Python list of all image ids, in file order.
id_array = id_labels_array["ID"].tolist()

# Map every image id to the value of the first non-ID column of its row.
_first_value_column = id_labels_array.set_index('ID').iloc[:, 0]
labels_dict = dict(zip(_first_value_column.index, _first_value_column))

In [None]:
%%script echo skipping
#function to combine rgby to rgb
#source:https://www.kaggle.com/kwentar/visualization-examples-of-each-class-in-rgb#Load-data:
import numpy as np
def rgby_to_rgb(r,g,b,y):
    """Merge four single-channel images into one 8-bit RGB image.

    Yellow is treated as red + green, so half of the yellow intensity is
    added to each of those two channels. The result is rescaled so that its
    brightest value becomes 255.

    Args:
        r, g, b, y: 2-D images of identical size; anything ``np.asarray``
            accepts (PIL images or arrays).

    Returns:
        ``uint8`` array of shape (height, width, 3). An all-zero input gives
        an all-zero image instead of dividing by zero.
    """
    # np.float was removed from NumPy; use an explicit 64-bit float instead.
    red, green, blue, yellow = (
        np.asarray(channel, dtype=np.float64) for channel in (r, g, b, y)
    )
    rgb_image = np.zeros(red.shape[:2] + (3,), dtype=np.float64)
    # yellow is red + green
    rgb_image[:, :, 0] = red + yellow / 2
    rgb_image[:, :, 1] = green + yellow / 2
    rgb_image[:, :, 2] = blue
    # Normalize image (skipped for a completely black image)
    peak = rgb_image.max()
    if peak > 0:
        rgb_image = rgb_image / peak * 255
    return rgb_image.astype(np.uint8)

In [None]:
%%script echo skipping
#function to get rgb image from only img_id
from PIL import Image
def imgid_to_rgb(img_id):
    """Load the four channel PNGs of one image id and merge them to RGB.

    Args:
        img_id: image id; the files <img_folder_path><img_id>_<color>.png
            are opened for red, green, blue and yellow.

    Returns:
        Whatever ``rgby_to_rgb`` returns for the four channels.
    """
    prefix = img_folder_path + img_id
    # Context managers make sure the four file handles are released; the
    # pixel data is read while the files are still open.
    with Image.open(prefix + "_red.png") as r, \
         Image.open(prefix + "_green.png") as g, \
         Image.open(prefix + "_blue.png") as b, \
         Image.open(prefix + "_yellow.png") as y:
        return rgby_to_rgb(r, g, b, y)

In [None]:
%%script echo skipping
#function to convert all images to rgb images
#exceeds 9h kaggle runtime so i ran this offline
#these files have been saved in the datasets named hpa-composite-images-x-of-20
from PIL import Image
import numpy
from tqdm import tqdm

# Output folder for the merged RGB composites.
COMPOSITE_IMG_PATH="./composites/"
import os
# Create the output folder on first run.
if not os.path.exists(COMPOSITE_IMG_PATH):
    os.makedirs(COMPOSITE_IMG_PATH)

# Only the first five ids are converted here; the slice keeps this cell a
# short demonstration (the comment above says the full run was done offline).
for img_id in tqdm(id_array[:5]):
    img_rgb = Image.fromarray(imgid_to_rgb(img_id))
    img_rgb.save(COMPOSITE_IMG_PATH+img_id+".png")
    
#convert the id_labels_array to proper file paths of the composite images
import numpy as np
from tqdm import tqdm
# Map every image id to the path of its composite PNG. The ids are cut into
# 20 consecutive chunks and chunk k is looked up in dataset "k-of-20".
img_paths = {}
id_array_split = np.array_split(id_array, 20) #the dataset was split into 20 parts
for part_number, part_ids in enumerate(id_array_split, start=1):
    part_folder = "../input/hpa-composite-images-" + str(part_number) + "-of-20/"
    for img_id in tqdm(part_ids):
        img_paths[img_id] = part_folder + img_id + ".png"

In [None]:
%%script echo skipping
%%capture
#function for cell segmentation
!pip install https://github.com/CellProfiling/HPA-Cell-Segmentation/archive/master.zip
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei
from PIL import Image
import numpy as np
# Paths of the pretrained weight files handed to the segmentator.
NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"
# One shared segmentator instance, configured to run on the CPU.
# NOTE(review): the exact meaning of scale_factor / padding /
# multi_channel_model is defined by hpacellseg - confirm against its docs.
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    scale_factor=0.25,
    device="cpu",
    padding=False,
    multi_channel_model=True,
)
def get_mask(img_id):
    """Segment one image and return its labelled cell mask.

    The red, yellow and blue channel PNGs of ``img_id`` are fed to the
    module-level ``segmentator``; nuclei and cell predictions are combined
    with ``label_cell``.

    Args:
        img_id: image id used to build the channel file names.

    Returns:
        The cell mask returned by ``label_cell`` as ``uint8``, or as
        ``uint16`` when it contains ids above 255 (a plain ``uint8`` cast
        would silently wrap those ids around).
    """
    prefix = img_folder_path + img_id
    # Read the pixel data while the files are open, then release the handles.
    with Image.open(prefix + "_red.png") as ch_r, \
         Image.open(prefix + "_yellow.png") as ch_y, \
         Image.open(prefix + "_blue.png") as ch_b:
        red, yellow, blue = np.array(ch_r), np.array(ch_y), np.array(ch_b)

    nuc_segmentations = segmentator.pred_nuclei([blue])
    cell_segmentations = segmentator.pred_cells([[red], [yellow], [blue]])
    # Only the cell mask is needed; the nuclei mask is discarded.
    _, mask = label_cell(nuc_segmentations[0], cell_segmentations[0])
    mask = np.asarray(mask)
    if mask.max() > np.iinfo(np.uint8).max:
        return mask.astype(np.uint16)
    return mask.astype(np.uint8)

In [None]:
%%script echo skipping
#https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
#these are only helper functions
def mask2rle(img):
    '''
    Run-length encode a binary mask in column-major order.

    img: numpy array, 1 - mask, 0 - background
    Returns the encoding as a string "start length start length ..." with
    1-based start positions.
    '''
    column_major = img.T.flatten()
    # Zero sentinels on both sides so every run has a start and an end edge.
    padded = np.concatenate([[0], column_major, [0]])
    edges = np.where(padded[1:] != padded[:-1])[0] + 1
    # Odd entries are run ends; turn them into run lengths in place.
    edges[1::2] = edges[1::2] - edges[::2]
    return ' '.join(map(str, edges))
 
def rle2mask(mask_rle, shape):
    '''
    Decode a run-length string produced by mask2rle.

    mask_rle: run-length as string formated (start length), starts 1-based
    shape: (width,height) of array to return
    Returns numpy uint8 array, 1 - mask, 0 - background; because of the final
    transpose the returned array has shape (shape[1], shape[0]).
    '''
    tokens = mask_rle.split()
    # Even positions hold starts (made 0-based here), odd positions lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape).T

In [None]:
%%script echo skipping
#use these for the masks of the dataset (they have different cell instances in one mask)
def multicell_rlencode(mask):
    """Run-length encode every cell of a labelled mask separately.

    Args:
        mask: array in which 0 is background and every other value marks the
            pixels of one cell.

    Returns:
        List with one ``mask2rle`` string per non-zero id, in ascending id
        order. A mask without any background pixel is handled (the former
        ``set.remove(0)`` raised KeyError), and the ordering no longer
        depends on set iteration order.
    """
    mask = np.asarray(mask)
    return [
        mask2rle(np.where(mask == cell_id, 1, 0))
        for cell_id in np.unique(mask)
        if cell_id != 0
    ]
def multicell_rldecode(mask_rle,shape):
    """Rebuild a labelled mask from a list of per-cell run-length strings.

    Args:
        mask_rle: iterable of run-length strings, one per cell.
        shape: (width, height), passed unchanged to ``rle2mask``.

    Returns:
        Integer array of shape (height, width) in which the pixels of the
        i-th string (1-based) carry the value i; later cells overwrite
        earlier ones where they overlap.
    """
    # rle2mask returns (shape[1], shape[0]); allocate the accumulator with
    # that same layout so non-square images decode as well.
    mask = np.zeros((shape[1], shape[0]), dtype=int)
    for cell_number, layer in enumerate(mask_rle, start=1):
        mask[rle2mask(layer, shape) == 1] = cell_number
    return mask

In [None]:
%%script echo skipping
#function for bbox creation
def get_bboxes(mask):
    """Compute one bounding box per cell of a labelled 2-D mask.

    Args:
        mask: 2-D array; 0 is background, every other value is a cell id.

    Returns:
        List of ``[ymin, ymax, xmin, xmax]`` (inclusive pixel indices), one
        entry per non-zero id in ascending id order. Returns an empty list
        when the mask holds no cells; a mask without background no longer
        raises KeyError.
    """
    mask = np.asarray(mask)
    bboxes = []
    for cell_id in np.unique(mask):
        if cell_id == 0:
            continue
        rows, cols = np.where(mask == cell_id)
        bboxes.append([rows.min(), rows.max(), cols.min(), cols.max()])
    return bboxes

In [None]:
%%script echo skipping
from PIL import Image
def get_shape(img_id,img_path):
    """Return the size of an image's green-channel PNG.

    Args:
        img_id: image id.
        img_path: folder that holds <img_id>_green.png.

    Returns:
        The PIL ``size`` tuple, i.e. (width, height).
    """
    # The context manager closes the file handle once the size is read.
    with Image.open(img_path+"/"+img_id+"_green.png") as im:
        return im.size
def load_mask_from_file(img_id,mask_folder_path):
    """Load a saved run-length mask file and decode it to a labelled mask.

    Args:
        img_id: image id.
        mask_folder_path: folder that holds mask_rle_<img_id>.npy.

    Returns:
        The array produced by ``multicell_rldecode`` for that image.
    """
    # The folder parameter was previously ignored in favour of an undefined
    # name ``path``, which raised NameError.
    mask_rle=np.load(mask_folder_path+"/mask_rle_"+img_id+'.npy', allow_pickle=True)
    # The image size is always looked up in the training image folder.
    shape=get_shape(img_id,"../input/hpa-single-cell-image-classification/train")
    return multicell_rldecode(mask_rle,shape)

In [None]:
%%script echo skipping
#iterates through image ids in array id_array:
#    1.creates mask as rle string and saves as numpy array
#    2.creates bboxes from mask and adds to dictionary(saved as pickle file)
#I've done this offline as the process takes longer than the max 9h offered by kaggle
#the results of this are stored in my dataset on kaggle:
#    bboxes stored in dataset "hpa-bboxes"
#    rle encoded masks not stored online because too big for proper upload


from tqdm import tqdm
import os
import pickle
# For every image without a saved mask: segment it, record the bounding
# boxes of its cells and store the run-length encoded mask as a .npy file.
bboxes_dict={}
os.makedirs(mask_folder_path, exist_ok=True)
for img_id in tqdm(id_array):
    maskpath=mask_folder_path+"/mask_rle_"+img_id+".npy"
    # NOTE: images whose mask file already exists are skipped entirely, so
    # they do not appear in bboxes_dict on a resumed run.
    if not os.path.isfile(maskpath):
        mask=get_mask(img_id)
        bboxes_dict[img_id]=get_bboxes(mask)
        # Fixed: the encoder is named multicell_rlencode (multicell_rle does
        # not exist and raised NameError).
        np.save(maskpath, multicell_rlencode(mask))
# "wb" truncates an existing file, so no explicit delete is required; the
# context manager closes the file even if pickling fails.
with open("bboxes.pkl", "wb") as bboxes_dict_file:
    pickle.dump(bboxes_dict, bboxes_dict_file)

In [None]:
%%script echo skipping
#visualize that the rle encoded mask numpy file is correct
#just some random image id
import matplotlib.pyplot as plt
testmask=load_mask_from_file("466a98c6-bbae-11e8-b2ba-ac1f6b6435d0","../input/test-rle-mask")
plt.imshow(testmask)

In [None]:
%%script echo skipping
#show cell cropping and resize script here
#cropped resized cells were saved in the datasets hpa-resized-224x224-cropped-cells-x20
#example img path:
#../input/hpa-resized-224x224-cropped-cells-420/00481c70-bba3-11e8-b2b9-ac1f6b6435d0_13.png
#-->img_id:00481c70-bba3-11e8-b2b9-ac1f6b6435d0
#-->cell_id:13

#make a dict with all the correct file locations:

import numpy as np
from tqdm import tqdm
import os
# Build a lookup from image id to a cropped-cell file path.
# File names look like <img_id>_<cell_id>.png. The dict is keyed by image id
# only, so each image keeps the path of the last cell file listed for it.
img_paths = {}
img_folders = [
    "../input/hpa-resized-224x224-cropped-cells-" + str(part) + "20/"
    for part in range(1, 21)
]
for folder in img_folders:
    for img in tqdm(os.listdir(folder)):
        name_parts = img.split("_")
        # Text before the first underscore, without any directory prefix.
        img_id = name_parts[0].split("/")[-1]
        # Text after the last underscore, up to the ".png" suffix.
        cell_id = name_parts[-1].split(".png")[0]
        img_paths[img_id] = folder + img_id + "_" + cell_id + ".png"

In [None]:
#%%script echo skipping
#load readymade dataframe with ids, labels, and bounding boxes
#change ids to new file paths
import pickle
import pandas as pd
id_labels_cells_array = pd.read_pickle("../input/hpa-bboxes/hpa-data.pkl")

def id_to_path(x):
    """Return the file path stored for image id ``x`` in ``img_paths``.

    Raises KeyError when the id is not in the dictionary.
    NOTE(review): ``img_paths`` is built in a cell marked as skipped, so it
    must exist before this helper is used - confirm.
    """
    return img_paths[x]

#id_labels_cells_array["ID"] = id_labels_cells_array["ID"].apply(id_to_path)

In [None]:
#%%script echo skipping
id_labels_cells_array

In [None]:
#%%script echo skipping
id_labels_cells_array.to_pickle("./hpadataframe_cropped_resized_cells_filepaths.pkl")

In [None]:
%%script echo skipping
#load ALL of the info necessary for the training of the network
#bboxes dict has been saved to dataset hpa-bboxes as a pickle file (bboxesone.pkl)
#combine all the information and save overall dataframe to hpa-data.pkl
import pickle
from tqdm import tqdm

# Load the bounding boxes computed offline.
# NOTE: pickle executes code on load; only open files from a trusted source.
with open("../input/hpa-bboxes/bboxesone.pkl", "rb") as bboxes_dict_file:
    bboxes_dict = pickle.load(bboxes_dict_file)

# One row per cell: image id, 1-based cell number, the image-level label
# string and the four box coordinates (stored as strings, as before).
# Rows are collected in a list and turned into a dataframe once;
# DataFrame.append inside the loop copied the frame on every call and no
# longer exists in pandas 2.
cell_rows = []
for img_id in tqdm(id_array):
    for cell_number, (ymin, ymax, xmin, xmax) in enumerate(bboxes_dict[img_id], start=1):
        cell_rows.append({
            'ID': str(img_id),
            'cell': str(cell_number),
            'Label': labels_dict[img_id],
            'ymin': str(ymin),
            'ymax': str(ymax),
            'xmin': str(xmin),
            'xmax': str(xmax),
        })
id_labels_cells_array = pd.DataFrame(
    cell_rows, columns=['ID','cell','Label','ymin','ymax','xmin','xmax']
)
#id_labels_cells_array.to_csv(r'./hpa-dataframe.csv', index = True)
id_labels_cells_array.to_pickle("./hpa-data.pkl")

# Part TWO: Visualization

In [None]:
#%%script echo skipping
#visualize some input images

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Show the four raw channels of one sample image side by side.
img_id = "b8f0d89e-bbc0-11e8-b2bb-ac1f6b6435d0"
base = "../input/hpa-single-cell-image-classification/train/"
ext = ".png"

# (file-name suffix, subplot title) for each channel, in display order.
channel_specs = [
    ("_red", 'Microtubules'),
    ("_green", 'Protein of Interest'),
    ("_blue", 'Nucleus'),
    ("_yellow", 'Endoplasmatic Reticulum'),
]
# Open all four files before the figure is created.
channel_images = [Image.open(base + img_id + suffix + ext) for suffix, _ in channel_specs]

fig, axs = plt.subplots(1, 4, figsize=(20, 10))
for ax, channel_image, (_, title) in zip(axs, channel_images, channel_specs):
    ax.imshow(np.asarray(channel_image), vmin=0, vmax=255)
    ax.set_title(title)
plt.show()
#plt.savefig('image_layers.png')
#plt.savefig('image_layers.png')

In [None]:
%%script echo skipping
#https://www.kaggle.com/hamditarek/exploring-human-protein-atlas-cell-classification
# Per image: list of label-id strings, e.g. ["8", "5", "0"].
labels_num = [value.split('|') for value in id_labels_array['Label']]
# All label ids of all images in one flat list of ints.
labels_num_flat = [int(item) for sublist in labels_num for item in sublist]
# Same list with the ids replaced by their class names.
labels = [LABELS[label_id] for label_id in labels_num_flat]

# Horizontal bar chart of how often each class name occurs.
fig, ax = plt.subplots(figsize=(15, 5))
pd.Series(labels).value_counts().plot(kind='barh', fontsize=14)

In [None]:
%%script echo skipping
# Number of labels attached to each image.
arr_len = [len(image_labels) for image_labels in labels_num]
# Distinct label counts that occur in the data.
lengths = set(arr_len)
# label count -> [label count, number of images with that many labels]
count = {length: [length, arr_len.count(length)] for length in lengths}

In [None]:
%%script echo skipping
# Turn the {n_labels: [n_labels, n_images]} dict into a two-column table,
# one row per distinct number of labels.
count=pd.DataFrame.from_dict(count, orient='index',
                       columns=['labels', 'occurence'])

In [None]:
%%script echo skipping
count

# Part THREE: Single-Cell Multi-Label Classification Network  
https://www.kaggle.com/ayuraj/hpa-multi-label-classification-with-tf-and-w-b  
https://cs230.stanford.edu/blog/datapipeline/  
   

In [None]:
#%%script echo skipping
import tensorflow as tf
print(tf.__version__)

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow_addons as tfa

import os
import re
import cv2
import glob
import numpy as np
import pandas as pd
import seaborn as sns
from functools import partial
import matplotlib.pyplot as plt

%matplotlib inline

# Imports for augmentations. 
from albumentations import (
    Compose, RandomCrop, RandomResizedCrop, HorizontalFlip, VerticalFlip, Resize 
)

In [None]:
import pandas as pd
id_labels_cells_array = pd.read_pickle("../input/hpa-cropped-resized-with-labels-and-paths/hpadataframe_cropped_resized_cells_filepaths.pkl")

In [None]:
id_labels_cells_array

In [None]:
import tensorflow as tf
#test the normalization
import matplotlib.pyplot as plt
import numpy as np
def histogram(image):
    """Show an image above its pixel histogram and cumulative histogram.

    Both histograms use 64 bins over the range (0, 1), so pixel values
    outside that range are not counted.

    Args:
        image: numpy array; it is displayed with ``imshow`` and flattened
            for the histograms.
    """
    # Display image in top subplot using color map 'gray'
    plt.subplot(2,1,1)
    plt.imshow(image, cmap='gray')
    plt.title('Original image')
    plt.axis('off')

    # Flatten the image into 1 dimension: pixels
    pixels = image.flatten()

    # Display a histogram of the pixels in the bottom subplot
    plt.subplot(2,1,2)
    pdf = plt.hist(pixels, bins=64, range=(0,1), density=True,
                   color='red', alpha=0.4)
    # A boolean is required here: the string 'off' is truthy and would
    # switch the grid on instead of off.
    plt.grid(False)

    # Use plt.twinx() to overlay the CDF in the bottom subplot
    plt.twinx()

    # Display a cumulative histogram of the pixels
    cdf = plt.hist(pixels, bins=64, range=(0,1),
                   density=True, cumulative=True,
                   color='blue', alpha=0.4)

    # Specify x-axis range, hide the grid, add title and display plot
    plt.xlim((0,1))
    plt.grid(False)
    plt.title('PDF & CDF (original image)')
    plt.show()

def numpyhisto(img):
    """Plot a 256-bin histogram of ``img`` over [0, 1] with its scaled CDF.

    The cumulative sum is rescaled so that its maximum equals the tallest
    histogram bar, which lets both curves share one y-axis.

    Args:
        img: numpy array of pixel values; values outside [0, 1] are ignored
            by the histogram.
    """
    pixels = img.flatten()
    hist,bins = np.histogram(pixels,256,[0,1])
    cdf = hist.cumsum()
    total = cdf.max()
    if total > 0:
        cdf_normalized = cdf * float(hist.max()) / total
    else:
        # No pixel fell into [0, 1]; avoid dividing by zero.
        cdf_normalized = np.zeros(len(cdf))
    # Plot against the left bin edges so the curve lies on the same [0, 1]
    # x-axis as the histogram (it used to be drawn at x = 0..255, almost
    # entirely outside the visible range).
    plt.plot(bins[:-1], cdf_normalized, color = 'b')
    plt.hist(pixels,256,[0,1], color = 'r')
    plt.xlim([0,1])
    plt.legend(('cdf','histogram'), loc = 'upper left')
    plt.show()

# Load the image into an array: image
# First pass: the composite PNG exactly as matplotlib reads it.
image = plt.imread('../input/hpa-composite-images-1-of-20/000a6c98-bb9b-11e8-b2b9-ac1f6b6435d0.png')
histogram(image)
numpyhisto(image)
# Second pass: the same file decoded by TensorFlow and standardized per image.
rgb = tf.io.read_file("../input/hpa-composite-images-1-of-20/000a6c98-bb9b-11e8-b2b9-ac1f6b6435d0.png")
image = tf.image.decode_png(rgb, channels=3)
image=tf.image.per_image_standardization(image)
# Back to numpy so the plotting helpers can flatten it.
image=image.numpy()
# NOTE(review): both helpers only bin values in [0, 1]; standardized pixels
# outside that range will not show up in these plots - confirm this is intended.
histogram(image)
numpyhisto(image)

In [None]:
#%%script echo skipping
#WORKING_DIR_PATH = '../input/hpa-single-cell-image-classification/'

# Input size of the network; also used by the resize helpers below.
IMG_WIDTH = 224
IMG_HEIGHT = 224
# Number of samples per batch in the tf.data pipelines.
BATCH_SIZE = 16

# Lets tf.data choose the parallelism / prefetch buffer size itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
#%%script echo skipping
#shuffle dataframe to prevent overfitting
#from sklearn.utils import shuffle
#id_labels_cells_array = shuffle(id_labels_cells_array)

In [None]:
#%%script echo skipping
#set an amount of folds to split dataframe into --> k-fold cross validation
# explanation: https://towardsdatascience.com/cross-validation-explained-evaluating-estimator-performance-e51e5430ff85
N_FOLDS = 5
#choose which one of the 5 folds will be used as validation set this time
i_VAL_FOLD = 1

# Cut the dataframe into N_FOLDS + 1 consecutive parts: the last part is the
# held-out test set, the remaining N_FOLDS parts are the cross-validation folds.
*id_labels_cells_array, df_test_split = np.array_split(id_labels_cells_array, N_FOLDS + 1)

# Zero-based index of the validation fold; every other fold is training data.
i_validation = i_VAL_FOLD - 1
i_training = list(range(N_FOLDS))
del i_training[i_validation]

df_train_split = pd.concat([id_labels_cells_array[fold] for fold in i_training])
df_val_split = id_labels_cells_array[i_validation]

In [None]:
df_train_split

In [None]:
df_val_split

In [None]:
df_test_split

In [None]:
#%%script echo skipping
#analyze class imbalance and set up class weights here
#https://www.analyticsvidhya.com/blog/2020/10/improve-class-imbalance-class-weights/
from sklearn.utils import class_weight

# Flatten the label strings of the training rows ("8|5|0") into one 1-D
# array of integer class ids; a multi-label row contributes several ids.
y_train = np.concatenate(
    df_train_split["Label"].apply(lambda x: list(map(int, x.split("|")))).values
)
# 'balanced' weights for the classes present in y_train. Keyword arguments
# are required by current scikit-learn releases (positional use was
# deprecated and later rejected).
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train,
)

In [None]:
#%%script echo skipping
# Convert the weight array into the {class_id: weight} dict passed to fit().
# The array is aligned with np.unique(y_train), not with range(len(LABELS)),
# so each weight is paired with its actual class id. Classes that do not
# occur in the training split keep a neutral weight of 1.0 instead of
# causing an IndexError or shifting every later weight by one position.
if not isinstance(class_weights, dict):  # keeps the cell safe to re-run
    tmp_dict = {label_id: 1.0 for label_id in LABELS}
    for label_id, weight in zip(np.unique(y_train), class_weights):
        tmp_dict[int(label_id)] = float(weight)
    class_weights = tmp_dict
class_weights

In [None]:
#%%script echo skipping
LABELS

In [None]:
#%%script echo skipping
@tf.function
def multiple_one_hot(cat_tensor, depth_list):
    """Creates one-hot-encodings for multiple categorical attributes and
    concatenates the resulting encodings

    Args:
        cat_tensor (tf.Tensor): tensor with mutiple columns containing categorical features
        depth_list (list): list of the no. of values (depth) for each categorical

    Returns:
        one_hot_enc_tensor (tf.Tensor): concatenated one-hot-encodings of cat_tensor
    """
    # The body used the undefined name ``cat_int_tensor``; the parameter is
    # called ``cat_tensor``.
    one_hot_enc_tensor = tf.one_hot(cat_tensor[:,0], depth_list[0], axis=1)
    for col in range(1, len(depth_list)):
        add = tf.one_hot(cat_tensor[:,col], depth_list[col], axis=1)
        # Append this column's encoding along the feature axis.
        one_hot_enc_tensor = tf.concat([one_hot_enc_tensor, add], axis=1)

    return one_hot_enc_tensor

def resize_val_image(image, label):
    """Resize a validation image to (IMG_HEIGHT, IMG_WIDTH); label unchanged."""
    target_size = [IMG_HEIGHT, IMG_WIDTH]
    resized = tf.image.resize(image, target_size)
    return resized, label
def resize_train_image(image, label):
    """Resize a training image to (IMG_HEIGHT, IMG_WIDTH); label unchanged."""
    target_size = [IMG_HEIGHT, IMG_WIDTH]
    resized = tf.image.resize(image, target_size)
    return resized, label

@tf.function
def load_image(df_dict):
    """Load one sample for the tf.data pipeline.

    Args:
        df_dict: mapping with at least the entries 'ID' (path of a PNG file)
            and 'Label' (class ids joined by '|', e.g. "8|5|0").

    Returns:
        (image, label): the PNG decoded to 3 channels, and a multi-hot
        vector of length ``len(LABELS)`` with 1.0 at every listed class id.
    """
    # Load image; no crop or resize is applied here.
    # NOTE(review): presumably the files are already cropped and resized
    # offline - confirm against the dataset.
    rgb = tf.io.read_file(df_dict['ID'])
    image = tf.image.decode_png(rgb, channels=3)

    # Parse label
    label = tf.strings.split(df_dict['Label'], sep='|')
    label = tf.strings.to_number(label, out_type=tf.int32)
    # depth follows LABELS (same source as the model's output layer) rather
    # than a hard-coded 19; reduce_max keeps the target strictly 0/1 even if
    # an id is listed twice (reduce_sum would give 2).
    label = tf.reduce_max(tf.one_hot(indices=label, depth=len(LABELS)), axis=0)

    return image, label

In [None]:
#%%script echo skipping
# Consume training CSV 
def _build_pipeline(dataframe):
    """Turn a dataframe split into a shuffled, batched, prefetched dataset."""
    slices = tf.data.Dataset.from_tensor_slices(dict(dataframe))
    shuffled = slices.shuffle(1024)
    decoded = shuffled.map(load_image, num_parallel_calls=AUTOTUNE)
    batched = decoded.batch(BATCH_SIZE)
    return batched.prefetch(tf.data.experimental.AUTOTUNE)

# Training Dataset
train_ds = _build_pipeline(df_train_split)

# Validation Dataset
val_ds = _build_pipeline(df_val_split)

In [None]:
#%%script echo skipping
def get_label_name(labels):
    """Return the class names of all entries equal to 1, joined by '-'.

    ``labels`` is a numpy multi-hot vector; an all-zero vector gives ''.
    """
    active_ids = np.where(labels == 1.)[0]
    return '-'.join(str(LABELS[class_id]) for class_id in active_ids)

def show_batch(image_batch, label_batch):
    """Plot up to ten images of a batch with their class names as titles.

    Args:
        image_batch: indexable batch of images accepted by ``plt.imshow``.
        label_batch: batch of multi-hot label tensors (``.numpy()`` is
            called on each entry).
    """
    plt.figure(figsize=(20,20))
    # A final, smaller batch may hold fewer than ten samples; never index
    # past its end.
    n_shown = min(10, len(image_batch))
    for n in range(n_shown):
        plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.title(get_label_name(label_batch[n].numpy()))
        plt.axis('off')

In [None]:
#%%script echo skipping
# Training batch
image_batch, label_batch = next(iter(train_ds))
show_batch(image_batch, label_batch)
#print(label_batch)

In [None]:
#%%script echo skipping

def get_model():
    """Build the multi-label classifier.

    An ImageNet-pretrained EfficientNetB0 without its top, fully trainable,
    followed by global average pooling and a sigmoid layer with one unit per
    entry of ``LABELS``.

    Returns:
        An uncompiled Keras ``Model`` taking (IMG_HEIGHT, IMG_WIDTH, 3) inputs.
    """
    base_model = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet')
    base_model.trainable = True

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, 3))
    # Do not hard-code training=True here: a fixed flag keeps the backbone's
    # BatchNormalization/Dropout layers in training mode during evaluate()
    # and predict() as well. Without it Keras passes the right mode itself.
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(len(LABELS), activation='sigmoid')(x)

    return Model(inputs, outputs)

# Reset Keras' global state, then build the model once to print its summary.
tf.keras.backend.clear_session()
model = get_model()
model.summary()

In [None]:
#%%script echo skipping
# Stop when val_loss has not improved for 10 epochs and roll back to the
# best weights seen.
earlystopper = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=10, verbose=0, mode='min',
    restore_best_weights=True
)

# Halve the learning rate when val_loss has not improved for 5 epochs.
# NOTE(review): this callback is not in the callbacks list of the fit() call
# further down - confirm whether that is intentional.
lronplateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, verbose=0,
    mode='auto', min_delta=0.0001, cooldown=0, min_lr=0
)

In [None]:
#%%script echo skipping
#set up checkpoint save
#source:https://www.tensorflow.org/tutorials/keras/save_and_load
!pip install -q pyyaml h5py
import os
# Location of the weight checkpoint written during training.
checkpoint_path = "./cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
# (weights only, not the full model; verbose=1 logs each save).
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

In [None]:
#https://stackoverflow.com/questions/43198613/scikit-learn-f1-score-for-list-of-strings
from sklearn.metrics import f1_score

def f1_weighted(y_true, y_pred):
    """Return the weighted-average F1 score of binarized labels.

    Both arguments are passed through ``binarizer.transform`` before being
    handed to sklearn's ``f1_score`` with ``average='weighted'``.

    NOTE(review): ``binarizer`` is not defined anywhere in the visible part
    of this notebook, so calling this function raises NameError unless a
    fitted binarizer exists as a global - presumably a fitted
    MultiLabelBinarizer as in the linked answer; confirm.
    """
    f1=f1_score(binarizer.transform(y_true), 
         binarizer.transform(y_pred), 
         average='weighted')
    return f1

In [None]:
#%%script echo skipping
# tf.nn.sigmoid_cross_entropy_with_logits used as loss fct
# tensorflow says it can be used for multi-label multi-class problems
# https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
import keras.backend as K
K_epsilon = K.epsilon()
def f1(y_true, y_pred):
    """Macro-averaged F1 metric on predictions thresholded at 0.5.

    Args:
        y_true: multi-hot ground-truth tensor.
        y_pred: predicted scores of the same shape.

    Returns:
        Scalar tensor: mean over the last axis' classes of the per-class F1
        computed across the batch axis (axis 0).
    """
    # Scores strictly above 0.5 count as positive predictions.
    y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), 0.5), K.floatx())
    # Cast before multiplying so integer labels are accepted as well.
    y_true = K.cast(y_true, K.floatx())

    # Per-class counts over the batch (true negatives are not needed for F1).
    tp = K.sum(y_true*y_pred, axis=0)
    fp = K.sum((1-y_true)*y_pred, axis=0)
    fn = K.sum(y_true*(1-y_pred), axis=0)

    # K_epsilon keeps the divisions defined for classes without positives.
    p = tp / (tp + fp + K_epsilon)
    r = tp / (tp + fn + K_epsilon)

    f1 = 2*p*r / (p+r+K_epsilon)
    f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
    return K.mean(f1)
def f1_loss(y_true, y_pred):
    """Differentiable loss: 1 minus the soft macro F1.

    Unlike the ``f1`` metric, the predictions are not thresholded, so the
    raw scores act as soft counts and gradients can flow.

    Args:
        y_true: multi-hot ground-truth tensor.
        y_pred: predicted scores of the same shape.

    Returns:
        Scalar tensor ``1 - mean(per-class soft F1)``.
    """
    # Cast before multiplying so integer labels are accepted as well.
    y_true = K.cast(y_true, K.floatx())
    y_pred = K.cast(y_pred, K.floatx())

    # Soft per-class counts over the batch (true negatives are not needed).
    tp = K.sum(y_true*y_pred, axis=0)
    fp = K.sum((1-y_true)*y_pred, axis=0)
    fn = K.sum(y_true*(1-y_pred), axis=0)

    # K_epsilon keeps the divisions defined for classes without positives.
    p = tp / (tp + fp + K_epsilon)
    r = tp / (tp + fn + K_epsilon)

    f1 = 2*p*r / (p+r+K_epsilon)
    f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
    return 1-K.mean(f1)

In [None]:
#%%script echo skipping
import tensorflow as tf
import timeit

# Report whether TensorFlow sees a GPU. The "Found GPU" line used to be
# printed unconditionally, even right after "GPU device not found".
device_name = tf.test.gpu_device_name()
if "GPU" in device_name:
    print('Found GPU at: {}'.format(device_name))
else:
    print("GPU device not found")

In [None]:
#%%script echo skipping
# Initialize model
# Start from a fresh graph and a freshly initialised model.
tf.keras.backend.clear_session()
model = get_model()

#model.load_weights(checkpoint_path)

# Compile model: soft-F1 loss for optimisation, thresholded F1 as the
# reported metric. Metrics are passed as a list, the documented form.
model.compile(optimizer='adam', loss=f1_loss, metrics=[f1])

# Train
# NOTE(review): class_weight is combined with multi-hot targets here -
# confirm that Keras applies the weights the way you expect for this case.
history=model.fit(train_ds,
                  epochs=20,
                  validation_data=val_ds,
                  class_weight=class_weights,
                  callbacks=[cp_callback,earlystopper])

In [None]:
#%%script echo skipping
history.history

In [None]:
#%%script echo skipping
#source: https://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/
# list all data in history
print(history.history.keys())
# One figure per tracked quantity: training curve and validation curve.
# Each entry is (history key, figure title, y-axis label).
curve_specs = (
    ('f1', 'model f1 score', 'f1 score'),
    ('loss', 'model loss', 'loss'),
)
for history_key, figure_title, y_label in curve_specs:
    plt.plot(history.history[history_key])
    plt.plot(history.history['val_' + history_key])
    plt.title(figure_title)
    plt.ylabel(y_label)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

## confusion matrix

In [None]:
#%%script echo skipping
# Take a single batch from the training pipeline and predict on it.
x, y_true = next(train_ds.as_numpy_iterator())
y_pred = model.predict(x)
print(x.shape)
print(y_true.shape)
print(y_pred.shape)
print("y_true:")
print(y_true)
print("y_pred:")
print(y_pred)
# Binarize with the same rule as the f1 metric (strictly greater than 0.5).
# The two in-place assignments used before left scores of exactly 0.5
# untouched, which makes the sklearn metrics reject the array as non-binary.
y_pred = (y_pred > 0.5).astype(y_pred.dtype)
print("y_pred:")
print(y_pred)

In [None]:
f1_score(y_true, y_pred, average='weighted')

In [None]:
#%%script echo skipping
#https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html
import sklearn.metrics as skm

# One 2x2 confusion matrix per class, followed by the per-class
# precision/recall/F1 report, for the single batch predicted above.
cm = skm.multilabel_confusion_matrix(y_true, y_pred)
print(cm)
print(skm.classification_report(y_true,y_pred))

In [None]:
%%script echo skipping
target_names = np.array(list(LABELS.values()))
print(target_names)

In [None]:
%%script echo skipping
import sklearn.metrics as skm
cm = skm.multilabel_confusion_matrix(y_true, y_pred)
print(cm)
print( skm.classification_report(y_true,y_pred))

In [None]:
from mlxtend.evaluate import confusion_matrix

# Small hand-written multi-class example (not model output) for trying out
# mlxtend's confusion_matrix.
y_target =    [1, 1, 1, 0, 0, 2, 0, 3]
y_predicted = [1, 0, 1, 0, 0, 2, 1, 3]

# binary=False is passed so the labels are not collapsed to two classes.
# NOTE(review): exact semantics of `binary` are defined by mlxtend - confirm.
cm = confusion_matrix(y_target=y_target, 
                      y_predicted=y_predicted, 
                      binary=False)
cm