# 4. Teste do modelo

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import h5py
import keras
import k3d
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
from ipywidgets import interact, widgets

import src

In [None]:
# Read the whole test split into memory; the HDF5 file is closed when the block exits.
with h5py.File('../data/processed/stanford6_32.h5', 'r') as dataset:
    x_test, y_test = (np.array(dataset[key]) for key in ('test/X', 'test/Y'))

# Class id -> display name; the names are later passed to src.plot_confusion_matrix.
classnames = dict(enumerate(('Floodplain', 'Pointbar', 'Channel', 'Boundary')))

# Load the trained Keras model from disk and print its layer summary.
model = keras.models.load_model('../models/trained_model_32.h5')
model.summary()

In [None]:
print('Evaluating model...\n')
# argmax over the last axis turns per-class scores (model output) and the
# targets (presumably one-hot encoded — confirm) into integer class ids.
y_pred = np.argmax(model.predict(x_test), axis=-1)
y_true = np.argmax(y_test, axis=-1)

# Support-weighted averages over all classes present.
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')

# Pin the label set so the matrix is always len(classnames) x len(classnames)
# and row/column i corresponds to classnames[i], even when some class never
# appears in y_true or y_pred (otherwise sklearn silently drops it).
matrix = confusion_matrix(y_true, y_pred, labels=list(classnames))
# Row-normalise: entry (i, j) is the fraction of true class i predicted as j.
# Rows of classes with no true samples stay 0 instead of becoming NaN.
row_totals = matrix.sum(axis=1, keepdims=True)
matrix = np.divide(
    matrix,
    row_totals,
    out=np.zeros(matrix.shape, dtype=float),
    where=row_totals != 0,
)

print(f'Precision: \t{precision}')
print(f'Recall: \t{recall}')
print(f'F1-Score: \t{f1}')

# Pass a concrete list of names rather than a dict view.
src.plot_confusion_matrix(matrix, list(classnames.values()), title="Confusion matrix")

In [None]:
# Number of positions a window of length image_size can take along an axis of
# length n is n + 1 - image_size. The labels are assumed to hold one prediction
# per window position in a 150 x 200 x 119 volume — TODO confirm.
image_size = 32
output_shape = [150 + 1 - image_size, 200 + 1 - image_size, 119 + 1 - image_size]


def plot_section(z=0):
    """Show the true (left) and predicted (right) class maps at depth index ``z``.

    Both flat label vectors are reshaped to ``output_shape``; the ``z``-th slice
    along the last axis is transposed before display. The colour scale is pinned
    to the full class-id range so that a given colour means the same class in
    both panels, even when a slice lacks some class.
    """
    # Shared colour limits: without these imshow rescales each panel to its
    # own min/max, so the two panels would not be comparable.
    color_limits = {'vmin': 0, 'vmax': len(classnames) - 1}

    plt.subplot(121)
    plt.imshow(y_true.reshape(output_shape)[:, :, z].T, **color_limits)
    plt.title('True')
    plt.subplot(122)
    plt.imshow(y_pred.reshape(output_shape)[:, :, z].T, **color_limits)
    plt.title('Predicted')


# Slider bound derived from output_shape (last valid index) so it cannot drift
# from the volume size; equals 119 - image_size as before.
interact(plot_section, z=widgets.IntSlider(min=0, max=output_shape[2] - 1, step=1, value=0))

In [None]:
# Per the k3d docs, voxel value 0 is an empty (not drawn) voxel and value v >= 1
# is drawn with color_map[v - 1]; so class 0 (Floodplain) is left transparent
# and classes 1-3 get the three colours below — confirm this is intended.
color_map = (0x3A528B, 0x20908C, 0xFDE724)

plot = k3d.plot(camera_auto_fit=False)
# k3d stores voxels as uint8. Cast explicitly (and make the transposed view
# contiguous) instead of relying on the trait's implicit coercion of the int64
# argmax output, which emits a dtype-mismatch warning.
pred_voxels = np.ascontiguousarray(y_pred.reshape(output_shape).T, dtype=np.uint8)
obj = k3d.voxels(pred_voxels, color_map, compression_level=1)
plot += obj
# Fixed view: camera position (3), look-at target (3), up vector (3).
plot.camera = [150, 230, -40, 60, 85, 80, 0.0, 0.0, -1.0]
plot.display()

In [None]:
# Per the k3d docs, voxel value 0 is an empty (not drawn) voxel and value v >= 1
# is drawn with color_map[v - 1]; so class 0 (Floodplain) is left transparent
# and classes 1-3 get the three colours below — confirm this is intended.
color_map = (0x3A528B, 0x20908C, 0xFDE724)

plot = k3d.plot(camera_auto_fit=False)
# k3d stores voxels as uint8. Cast explicitly (and make the transposed view
# contiguous) instead of relying on the trait's implicit coercion of the int64
# argmax output, which emits a dtype-mismatch warning.
true_voxels = np.ascontiguousarray(y_true.reshape(output_shape).T, dtype=np.uint8)
obj = k3d.voxels(true_voxels, color_map, compression_level=1)
plot += obj
# Same fixed view as the predicted volume, for side-by-side comparison:
# camera position (3), look-at target (3), up vector (3).
plot.camera = [150, 230, -40, 60, 85, 80, 0.0, 0.0, -1.0]
plot.display()