# Example 4: Retraining a model 

This example demonstrates how training can be continued (re-training) from saved states. The states referred to
here are the saved baseline model, the saved ESO model, and the ESO chromosome.

In [10]:
### Import libraries
import pickle
import os
import json

import torch

import eso.utils as utils
from eso.model.model import Model

from pathlib import Path

## For logging info during execution
import logging

In [12]:
# Root folders. Each one can be overridden with an environment variable so the
# notebook runs on another machine without editing this cell; when the variable
# is not set, the original local path is used unchanged.
RESULTS_PATH = Path(os.environ.get('ESO_RESULTS_PATH', '/home/aaron-joel/Documents/Examples/results'))
# Audio data folder (a small dataset for demonstration purposes only)
DATA_PATH = Path(os.environ.get('ESO_DATA_PATH', '/home/aaron-joel/Documents/Examples/SmallData/SavedData'))

# Saved baseline CNN model path
BASELINE_CNN_STATE_PATH = RESULTS_PATH / 'baseline_cnn_state.pth'
# Saved ESO CNN model path
ESO_CNN_STATE_PATH = RESULTS_PATH / 'chromosome_cnn_state.pth'
# Saved path of the best performing chromosome
CHROMOSOME_PKL_PATH = RESULTS_PATH / 'eso_chromosome.pkl'

# Dataset generated for baseline training
X_TRAIN_BASE_PATH = DATA_PATH / 'preprocessed' / 'train' / 'X.pkl'
Y_TRAIN_BASE_PATH = DATA_PATH / 'preprocessed' / 'train' / 'Y.pkl'

# Dataset generated for baseline validation
X_VAL_BASE_PATH = DATA_PATH / 'preprocessed' / 'validation' / 'X.pkl'
Y_VAL_BASE_PATH = DATA_PATH / 'preprocessed' / 'validation' / 'Y.pkl'

# Dataset generated for ESO training
X_TRAIN_ESO_PATH = DATA_PATH / 'unpreprocessed' / 'train' / 'X.pkl'
Y_TRAIN_ESO_PATH = DATA_PATH / 'unpreprocessed' / 'train' / 'Y.pkl'

# Dataset generated for ESO validation
X_VAL_ESO_PATH = DATA_PATH / 'unpreprocessed' / 'validation' / 'X.pkl'
Y_VAL_ESO_PATH = DATA_PATH / 'unpreprocessed' / 'validation' / 'Y.pkl'

In [13]:
## Number of training epochs used by every run in this notebook
EPOCHS = 10

## Keyword arguments describing the optimizer, loss, schedule and metric.
## Built with dict(...) so each entry reads as a keyword; key order is unchanged.
MODEL_ARGS = dict(
    optimizer_name="adam",
    loss_function_name="cross_entropy",
    num_epochs=EPOCHS,
    batch_size=128,
    learning_rate=0.001,
    shuffle=True,
    metric="f1",
)

In [14]:
## Let INFO-level records through the 'eso' logger, then keep a handle to it
# NOTE(review): no handler is attached in this cell; whether INFO records are
# actually shown or stored depends on handlers configured elsewhere -- confirm.
logging.getLogger('eso').setLevel(logging.INFO)
logger = logging.getLogger('eso')

In [16]:
## Read baseline data from pickle files
# Each file is opened in a `with` block so its handle is closed as soon as the
# object has been loaded, instead of being left open until garbage collection.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(X_TRAIN_BASE_PATH, 'rb') as pkl_file:
    X_train_base = pickle.load(pkl_file)
with open(Y_TRAIN_BASE_PATH, 'rb') as pkl_file:
    Y_train_base = pickle.load(pkl_file)

with open(X_VAL_BASE_PATH, 'rb') as pkl_file:
    X_val_base = pickle.load(pkl_file)
with open(Y_VAL_BASE_PATH, 'rb') as pkl_file:
    Y_val_base = pickle.load(pkl_file)

# Quick sanity check of what was loaded
print(f'The shape of X_train_base: {X_train_base.shape}')
print(f'The shape of Y_train_base: {Y_train_base.shape}')

print(f'The shape of X_val_base: {X_val_base.shape}')
print(f'The shape of Y_val_base: {Y_val_base.shape}')

The shape of X_train_base: (3452, 128, 76)
The shape of Y_train_base: (3452, 2)
The shape of X_val_base: (338, 128, 76)
The shape of Y_val_base: (338, 2)


In [23]:
print('----RE-TRAINING BASELINE MODEL----')

# The image shape is stored under the 'image_shape' key of baseline.json
baseline_json = RESULTS_PATH / 'baseline.json'
data = json.loads(baseline_json.read_text())
image_shape = data['image_shape']
print(f"Image shape for baseline model is: {image_shape}", end='\n\n')

## Instantiate a baseline model; the input shape is the image shape with a leading 1
# NOTE(review): BASELINE_CNN_STATE_PATH is not referenced in this cell, so no
# saved weights appear to be loaded before training -- confirm this is intended
# for an example about continuing training from saved states.
input_shape = (1, *image_shape)
model = Model(input_shape=input_shape, **MODEL_ARGS, logger=logger)
print(model)

## Train on the baseline training split
model.train(X_train=X_train_base, Y_train=Y_train_base)

## Save the trained model (called with RESULTS_PATH and the name 'baseline_example4')
model.save_model(RESULTS_PATH, 'baseline_example4')

