# model-training

Trains the Next Basket predictor and saves the fitted model to disk.

Requires:
1. `data/x_train.npy`
2. `data/y_train.npy`
3. `data/x_test.npy`
4. `data/y_test.npy`

Produces:
1. `model/model.hdf5`

In [1]:
import sys

# Add '..' to the module search path. The entry is relative, so it is resolved
# against the current working directory at import time.
# NOTE(review): presumably this is what lets the `config` import below resolve — confirm.
PARENT_DIR = '..'
sys.path.append(PARENT_DIR)

In [2]:
import numpy as np
import tensorflow as tf

from config import x_train_fpath, y_train_fpath, x_test_fpath, y_test_fpath, model_fpath

In [3]:
# Read the four splits from the paths supplied by `config`, casting each array
# to float64 as it is loaded.
x_train, y_train, x_test, y_test = (
    np.load(split_fpath).astype(np.float64)
    for split_fpath in (x_train_fpath, y_train_fpath, x_test_fpath, y_test_fpath)
)

# Display the shape of every split as the cell output.
x_train.shape, y_train.shape, x_test.shape, y_test.shape

((26038, 5, 243), (26038, 243), (8560, 5, 243), (8560, 243))

In [4]:
# Rescale every split with log10(1 + v); a raw value of 0 stays 0.
# NOTE: the results are bound back to the same names, so running this cell a
# second time without reloading the data would apply the transform twice.
x_train, x_test = np.log10(1 + x_train), np.log10(1 + x_test)
y_train, y_test = np.log10(1 + y_train), np.log10(1 + y_test)

# Display the maximum of each transformed split as the cell output.
x_train.max(), y_train.max(), x_test.max(), y_test.max()

(2.1492191126553797, 2.1367205671564067, 2.1492191126553797, 2.123851640967086)

In [5]:
# The per-sample shape of x_train fixes the network input; the number of
# columns in y_train fixes the width of the output layer.
sample_shape = x_train.shape[1:]
n_targets = y_train.shape[1]

model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=sample_shape),
    # Disabled: an additional sequence-returning LSTM layer in front.
    # tf.keras.layers.LSTM(128, dropout=0.1, return_sequences=True),
    tf.keras.layers.LSTM(128, dropout=0.2),
    # ReLU on the output restricts predictions to non-negative values.
    tf.keras.layers.Dense(n_targets, activation='relu'),
])

# Adam optimizer with a mean-squared-error loss.
model.compile(loss='mse', optimizer='adam')

# Print the layer-by-layer summary as the cell output.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 128)               190464    
_________________________________________________________________
dense (Dense)                (None, 243)               31347     
Total params: 221,811
Trainable params: 221,811
Non-trainable params: 0
_________________________________________________________________


In [6]:
# Write the model to model_fpath before any training has run; the training
# cell saves to this same path again once fitting finishes.
# NOTE(review): presumably an early check that the output path is writable — confirm.
model.save(filepath=model_fpath)

In [7]:
# Training settings gathered in one place. Keras evaluates the loss on
# `validation_data` at the end of each epoch.
# NOTE(review): the test split doubles as validation data here — confirm that is intended.
fit_options = dict(
    epochs=24,
    batch_size=16,
    validation_data=(x_test, y_test),
)
model.fit(x_train, y_train, **fit_options)

# Persist the trained model to the path imported from `config`.
model.save(model_fpath)

Epoch 1/24
Epoch 2/24
Epoch 3/24
Epoch 4/24
Epoch 5/24
Epoch 6/24
Epoch 7/24
Epoch 8/24
Epoch 9/24
Epoch 10/24
Epoch 11/24
Epoch 12/24
Epoch 13/24
Epoch 14/24
Epoch 15/24
Epoch 16/24
Epoch 17/24
Epoch 18/24
Epoch 19/24
Epoch 20/24
Epoch 21/24
Epoch 22/24
Epoch 23/24
Epoch 24/24
