@@ -769,24 +769,24 @@ modules in the ``forward()`` method of a module we define ourselves:
769769.. code-block:: cpp
770770
771771 struct GeneratorImpl : nn::Module {
772- GeneratorImpl()
773- : conv1(nn::Conv2dOptions(kNoiseSize, 512 , 4)
772+ GeneratorImpl(int kNoiseSize )
773+ : conv1(nn::Conv2dOptions(kNoiseSize, 256 , 4)
774774 .with_bias(false)
775775 .transposed(true)),
776- batch_norm1(512 ),
777- conv2(nn::Conv2dOptions(512, 256, 4 )
776+ batch_norm1(256 ),
777+ conv2(nn::Conv2dOptions(256, 128, 3 )
778778 .stride(2)
779779 .padding(1)
780780 .with_bias(false)
781781 .transposed(true)),
782- batch_norm2(256 ),
783- conv3(nn::Conv2dOptions(256, 128 , 4)
782+ batch_norm2(128 ),
783+ conv3(nn::Conv2dOptions(128, 64 , 4)
784784 .stride(2)
785785 .padding(1)
786786 .with_bias(false)
787787 .transposed(true)),
788- batch_norm3(128 ),
789- conv4(nn::Conv2dOptions(128, 64 , 4)
788+ batch_norm3(64 ),
789+ conv4(nn::Conv2dOptions(64, 1 , 4)
790790 .stride(2)
791791 .padding(1)
792792 .with_bias(false)
@@ -796,19 +796,28 @@ modules in the ``forward()`` method of a module we define ourselves:
796796 .stride(2)
797797 .padding(1)
798798 .with_bias(false)
799- .transposed(true)) {}
800-
801- torch::Tensor forward(torch::Tensor x) {
802- x = torch::relu(batch_norm1(conv1(x)));
803- x = torch::relu(batch_norm2(conv2(x)));
804- x = torch::relu(batch_norm3(conv3(x)));
805- x = torch::relu(batch_norm4(conv4(x)));
806- x = torch::tanh(conv5(x));
807- return x;
808- }
809-
810- nn::Conv2d conv1, conv2, conv3, conv4, conv5;
811- nn::BatchNorm batch_norm1, batch_norm2, batch_norm3, batch_norm4;
799+ .transposed(true))
800+ {
801+ // register_module() is needed if we want to use the parameters() method later on
802+ register_module("conv1", conv1);
803+ register_module("conv2", conv2);
804+ register_module("conv3", conv3);
805+ register_module("conv4", conv4);
806+ register_module("batch_norm1", batch_norm1);
807+ register_module("batch_norm2", batch_norm1);
808+ register_module("batch_norm3", batch_norm1);
809+ }
810+
811+ torch::Tensor forward(torch::Tensor x) {
812+ x = torch::relu(batch_norm1(conv1(x)));
813+ x = torch::relu(batch_norm2(conv2(x)));
814+ x = torch::relu(batch_norm3(conv3(x)));
815+ x = torch::tanh(conv4(x));
816+ return x;
817+ }
818+
819+ nn::Conv2d conv1, conv2, conv3, conv4;
820+ nn::BatchNorm batch_norm1, batch_norm2, batch_norm3;
812821 };
813822 TORCH_MODULE(Generator);
814823
0 commit comments