Skip to content

Commit 03ba8fa

Browse files
author
Jessica Lin
committed
Correct errors in the TorchModule code to match the sequential version
1 parent 1bc0efe commit 03ba8fa

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

advanced_source/cpp_frontend.rst

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -769,24 +769,24 @@ modules in the ``forward()`` method of a module we define ourselves:
769769
.. code-block:: cpp
770770
771771
struct GeneratorImpl : nn::Module {
772-
GeneratorImpl()
773-
: conv1(nn::Conv2dOptions(kNoiseSize, 512, 4)
772+
GeneratorImpl(int kNoiseSize)
773+
: conv1(nn::Conv2dOptions(kNoiseSize, 256, 4)
774774
.with_bias(false)
775775
.transposed(true)),
776-
batch_norm1(512),
777-
conv2(nn::Conv2dOptions(512, 256, 4)
776+
batch_norm1(256),
777+
conv2(nn::Conv2dOptions(256, 128, 3)
778778
.stride(2)
779779
.padding(1)
780780
.with_bias(false)
781781
.transposed(true)),
782-
batch_norm2(256),
783-
conv3(nn::Conv2dOptions(256, 128, 4)
782+
batch_norm2(128),
783+
conv3(nn::Conv2dOptions(128, 64, 4)
784784
.stride(2)
785785
.padding(1)
786786
.with_bias(false)
787787
.transposed(true)),
788-
batch_norm3(128),
789-
conv4(nn::Conv2dOptions(128, 64, 4)
788+
batch_norm3(64),
789+
conv4(nn::Conv2dOptions(64, 1, 4)
790790
.stride(2)
791791
.padding(1)
792792
.with_bias(false)
@@ -796,19 +796,28 @@ modules in the ``forward()`` method of a module we define ourselves:
796796
.stride(2)
797797
.padding(1)
798798
.with_bias(false)
799-
.transposed(true)) {}
800-
801-
torch::Tensor forward(torch::Tensor x) {
802-
x = torch::relu(batch_norm1(conv1(x)));
803-
x = torch::relu(batch_norm2(conv2(x)));
804-
x = torch::relu(batch_norm3(conv3(x)));
805-
x = torch::relu(batch_norm4(conv4(x)));
806-
x = torch::tanh(conv5(x));
807-
return x;
808-
}
809-
810-
nn::Conv2d conv1, conv2, conv3, conv4, conv5;
811-
nn::BatchNorm batch_norm1, batch_norm2, batch_norm3, batch_norm4;
799+
.transposed(true))
800+
{
801+
// register_module() is needed if we want to use the parameters() method later on
802+
register_module("conv1", conv1);
803+
register_module("conv2", conv2);
804+
register_module("conv3", conv3);
805+
register_module("conv4", conv4);
806+
register_module("batch_norm1", batch_norm1);
807+
register_module("batch_norm2", batch_norm1);
808+
register_module("batch_norm3", batch_norm1);
809+
}
810+
811+
torch::Tensor forward(torch::Tensor x) {
812+
x = torch::relu(batch_norm1(conv1(x)));
813+
x = torch::relu(batch_norm2(conv2(x)));
814+
x = torch::relu(batch_norm3(conv3(x)));
815+
x = torch::tanh(conv4(x));
816+
return x;
817+
}
818+
819+
nn::Conv2d conv1, conv2, conv3, conv4;
820+
nn::BatchNorm batch_norm1, batch_norm2, batch_norm3;
812821
};
813822
TORCH_MODULE(Generator);
814823

0 commit comments

Comments
 (0)