In [None]:
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from PIL import Image
import hsn_v1
import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Dataset root that is scanned recursively for *.png files.
# (The other root used with this notebook was 'img/02_glas_full'.)
path = 'img/02_glas_patch'

In [None]:
# Every PNG below the dataset root, as strings. rglob() yields entries in filesystem order,
# which can differ between machines; sorting keeps the seeded train/val split reproducible.
glas_paths = sorted(str(png) for png in Path(path).rglob('*.png'))
assert glas_paths, f'no .png files found under {path!r}'

In [None]:
# Number of files found plus a short preview, instead of dumping the entire list.
len(glas_paths), glas_paths[:5]

In [None]:
# Route each file by the 'test' / 'train' marker in its path *relative to the dataset root*,
# so a root directory whose name happens to contain either word cannot misroute every file.
glas_test = [file for file in glas_paths if 'test' in str(Path(file).relative_to(path))]
glas_train = [file for file in glas_paths if 'train' in str(Path(file).relative_to(path))]
assert glas_train and glas_test, 'expected both train and test files under the dataset root'
# A file matching both markers would leak test data into training.
assert not set(glas_train) & set(glas_test), 'some files matched both train and test'

In [None]:
# Decode every training PNG into a numpy array, then stack them into one array.
# The context manager closes each file; np.array() forces the pixel data to load first.
X_train = []
for file in glas_train:
    with Image.open(file) as img:
        X_train.append(np.array(img))
X_train = np.array(X_train)

In [None]:
# Decode every test PNG into a numpy array, then stack them into one array.
# The context manager closes each file; np.array() forces the pixel data to load first.
X_test = []
for file in glas_test:
    with Image.open(file) as img:
        X_test.append(np.array(img))
X_test = np.array(X_test)

In [None]:
# Array shapes of the loaded train and test image stacks.
(X_train.shape, X_test.shape)

### Prepare labels

In [None]:
# Column of the 51-wide label vector that is set to 1 for every GlaS patch.
# NOTE(review): presumably the HTT glandular class in HistoNet's output — confirm against the model's class list.
GO_INDEX = 48

In [None]:
# Targets: one row of 51 zeros per training image, with only column GO_INDEX set to 1.
y_train = np.zeros((len(X_train), 51))
y_train[:, GO_INDEX] = 1.0

In [None]:
# Targets: one row of 51 zeros per test image, with only column GO_INDEX set to 1.
y_test = np.zeros((len(X_test), 51))
y_test[:, GO_INDEX] = 1.0

In [None]:
# Hold out 20% of the training images as a validation set (fixed seed 42).
# NOTE: this rebinds X_train / y_train, so re-running the cell shrinks the training set again.
# NOTE(review): if several patches come from the same source image, a random split can put
# patches of one image on both sides and make val_loss optimistic — confirm the file layout.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train,
    test_size=0.2,
    random_state=42,
)
(X_train.shape, X_val.shape)

### Load model

In [None]:
# ---- Pixel-resolution settings for this image set ----
IN_PX_RESOL = 0.620
OUT_PX_RESOL = 0.25 * 1088 / 224        # == 1.2142857142857142
DOWNSAMPLE_FACTOR = OUT_PX_RESOL / IN_PX_RESOL

# ---- HistoSegNet run options ----
MODEL_NAME = 'histonet_X1.7_clrdecay_5'
# NOTE(review): names the full-image set while the data above is read from 'img/02_glas_patch' — confirm intended.
INPUT_NAME = '02_glas_full'
INPUT_MODE = 'patch'       # one of {'patch', 'wsi'}
INPUT_SIZE = [224, 224]    # two ints, both > 0
HTT_MODE = 'glas'          # one of {'both', 'morph', 'func', 'glas'}
BATCH_SIZE = 1             # int > 0
GT_MODE = 'on'             # one of {'on', 'off'}
RUN_LEVEL = 3              # 1: HTT confidence scores, 2: Grad-CAMs, 3: segmentation masks
SAVE_TYPES = [1, 1, 1, 1]  # flags for: HTT confidence scores, Grad-CAMs, segmentation masks, summary images
VERBOSITY = 'QUIET'        # one of {'NORMAL', 'QUIET'}