----RE-TRAINING BASELINE MODEL----
Image shape for baseline model is: [128, 76]

BaseCNN(
  (conv_layers): Sequential(
    (conv0): Conv2d(1, 8, kernel_size=(8, 8), stride=(1, 1))
    (relu0): ReLU()
    (dropout0): Dropout(p=0.5, inplace=False)
    (maxpool0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layers): Sequential(
    (fc0): Linear(in_features=4080, out_features=32, bias=True)
    (relu0): ReLU()
    (dropout0): Dropout(p=0.5, inplace=False)
    (fc1): Linear(in_features=32, out_features=32, bias=True)
    (relu1): ReLU()
    (dropout1): Dropout(p=0.5, inplace=False)
  )
  (output_layer): Linear(in_features=32, out_features=2, bias=True)
  (softmax): Softmax(dim=1)
)


In [26]:
## Score the trained model on the validation split.
## evaluate() is unpacked into the metric value and the metric's name.
baseline_metric, _name = model.evaluate(
    X_val=X_val_base,
    Y_val=Y_val_base,
)

# Both report lines are emitted by a single print call, one per line
print(
    '----BASELINE MODEL TRAINING FINISHED---',
    f'BASELINE METRIC {_name} : {baseline_metric}',
    sep='\n',
)

----BASELINE MODEL TRAINING FINISHED---
BASELINE METRIC F1 : 0.7955801104972375


In [27]:
# Drop the baseline model and datasets before the ESO stage
# (useful especially for large datasets)
del model
del X_train_base, Y_train_base
del X_val_base, Y_val_base

### Performing the training with an ESO classifier

In [29]:
# Read saved ESO related dataset
# Each file is opened in a `with` block so its handle is closed as soon as the
# object has been loaded, instead of being left open until garbage collection.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(X_TRAIN_ESO_PATH, 'rb') as pkl_file:
    X_train_eso = pickle.load(pkl_file)
with open(Y_TRAIN_ESO_PATH, 'rb') as pkl_file:
    Y_train_eso = pickle.load(pkl_file)

with open(X_VAL_ESO_PATH, 'rb') as pkl_file:
    X_val_eso = pickle.load(pkl_file)
with open(Y_VAL_ESO_PATH, 'rb') as pkl_file:
    Y_val_eso = pickle.load(pkl_file)

# Quick sanity check of what was loaded
print(f'The shape of X_train_eso is: {X_train_eso.shape}')
print(f'The shape of Y_train_eso is: {Y_train_eso.shape}')
print(f'The shape of X_val_eso is: {X_val_eso.shape}')
print(f'The shape of Y_val_eso is: {Y_val_eso.shape}')

The shape of X_train_eso is: (3452, 128, 151)
The shape of Y_train_eso is: (3452, 2)
The shape of X_val_eso is: (338, 128, 151)
The shape of Y_val_eso is: (338, 2)


In [32]:
## Read the saved chromosome
# CHROMOSOME_PKL_PATH (defined with the other paths) points to the same file as
# RESULTS_PATH / 'eso_chromosome.pkl'; using the constant keeps the path in one place.
# The `with` block closes the file even if unpickling raises.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(CHROMOSOME_PKL_PATH, 'rb') as eso_chrom_fb:
    eso_chromosome = utils.unpickler.CPU_Unpickler(eso_chrom_fb).load()
print(eso_chromosome)

Chromosome Info:
Number of Genes: 4
Validation F1: 1.0
Trainable parameters: 38538
Fitness: 0.3154045332190305
Genes: Gene 1: (13, 6)
Gene 2: (40, 7)
Gene 3: (56, 6)
Gene 4: (81, 7)




In [33]:
## Set MODEL ARGS
# Replaces the loaded chromosome's `_model_args` attribute with this notebook's
# MODEL_ARGS dict. The same dict object is assigned (no copy is made).
# NOTE(review): this writes a private (underscore-prefixed) attribute directly;
# presumably it controls the hyper-parameters used by the next train() call --
# confirm, and prefer a public setter if the library offers one.
eso_chromosome._model_args = MODEL_ARGS

In [35]:
print("---TRAINING CHROMOSOME MODEL---")

## Train the chromosome, supplying both the training and the validation split
eso_chromosome.train(
    X_train=X_train_eso,
    Y_train=Y_train_eso,
    X_val=X_val_eso,
    Y_val=Y_val_eso,
)

## Save the chromosome after training (called with RESULTS_PATH and the name 'eso_chromosome4')
eso_chromosome.save(RESULTS_PATH, 'eso_chromosome4')

## Save the trained model (called with RESULTS_PATH only)
# NOTE(review): no file name is passed here; check whether this overwrites a
# previously saved model in RESULTS_PATH.
eso_chromosome.save_model(RESULTS_PATH)

---TRAINING CHROMOSOME MODEL---


In [36]:
## Show the metrics
# Each getter is called right before its line is printed, so the call/print
# order is: metric, then fitness.
for label, getter in (
    ("ESO CHROMOSOME METRIC:  ", eso_chromosome.get_metric),
    ("ESO CHROMOSOME FITNESS: ", eso_chromosome.get_fitness),
):
    print(f"{label}{getter()}")

ESO CHROMOSOME METRIC:  0.9977011494252873
ESO CHROMOSOME FITNESS: 0.31395170633298114
