In [None]:
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from PIL import Image
import hsn_v1
from tensorflow import keras
import csv
import tensorflow as tf
from matplotlib import pyplot as plt

In [None]:
# Root directory that is searched recursively for the fold CSV files.
folds_path = 'folds/glas/split_0/'

In [None]:
# Every CSV found anywhere under the folds directory, as plain string paths.
folds_files = list(map(str, Path(folds_path).rglob('*.csv')))

In [None]:
def make_filter(pattern):
    """Build a predicate that reports whether `pattern` occurs in its argument.

    The returned callable takes a single value (a path string in this
    notebook) and returns True when `pattern in value` holds, False otherwise.
    """
    return lambda file: pattern in file

### Collect and sort the cross-validation CSV files

In [None]:
# Validation CSVs in lexicographic order (presumably so index i refers to the
# same fold in every split list -- confirm against the file naming).
valid_csv = sorted(filter(make_filter('valid'), folds_files))

In [None]:
# Test CSVs in lexicographic order.
test_csv = sorted(filter(make_filter('test'), folds_files))

In [None]:
# Training CSVs in lexicographic order.
train_csv = sorted(filter(make_filter('train'), folds_files))

#### Get file names

In [None]:
import pandas as pd

In [None]:
def read_csv(file):
    """Return the image file names listed in the first column of a fold CSV.

    The CSV is read without a header row. Its first three columns are parsed
    as 'img', 'gt' and 'class', but only 'img' is used. Every '.bmp'
    occurrence in a name is replaced by '.png'.
    """
    columns = ['img', 'gt', 'class']
    table = pd.read_csv(file, header=None, usecols=[0, 1, 2], names=columns)
    return [entry.replace('.bmp', '.png') for entry in table['img'].tolist()]

In [None]:
# One list of image names per fold CSV, for each of the three splits.
train_files = list(map(read_csv, train_csv))
test_files = list(map(read_csv, test_csv))
valid_files = list(map(read_csv, valid_csv))

#### Find patches

In [None]:
# Root directory that is searched recursively for the PNG images.
imgs_path = 'img/02_glas_full'

In [None]:
# Every PNG found anywhere under the image directory, as plain string paths.
glas_paths = list(map(str, Path(imgs_path).rglob('*.png')))

In [None]:
def get_patches_files(folds, all_files):
    """Collect, for every fold, the paths in `all_files` that match its names.

    A path matches a fold when any of the fold's names is a substring of it.

    Parameters
    ----------
    folds : iterable of iterables of str
        One collection of image names per fold.
    all_files : iterable of str
        Candidate file paths.

    Returns
    -------
    list of list of str
        One list per fold holding the unique matching paths, sorted
        lexicographically.
    """
    # Scanned once per fold, so materialise it in case a one-shot iterator is passed.
    all_files = list(all_files)

    out = []
    for fold in folds:
        names = list(fold)
        matches = {f for f in all_files if any(name in f for name in names)}
        # Sort instead of list(set(...)): set order follows string hashing, which
        # Python randomises per process, so the fold order would otherwise change
        # between kernel restarts.
        out.append(sorted(matches))
    return out

In [None]:
# Per-fold lists of image paths matched for the training split.
train_patches = get_patches_files(train_files, glas_paths)

In [None]:
# Per-fold lists of image paths matched for the test split.
test_patches = get_patches_files(test_files, glas_paths)

In [None]:
# Per-fold lists of image paths matched for the validation split.
val_patches = get_patches_files(valid_files, glas_paths)

In [None]:
# Sanity check: number of matched image paths in each training fold.
for fold in train_patches:
    print(len(fold))

### Load model

In [None]:
# Switch read further down to choose the model name, the `pretrained` flag
# passed to load_histonet, and the checkpoint file name.
IS_FINETUNE = False

In [None]:
# Model name depends on whether this is the fine-tuning run.
MODEL_NAME = 'histonet_X1.7_clrdecay_5' if IS_FINETUNE else 'histonet_glas'

MODEL_NAME

In [None]:
# Settings collected here and handed to hsn_v1.HistoSegNetV1 via its `params` dict.
INPUT_NAME = '02_glas_full'
INPUT_MODE = 'patch'                    # {'patch', 'wsi'}
INPUT_SIZE = [224, 224]                 # [<int>, <int>] > 0
HTT_MODE = 'glas'                       # {'both', 'morph', 'func', 'glas'}
BATCH_SIZE = 1                          # int > 0
GT_MODE = 'on'                          # {'on', 'off'}
RUN_LEVEL = 3                           # {1: HTT confidence scores, 2: Grad-CAMs, 3: Segmentation masks}
SAVE_TYPES = [1, 1, 1, 1]               # {HTT confidence scores, Grad-CAMs, Segmentation masks, Summary images}
VERBOSITY = 'QUIET'                    # {'NORMAL', 'QUIET'}
# Settings for image set
IN_PX_RESOL = 0.620
OUT_PX_RESOL = 0.25 * 1088 / 224    # 1.21428571429
# Ratio of output to input pixel resolution; passed on as 'down_fac'.
DOWNSAMPLE_FACTOR = OUT_PX_RESOL / IN_PX_RESOL

