In [1]:
#include <backprop_tools/operations/cpu.h>
#include <backprop_tools/nn_models/models.h>
#include <backprop_tools/nn/operations_cpu.h>
#include <backprop_tools/nn_models/operations_cpu.h>
#include <backprop_tools/containers/persist.h>
namespace bpt = backprop_tools;

In [2]:
#include <random>
#include <chrono>
#include <highfive/H5File.hpp>

In [3]:
#pragma cling load("hdf5")

In [4]:
using T = float;
using DEVICE = bpt::devices::DefaultCPU;
using TI = typename DEVICE::index_t;

In [5]:
// --- Dataset geometry ---
constexpr TI INPUT_DIM = 28 * 28;          // 784 values per sample
constexpr TI OUTPUT_DIM = 10;              // width of the network output
constexpr TI DATASET_SIZE_TRAIN = 60000;   // rows of the training matrices
constexpr TI DATASET_SIZE_VAL = 10000;     // rows of the validation matrices

// --- Network shape ---
constexpr TI NUM_LAYERS = 3;
constexpr TI HIDDEN_DIM = 50;

// --- Training schedule ---
constexpr TI BATCH_SIZE = 32;
constexpr TI NUM_EPOCHS = 1;
// Integer division: samples that do not fill a complete batch are not counted.
constexpr TI NUM_BATCHES = DATASET_SIZE_TRAIN / BATCH_SIZE;

In [6]:
// Optimizer: Adam, parameterised by the library's `DefaultParametersTF` preset.
using OPTIMIZER_PARAMETERS = bpt::nn::optimizers::adam::DefaultParametersTF<T>;
using OPTIMIZER = bpt::nn::optimizers::Adam<OPTIMIZER_PARAMETERS>;

// MLP structure: dimensions and layer count come from the constants declared
// earlier in this file; RELU and IDENTITY are the two activation functions
// handed to the specification.
using StructureSpecification = bpt::nn_models::mlp::StructureSpecification<
    T,
    DEVICE::index_t,
    INPUT_DIM,
    OUTPUT_DIM,
    NUM_LAYERS,
    HIDDEN_DIM,
    bpt::nn::activation_functions::RELU,
    bpt::nn::activation_functions::IDENTITY,
    1  // NOTE(review): meaning of this last argument is not visible here (possibly a batch size) - confirm
>;

// Network type built from the structure above, in its Adam variant.
using NETWORK_SPEC = bpt::nn_models::mlp::AdamSpecification<StructureSpecification>;
using MODEL_TYPE = bpt::nn_models::mlp::NeuralNetworkAdam<NETWORK_SPEC>;

In [7]:
// Location of the HDF5 file that holds the MNIST data.
std::string dataset_path{"/home/mambauser/mnist.hdf5"};

In [8]:
// Execution context: the device keeps a pointer to the logger, so `logger`
// has to stay alive for as long as `device` is used.
DEVICE device;
DEVICE::SPEC::LOGGING logger;
device.logger = &logger;

// Trainable network, its optimizer, and the scratch buffers passed to the backward pass.
MODEL_TYPE model;
OPTIMIZER optimizer;
typename MODEL_TYPE::Buffers<1> buffers;

// Dataset storage: one row per sample. Inputs are INPUT_DIM columns of T,
// labels are a single TI column.
bpt::MatrixDynamic<
    bpt::matrix::Specification<T, TI, DATASET_SIZE_TRAIN, INPUT_DIM>
> x_train;
bpt::MatrixDynamic<
    bpt::matrix::Specification<TI, TI, DATASET_SIZE_TRAIN, 1>
> y_train;
bpt::MatrixDynamic<
    bpt::matrix::Specification<T, TI, DATASET_SIZE_VAL, INPUT_DIM>
> x_val;
bpt::MatrixDynamic<
    bpt::matrix::Specification<TI, TI, DATASET_SIZE_VAL, 1>
> y_val;

// Single-row gradient matrices: loss gradient w.r.t. the network output, and
// the gradient w.r.t. the network input.
bpt::MatrixDynamic<
    bpt::matrix::Specification<
        T, DEVICE::index_t, 1, OUTPUT_DIM,
        bpt::matrix::layouts::RowMajorAlignment<typename DEVICE::index_t>
    >
> d_loss_d_output_matrix;
bpt::MatrixDynamic<
    bpt::matrix::Specification<
        T, DEVICE::index_t, 1, INPUT_DIM,
        bpt::matrix::layouts::RowMajorAlignment<typename DEVICE::index_t>
    >
> d_input_matrix;

// Allocate backing storage for everything declared above.
bpt::malloc(device, model);
bpt::malloc(device, buffers);

bpt::malloc(device, x_train);
bpt::malloc(device, y_train);
bpt::malloc(device, x_val);
bpt::malloc(device, y_val);

bpt::malloc(device, d_loss_d_output_matrix);
bpt::malloc(device, d_input_matrix);

