In [None]:
import numpy as np
import pandas as pd
import os
import h5py
import glob

from utils.dataGenerator import DataGenerator, DataGenerator_metaData
from models import models

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow as tf 
from tensorflow import keras

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
# Report which accelerators TensorFlow can see before starting a long training run.
# tf.test.gpu_device_name() returns an empty string when no GPU is visible; its result
# is printed here because a bare call that is not the cell's last expression is discarded.
print("GPU device name: ", tf.test.gpu_device_name() or "<none>")
# tf.config.list_physical_devices is the stable (non-experimental) API in TF 2.x.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# NOTE: device_lib lives under tensorflow.python (a private module); it is only used
# here for a verbose dump of every local device.
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())


In [None]:
# Data folders.
# NOTE(review): these are absolute, machine-specific Windows paths; consider deriving
# them from a single configurable root directory.
patches_folder = r"D:\annotated_slides\separate_patches_and_labels_hooknet_exp_1"
svsfolder = r"D:\annotated_slides\Slides"

# glob() returns matches in arbitrary, filesystem-dependent order; sort them so the
# patch list (and anything derived from it) is reproducible across runs and machines.
slide_patches = sorted(glob.glob(os.path.join(patches_folder, '*.h5')))
print("number of slide patches: ", len(slide_patches))

### train, validation and test splits

In [None]:
# # Load csv file
# df = pd.read_csv(os.path.join(svsfolder, "csv_file.csv"))

# slide_ids = list(df['slide_id'].values)
# train_ids, val_test_ids = train_test_split(slide_ids, test_size=0.30, random_state=42)
# validation_ids, test_ids = train_test_split(val_test_ids, test_size=0.35, random_state=42)

# print("train: {:d}, validation: {:d}, test: {:d}".format(len(train_ids), len(validation_ids), len(test_ids)))  

# train_patches = []
# validation_patches = []
# test_patches = []
# for patch_path in slide_patches:
#     patch_name = patch_path.split("\\")[-1]
    
#     for tran_id in train_ids:
#         if tran_id in patch_name:
#             train_patches.append(patch_name)
#             break
    
#     for val_id in validation_ids:
#         if val_id in patch_name:
#             validation_patches.append(patch_name)
#             break
            
#     for test_id in test_ids:
#         if test_id in patch_name:
#             test_patches.append(patch_name)
#             break
            
# # # print([x for x in train_patches if x in validation_pathes])
# print("train patches: {:d}, validation patches: {:d}, test patches: {:d}".format(
#     len(train_patches), len(validation_patches), len(test_patches)))  

# partition = {'train': train_patches,
#              'validation': validation_patches}

# The train/validation/test split is read from pre-computed CSV tables stored in
# patches_folder; the row count of each table is reported as its number of patches.
classmeta_train, classmeta_validation, classmeta_test = (
    pd.read_csv(os.path.join(patches_folder, csv_name))
    for csv_name in ("classmeta_train.csv", "classmeta_validation.csv", "classmeta_test.csv")
)

print("train patches: {:d}, validation patches: {:d}, test patches: {:d}".format(
    *(len(split_table) for split_table in (classmeta_train, classmeta_validation, classmeta_test))))


In [None]:
# # calculate sample weights
# class_counts = {}
# for patch in train_patches[0:20]:
#     patchfile = os.path.join(patches_folder, patch)  # train_patches[1])
# #     print("patchfile: ", patchfile)
    
#     with h5py.File(patchfile, 'r') as f:
#         seg = f['patches_20x']['segmentation'][:]
#         labels, counts = np.unique(seg, return_counts=True)
#         dict_count = dict(zip(labels, counts))
# #         print(dict_count)
        
#         for label in labels:
#             print(label)
#             if label in class_counts.keys():
#                 class_counts[label] += dict_count[label]
#             else:
#                 class_counts[label] = dict_count[label]
                
# print(class_counts)

