diff --git a/CMakeLists.txt b/CMakeLists.txt index 9352819..b6c8449 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory("tutorials/basics/linear_regression") add_subdirectory("tutorials/basics/logistic_regression") add_subdirectory("tutorials/basics/pytorch_basics") add_subdirectory("tutorials/intermediate/convolutional_neural_network") +add_subdirectory("tutorials/intermediate/deep_residual_network") # The following code block is suggested to be used on Windows. # According to https://github.com/pytorch/pytorch/issues/25457, diff --git a/README.md b/README.md index a7da6b8..d62bdc0 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ $ ./scripts.sh build #### 2. Intermediate * [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp) -* [Deep Residual Network]() +* [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp) * [Recurrent Neural Network]() * [Bidirectional Recurrent Neural Network]() * [Language Model (RNN-LM)]() diff --git a/scripts.sh b/scripts.sh index eeacc1e..2040d48 100755 --- a/scripts.sh +++ b/scripts.sh @@ -1,9 +1,9 @@ #!/bin/bash function install() { - wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip - unzip libtorch-shared-with-deps-latest.zip - rm -rf libtorch-shared-with-deps-latest.zip + wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.3.1%2Bcpu.zip + unzip libtorch-shared-with-deps-1.3.1+cpu.zip + rm -rf libtorch-shared-with-deps-1.3.1+cpu.zip } function build() { diff --git a/tutorials/intermediate/deep_residual_network/CMakeLists.txt b/tutorials/intermediate/deep_residual_network/CMakeLists.txt new file mode 100644 index 0000000..eeb60ee --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.0 FATAL_ERROR) + +project(deep-residual-network VERSION 1.0.0 LANGUAGES CXX) + +# Files +set(SOURCES src/main.cpp + src/cifar10.cpp + src/residual_block.cpp + src/transform.cpp +) + +set(HEADERS include/residual_block.h + include/resnet.h + include/cifar10.h + include/transform.h +) + +set(EXECUTABLE_NAME deep-residual-network) + + +add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS}) +target_include_directories(${EXECUTABLE_NAME} PRIVATE include) + +target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}") + +set_target_properties(${EXECUTABLE_NAME} PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES +) + +# The following code block is suggested to be used on Windows. +# According to https://github.com/pytorch/pytorch/issues/25457, +# the DLLs need to be copied to avoid memory errors. +# See https://pytorch.org/cppdocs/installing.html. +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET ${EXECUTABLE_NAME} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $) +endif (MSVC) \ No newline at end of file diff --git a/tutorials/intermediate/deep_residual_network/include/cifar10.h b/tutorials/intermediate/deep_residual_network/include/cifar10.h new file mode 100644 index 0000000..2656d7f --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/include/cifar10.h @@ -0,0 +1,43 @@ +// Copyright 2019 Markus Fleischhacker +#pragma once + +#include +#include +#include +#include +#include + +// CIFAR10 dataset +// based on: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/data/datasets/mnist.h. +class CIFAR10 : public torch::data::datasets::Dataset { + public: + // The mode in which the dataset is loaded + enum Mode { kTrain, kTest }; + + // Loads the CIFAR10 dataset from the `root` path. + // + // The supplied `root` path should contain the *content* of the unzipped + // CIFAR10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html. + explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); + + // Returns the `Example` at the given `index`. + torch::data::Example<> get(size_t index) override; + + // Returns the size of the dataset. + torch::optional size() const override; + + // Returns true if this is the training subset of CIFAR10. + bool is_train() const noexcept; + + // Returns all images stacked into a single tensor. + const torch::Tensor& images() const; + + // Returns all targets stacked into a single tensor. + const torch::Tensor& targets() const; + + private: + torch::Tensor images_; + torch::Tensor targets_; + Mode mode_; +}; + diff --git a/tutorials/intermediate/deep_residual_network/include/residual_block.h b/tutorials/intermediate/deep_residual_network/include/residual_block.h new file mode 100644 index 0000000..74cb057 --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/include/residual_block.h @@ -0,0 +1,25 @@ +// Copyright 2019 Markus Fleischhacker +#pragma once + +#include + +namespace resnet { +class ResidualBlockImpl : public torch::nn::Module { + public: + ResidualBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride = 1, + torch::nn::Sequential downsample = nullptr); + torch::Tensor forward(torch::Tensor x); + + private: + torch::nn::Conv2d conv1; + torch::nn::BatchNorm bn1; + torch::nn::Functional relu; + torch::nn::Conv2d conv2; + torch::nn::BatchNorm bn2; + torch::nn::Sequential downsampler; +}; + +torch::nn::Conv2d conv3x3(int64_t in_channels, int64_t out_channels, int64_t stride = 1); + +TORCH_MODULE(ResidualBlock); +} // namespace resnet diff --git a/tutorials/intermediate/deep_residual_network/include/resnet.h b/tutorials/intermediate/deep_residual_network/include/resnet.h new file mode 100644 index 0000000..2acd45a --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/include/resnet.h @@ -0,0 +1,91 @@ +// Copyright 2019 Markus Fleischhacker +#pragma once + +#include +#include +#include "residual_block.h" + +namespace resnet { +template +class ResNetImpl : public torch::nn::Module { + public: + explicit ResNetImpl(const std::array& layers, int64_t num_classes = 10); + torch::Tensor forward(torch::Tensor x); + + private: + int64_t in_channels = 16; + torch::nn::Conv2d conv{conv3x3(3, 16)}; + torch::nn::BatchNorm bn{16}; + torch::nn::Functional relu{torch::relu}; + torch::nn::Sequential layer1; + torch::nn::Sequential layer2; + torch::nn::Sequential layer3; + torch::nn::AvgPool2d avg_pool{8}; + torch::nn::Linear fc; + + torch::nn::Sequential make_layer(int64_t out_channels, int64_t blocks, int64_t stride = 1); +}; + +template +ResNetImpl::ResNetImpl(const std::array& layers, int64_t num_classes) : + layer1(make_layer(16, layers[0])), + layer2(make_layer(32, layers[1], 2)), + layer3(make_layer(64, layers[2], 2)), + fc(64, num_classes) { + register_module("conv", conv); + register_module("bn", bn); + register_module("relu", relu); + register_module("layer1", layer1); + register_module("layer2", layer2); + register_module("layer3", layer3); + register_module("avg_pool", avg_pool); + register_module("fc", fc); +} + +template +torch::Tensor ResNetImpl::forward(torch::Tensor x) { + auto out = conv->forward(x); + out = bn->forward(out); + out = relu->forward(out); + out = layer1->forward(out); + out = layer2->forward(out); + out = layer3->forward(out); + out = avg_pool->forward(out); + out = out.view({out.size(0), -1}); + out = fc->forward(out); + + return torch::log_softmax(out, 1); +} + +template +torch::nn::Sequential ResNetImpl::make_layer(int64_t out_channels, int64_t blocks, int64_t stride) { + torch::nn::Sequential layers; + torch::nn::Sequential downsample{nullptr}; + + if (stride != 1 || in_channels != out_channels) { + downsample = torch::nn::Sequential{ + conv3x3(in_channels, out_channels, stride), + torch::nn::BatchNorm(out_channels) + }; + } + + layers->push_back(Block(in_channels, out_channels, stride, downsample)); + + in_channels = out_channels; + + for (int64_t i = 1; i != blocks; ++i) { + layers->push_back(Block(out_channels, out_channels)); + } + + return layers; +} + +// Wrap class into ModuleHolder (a shared_ptr wrapper), +// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/pimpl.h +template +class ResNet : public torch::nn::ModuleHolder> { + public: + using torch::nn::ModuleHolder>::ModuleHolder; +}; +} // namespace resnet + diff --git a/tutorials/intermediate/deep_residual_network/include/transform.h b/tutorials/intermediate/deep_residual_network/include/transform.h new file mode 100644 index 0000000..f5c442d --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/include/transform.h @@ -0,0 +1,55 @@ +// Copyright 2019 Markus Fleischhacker +#pragma once + +#include +#include + +namespace transform { +class RandomHorizontalFlip : public torch::data::transforms::TensorTransform { + public: + // Creates a transformation that randomly horizontally flips a tensor. + // + // The parameter `p` determines the probability that a tensor is flipped (default = 0.5). + explicit RandomHorizontalFlip(double p = 0.5); + + torch::Tensor operator()(torch::Tensor input) override; + + private: + double p_; +}; + +class ConstantPad : public torch::data::transforms::TensorTransform { + public: + // Creates a transformation that pads a tensor. + // + // `padding` is expected to be a vector of size 4 whose entries correspond to the + // padding of the sides, i.e {left, right, top, bottom}. `value` determines the value + // for the padded pixels. + explicit ConstantPad(const std::vector& padding, torch::Scalar value = 0); + + // Creates a transformation that pads a tensor. + // + // The padding will be performed using the size `padding` for all 4 sides. + // `value` determines the value for the padded pixels. + explicit ConstantPad(int64_t padding, torch::Scalar value = 0); + + torch::Tensor operator()(torch::Tensor input) override; + + private: + std::vector padding_; + torch::Scalar value_; +}; + +class RandomCrop : public torch::data::transforms::TensorTransform { + public: + // Creates a transformation that randomly crops a tensor. + // + // The parameter `size` is expected to be a vector of size 2 + // and determines the output size {height, width}. + explicit RandomCrop(const std::vector& size); + torch::Tensor operator()(torch::Tensor input) override; + + private: + std::vector size_; +}; +} // namespace transform diff --git a/tutorials/intermediate/deep_residual_network/src/cifar10.cpp b/tutorials/intermediate/deep_residual_network/src/cifar10.cpp new file mode 100644 index 0000000..ced5bc9 --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/src/cifar10.cpp @@ -0,0 +1,98 @@ +// Copyright 2019 Markus Fleischhacker +#include "cifar10.h" + +namespace { +// CIFAR10 dataset description can be found at https://www.cs.toronto.edu/~kriz/cifar.html. +constexpr uint32_t kTrainSize = 50000; +constexpr uint32_t kTestSize = 10000; +constexpr uint32_t kSizePerBatch = 10000; +constexpr uint32_t kImageRows = 32; +constexpr uint32_t kImageColumns = 32; +constexpr uint32_t kBytesPerRow = 3073; +constexpr uint32_t kBytesPerChannelPerRow = (kBytesPerRow - 1) / 3; +constexpr uint32_t kBytesPerBatchFile = kBytesPerRow * kSizePerBatch; + +const std::vector kTrainDataBatchFiles = { + "data_batch_1.bin", + "data_batch_2.bin", + "data_batch_3.bin", + "data_batch_4.bin", + "data_batch_5.bin", +}; + +const std::vector kTestDataBatchFiles = { + "test_batch.bin" +}; + +// Source: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp. +std::string join_paths(std::string head, const std::string& tail) { + if (head.back() != '/') { + head.push_back('/'); + } + head += tail; + return head; +} +// Partially based on https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp. +std::pair read_data(const std::string& root, bool train) { + const auto& files = train ? kTrainDataBatchFiles : kTestDataBatchFiles; + const auto num_samples = train ? kTrainSize : kTestSize; + + std::vector data_buffer; + data_buffer.reserve(files.size() * kBytesPerBatchFile); + + for (const auto& file : files) { + const auto path = join_paths(root, file); + std::ifstream data(path, std::ios::binary); + TORCH_CHECK(data, "Error opening data file at", path); + + data_buffer.insert(data_buffer.end(), std::istreambuf_iterator(data), {}); + } + + TORCH_CHECK(data_buffer.size() == files.size() * kBytesPerBatchFile, "Unexpected file sizes"); + + auto targets = torch::empty(num_samples, torch::kByte); + auto images = torch::empty({num_samples, 3, kImageRows, kImageColumns}, torch::kByte); + + for (uint32_t i = 0; i != num_samples; ++i) { + // The first byte of each row is the target class index. + uint32_t start_index = i * kBytesPerRow; + targets[i] = data_buffer[start_index]; + + // The next bytes correspond to the rgb channel values in the following order: + // red (32 *32 = 1024 bytes) | green (1024 bytes) | blue (1024 bytes) + uint32_t image_start = start_index + 1; + uint32_t image_end = image_start + 3 * kBytesPerChannelPerRow; + std::copy(&data_buffer[image_start], &data_buffer[image_end], + reinterpret_cast(images[i].data_ptr())); + } + + return {images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64)}; +} +} // namespace + +CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) { + auto data = read_data(root, mode == Mode::kTrain); + + images_ = std::move(data.first); + targets_ = std::move(data.second); +} + +torch::data::Example<> CIFAR10::get(size_t index) { + return {images_[index], targets_[index]}; +} + +torch::optional CIFAR10::size() const { + return images_.size(0); +} + +bool CIFAR10::is_train() const noexcept { + return mode_ == Mode::kTrain; +} + +const torch::Tensor& CIFAR10::images() const { + return images_; +} + +const torch::Tensor& CIFAR10::targets() const { + return targets_; +} diff --git a/tutorials/intermediate/deep_residual_network/src/main.cpp b/tutorials/intermediate/deep_residual_network/src/main.cpp new file mode 100644 index 0000000..5561846 --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/src/main.cpp @@ -0,0 +1,142 @@ +// Copyright 2019 Markus Fleischhacker +#include +#include +#include +#include "resnet.h" +#include "cifar10.h" +#include "transform.h" + +using resnet::ResNet; +using resnet::ResidualBlock; +using transform::ConstantPad; +using transform::RandomCrop; +using transform::RandomHorizontalFlip; + +int main() { + std::cout << "Deep Residual Network\n\n"; + + // Device + torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); + + // Hyper parameters + const int64_t num_epochs = 20; + const int64_t num_classes = 10; + const int64_t batch_size = 100; + const double learning_rate = 0.001; + const int64_t learning_rate_decay_frequency = 8; // number of epochs after which to decay the learning rate + const double learning_rate_decay_factor = 1.0 / 3.0; + + const std::string CIFAR_data_path = "../../../../tutorials/intermediate/deep_residual_network/data/"; + + // CIFAR10 custom dataset + auto train_dataset = CIFAR10(CIFAR_data_path) + .map(ConstantPad(4)) + .map(RandomHorizontalFlip()) + .map(RandomCrop({32, 32})) + .map(torch::data::transforms::Stack<>()); + + // Number of samples in the training set + auto num_train_samples = train_dataset.size().value(); + + auto test_dataset = CIFAR10(CIFAR_data_path, CIFAR10::Mode::kTest) + .map(torch::data::transforms::Stack<>()); + + // Number of samples in the testset + auto num_test_samples = test_dataset.size().value(); + + // Data loader + auto train_loader = torch::data::make_data_loader( + std::move(train_dataset), batch_size); + auto test_loader = torch::data::make_data_loader( + std::move(test_dataset), batch_size); + + // Model + std::array layers{2, 2, 2}; + ResNet model(layers, num_classes); + model->to(device); + + // Optimizer + auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate)); + + // Set floating point output precision + std::cout << std::fixed << std::setprecision(4); + + auto current_learning_rate = learning_rate; + + std::cout << "Training...\n"; + + // Train the model + for (size_t epoch = 0; epoch != num_epochs; ++epoch) { + // Initialize running metrics + float running_loss = 0.0; + size_t num_correct = 0; + + for (auto& batch : *train_loader) { + // Transfer images and target labels to device + auto data = batch.data.to(device); + auto target = batch.target.to(device); + + // Forward pass + auto output = model->forward(data); + + // Calculate loss + auto loss = torch::nll_loss(output, target); + + // Update running loss + running_loss += loss.item().toFloat() * data.size(0); + + // Calculate prediction + auto prediction = output.argmax(1); + + // Update number of correctly classified samples + num_correct += prediction.eq(target).sum().item().toLong(); + + // Backward pass and optimize + optimizer.zero_grad(); + loss.backward(); + optimizer.step(); + } + + // Decay learning rate + if ((epoch + 1) % learning_rate_decay_frequency == 0) { + current_learning_rate *= learning_rate_decay_factor; + optimizer.options.learning_rate(current_learning_rate); + } + + auto sample_mean_loss = running_loss / num_train_samples; + auto accuracy = static_cast(num_correct) / num_train_samples; + + std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: " + << sample_mean_loss << ", Accuracy: " << accuracy << '\n'; + } + + std::cout << "Training finished!\n\n"; + std::cout << "Testing...\n"; + + // Test the model + model->eval(); + torch::NoGradGuard no_grad; + + float running_loss = 0.0; + size_t num_correct = 0; + + for (const auto& batch : *test_loader) { + auto data = batch.data.to(device); + auto target = batch.target.to(device); + + auto output = model->forward(data); + + auto loss = torch::nll_loss(output, target); + running_loss += loss.item().toFloat() * data.size(0); + + auto prediction = output.argmax(1); + num_correct += prediction.eq(target).sum().item().toLong(); + } + + std::cout << "Testing finished!\n"; + + auto test_accuracy = static_cast(num_correct) / num_test_samples; + auto test_sample_mean_loss = running_loss / num_test_samples; + + std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n'; +} diff --git a/tutorials/intermediate/deep_residual_network/src/residual_block.cpp b/tutorials/intermediate/deep_residual_network/src/residual_block.cpp new file mode 100644 index 0000000..d34bc83 --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/src/residual_block.cpp @@ -0,0 +1,45 @@ +// Copyright 2019 Markus Fleischhacker +#include "residual_block.h" +#include + +namespace resnet { +ResidualBlockImpl::ResidualBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride, + torch::nn::Sequential downsample) : + conv1(conv3x3(in_channels, out_channels, stride)), + bn1(out_channels), + relu(torch::relu), + conv2(conv3x3(out_channels, out_channels)), + bn2(out_channels), + downsampler(downsample) { + register_module("conv1", conv1); + register_module("bn1", bn1); + register_module("relu", relu); + register_module("conv2", conv2); + register_module("bn2", bn2); + + if (downsampler) { + register_module("downsampler", downsampler); + } +} + +torch::Tensor ResidualBlockImpl::forward(torch::Tensor x) { + auto out = conv1->forward(x); + out = bn1->forward(out); + out = relu->forward(out); + out = conv2->forward(out); + out = bn2->forward(out); + + auto residual = downsampler ? downsampler->forward(x) : x; + out += residual; + out = relu->forward(out); + + return out; +} + +torch::nn::Conv2d conv3x3(int64_t in_channels, int64_t out_channels, int64_t stride) { + return torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3) + .stride(stride) + .padding(1) + .with_bias(false)); +} +} // namespace resnet diff --git a/tutorials/intermediate/deep_residual_network/src/transform.cpp b/tutorials/intermediate/deep_residual_network/src/transform.cpp new file mode 100644 index 0000000..8c6eb24 --- /dev/null +++ b/tutorials/intermediate/deep_residual_network/src/transform.cpp @@ -0,0 +1,50 @@ +// Copyright 2019 Markus Fleischhacker +#include "transform.h" + +namespace transform { +namespace { + double rand_double() { + return torch::rand(1)[0].item(); + } + + int64_t rand_int(int64_t max) { + return torch::randint(max, 1)[0].item(); + } +} // namespace + +// RandomHorizontalFlip +RandomHorizontalFlip::RandomHorizontalFlip(double p) : p_(p) {} + +torch::Tensor RandomHorizontalFlip::operator()(torch::Tensor input) { + if (rand_double() < p_) { + return input.flip(-1); + } + + return input; +} + +// ConstantPad +ConstantPad::ConstantPad(const std::vector& padding, torch::Scalar value) + : padding_(padding), value_(value) {} + +ConstantPad::ConstantPad(int64_t padding, torch::Scalar value) + : padding_(4, padding), value_(value) {} + +torch::Tensor ConstantPad::operator()(torch::Tensor input) { + return torch::constant_pad_nd(input, padding_, value_); +} + +// RandomCrop +RandomCrop::RandomCrop(const std::vector& size) : size_(size) {} + +torch::Tensor RandomCrop::operator()(torch::Tensor input) { + auto height_offset_length = input.size(-2) - size_[0]; + auto width_offset_length = input.size(-1) - size_[1]; + + auto height_offset = rand_int(height_offset_length); + auto width_offset = rand_int(width_offset_length); + + return input.slice(-2, height_offset, height_offset + size_[0]) + .slice(-1, width_offset, width_offset + size_[1]); +} +} // namespace transform