Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 55 additions & 40 deletions cpp/dcgan/dcgan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,50 @@ const int64_t kLogInterval = 10;

using namespace torch;

struct DCGANGeneratorImpl : nn::Module {
DCGANGeneratorImpl(int kNoiseSize)
: conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
.bias(false)),
batch_norm1(256),
conv2(nn::ConvTranspose2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.bias(false)),
batch_norm2(128),
conv3(nn::ConvTranspose2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.bias(false)),
batch_norm3(64),
conv4(nn::ConvTranspose2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.bias(false))
{
// register_module() is needed if we want to use the parameters() method later on
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv3", conv3);
register_module("conv4", conv4);
register_module("batch_norm1", batch_norm1);
register_module("batch_norm2", batch_norm2);
register_module("batch_norm3", batch_norm3);
}

torch::Tensor forward(torch::Tensor x) {
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));
return x;
}

nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};

TORCH_MODULE(DCGANGenerator);

int main(int argc, const char* argv[]) {
torch::manual_seed(1);

Expand All @@ -41,57 +85,28 @@ int main(int argc, const char* argv[]) {
device = torch::Device(torch::kCUDA);
}

nn::Sequential generator(
// Layer 1
nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(256),
nn::Functional(torch::relu),
// Layer 2
nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(128),
nn::Functional(torch::relu),
// Layer 3
nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(64),
nn::Functional(torch::relu),
// Layer 4
nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::Functional(torch::tanh));
DCGANGenerator generator(kNoiseSize);
generator->to(device);

nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
nn::Functional(torch::leaky_relu, 0.2),
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(128),
nn::Functional(torch::leaky_relu, 0.2),
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(256),
nn::Functional(torch::leaky_relu, 0.2),
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
nn::Functional(torch::sigmoid));
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
discriminator->to(device);

// Assume the MNIST dataset is available under `kDataFolder`;
Expand Down
4 changes: 2 additions & 2 deletions cpp/mnist/mnist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct Net : torch::nn::Module {

torch::nn::Conv2d conv1;
torch::nn::Conv2d conv2;
torch::nn::FeatureDropout conv2_drop;
torch::nn::Dropout2d conv2_drop;
torch::nn::Linear fc1;
torch::nn::Linear fc2;
};
Expand Down Expand Up @@ -99,7 +99,7 @@ void test(
output,
targets,
/*weight=*/{},
Reduction::Sum)
torch::Reduction::Sum)
.template item<float>();
auto pred = output.argmax(1);
correct += pred.eq(targets).sum().template item<int64_t>();
Expand Down