In [None]:
# with h5py.File(patchfile, 'r') as f:
#     print(list(f.keys()))
#     print(list(f.values()))

#     patch = f['patches_20x']['patch'][:]
#     seg = f['patches_20x']['segmentation'][:]
#     print(patch.shape)
#     print(patch.dtype)
#     print(seg.shape)
#     print(seg.dtype)
    

### model folder

In [None]:
# Trained-model artefacts (checkpoints, TensorBoard logs) are written under
# <main_folder>/models; the directory is created if it does not exist yet.
main_folder = r"D:\annotated_slides"
modeldir = os.path.join(main_folder, "models")
os.makedirs(name=modeldir, exist_ok=True)


### Generators

In [None]:
# Hyper-parameters shared by the training and validation generators and the model.
# 'dim' is the spatial patch size used for both model inputs (context and target).
# 'data_folder' reuses patches_folder so the two copies of this path cannot drift apart.
params = {'dim': (256, 256),
          'batch_size': 14,
          'n_classes': 7,
          'n_channels': 3,
          'shuffle': True,
          'data_folder': patches_folder}


In [None]:
# TensorBoard logs go to <modeldir>/logs (created if missing).
logdir = os.path.join(modeldir, "logs")
os.makedirs(logdir, exist_ok=True)

# Training callbacks:
#   - stop once val_loss has not improved for 50 epochs,
#   - write the weights with the best val_loss seen so far to a single .h5 file,
#   - log training curves to TensorBoard.
my_callbacks = [
    EarlyStopping(monitor='val_loss', min_delta=0, patience=50,
                  verbose=0, mode='auto'),
    # Alternative filename pattern that embeds the epoch and val_loss:
    #   filepath=os.path.join(modeldir, "model.{epoch:02d}-{val_loss:.2f}.h5"),
    ModelCheckpoint(filepath=os.path.join(modeldir, "model_hooknet_exp_1.h5"),
                    monitor='val_loss', save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
]


In [None]:
# Model architecture: a two-input (context + target) U-Net. Both inputs have shape
# params['dim'] followed by a trailing channel axis of size params['n_channels'].
model = models.context_target_unet(
    nClass=params['n_classes'],
    context_input=params['dim'] + (params['n_channels'],),
    target_input=params['dim'] + (params['n_channels'],),
)


In [None]:
# plot the model
# keras.utils.plot_model(model, "context_target_unet.png", show_shapes=True)

In [None]:
# Compile
# model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), 
#               loss='categorical_crossentropy',
# #               sample_weight_mode="temporal",
#               metrics=['accuracy'])


# Compile the two-headed model. The loss dict keys ("out_context", "out_target") must
# match the model's output-layer names.
# NOTE(review): from_logits=True assumes both heads emit raw logits (no final softmax);
# confirm against models.context_target_unet.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss={
        "out_context": tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        "out_target": tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    # "temporal" requests per-timestep (here presumably per-pixel) sample weights,
    # matching the third (sample-weight) element the generator yields for each batch.
    sample_weight_mode={
        "out_context": "temporal",
        "out_target": "temporal"
    },
    # Keyed by output name, like `loss`, so the weights cannot be silently swapped if
    # the model's output order differs from the order written here.
    loss_weights={"out_context": 0.25, "out_target": 0.75},
    metrics=['accuracy']
)


In [None]:
# Batch generators built from the per-split metadata tables.
# Shuffling only benefits training. The validation generator is given shuffle=False so
# validation batches come in a fixed order every epoch (assuming DataGenerator_metaData
# honours the `shuffle` argument).
training_generator = DataGenerator_metaData(classmeta_train, **params)
validation_generator = DataGenerator_metaData(classmeta_validation, **{**params, 'shuffle': False})

In [None]:
# example data generator
max_iter = 1  # maximum number of iterations, in each iteration one batch is generated; the proper value depends on batch size and size of whole data
i = 0
for (dc, dt), (lc, lt), (swc, swt) in training_generator:
    i += 1
    if i == max_iter:
        break