In [9]:
auto data_file = HighFive::File(dataset_path, HighFive::File::ReadOnly);
bpt::load(device, x_train, data_file.getGroup("train"), "inputs");
bpt::load(device, y_train, data_file.getGroup("train"), "labels");
bpt::load(device, x_val, data_file.getGroup("test"), "inputs");
bpt::load(device, y_val, data_file.getGroup("test"), "labels");

In [10]:
// Random engine with the fixed seed 2, used for the weight initialisation below.
auto rng = bpt::random::default_engine(DEVICE::SPEC::RANDOM{}, 2);

// Reset the optimizer state attached to the model, then initialise the weights.
bpt::reset_optimizer_state(device, model, optimizer);
bpt::init_weights(device, model, rng);

In [11]:
for (int batch_i=0; batch_i < NUM_BATCHES; batch_i++){
    T loss = 0;
    bpt::zero_gradient(device, model);
    for (int sample_i=0; sample_i < BATCH_SIZE; sample_i++){
        auto input = bpt::row(device, x_train, batch_i * BATCH_SIZE + sample_i);
        auto output = bpt::row(device, y_train, batch_i * BATCH_SIZE + sample_i);
        auto prediction = bpt::row(device, model.output_layer.output, 0);
        bpt::forward(device, model, input);
        bpt::nn::loss_functions::categorical_cross_entropy::gradient(device, prediction, output, d_loss_d_output_matrix, T(1)/((T)BATCH_SIZE));
        loss += bpt::nn::loss_functions::categorical_cross_entropy::evaluate(device, prediction, output, T(1)/((T)BATCH_SIZE));
        T d_input[INPUT_DIM];
        d_input_matrix._data = d_input;
        bpt::backward(device, model, input, d_loss_d_output_matrix, d_input_matrix, buffers);
    }
    loss /= BATCH_SIZE;
    bpt::update(device, model, optimizer);
    if(batch_i % 100 == 0){
        std::cout << "batch: " << batch_i << " loss: " << loss << std::endl;
    }
}

batch: 0 loss: 0.280092
batch: 100 loss: 0.025453
batch: 200 loss: 0.0289114
batch: 300 loss: 0.00838864
batch: 400 loss: 0.00909099
batch: 500 loss: 0.0176317
batch: 600 loss: 0.0102458
batch: 700 loss: 0.00427028
batch: 800 loss: 0.00866197
batch: 900 loss: 0.00429899
batch: 1000 loss: 0.0209082
batch: 1100 loss: 0.00755304
batch: 1200 loss: 0.00581563
batch: 1300 loss: 0.00471603
batch: 1400 loss: 0.0080257
batch: 1500 loss: 0.00992348
batch: 1600 loss: 0.00634431
batch: 1700 loss: 0.00276589
batch: 1800 loss: 0.00870145


In [12]:
// Validation pass: runs every validation sample through the network,
// accumulates the mean categorical cross-entropy and the classification
// accuracy, and logs the accuracy as a percentage.
T val_loss = 0;
T accuracy = 0;
for (TI sample_i = 0; sample_i < DATASET_SIZE_VAL; sample_i++){
    auto input = bpt::row(device, x_val, sample_i);
    auto output = bpt::row(device, y_val, sample_i);
    bpt::forward(device, model, input);
    // Weight 1 per sample: the sum is turned into a mean by the division by
    // DATASET_SIZE_VAL below, so no per-batch scaling belongs here.
    val_loss += bpt::nn::loss_functions::categorical_cross_entropy::evaluate(device, model.output_layer.output, output, T(1));
    TI predicted_label = bpt::argmax_row(device, model.output_layer.output);
    if(sample_i % 1000 == 0){
        // Every 1000th sample: draw the input as a 28 x 28 grid, printing the
        // predicted label wherever the input value exceeds 0.5.
        for(TI row_i = 0; row_i < 28; row_i++){
            for(TI col_i = 0; col_i < 28; col_i++){
                T val = bpt::get(input, 0, row_i * 28 + col_i);
                std::cout << (val > 0.5 ? (std::string(" ") + std::to_string(predicted_label)) : std::string("  "));
            }
            std::cout << std::endl;
        }
    }
    // Counts 1 for a correct prediction, 0 otherwise.
    accuracy += predicted_label == bpt::get(output, 0, 0);
}
val_loss /= DATASET_SIZE_VAL;
accuracy /= DATASET_SIZE_VAL;
bpt::logging::text(device, device.logger, "Validation accuracy: ", accuracy * 100, "%", "");

                                                        
                                                        
                                                        
                                                        
                                                        
                                                        
                                                        
             7 7 7 7 7 7                                
             7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7            
             7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7            
                       7 7 7 7 7 7 7 7 7 7 7            
                                     7 7 7 7            
                                   7 7 7 7              
                                   7 7 7 7              
                                 7 7 7 7                
                                 7 7 7 7                
                               7 7 7 7                  
                               