In [None]:
# Instantiate the segmentation wrapper with the settings defined above.
hsn = hsn_v1.HistoSegNetV1(params=dict(
    input_name=INPUT_NAME,
    input_size=INPUT_SIZE,
    input_mode=INPUT_MODE,
    down_fac=DOWNSAMPLE_FACTOR,
    batch_size=BATCH_SIZE,
    htt_mode=HTT_MODE,
    gt_mode=GT_MODE,
    run_level=RUN_LEVEL,
    save_types=SAVE_TYPES,
    verbosity=VERBOSITY,
))

In [None]:
# Ask the wrapper to load the network named MODEL_NAME; `pretrained` follows the
# fine-tuning switch. (Presumably this populates `hsn.hn` -- confirm in hsn_v1.)
hsn.load_histonet(params={'model_name': MODEL_NAME}, pretrained=IS_FINETUNE)

In [None]:
# Keep a handle on the object stored at `hsn.hn` and print the summary of its model.
histonet = hsn.hn
histonet.model.summary()

### Load images

In [None]:
def preprocess(x, y):
    """Randomly augment and standardise one image; the label passes through.

    Parameters
    ----------
    x : tf.Tensor
        Image tensor that can be cropped to 416 x 416 x 3.
    y : tf.Tensor
        Label for the image; returned unchanged.

    Returns
    -------
    (x, y)
        `x` cropped to 416 x 416, resized to 224 x 224, colour-jittered,
        rotated by a random multiple of 90 degrees, passed through
        `histonet.normalize_image`, clipped to [0, 255] and standardised with
        the fixed training mean and standard deviation below.
    """

    # Random crop and resize
    crop_size = [416, 416, 3]
    resize_size = [224, 224]
    x = tf.image.random_crop(x, crop_size)
    x = tf.image.resize(x, resize_size)

    # Color shifts
    # NOTE(review): the clip to [0, 255] below suggests pixels are on a 0-255
    # scale here; a brightness max_delta of 0.5 is then almost a no-op, and the
    # hue/saturation ops are documented for float images in [0, 1] -- confirm
    # the intended scale.
    x = tf.image.random_hue(x, 0.5)
    x = tf.image.random_saturation(x, 0.5, 1.5)
    x = tf.image.random_brightness(x, 0.5)
    x = tf.image.random_contrast(x, 0.5, 1.5)

    # Random rotation: k quarter-turns with k drawn uniformly from {0, 1, 2, 3}.
    # tf.random.uniform is used because the tf.random_uniform alias exists only
    # in TensorFlow 1.x, whereas tf.random.uniform is available in both 1.x and 2.x.
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    x = tf.image.rot90(x, quarter_turns)

    # Normalize
    x = histonet.normalize_image(x, is_glas=True)
    train_mean = 193.09203
    train_std = 56.450138
    x = tf.clip_by_value(x, 0, 255)
    # The small epsilon guards against division by zero.
    x = (x - train_mean)/(train_std + 1e-7)

    return x, y

In [None]:
def _stack_folds(arrays):
    """Combine per-fold arrays into one array indexed by fold.

    Folds with identical shapes are stacked into a regular ndarray. Folds of
    different sizes go into a 1-D object array instead, because np.array() on
    ragged input is deprecated and raises ValueError in recent NumPy releases.
    """
    if len({a.shape for a in arrays}) <= 1:
        return np.array(arrays)
    stacked = np.empty(len(arrays), dtype=object)
    for idx, a in enumerate(arrays):
        stacked[idx] = a
    return stacked


