Skip to content

Commit

Permalink
Merge pull request #5 from mfl28/convolutional-neural-network
Browse files Browse the repository at this point in the history
Add Convolutional Neural Network tutorial
  • Loading branch information
prabhuomkar committed Dec 3, 2019
2 parents 4aa93eb + 1b977d5 commit 4d369b8
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_subdirectory("tutorials/basics/feedforward_neural_network")
add_subdirectory("tutorials/basics/linear_regression")
add_subdirectory("tutorials/basics/logistic_regression")
add_subdirectory("tutorials/basics/pytorch_basics")
add_subdirectory("tutorials/intermediate/convolutional_neural_network")

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ $ ./scripts.sh build
* [Feedforward Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/basics/feedforward_neural_network/main.cpp)

#### 2. Intermediate
* [Convolutional Neural Network]()
* [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
* [Deep Residual Network]()
* [Recurrent Neural Network]()
* [Bidirectional Recurrent Neural Network]()
Expand All @@ -47,4 +47,5 @@ $ ./scripts.sh build

## Authors
- Omkar Prabhu - [prabhuomkar](https://github.com/prabhuomkar)
- Markus Fleischhacker - [mfl28](https://github.com/mfl28)

37 changes: 37 additions & 0 deletions tutorials/intermediate/convolutional_neural_network/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)

project(convolutional-neural-network VERSION 1.0.0 LANGUAGES CXX)

# Files
set(SOURCES src/main.cpp
src/convnet.cpp
)

set(HEADERS include/convnet.h
)

set(EXECUTABLE_NAME convolutional-neural-network)


add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)

target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")

set_target_properties(${EXECUTABLE_NAME} PROPERTIES
CXX_STANDARD 11
CXX_STANDARD_REQUIRED YES
)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# See https://pytorch.org/cppdocs/installing.html.
if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET ${EXECUTABLE_NAME}
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
endif (MSVC)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>

class ConvNetImpl : public torch::nn::Module {
public:
explicit ConvNetImpl(int64_t num_classes = 10);
torch::Tensor forward(torch::Tensor x);

private:
torch::nn::Sequential layer1{
torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 16, 5).stride(1).padding(2)),
torch::nn::BatchNorm(16),
torch::nn::Functional(torch::relu),
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
};

torch::nn::Sequential layer2{
torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 32, 5).stride(1).padding(2)),
torch::nn::BatchNorm(32),
torch::nn::Functional(torch::relu),
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
};

torch::nn::Linear fc;
};

TORCH_MODULE(ConvNet);


Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2019 Markus Fleischhacker
#include "convnet.h"
#include <torch/torch.h>

ConvNetImpl::ConvNetImpl(int64_t num_classes)
: fc(7 * 7 * 32, num_classes) {
register_module("layer1", layer1);
register_module("layer2", layer2);
register_module("fc", fc);
}

torch::Tensor ConvNetImpl::forward(torch::Tensor x) {
x = layer1->forward(x);
x = layer2->forward(x);
x = x.view({-1, 7 * 7 * 32});
x = fc->forward(x);
return torch::log_softmax(x, 1);
}
122 changes: 122 additions & 0 deletions tutorials/intermediate/convolutional_neural_network/src/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2019 Markus Fleischhacker
#include <torch/torch.h>
#include <iostream>
#include <iomanip>
#include "convnet.h"

int main() {
std::cout << "Convolutional Neural Network\n\n";

// Device
torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

// Hyper parameters
const int64_t num_epochs = 5;
const int64_t num_classes = 10;
const int64_t batch_size = 100;
const double learning_rate = 0.001;

const std::string MNIST_data_path = "../../../../tutorials/intermediate/convolutional_neural_network/data/";

// MNIST dataset
auto train_dataset = torch::data::datasets::MNIST(MNIST_data_path)
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());

// Number of samples in the training set
auto num_train_samples = train_dataset.size().value();

auto test_dataset = torch::data::datasets::MNIST(MNIST_data_path, torch::data::datasets::MNIST::Mode::kTest)
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());

// Number of samples in the testset
auto num_test_samples = test_dataset.size().value();

// Data loader
auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
std::move(train_dataset), batch_size);
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
std::move(test_dataset), batch_size);

// Model
ConvNet model(num_classes);
model->to(device);

// Optimizer
auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));

// Set floating point output precision
std::cout << std::fixed << std::setprecision(4);

std::cout << "Training...\n";

// Train the model
for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
// Initialize running metrics
float running_loss = 0.0;
size_t num_correct = 0;

for (auto& batch : *train_loader) {
// Transfer images and target labels to device
auto data = batch.data.to(device);
auto target = batch.target.to(device);

// Forward pass
auto output = model->forward(data);

// Calculate loss
auto loss = torch::nll_loss(output, target);

// Update running loss
running_loss += loss.item().toFloat() * data.size(0);

// Calculate prediction
auto prediction = output.argmax(1);

// Update number of correctly classified samples
num_correct += prediction.eq(target).sum().item().toLong();

// Backward pass and optimize
optimizer.zero_grad();
loss.backward();
optimizer.step();
}

auto sample_mean_loss = running_loss / num_train_samples;
auto accuracy = static_cast<float>(num_correct) / num_train_samples;

std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
<< sample_mean_loss << ", Accuracy: " << accuracy << '\n';
}

std::cout << "Training finished!\n\n";
std::cout << "Testing...\n";

// Test the model
model->eval();
torch::NoGradGuard no_grad;

float running_loss = 0.0;
size_t num_correct = 0;

for (const auto& batch : *test_loader) {
auto data = batch.data.to(device);
auto target = batch.target.to(device);

auto output = model->forward(data);

auto loss = torch::nll_loss(output, target);
running_loss += loss.item().toFloat() * data.size(0);

auto prediction = output.argmax(1);
num_correct += prediction.eq(target).sum().item().toLong();
}

std::cout << "Testing finished!\n";

auto test_accuracy = static_cast<float>(num_correct) / num_test_samples;
auto test_sample_mean_loss = running_loss / num_test_samples;

std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n';
}

0 comments on commit 4d369b8

Please sign in to comment.