test/cpp/api/container.cpp (16 additions, 10 deletions)
@@ -8,27 +8,34 @@ using namespace torch::nn;
 class TestModel : public Module {
  public:
   TestModel() {
-    add(Linear(10, 3).build(), "l1");
-    add(Linear(3, 5).build(), "l2");
-    add(Linear(5, 100).build(), "l3");
+    register_module("l1", &TestModel::l1, Linear(10, 3).build());
+    register_module("l2", &TestModel::l2, Linear(3, 5).build());
+    register_module("l3", &TestModel::l3, Linear(5, 100).build());
   }
 
   variable_list forward(variable_list input) override {
     return input;
-  };
+  }
+
+  std::shared_ptr<Linear> l1, l2, l3;
 };
 
 class NestedModel : public Module {
  public:
   NestedModel() {
-    add(Linear(5, 20).build(), "l1");
-    add(std::make_shared<TestModel>(), "test");
-    add(Var(at::CPU(at::kFloat).tensor({3, 2, 21}), false), "param");
+    register_module("l1", &NestedModel::l1, Linear(5, 20).build());
+    register_module("test", &NestedModel::t, std::make_shared<TestModel>());
+    register_parameter(
+        "param", &NestedModel::param_, at::CPU(at::kFloat).tensor({3, 2, 21}));
   }
 
   variable_list forward(variable_list input) override {
     return input;
   };
+
+  Variable param_;
+  std::shared_ptr<Linear> l1;
+  std::shared_ptr<TestModel> t;
 };
 
 TEST_CASE("containers") {
@@ -97,8 +104,7 @@ TEST_CASE("containers") {
   }
 
   REQUIRE(
-      model->parameters().at("weight").grad().numel() ==
-      3 * 2 * 3 * 3 * 3);
+      model->parameters().at("weight").grad().numel() == 3 * 2 * 3 * 3 * 3);
   }
 }
 SECTION("linear") {
@@ -137,7 +143,7 @@ TEST_CASE("containers") {
   }
 
   SECTION("simple") {
-    auto model = std::make_shared<SimpleContainer>();
+    auto model = std::make_shared<Sequential>();
     auto l1 = model->add(Linear(10, 3).build(), "l1");
     auto l2 = model->add(Linear(3, 5).build(), "l2");
     auto l3 = model->add(Linear(5, 100).build(), "l3");
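The shape of the change above: each test module now declares its submodules and parameters as ordinary data members and registers them by name plus pointer-to-member, replacing the old add(..., "name") calls. Below is a minimal sketch of the pattern, modeled on the test code in this diff; the struct Net and its member names are illustrative, not part of the PR.

struct Net : public Module {
  Net() {
    // Name, pointer-to-member to bind, and the value to store. These
    // register_* signatures mirror the calls in the updated tests.
    register_module("fc", &Net::fc, Linear(4, 2).build());
    register_parameter(
        "scale", &Net::scale, at::ones(at::CPU(at::kFloat), {2, 2}));
  }

  variable_list forward(variable_list input) override {
    return fc->forward(input);  // forward through the registered submodule
  }

  std::shared_ptr<Linear> fc;  // the member that register_module binds
  Variable scale;              // the member that register_parameter binds
};

Registering through a pointer-to-member keeps the named registry entry and the data member referring to the same object, which the clone tests further down depend on.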
test/cpp/api/integration.cpp (3 additions, 3 deletions)
@@ -221,7 +221,7 @@ TEST_CASE("integration") {
   std::cerr
       << "Training episodic policy gradient with a critic for up to 3000"
          " episodes, rest your eyes for a bit!\n";
-  auto model = std::make_shared<SimpleContainer>();
+  auto model = std::make_shared<Sequential>();
   auto linear = model->add(Linear(4, 128).build(), "linear");
   auto policyHead = model->add(Linear(128, 2).build(), "policy");
   auto valueHead = model->add(Linear(128, 1).build(), "action");
@@ -320,7 +320,7 @@ TEST_CASE("integration") {
 }
 
 TEST_CASE("integration/mnist", "[cuda]") {
-  auto model = std::make_shared<SimpleContainer>();
+  auto model = std::make_shared<Sequential>();
   auto conv1 = model->add(Conv2d(1, 10, 5).build(), "conv1");
   auto conv2 = model->add(Conv2d(10, 20, 5).build(), "conv2");
   auto drop = Dropout(0.3).build();
@@ -355,7 +355,7 @@ TEST_CASE("integration/mnist", "[cuda]") {
 }
 
 TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
-  auto model = std::make_shared<SimpleContainer>();
+  auto model = std::make_shared<Sequential>();
   auto conv1 = model->add(Conv2d(1, 10, 5).build(), "conv1");
   auto batchnorm2d =
       model->add(BatchNorm(10).stateful(true).build(), "batchnorm2d");
test/cpp/api/module.cpp (49 additions, 11 deletions)
@@ -45,8 +45,8 @@ TEST_CASE("module/training-mode") {
 
 TEST_CASE("module/zero-grad") {
   auto model = Linear(3, 4).build();
-  auto weights = Var(at::ones(at::CPU(at::kFloat), {8, 3}));
-  auto loss = model->forward({weights}).front().sum();
+  auto weight = Var(at::ones(at::CPU(at::kFloat), {8, 3}));
+  auto loss = model->forward({weight}).front().sum();
   backward(loss);
   for (auto& parameter : model->parameters()) {
     Variable grad = parameter.second.grad();
@@ -145,16 +145,18 @@ TEST_CASE("module/clone") {
   SECTION("Cloning creates distinct parameters") {
     struct TestModel : public CloneableModule<TestModel> {
       TestModel() {
-        add(Linear(10, 3).build(), "l1");
-        add(Linear(3, 5).build(), "l2");
-        add(Linear(5, 100).build(), "l3");
+        register_module("l1", &TestModel::l1, Linear(10, 3).build());
+        register_module("l2", &TestModel::l2, Linear(3, 5).build());
+        register_module("l3", &TestModel::l3, Linear(5, 100).build());
       }
 
       void reset() override {}
 
       variable_list forward(variable_list input) override {
         return input;
       }
+
+      std::shared_ptr<Linear> l1, l2, l3;
     };
 
     auto model = TestModel().build();
@@ -163,6 +165,7 @@ TEST_CASE("module/clone") {
     auto m1param = model->parameters();
     auto m2param = model2->parameters();
     for (auto& param : m1param) {
+      REQUIRE(!pointer_equal(param.second, m2param[param.first]));
       REQUIRE(param.second.allclose(m2param[param.first]));
       param.second.data().mul_(2);
     }
@@ -174,23 +177,58 @@
   SECTION("Cloning preserves external references") {
     struct TestModel : public CloneableModule<TestModel> {
       void reset() {
-        weights = add(Var(at::ones(at::CPU(at::kFloat), {4, 4})), "weight");
+        register_parameter(
+            "weight",
+            &TestModel::weight,
+            at::ones(at::CPU(at::kFloat), {4, 4}));
       }
 
       variable_list forward(variable_list input) override {
         return input;
       }
 
-      Variable weights;
+      Variable weight;
     };
 
     auto model = TestModel().build();
-    REQUIRE(pointer_equal(model->weights, model->parameters_["weight"]));
+    REQUIRE(pointer_equal(model->weight, model->param("weight")));
 
     auto model2 = std::dynamic_pointer_cast<TestModel>(
         std::shared_ptr<Module>(model->clone()));
-    REQUIRE(!pointer_equal(model2->weights, model->weights));
-    REQUIRE(pointer_equal(model2->weights, model2->parameters_["weight"]));
-    REQUIRE(!pointer_equal(model2->weights, model->parameters_["weight"]));
+    REQUIRE(!pointer_equal(model2->weight, model->weight));
+    REQUIRE(pointer_equal(model2->weight, model2->param("weight")));
+    REQUIRE(!pointer_equal(model2->weight, model->param("weight")));
   }
 }
+
+TEST_CASE("module/parameters") {
+  struct TestModule : Module {
+    TestModule() {
+      register_parameter(
+          "a", &TestModule::a, at::zeros(at::CPU(at::kFloat), {2, 2}));
+      register_parameter(
+          "b", &TestModule::b, at::ones(at::CPU(at::kFloat), {2, 2}));
+      register_parameter(
+          "c", &TestModule::c, at::ones(at::CPU(at::kFloat), {2, 2}) * 2);
+    }
+
+    variable_list forward(variable_list) override {
+      return {};
+    }
+
+    Variable a, b, c;
+  };
+
+  TestModule module;
+
+  SECTION("has correct number of parameters") {
+    REQUIRE(module.parameters().size() == 3);
+  }
+
+  SECTION("contains parameters with the correct name") {
+    auto parameters = module.parameters();
+    REQUIRE(parameters.count("a"));
+    REQUIRE(parameters.count("b"));
+    REQUIRE(parameters.count("c"));
+  }
+}
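Why the clone test above can require pointer_equal(model2->weight, model2->param("weight")) after cloning: registering by pointer-to-member records which field to read, not a raw address, so the same registration resolves against whatever instance it is applied to, including a clone. A standalone illustration of that core C++ property (a generic sketch, not code from the PR):

#include <cassert>

// A pointer-to-member names a field rather than a location, so applying the
// same &Holder::value to two different objects yields each object's own
// member. This is what lets a per-class registry keep working on clones.
struct Holder {
  int value = 0;
};

int main() {
  int Holder::*member = &Holder::value;
  Holder original;
  Holder copy = original;  // distinct storage
  copy.*member = 42;       // writes only to the copy's member
  assert(original.value == 0);
  assert(copy.value == 42);
  return 0;
}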
test/cpp/api/rnn.cpp (1 addition, 1 deletion)
@@ -8,7 +8,7 @@ using namespace torch::nn;
 template <typename R, typename Func>
 bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
   auto nhid = 32;
-  auto model = std::make_shared<SimpleContainer>();
+  auto model = std::make_shared<Sequential>();
   auto l1 = model->add(Linear(1, nhid).build(), "l1");
   auto rnn = model->add(model_maker(nhid), "rnn");
   auto lo = model->add(Linear(nhid, 1).build(), "lo");
torch/csrc/api/include/torch/detail/member_ref.h (new file, 63 additions)
@@ -0,0 +1,63 @@
+#pragma once
+
+#include <torch/csrc/autograd/variable.h>
+
+#include <ATen/Error.h>
+
+#include <functional>
+#include <type_traits>
+
+namespace torch {
+namespace detail {
+
+/// A class that stores const-correct getters to a member variable, accessible
+/// through a pointer to an object.
+template <typename T>
+class MemberRef {
+ public:
+  // TODO: Replace with (std/boost/our)::any.
+  using This = void*;
+  using ConstThis = const void*;
+  using Getter = std::function<T&(This)>;
+  using ConstGetter = std::function<const T&(ConstThis)>;
+
+  template <typename Class>
+  /* implicit */ MemberRef(T Class::*member)
+      : getter_([member](This self) -> T& {
+          return static_cast<Class*>(self)->*member;
+        }),
+        const_getter_([member](ConstThis self) -> const T& {
+          return static_cast<const Class*>(self)->*member;
+        }) {}
+
+  template <typename Class>
+  MemberRef(std::vector<T> Class::*member, size_t index)
+      : getter_([member, index](This self) -> T& {
+          return (static_cast<Class*>(self)->*member)[index];
+        }),
+        const_getter_([member, index](ConstThis self) -> const T& {
+          return (static_cast<const Class*>(self)->*member)[index];
+        }) {}
+
+  MemberRef(Getter getter, ConstGetter const_getter)
+      : getter_(std::move(getter)), const_getter_(std::move(const_getter)) {}
+
+  template <typename Class>
+  T& operator()(Class* object) {
+    AT_CHECK(getter_ != nullptr, "Calling empty getter");
+    return getter_(object);
+  }
+
+  template <typename Class>
+  const T& operator()(const Class* object) const {
+    AT_CHECK(const_getter_ != nullptr, "Calling empty const getter");
+    return const_getter_(object);
+  }
+
+ private:
+  Getter getter_;
+  ConstGetter const_getter_;
+};
+
+} // namespace detail
+} // namespace torch
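Since MemberRef is the one genuinely new type in this PR, a brief usage sketch may help. The caller below is hypothetical (Net is not from the PR), the include path assumes the api include root, and int stands in for the member type purely for illustration:

#include <torch/detail/member_ref.h>

// MemberRef's implicit constructor captures a pointer-to-member and
// type-erases the owning class behind void*-based getters.
struct Net {
  int weight = 1;
};

int main() {
  torch::detail::MemberRef<int> ref(&Net::weight);

  Net net;
  ref(&net) = 42;  // non-const operator() returns a mutable T&

  const auto& cref = ref;
  const Net* cnet = &net;
  return cref(cnet) == 42 ? 0 : 1;  // const operator() returns const T&
}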