diff --git a/cpp/dcgan/dcgan.cpp b/cpp/dcgan/dcgan.cpp index ffbb28bc3d..116131b88d 100644 --- a/cpp/dcgan/dcgan.cpp +++ b/cpp/dcgan/dcgan.cpp @@ -31,6 +31,50 @@ const int64_t kLogInterval = 10; using namespace torch; +struct DCGANGeneratorImpl : nn::Module { + DCGANGeneratorImpl(int kNoiseSize) + : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4) + .bias(false)), + batch_norm1(256), + conv2(nn::ConvTranspose2dOptions(256, 128, 3) + .stride(2) + .padding(1) + .bias(false)), + batch_norm2(128), + conv3(nn::ConvTranspose2dOptions(128, 64, 4) + .stride(2) + .padding(1) + .bias(false)), + batch_norm3(64), + conv4(nn::ConvTranspose2dOptions(64, 1, 4) + .stride(2) + .padding(1) + .bias(false)) + { + // register_module() is needed if we want to use the parameters() method later on + register_module("conv1", conv1); + register_module("conv2", conv2); + register_module("conv3", conv3); + register_module("conv4", conv4); + register_module("batch_norm1", batch_norm1); + register_module("batch_norm2", batch_norm2); + register_module("batch_norm3", batch_norm3); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(batch_norm1(conv1(x))); + x = torch::relu(batch_norm2(conv2(x))); + x = torch::relu(batch_norm3(conv3(x))); + x = torch::tanh(conv4(x)); + return x; + } + + nn::ConvTranspose2d conv1, conv2, conv3, conv4; + nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3; +}; + +TORCH_MODULE(DCGANGenerator); + int main(int argc, const char* argv[]) { torch::manual_seed(1); @@ -41,57 +85,28 @@ int main(int argc, const char* argv[]) { device = torch::Device(torch::kCUDA); } - nn::Sequential generator( - // Layer 1 - nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4) - .with_bias(false) - .transposed(true)), - nn::BatchNorm(256), - nn::Functional(torch::relu), - // Layer 2 - nn::Conv2d(nn::Conv2dOptions(256, 128, 3) - .stride(2) - .padding(1) - .with_bias(false) - .transposed(true)), - nn::BatchNorm(128), - nn::Functional(torch::relu), - // Layer 3 - nn::Conv2d(nn::Conv2dOptions(128, 64, 4) - .stride(2) - .padding(1) - .with_bias(false) - .transposed(true)), - nn::BatchNorm(64), - nn::Functional(torch::relu), - // Layer 4 - nn::Conv2d(nn::Conv2dOptions(64, 1, 4) - .stride(2) - .padding(1) - .with_bias(false) - .transposed(true)), - nn::Functional(torch::tanh)); + DCGANGenerator generator(kNoiseSize); generator->to(device); nn::Sequential discriminator( // Layer 1 nn::Conv2d( - nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)), - nn::Functional(torch::leaky_relu, 0.2), + nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), // Layer 2 nn::Conv2d( - nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)), - nn::BatchNorm(128), - nn::Functional(torch::leaky_relu, 0.2), + nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)), + nn::BatchNorm2d(128), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), // Layer 3 nn::Conv2d( - nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)), - nn::BatchNorm(256), - nn::Functional(torch::leaky_relu, 0.2), + nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)), + nn::BatchNorm2d(256), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), // Layer 4 nn::Conv2d( - nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)), - nn::Functional(torch::sigmoid)); + nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)), + nn::Sigmoid()); discriminator->to(device); // Assume the MNIST dataset is available under `kDataFolder`; diff --git a/cpp/mnist/mnist.cpp b/cpp/mnist/mnist.cpp index 4c6f103dbc..d353c4f021 100644 --- a/cpp/mnist/mnist.cpp +++ b/cpp/mnist/mnist.cpp @@ -47,7 +47,7 @@ struct Net : torch::nn::Module { torch::nn::Conv2d conv1; torch::nn::Conv2d conv2; - torch::nn::FeatureDropout conv2_drop; + torch::nn::Dropout2d conv2_drop; torch::nn::Linear fc1; torch::nn::Linear fc2; }; @@ -99,7 +99,7 @@ void test( output, targets, /*weight=*/{}, - Reduction::Sum) + torch::Reduction::Sum) .template item(); auto pred = output.argmax(1); correct += pred.eq(targets).sum().template item();