Add Deep Residual Network tutorial #7

Merged: 4 commits, Dec 4, 2019. The diff below shows changes from 1 commit.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory("tutorials/basics/linear_regression")
add_subdirectory("tutorials/basics/logistic_regression")
add_subdirectory("tutorials/basics/pytorch_basics")
add_subdirectory("tutorials/intermediate/convolutional_neural_network")
add_subdirectory("tutorials/intermediate/deep_residual_network")

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
2 changes: 1 addition & 1 deletion README.md
@@ -30,7 +30,7 @@ $ ./scripts.sh build

#### 2. Intermediate
* [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
* [Deep Residual Network]()
* [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp)
* [Recurrent Neural Network]()
* [Bidirectional Recurrent Neural Network]()
* [Language Model (RNN-LM)]()
42 changes: 42 additions & 0 deletions tutorials/intermediate/deep_residual_network/CMakeLists.txt
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)

project(deep-residual-network VERSION 1.0.0 LANGUAGES CXX)

# Files
set(SOURCES src/main.cpp
src/cifar10.cpp
src/residual_block.cpp
src/transform.cpp
)

set(HEADERS include/residual_block.h
include/resnet.h
include/cifar10.h
include/transform.h
)

set(EXECUTABLE_NAME deep-residual-network)


add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)

target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")

set_target_properties(${EXECUTABLE_NAME} PROPERTIES
CXX_STANDARD 11
CXX_STANDARD_REQUIRED YES
)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# See https://pytorch.org/cppdocs/installing.html.
if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET ${EXECUTABLE_NAME}
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
endif (MSVC)
43 changes: 43 additions & 0 deletions tutorials/intermediate/deep_residual_network/include/cifar10.h
@@ -0,0 +1,43 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/data/datasets/base.h>
#include <torch/data/example.h>
#include <torch/types.h>
#include <cstddef>
#include <string>

// Cifar10 dataset
// based on: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/data/datasets/mnist.h.
class Cifar10 : public torch::data::datasets::Dataset<Cifar10> {
public:
// The mode in which the dataset is loaded
enum Mode { kTrain, kTest };

// Loads the Cifar10 dataset from the `root` path.
//
// The supplied `root` path should contain the *content* of the unzipped
// Cifar10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html.
explicit Cifar10(const std::string& root, Mode mode = Mode::kTrain);

// Returns the `Example` at the given `index`.
torch::data::Example<> get(size_t index) override;

// Returns the size of the dataset.
torch::optional<size_t> size() const override;

// Returns true if this is the training subset of Cifar10.
bool is_train() const noexcept;

// Returns all images stacked into a single tensor.
const torch::Tensor& images() const;

// Returns all targets stacked into a single tensor.
const torch::Tensor& targets() const;

private:
torch::Tensor images_;
torch::Tensor targets_;
Mode mode_;
};
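For orientation, a minimal usage sketch of the dataset class declared above; the `root` path is a placeholder and assumes the unzipped CIFAR-10 binary batches:

```cpp
// Sketch only: "data/cifar-10-batches-bin" is a placeholder root directory.
#include <iostream>
#include "cifar10.h"

void inspect_dataset() {
    Cifar10 train_set("data/cifar-10-batches-bin", Cifar10::Mode::kTrain);

    std::cout << "samples: " << train_set.size().value() << '\n';           // 50000
    std::cout << "image shape: " << train_set.images()[0].sizes() << '\n';  // [3, 32, 32]

    auto sample = train_set.get(0);  // torch::data::Example<> with .data (image) and .target (label)
}
```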

25 changes: 25 additions & 0 deletions tutorials/intermediate/deep_residual_network/include/residual_block.h
@@ -0,0 +1,25 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>

namespace resnet {
class ResidualBlockImpl : public torch::nn::Module {
public:
ResidualBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride = 1,
torch::nn::Sequential downsample = nullptr);
torch::Tensor forward(torch::Tensor x);

private:
torch::nn::Conv2d conv1;
torch::nn::BatchNorm bn1;
torch::nn::Functional relu;
torch::nn::Conv2d conv2;
torch::nn::BatchNorm bn2;
torch::nn::Sequential downsampler;
};

torch::nn::Conv2d conv3x3(int64_t in_channels, int64_t out_channels, int64_t stride = 1);

TORCH_MODULE(ResidualBlock);
} // namespace resnet
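The matching `src/residual_block.cpp` is not included in this hunk; as orientation, here is a hedged sketch of the forward pass implied by the members declared above (an assumption, not necessarily the PR's exact implementation):

```cpp
// Hedged sketch: standard ResNet basic-block wiring; the PR's residual_block.cpp may differ.
#include "residual_block.h"

namespace resnet {
torch::Tensor ResidualBlockImpl::forward(torch::Tensor x) {
    auto residual = x;

    auto out = conv1->forward(x);
    out = bn1->forward(out);
    out = relu->forward(out);
    out = conv2->forward(out);
    out = bn2->forward(out);

    // When the spatial size or channel count changes, the identity branch is
    // projected by the optional downsampler before the addition.
    if (downsampler) {
        residual = downsampler->forward(x);
    }

    out += residual;
    return relu->forward(out);
}
}  // namespace resnet
```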
91 changes: 91 additions & 0 deletions tutorials/intermediate/deep_residual_network/include/resnet.h
@@ -0,0 +1,91 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>
#include <array>
#include <vector>
#include "residual_block.h"

namespace resnet {
template<typename Block>
class ResNetImpl : public torch::nn::Module {
public:
explicit ResNetImpl(const std::array<int64_t, 3>& layers, int64_t num_classes = 10);
torch::Tensor forward(torch::Tensor x);

private:
int64_t in_channels = 16;
torch::nn::Conv2d conv{conv3x3(3, 16)};
torch::nn::BatchNorm bn{16};
torch::nn::Functional relu{torch::relu};
torch::nn::Sequential layer1;
torch::nn::Sequential layer2;
torch::nn::Sequential layer3;
torch::nn::AvgPool2d avg_pool{8};
torch::nn::Linear fc;

torch::nn::Sequential make_layer(int64_t out_channels, int64_t blocks, int64_t stride = 1);
};

template<typename Block>
ResNetImpl<Block>::ResNetImpl(const std::array<int64_t, 3>& layers, int64_t num_classes) :
layer1(make_layer(16, layers[0])),
layer2(make_layer(32, layers[1], 2)),
layer3(make_layer(64, layers[2], 2)),
fc(64, num_classes) {
register_module("conv", conv);
register_module("bn", bn);
register_module("relu", relu);
register_module("layer1", layer1);
register_module("layer2", layer2);
register_module("layer3", layer3);
register_module("avg_pool", avg_pool);
register_module("fc", fc);
}

template<typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
auto out = conv->forward(x);
out = bn->forward(out);
out = relu->forward(out);
out = layer1->forward(out);
out = layer2->forward(out);
out = layer3->forward(out);
out = avg_pool->forward(out);
out = out.view({out.size(0), -1});
out = fc->forward(out);

return torch::log_softmax(out, 1);
}

template<typename Block>
torch::nn::Sequential ResNetImpl<Block>::make_layer(int64_t out_channels, int64_t blocks, int64_t stride) {
torch::nn::Sequential layers;
torch::nn::Sequential downsample{nullptr};

if (stride != 1 || in_channels != out_channels) {
downsample = {
conv3x3(in_channels, out_channels, stride),
torch::nn::BatchNorm(out_channels)
};
}

layers->push_back(Block(in_channels, out_channels, stride, downsample));

in_channels = out_channels;

for (int64_t i = 1; i != blocks; ++i) {
layers->push_back(Block(out_channels, out_channels));
}

return layers;
}

// Wrap class into ModuleHolder (a shared_ptr wrapper),
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/pimpl.h
template<typename Block = ResidualBlock>
class ResNet : public torch::nn::ModuleHolder<ResNetImpl<Block>> {
public:
using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};
} // namespace resnet
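For reference, a minimal construction sketch; the block counts `{2, 2, 2}` and the device handling are illustrative placeholders (the actual configuration lives in the PR's `main.cpp`, which is not shown here):

```cpp
// Sketch only: block counts and device are placeholders, not taken from main.cpp.
#include <array>
#include <torch/torch.h>
#include "resnet.h"

void build_model() {
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    // Three stages with 2 residual blocks each; num_classes defaults to 10 (CIFAR-10).
    std::array<int64_t, 3> layers{2, 2, 2};
    resnet::ResNet<resnet::ResidualBlock> model(layers, 10);
    model->to(device);

    // A CIFAR-10 sized dummy batch: N x 3 x 32 x 32.
    auto input = torch::randn({4, 3, 32, 32}, device);
    auto log_probs = model->forward(input);  // [4, 10] log-softmax scores
}
```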

55 changes: 55 additions & 0 deletions tutorials/intermediate/deep_residual_network/include/transform.h
@@ -0,0 +1,55 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>
#include <random>
#include <vector>

namespace transform {
class RandomHorizontalFlip : public torch::data::transforms::TensorTransform<torch::Tensor> {
public:
// Creates a transformation that randomly horizontally flips a tensor.
//
// The parameter `p` determines the probability that a tensor is flipped (default = 0.5).
explicit RandomHorizontalFlip(double p = 0.5);

torch::Tensor operator()(torch::Tensor input) override;

private:
double p_;
};

class ConstantPad : public torch::data::transforms::TensorTransform<torch::Tensor> {
public:
// Creates a transformation that pads a tensor.
//
// `padding` is expected to be a vector of size 4 whose entries correspond to the
// padding of the sides, i.e {left, right, top, bottom}. `value` determines the value
// for the padded pixels.
explicit ConstantPad(const std::vector<int64_t>& padding, torch::Scalar value = 0);

// Creates a transformation that pads a tensor.
//
// The padding will be performed using the size `padding` for all 4 sides.
// `value` determines the value for the padded pixels.
explicit ConstantPad(int64_t padding, torch::Scalar value = 0);

torch::Tensor operator()(torch::Tensor input) override;

private:
std::vector<int64_t> padding_;
torch::Scalar value_;
};

class RandomCrop : public torch::data::transforms::TensorTransform<torch::Tensor> {
public:
// Creates a transformation that randomly crops a tensor.
//
// The parameter `size` is expected to be a vector of size 2
// and determines the output size {height, width}.
explicit RandomCrop(const std::vector<int64_t>& size);
torch::Tensor operator()(torch::Tensor input) override;

private:
std::vector<int64_t> size_;
};
} // namespace transform
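Since these are `TensorTransform`s, they can also be exercised directly on a single image tensor; a short sketch with values chosen to match the header comments (the padding and crop sizes are illustrative, not taken from this PR):

```cpp
// Sketch only: demonstrates the documented semantics of the transforms above.
#include <torch/torch.h>
#include "transform.h"

void augment_one_image() {
    auto image = torch::rand({3, 32, 32});  // CHW image, CIFAR-10 sized

    transform::RandomHorizontalFlip flip(/*p=*/0.5);
    transform::ConstantPad pad(/*padding=*/4, /*value=*/0);  // pad all four sides by 4
    transform::RandomCrop crop({32, 32});                    // crop back to 32x32

    auto augmented = crop(pad(flip(image)));
    // flip(image) -> [3, 32, 32], mirrored with probability 0.5
    // pad(...)    -> [3, 40, 40]
    // crop(...)   -> [3, 32, 32]
}
```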
98 changes: 98 additions & 0 deletions tutorials/intermediate/deep_residual_network/src/cifar10.cpp
@@ -0,0 +1,98 @@
// Copyright 2019 Markus Fleischhacker
#include "cifar10.h"

namespace {
// Cifar10 dataset description can be found at https://www.cs.toronto.edu/~kriz/cifar.html.
constexpr uint32_t kTrainSize = 50000;
constexpr uint32_t kTestSize = 10000;
constexpr uint32_t kSizePerBatch = 10000;
constexpr uint32_t kImageRows = 32;
constexpr uint32_t kImageColumns = 32;
constexpr uint32_t kBytesPerRow = 3073;
constexpr uint32_t kBytesPerChannelPerRow = (kBytesPerRow - 1) / 3;
constexpr uint32_t kBytesPerBatchFile = kBytesPerRow * kSizePerBatch;

const std::vector<std::string> kTrainDataBatchFiles = {
"data_batch_1.bin",
"data_batch_2.bin",
"data_batch_3.bin",
"data_batch_4.bin",
"data_batch_5.bin",
};

const std::vector<std::string> kTestDataBatchFiles = {
"test_batch.bin"
};

// Source: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
std::string join_paths(std::string head, const std::string& tail) {
if (head.back() != '/') {
head.push_back('/');
}
head += tail;
return head;
}
// Partially based on https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
std::pair<torch::Tensor, torch::Tensor> read_data(const std::string& root, bool train) {
const auto& files = train ? kTrainDataBatchFiles : kTestDataBatchFiles;
const auto num_samples = train ? kTrainSize : kTestSize;

std::vector<char> data_buffer;
data_buffer.reserve(files.size() * kBytesPerBatchFile);

for (const auto& file : files) {
const auto path = join_paths(root, file);
std::ifstream data(path, std::ios::binary);
TORCH_CHECK(data, "Error opening data file at ", path);

data_buffer.insert(data_buffer.end(), std::istreambuf_iterator<char>(data), {});
}

TORCH_CHECK(data_buffer.size() == files.size() * kBytesPerBatchFile, "Unexpected file sizes");

auto targets = torch::empty(num_samples, torch::kByte);
auto images = torch::empty({num_samples, 3, kImageRows, kImageColumns}, torch::kByte);

for (uint32_t i = 0; i != num_samples; ++i) {
// The first byte of each row is the target class index.
uint32_t start_index = i * kBytesPerRow;
targets[i] = data_buffer[start_index];

// The next bytes correspond to the RGB channel values in the following order:
// red (32 * 32 = 1024 bytes) | green (1024 bytes) | blue (1024 bytes)
uint32_t image_start = start_index + 1;
uint32_t image_end = image_start + 3 * kBytesPerChannelPerRow;
std::copy(&data_buffer[image_start], &data_buffer[image_end],
reinterpret_cast<char*>(images[i].data_ptr()));
}

return {images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64)};
}
} // namespace

Cifar10::Cifar10(const std::string& root, Mode mode) : mode_(mode) {
auto data = read_data(root, mode == Mode::kTrain);

images_ = std::move(data.first);
targets_ = std::move(data.second);
}

torch::data::Example<> Cifar10::get(size_t index) {
return {images_[index], targets_[index]};
}

torch::optional<size_t> Cifar10::size() const {
return images_.size(0);
}

bool Cifar10::is_train() const noexcept {
return mode_ == Mode::kTrain;
}

const torch::Tensor& Cifar10::images() const {
return images_;
}

const torch::Tensor& Cifar10::targets() const {
return targets_;
}
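Taken together, the dataset, transforms, and model compose into a standard training loop. The PR's `src/main.cpp` is not shown in this hunk; the following is only a hedged sketch of how the pieces might be wired, with paths, hyperparameters, and the optimizer chosen for illustration:

```cpp
// Hedged sketch: not the PR's main.cpp. The data path, batch size, epoch count,
// and SGD settings are assumptions for illustration only.
#include <array>
#include <torch/torch.h>
#include "cifar10.h"
#include "resnet.h"
#include "transform.h"

int main() {
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    auto train_dataset = Cifar10("data/cifar-10-batches-bin")
        .map(transform::RandomHorizontalFlip())
        .map(transform::ConstantPad(4))
        .map(transform::RandomCrop({32, 32}))
        .map(torch::data::transforms::Stack<>());

    auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
        std::move(train_dataset), /*batch_size=*/100);

    resnet::ResNet<resnet::ResidualBlock> model(std::array<int64_t, 3>{2, 2, 2}, 10);
    model->to(device);

    torch::optim::SGD optimizer(model->parameters(),
                                torch::optim::SGDOptions(0.1).momentum(0.9).weight_decay(1e-4));

    for (size_t epoch = 0; epoch != 10; ++epoch) {
        model->train();
        for (auto& batch : *train_loader) {
            auto data = batch.data.to(device);
            auto target = batch.target.to(device);

            optimizer.zero_grad();
            auto output = model->forward(data);           // log-softmax outputs
            auto loss = torch::nll_loss(output, target);  // pairs with log_softmax
            loss.backward();
            optimizer.step();
        }
    }
}
```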