def load_images(folds):
    """Load the images of every fold and build their one-hot label matrices.

    Parameters
    ----------
    folds : sequence of sequences of str
        One list of image file paths per fold.

    Returns
    -------
    (X, Y) : tuple of np.ndarray
        Both are indexed by fold. `X[i]` has shape (n_i, 522, 775, 3) and
        `Y[i]` has shape (n_i, 51) with a 1 in column 48 (the G.O class) for
        every image and 0 elsewhere. When all folds have the same size, X and
        Y are regular arrays with a leading fold axis; otherwise they are 1-D
        object arrays holding one array per fold.
    """
    GO_INDEX = 48
    NUM_CLASSES = 51
    IMG_HEIGHT, IMG_WIDTH = 522, 775

    X = []
    Y = []

    for fold in folds:
        imgs = np.zeros((len(fold), IMG_HEIGHT, IMG_WIDTH, 3))
        for i, f in enumerate(fold):
            # The context manager closes the file handle once the pixels are read.
            with Image.open(f) as img:
                # Guarantee three channels (palette, greyscale or RGBA inputs).
                rgb = img.convert('RGB')
                # Rescale with real interpolation. np.resize only repeats or
                # truncates the flattened buffer, which scrambles any image
                # that is not already 522 x 775. PIL sizes are (width, height).
                if rgb.size != (IMG_WIDTH, IMG_HEIGHT):
                    rgb = rgb.resize((IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR)
                imgs[i] = np.asarray(rgb, dtype="int32")

        X.append(imgs)

        # Create labels, only class is G.O
        y = np.zeros((len(imgs), NUM_CLASSES))
        y[:, GO_INDEX] = 1
        Y.append(y)

    return _stack_folds(X), _stack_folds(Y)

In [None]:
# Load the pixel data and label matrices for every fold of each split.
X_train_folds, Y_train_folds = load_images(train_patches)
X_test_folds, Y_test_folds = load_images(test_patches)
X_val_folds, Y_val_folds = load_images(val_patches)

In [None]:
# Shape check on the labels of the fold at index 4 (needs at least five training folds).
Y_train_folds[4].shape

In [None]:
def load_datasets(X_folds, Y_folds):
    """Wrap each (images, labels) fold pair in a preprocessed tf.data.Dataset.

    For every pair, taken in order from `X_folds` and `Y_folds`, the two shapes
    are printed, a dataset is sliced along the first axis, and `preprocess` is
    mapped over its elements. Returns the datasets as a list, one per fold.
    """

    def _build(images, labels):
        print(images.shape, labels.shape)
        slices = tf.data.Dataset.from_tensor_slices((images, labels))
        return slices.map(preprocess)

    return [_build(images, labels) for images, labels in zip(X_folds, Y_folds)]

In [None]:
# One tf.data pipeline per fold and split. Note that load_datasets maps the same
# `preprocess` function over all three splits, not only the training one.
train_datasets = load_datasets(X_train_folds, Y_train_folds)
test_datasets = load_datasets(X_test_folds, Y_test_folds)
val_datasets = load_datasets(X_val_folds, Y_val_folds)

### Train model

In [None]:
# Shorthand for the model object held by the HistoNet wrapper.
model = histonet.model

In [None]:
# Checkpoint file name depends on whether this is the fine-tuning run.
weights_path = 'data/histonet_glas_ft.h5' if IS_FINETUNE else 'data/histonet_glas.h5'

# Callback that writes weights only, and only when 'val_loss' improves.
model_chkpt = keras.callbacks.ModelCheckpoint(
    filepath=weights_path,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
weights_path

In [None]:
# Training schedule; the fold count is the length of the leading axis of the training data.
num_epochs, batch_size = 30, 1
num_folds = len(X_train_folds)
num_folds

In [None]:
# Whole batches available in each fold (floor division, so a partial batch is not counted).
train_steps_per_epoch = [len(paths) // batch_size for paths in train_patches]
val_steps_per_epoch = [len(paths) // batch_size for paths in val_patches]

train_steps_per_epoch, val_steps_per_epoch

In [None]:
# Fit the model on each fold in turn.
# NOTE(review): the same `model` object (and the same checkpoint callback and
# file) is used for every fold, so weights carry over from one fold to the
# next. If independent per-fold models are intended, the weights need to be
# reset at the start of each iteration -- confirm.
for i in range(num_folds):

    print('***** FOLD {} *****'.format(i))

    # steps_per_epoch / validation_steps are passed explicitly below, so the
    # train and validation datasets are repeated: a single-pass dataset is
    # exhausted after the first epoch when Keras does not rewind it between
    # epochs, which cuts training short.
    train_dataset = train_datasets[i].batch(batch_size).repeat()
    test_dataset = test_datasets[i].batch(batch_size)
    val_dataset = val_datasets[i].batch(batch_size).repeat()

    # Pass the checkpoint callback; without it the best weights are never saved.
    model.fit(train_dataset,
              epochs=num_epochs,
              validation_data=val_dataset,
              steps_per_epoch=train_steps_per_epoch[i],
              validation_steps=val_steps_per_epoch[i],
              callbacks=[model_chkpt])