In [None]:
# Build the HistoSegNet wrapper from the run options defined above.
hsn = hsn_v1.HistoSegNetV1(params=dict(
    input_name=INPUT_NAME,
    input_size=INPUT_SIZE,
    input_mode=INPUT_MODE,
    down_fac=DOWNSAMPLE_FACTOR,
    batch_size=BATCH_SIZE,
    htt_mode=HTT_MODE,
    gt_mode=GT_MODE,
    run_level=RUN_LEVEL,
    save_types=SAVE_TYPES,
    verbosity=VERBOSITY,
))

In [None]:
# Load HistoNet `MODEL_NAME` into the wrapper with pretrained=True.
hsn.load_histonet(pretrained=True, params=dict(model_name=MODEL_NAME))

In [None]:
# Keep a handle on the loaded HistoNet object and print its model summary.
histonet = hsn.hn
histonet.model.summary()

In [None]:
# (max, min) pixel values of train, test and val before normalisation, flattened into one tuple.
tuple(stat for split in (X_train, X_test, X_val) for stat in (split.max(), split.min()))

In [None]:
# Apply HistoNet's GlaS normalisation (is_glas=True) to every split, then report the new ranges.
# NOTE: this rebinds X_train / X_test / X_val; re-running the cell normalises the already-normalised
# arrays again (presumably not idempotent — confirm in hsn_v1).
X_train, X_test, X_val = [histonet.normalize_image(split, is_glas=True)
                          for split in (X_train, X_test, X_val)]
tuple(stat for split in (X_train, X_test, X_val) for stat in (split.max(), split.min()))

### Train model

In [None]:
num_epochs = 30
batch_size = 8
# Keras recommends steps_per_epoch = ceil(num_samples / batch_size). Flooring would drop the
# final partial batch every epoch, and give 0 steps if there are fewer samples than batch_size.
steps_per_epoch = (X_train.shape[0] + batch_size - 1) // batch_size

In [None]:
# The model object that is checkpointed and fine-tuned in the cells below.
model = histonet.model

In [None]:
# Save only the weights, and only when val_loss improves on the best value seen so far.
CHECKPOINT_PATH = 'data/histonet_glas_holdout_ft.h5'
# Create the target directory up front: depending on the Keras version, saving may not create
# it, and a missing 'data/' folder would only fail after a full epoch of training.
Path(CHECKPOINT_PATH).parent.mkdir(parents=True, exist_ok=True)
model_chkpt = keras.callbacks.ModelCheckpoint(filepath=CHECKPOINT_PATH, monitor='val_loss', verbose=1,
                                             save_best_only=True, save_weights_only=True)

In [None]:
# Augment training batches with random horizontal and vertical flips.
# Seed the iterator with the same seed as the train/val split (42) so that batch shuffling
# and flips are reproducible from run to run.
train_gen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)
train_generator = train_gen.flow(X_train, y_train, batch_size=batch_size, seed=42)

In [None]:
# Fine-tune on the augmented stream, scoring against the held-out split each epoch.
# model_chkpt keeps the best weights by val_loss.
# NOTE(review): fit_generator is deprecated in TF >= 2.1, where model.fit accepts generators. It is
# kept here because the Keras flavour behind hsn_v1 is not visible — confirm before switching.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    validation_data=(X_val, y_val),
    callbacks=[model_chkpt],
    shuffle=True,
    verbose=1,
)

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Training vs. validation loss per epoch. No fixed y-limits: the previous hard-coded
# [0.5, 1] window clipped any loss outside that range and could hide divergence or convergence.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train loss')
ax.plot(history.history['val_loss'], label='val loss')
ax.set(title='GlaS fine-tuning loss', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Per-epoch validation losses recorded by the training run.
(history.history['val_loss'])