# 4. Teste do modelo

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import h5py
import keras
import k3d
from sklearn.metrics import confusion_matrix, f1_score
from ipywidgets import interact, widgets

import src

In [None]:
# Test inputs and targets that are later fed to `model.evaluate`.
with h5py.File('../data/processed/stanford6_truncated_rgb.h5', 'r') as processed_file:
    x_test = np.array(processed_file['test/X'])
    y_test = np.array(processed_file['test/Y'])

# Test targets from the "interim" file; these are the ones compared against
# the arg-max predictions further down.
# NOTE(review): presumably integer class ids, while the processed `test/Y`
# is an encoded version of the same labels -- confirm against the
# preprocessing step.
with h5py.File('../data/interim/stanford6_truncated_rgb.h5', 'r') as interim_file:
    true_labels = np.array(interim_file['test/Y'])

# Class id -> human-readable name, ids assigned in listing order from 0.
classnames = dict(enumerate((
    'Floodplain',
    'Pointbar',
    'Channel',
    'Boundary',
)))

# Restore the trained network from disk and print its layer summary.
model = keras.models.load_model('../models/trained_model.h5')
model.summary()

In [None]:
# Evaluate the model on the test split and print every value Keras reports.
scores = model.evaluate(x_test, y_test, verbose=1)
if not isinstance(scores, (list, tuple)):
    # `evaluate` returns a bare scalar when the model has no extra metrics;
    # wrap it so the reporting loop below works in both cases.
    scores = [scores]
for metric_name, score in zip(model.metrics_names, scores):
    print(f'Test {metric_name}: {score}')

# Predicted class = index of the largest model output along axis 1.
predict_class = np.argmax(model.predict(x_test), axis=1)
print(f'F1-score: {f1_score(true_labels, predict_class, average="weighted")}')

# Pin the label set to the ids in `classnames` so the matrix always has one
# row/column per named class, even if a class is missing from both vectors.
class_ids = list(classnames)
matrix = confusion_matrix(true_labels, predict_class, labels=class_ids)
matrix = matrix.astype('float')

# Row-normalise: each row becomes the distribution of predictions for one
# true class. A class with no true samples keeps an all-zero row instead of
# turning into NaN through a division by zero.
row_totals = matrix.sum(axis=1, keepdims=True)
row_totals[row_totals == 0] = 1
matrix = matrix / row_totals

src.plot_confusion_matrix(matrix, classnames.values(), title="Confusion matrix")

In [None]:
# Dimensions used to unflatten the per-sample label vectors into a volume.
VOLUME_SHAPE = (119, 169, 88)


def plot_section(z=0):
    """Show one slice of the true and predicted label volumes side by side.

    Both label vectors are reshaped to ``VOLUME_SHAPE``; the slice at index
    ``z`` along the last axis is taken and transposed for display.

    Args:
        z: Index along the last axis of the reshaped volume.
    """
    # Use the same colour limits in both panels. Without them each image is
    # autoscaled to the values present in its own slice, so one class could
    # be drawn with different colours on the left and on the right.
    lowest_class, highest_class = min(classnames), max(classnames)
    panels = (('True labels', true_labels), ('Predicted labels', predict_class))
    for position, (title, labels) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(
            labels.reshape(VOLUME_SHAPE)[:, :, z].T,
            vmin=lowest_class,
            vmax=highest_class,
        )
        plt.title(title)


# The slider range follows the size of the sliced axis (0 .. depth - 1).
interact(
    plot_section,
    z=widgets.IntSlider(min=0, max=VOLUME_SHAPE[2] - 1, step=1, value=0),
)

In [None]:
# Interactive 3-D view of the predicted labels.
# NOTE(review): three colours for four class ids -- presumably k3d renders
# voxel value 0 as empty and maps 1..3 onto these entries; confirm in the
# k3d documentation.
color_map = (
    0x3A528B,
    0x20908C,
    0xFDE724,
)

# Unflatten the predictions into the volume grid and transpose the axes
# before handing the array to k3d.
predicted_volume = np.transpose(predict_class.reshape((119, 169, 88)))

plot = k3d.plot(camera_auto_fit=False)
obj = k3d.voxels(predicted_volume, color_map, compression_level=1)
plot += obj
# Fixed viewpoint (auto-fit is disabled above).
# NOTE(review): presumably position, look-at target and up vector, three
# values each -- confirm.
plot.camera = [150, 230, -40, 60, 85, 80, 0.0, 0.0, -1.0]
plot.display()

In [None]:
# Interactive 3-D view of the reference (true) labels, drawn with the same
# palette and camera as the prediction view so the two can be compared.
# NOTE(review): three colours for four class ids -- presumably k3d renders
# voxel value 0 as empty and maps 1..3 onto these entries; confirm in the
# k3d documentation.
color_map = (
    0x3A528B,
    0x20908C,
    0xFDE724,
)

# Unflatten the labels into the volume grid and transpose the axes before
# handing the array to k3d.
true_volume = np.transpose(true_labels.reshape((119, 169, 88)))

plot = k3d.plot(camera_auto_fit=False)
obj = k3d.voxels(true_volume, color_map, compression_level=1)
plot += obj
# Fixed viewpoint (auto-fit is disabled above).
# NOTE(review): presumably position, look-at target and up vector, three
# values each -- confirm.
plot.camera = [150, 230, -40, 60, 85, 80, 0.0, 0.0, -1.0]
plot.display()