In [None]:
# Show the context and target input patches of one sample from the batch drawn above,
# then print the distinct sample-weight values of that sample and how often each occurs.
i = 0  # index of the sample within the batch
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(dc[i])
ax[0].set_title("context input (dc[{}])".format(i))
ax[1].imshow(dt[i])
ax[1].set_title("target input (dt[{}])".format(i))

print("swc: ", np.unique(swc[i], return_counts=True))
print("swt: ", np.unique(swt[i], return_counts=True))

In [None]:
# Train the model. Model.fit accepts generators / keras.utils.Sequence objects directly;
# fit_generator is deprecated in TF 2.x. The returned History is kept so the loss
# curves can be inspected or plotted afterwards.
history = model.fit(x=training_generator,
                    validation_data=validation_generator,
                    callbacks=my_callbacks,
                    epochs=200)
#                     use_multiprocessing=True,
#                     workers=4)

In [None]:
# TODO: add data augmentation
# TODO: add an option for different magnifications
# TODO: decide whether class/sample weights are computed per batch or over the entire training set (the class statistics are not yet known)

In [None]:
import numpy as np


In [None]:
# Toy label mask for trying out the class-weight computation in the next cell.
# Only labels 0..n_classes-9 (here 0..2) are drawn, so classes 3..10 never appear and the
# zero-count handling gets exercised. A seeded Generator makes the example reproducible.
n_classes = 11
rng = np.random.default_rng(42)
mask = rng.integers(0, n_classes - 8, size=(256, 256))

In [None]:
# Per-patch class weights from a label mask:
#   1. count pixels per label, ignoring label 0;
#   2. give every class 1..n_classes-1 a raw weight of 1 - (its share of counted pixels);
#   3. normalise the raw weights so they sum to 1.
# Uses `mask` (integer label array) and `n_classes` defined in the previous cell.
labels, counts = np.unique(mask, return_counts=True)
# Cast numpy scalars to plain ints so present and added classes share one key type.
dict_count = {int(label): int(count) for label, count in zip(labels, counts)}
print("dict_count_1: ", dict_count)

# Label 0 is excluded from the weighting (presumably background — TODO confirm).
dict_count.pop(0, None)
print("dict_count_2: ", dict_count)

# Classes absent from the patch get an explicit zero count so every class 1..n_classes-1
# receives a weight. NOTE(review): with "1 - share" weighting, absent classes end up with
# the largest weight — confirm that is intended.
for label in range(1, n_classes):
    dict_count.setdefault(label, 0)
print("dict_count_3: ", dict_count)

counts_sum = sum(dict_count.values())
print("counts sum:", counts_sum)

# Raw weight = 1 - share of counted pixels. If no pixel was counted (e.g. an all-zero
# mask) the shares are undefined, so fall back to equal raw weights instead of
# dividing by zero.
if counts_sum > 0:
    class_weight_patch_dict = {key: 1.0 - count / counts_sum for key, count in dict_count.items()}
else:
    class_weight_patch_dict = {key: 1.0 for key in dict_count}
print("class_weight_patch_dict: ", class_weight_patch_dict)

weights_sum = sum(class_weight_patch_dict.values())
print("weights_sum: ", weights_sum)

# Normalise to sum to 1. When pixels were counted, the raw weights sum to
# (number of classes - 1), which is 0 for a single class; fall back to uniform weights then.
if weights_sum > 0:
    class_weight_patch_norm_dict = {key: w / weights_sum for key, w in class_weight_patch_dict.items()}
else:
    class_weight_patch_norm_dict = {key: 1.0 / len(class_weight_patch_dict) for key in class_weight_patch_dict}
print("class_weight_patch_norm_dict", class_weight_patch_norm_dict)

weights_sum = sum(class_weight_patch_norm_dict.values())
print("weights_sum norm: ", weights_sum)
