-
Notifications
You must be signed in to change notification settings - Fork 161
/
Copy pathmnist_lstm.cpp
47 lines (37 loc) · 1.58 KB
/
mnist_lstm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
// http://opensource.org/licenses/MIT)
//=======================================================================
#include <memory>

#include "dll/neural/dense/dense_layer.hpp"
#include "dll/neural/lstm/lstm_layer.hpp"
#include "dll/neural/recurrent/recurrent_last_layer.hpp"
#include "dll/network.hpp"
#include "dll/datasets.hpp"
/*
 * Train and evaluate an LSTM classifier on MNIST.
 *
 * Network: lstm_layer (time_steps x sequence_length inputs -> hidden_units)
 *          -> recurrent_last_layer -> dense softmax over num_classes outputs,
 * trained with the Adam updater.
 *
 * NOTE(review): with time_steps == sequence_length == 28, each 28x28 image is
 * presumably consumed row by row (one 28-pixel row per LSTM step) -- confirm
 * against the DLL dataset/LSTM documentation.
 */
int main(int /*argc*/, char* /*argv*/ []) {
    // Single source of truth for the mini-batch size: the dataset generator and
    // the network descriptor must agree on it, so it is not repeated as a literal.
    constexpr size_t mini_batch_size = 100;

    constexpr size_t time_steps      = 28;  // number of LSTM time steps
    constexpr size_t sequence_length = 28;  // inputs fed to the LSTM at each step
    constexpr size_t hidden_units    = 100; // size of the LSTM hidden state
    constexpr size_t num_classes     = 10;  // MNIST digits 0-9
    constexpr size_t epochs          = 50;  // training epochs

    // Load the dataset. scale_pre<255> presumably divides pixel values by 255
    // (mapping them into [0, 1]) -- confirm against the DLL dataset API.
    auto dataset = dll::make_mnist_dataset_nc(dll::batch_size<mini_batch_size>{}, dll::scale_pre<255>{});

    // Build the network.
    // NOTE(review): recurrent_last_layer presumably keeps only the output of
    // the final time step so the dense layer sees a hidden_units-sized vector.
    using network_t = dll::dyn_network_desc<
        dll::network_layers<
            dll::lstm_layer<time_steps, sequence_length, hidden_units, dll::last_only>,
            dll::recurrent_last_layer<time_steps, hidden_units>,
            dll::dense_layer<hidden_units, num_classes, dll::softmax>
        >
        , dll::updater<dll::updater_type::ADAM>   // Adam optimizer
        , dll::batch_size<mini_batch_size>        // must match the dataset's batch size
    >::network_t;

    auto net = std::make_unique<network_t>();

    // Display the network and dataset
    net->display_pretty();
    dataset.display_pretty();

    // Train the network for a fixed number of epochs
    net->train(dataset.train(), epochs);

    // Evaluate the trained network on the held-out test set
    net->evaluate(dataset.test());

    return 0;
}