In [1]:
import os

import cntk as ct
import numpy as np

from src.ferplus import FERPlusReader, FERPlusParameters
from src.models import *

In [2]:
def cost_func(training_mode, prediction, target):
    '''
    Build the training loss expression for the given labelling mode.

    Cross entropy is used for the 'majority', 'probability' and 'crossentropy'
    modes. The 'multi_target' mode treats all labels marked in `target` the
    same way: the loss is the negative log of the largest element of
    `target * prediction` along the last axis (for a 0/1 target this is the
    highest predicted probability among the marked labels).

    Args:
        training_mode: one of 'majority', 'probability', 'crossentropy' or
            'multi_target'.
        prediction: model output, e.g. softmax probabilities; the class axis
            is assumed to be the last axis.
        target: label tensor laid out like `prediction` along the last axis.

    Returns:
        A CNTK expression for the loss, reduced over the last axis.

    Raises:
        ValueError: if `training_mode` is not a supported mode. Previously an
            unknown mode silently produced None, which only failed later when
            the Trainer was built.
    '''
    if training_mode in ('majority', 'probability', 'crossentropy'):
        # Cross entropy: -sum(target * log(prediction)) over the last axis.
        return ct.negate(ct.reduce_sum(ct.element_times(target, ct.log(prediction)), axis=-1))
    if training_mode == 'multi_target':
        # Only the best-scoring marked label contributes, so every marked label counts equally.
        return ct.negate(ct.log(ct.reduce_max(ct.element_times(target, prediction), axis=-1)))
    raise ValueError("unsupported training_mode: {!r}".format(training_mode))

In [3]:
# Experiment parameters.
# Root folder: checkpoints live under <base_folder>/models and the test data
# folders are read relative to it.
base_folder = 'data'
test_folders = ['FER2013Test']

# Network architecture and the label/loss mode it was trained with (see cost_func).
model_name = 'VGG13'
training_mode = "crossentropy"

In [4]:
# Checkpoint folders: <base_folder>/models/<model_name>_<training_mode>.
output_model_path = os.path.join(base_folder, R'models')
output_model_folder = os.path.join(output_model_path, '_'.join((model_name, training_mode)))

# Create the folder (and any missing parents) only when it is not there yet.
folder_missing = not os.path.exists(output_model_folder)
if folder_missing:
    os.makedirs(output_model_folder)

In [5]:
# Emotion name -> integer class index (0-7), in declaration order.
emotion_table = {
    name: index
    for index, name in enumerate((
        'neutral', 'happiness', 'surprise', 'sadness',
        'anger', 'disgust', 'fear', 'contempt',
    ))
}

num_classes = len(emotion_table)

In [6]:
# Build the model object for `model_name`.
# NOTE(review): build_model arrives via the star import of src.models; presumably
# num_classes sets the size of the output layer — confirm. Later cells read
# .model, .input_height, .input_width and .learning_rate from the returned object.
model = build_model(num_classes, model_name)

In [7]:
# Graph inputs: a single-channel image of the model's input size, and a label
# vector with one float entry per emotion class.
input_var = ct.input((1, model.input_height, model.input_width), np.float32)
# The trailing comma makes the shape an explicit 1-tuple; `(num_classes)` was just an int.
label_var = ct.input((num_classes,), np.float32)

In [8]:
# The training mode given to FERPlusParameters controls the label encoding:
# labels are either 0/1 or values between 0 and 1.
# Evaluation deliberately uses "majority" here rather than `training_mode`.
# NOTE(review): the original author did not know why this is fixed to
# 'majority'; presumably it scores each test image against a single
# ground-truth class — confirm.
test_and_val_params = FERPlusParameters(
    num_classes,
    model.input_height,
    model.input_width,
    "majority",
    True,  # NOTE(review): meaning of this flag is defined in src.ferplus — confirm
)
test_data_reader = FERPlusReader.create(
    base_folder, test_folders, "label.csv", test_and_val_params
)

In [9]:
# Minibatch size, and the epoch size reported by the test reader
# (presumably the number of test samples).
minibatch_size = 32
epoch_size = test_data_reader.size()

In [10]:
# Raw network output `z`, and its softmax used as class probabilities.
z = model.model(input_var)
pred = ct.softmax(z)

In [11]:
# Optimizer schedules.
# Learning rate: 20 entries at the base rate, 20 at half the rate, then one
# entry at a tenth of the rate (presumably reused for all remaining periods,
# per CNTK schedule semantics — confirm).
lr_per_minibatch = (
    [model.learning_rate] * 20
    + [model.learning_rate / 2.0] * 20
    + [model.learning_rate / 10.0]
)
# Time constant chosen so the per-minibatch momentum is 0.9, assuming CNTK's
# per-sample momentum of exp(-1 / time_constant).
mm_time_constant = -minibatch_size / np.log(0.9)
lr_schedule = ct.learning_rate_schedule(
    lr_per_minibatch,
    unit=ct.UnitType.minibatch,
    epoch_size=epoch_size,
)
mm_schedule = ct.momentum_as_time_constant_schedule(mm_time_constant)

# Loss is computed on the softmax output; classification error on the raw output z.
train_loss = cost_func(training_mode, pred, label_var)
pe = ct.classification_error(z, label_var)

In [12]:
# Momentum-SGD learner over the model parameters, wrapped in a Trainer that
# tracks (loss, classification error).
loss_and_metric = (train_loss, pe)
learner = ct.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
trainer = ct.Trainer(z, loss_and_metric, learner)
trainer.total_number_of_samples_seen

0

In [13]:
# Restore trainer state from the checkpoint saved at `best_epoch`.
# NOTE(review): 98 is hard-coded; presumably the best epoch found during
# training/validation — confirm before reusing.
best_epoch = 98
checkpoint_path = os.path.join(output_model_folder, "model_{}".format(best_epoch))
trainer.restore_from_checkpoint(checkpoint_path)
trainer.total_number_of_samples_seen

2766753