In [1]:
import json

import pandas as pd
import tensorflow as tf
import numpy as np

from model.dataset import *
from model.model import *

%load_ext autoreload
%autoreload 2

# Read the experiment settings once; later cells index into `config`.
with open('config.json') as raw_config:
    config = json.loads(raw_config.read())

# Build the 'train' source from the metadata CSV and the data directory named in the
# config. The helper comes from the project's wildcard imports.
train_source = build_source_from_metadata(
    pd.read_csv(config['model']['metadata_path']),
    config['model']['data_path'],
    'train',
)

# Show the first five entries as the cell output.
train_source[:5]

[('data/train/a5c512d28fc64c7db1bab9090706fea1.jpg', 2),
 ('data/train/3def6df6ae6147be8b98db8730a9e6d2.jpg', 1),
 ('data/train/a35809301db3445dbea67f182481791c.jpg', 2),
 ('data/train/c23efc1434bf4c37ab4f831037818ce6.jpg', 1),
 ('data/train/01400fcdfd4e45128d558c04e2c8a781.jpg', 2)]

In [2]:
# Build the project's AlexNet wrapper from the loaded config with overfit_mode enabled.
# NOTE(review): presumably overfit_mode restricts training to a small subset as a sanity
# check -- confirm in model/model.py.
model = AlexNet(config, overfit_mode=True)
# Printing the wrapper emits its text representation (a layer summary in the saved output).
print(model)

AlexNet

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
CONV1 (Conv2D)               (None, 55, 55, 96)        34944     
_________________________________________________________________
MAX_POOL1 (MaxPooling2D)     (None, 27, 27, 96)        0         
_________________________________________________________________
NORM1 (BatchNormalization)   (None, 27, 27, 96)        384       
_________________________________________________________________
zero_padding2d (ZeroPadding2 (None, 31, 31, 96)        0         
_________________________________________________________________
CONV2 (Conv2D)               (None, 27, 27, 256)       614656    
_________________________________________________________________
MAX_POOL2 (MaxPooling2D)     (None, 13, 13, 256)       0         
_________________________________________________________________
NORM2 (BatchNormalization)   (None, 13, 13, 256

In [3]:
# Train the model and keep the returned object; the plotting cell below reads
# per-epoch metrics from `hist.history`.
hist = model.train()

W0707 01:55:01.535829 140428569782080 training_utils.py:1436] Expected a shuffled dataset but input dataset `x` is not shuffled. Please invoke `shuffle()` on input dataset.


Epoch 1/20


W0707 01:55:01.791127 140428569782080 deprecation.py:323] From /home/santiago/anaconda3/envs/CV/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 2/20
Epoch 3/20


KeyboardInterrupt: 

In [None]:
# `plt` is imported here explicitly: no import visible in this notebook binds it, so the
# cell would otherwise only run if one of the wildcard imports happened to expose it.
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# Training curves read from `hist.history`: loss on the left panel, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_figwidth(20)

for axis, metric_key, panel_title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    # The x axis is the epoch index, so restrict its ticks to integers.
    axis.xaxis.set_major_locator(MaxNLocator(integer=True))
    # Validation metrics are stored under the same key with a 'val_' prefix.
    axis.plot(hist.history[metric_key], label='train')
    axis.plot(hist.history['val_' + metric_key], label='validation')
    axis.legend()
    axis.set_xlabel('epoch')
    axis.set_title(panel_title)

plt.show()

In [None]:
# Evaluate the trained model; the data split and metrics used are defined by the
# project's model class (not visible here). The return value is the cell output.
model.evaluate()

# Whole data

In [None]:
# Reset Keras' global state before building a fresh model, so the graph from the
# previous (overfit-mode) model is discarded.
tf.keras.backend.clear_session()
# Rebuild the same architecture with overfit_mode disabled; `model` is rebound, so every
# later cell refers to this new instance.
model = AlexNet(config, overfit_mode=False)
print(model)

In [None]:
# Train the freshly built model; `hist` is rebound, so the plotting cell below shows
# this run rather than the earlier one.
hist = model.train()

In [None]:
# `plt` is imported here explicitly: no import visible in this notebook binds it, so the
# cell would otherwise only run if one of the wildcard imports happened to expose it.
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# Training curves for the most recent `hist`: loss on the left panel, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_figwidth(20)

for axis, metric_key, panel_title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    # The x axis is the epoch index, so restrict its ticks to integers.
    axis.xaxis.set_major_locator(MaxNLocator(integer=True))
    # Validation metrics are stored under the same key with a 'val_' prefix.
    axis.plot(hist.history[metric_key], label='train')
    axis.plot(hist.history['val_' + metric_key], label='validation')
    axis.legend()
    axis.set_xlabel('epoch')
    axis.set_title(panel_title)

plt.show()

In [None]:
# Take the first element yielded by `model.test_data` (presumably one batch -- confirm)
# and display it with the model's predictions, using the label names from the config.
imshow_with_predictions(model, next(iter(model.test_data)), show_label=True, label_map=config['model']['labels'])

In [None]:
# Evaluate the model trained with overfit_mode=False; the data split and metrics used are
# defined by the project's model class (not visible here).
model.evaluate()

In [None]:
# `plt` is imported here explicitly: no import visible in this notebook binds it, so the
# cell would otherwise only run if one of the wildcard imports happened to expose it.
import matplotlib.pyplot as plt

# The same label list names both axes, so the tick positions follow its length instead
# of a hard-coded 5.
class_labels = config['model']['labels']
tick_positions = np.arange(len(class_labels))

# Large font to suit the 20x20 figure. NOTE: rcParams is process-wide, so this also
# affects every figure drawn later in the session.
plt.rcParams.update({'font.size': 27})
plt.figure(figsize=(20, 20))

# Plot on a log10 scale; the small offset keeps zero entries finite instead of -inf.
confusion = compute_confusion_matrix(model).numpy()
plt.imshow(np.log10(confusion + 1e-7))
plt.yticks(tick_positions, labels=class_labels)
plt.xticks(tick_positions, labels=class_labels)
plt.title('Confusion matrix (log scale)')
plt.colorbar()
# End with show(), like the other plotting cells, so the Colorbar repr is not echoed.
plt.show()