# Deep Convolutional Generative Adversarial Network (DCGAN)
Train a DCGAN to generate MNIST images

See [this repo](https://github.com/FluxML/model-zoo/tree/master/vision/cdcgan_mnist), [this pytorch tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html), and [this](https://github.com/soumith/ganhacks) for useful tricks.

[Origianl GAN paper](https://arxiv.org/abs/1406.2661)

[DCGAN paper](https://arxiv.org/pdf/1511.06434.pdf)

In [None]:
# Import packages
using Pkg
Pkg.activate(".")
using Flux
using MLDatasets
using Random
using CUDA

: 

In [41]:
# Hyperparameters
batch_size = 128  
latent_dim = 100
epochs = 30
n_val = 16
lr_disc = 0.0002
lr_gen = 0.0002


0.0002

In [24]:
# Load MNIST dataset and reshape + normalize to [-1,1]
images = MLDatasets.MNIST(:train).features
image_tensor = reshape(@.(2f0*images - 1f0), 28, 28, 1, :)
size(images)

(28, 28, 60000)

In [6]:
# Labels to use for real images. As noted in the useful tricks, it helps to set real=0 and fake=1 and using soft labels: real=[0,1], fake=[0.9,1]
real_labels = 0.1.*rand(Float32, size(image_tensor)[end])
fake_labels = 0.1.*rand(Float32, size(image_tensor)[end]).+0.9

60000-element Vector{Float64}:
 0.9900133013725281
 0.9000288128852845
 0.9238494038581848
 0.9546479403972626
 0.9118001341819764
 0.9931084394454956
 0.9013398230075836
 0.978172916173935
 0.999093782901764
 0.9117977380752563
 0.9535072267055512
 0.9035910665988922
 0.9979776263237
 ⋮
 0.9675454676151276
 0.9204791843891144
 0.9304541230201722
 0.9807366907596589
 0.9015099227428437
 0.9300561308860779
 0.9478922963142395
 0.9884008467197418
 0.9303670108318329
 0.9534220278263092
 0.9790423393249512
 0.9586935639381409

In [32]:
# Fake but fixed noise to use for monitoring training process
fixed_noise = randn(Float32,(latent_dim, n_val))

100×16 Matrix{Float32}:
  0.00724127   1.62873     …   0.0110928  -1.04959      0.536255
  0.348034     1.25897         0.680787   -0.617239    -0.818312
 -0.610398     0.302688        0.748166    0.00465741   1.34457
  0.806511    -0.917159       -0.152679   -0.628985     1.14415
 -2.13614      0.791357        0.319659   -0.172842    -0.816172
  2.64086      0.14931     …  -0.185784   -1.15927      0.98001
 -0.131726     0.378519        0.689256   -1.27684      2.54643
  0.316592     0.609034        0.804003   -0.935041     0.0113185
 -0.0887196   -0.382533       -1.07647    -0.651147    -0.0727658
 -1.62575     -0.297545        1.05755     0.252879    -0.560706
  0.690676     0.00189033  …   0.466039   -0.496895    -0.224915
 -0.15587      0.947364       -0.405452   -1.69413      1.23672
  0.824651    -0.19469         0.34773    -0.828632    -0.181304
  ⋮                        ⋱                            ⋮
  1.22846     -0.148749       -0.173922   -1.36994      0.922575
  0.439776 

In [12]:
# Initializer for weights
dcgan_init(shape...) = randn(Float32,shape...)* 0.02f0

dcgan_init (generic function with 1 method)

In [40]:
# Function to build discriminator and its loss function
function Discriminator()
    return Chain(
        Conv((4,4),1 => 64; stride=2, pad=1, init=dcgan_init),
        x->leakyrelu.(x, 0.2f0),
        Dropout(0.25),
        Conv((4,4), 64 => 128; stride=2, pad=1, init=dcgan_init),
        x->leakyrelu.(x, 0.2f0),
        Dropout(0.25),
        x->reshape(x,7*7*128,:),
        Dense(7*7*128=>1)
    )
end

function discriminator_loss(real_output, fake_output, real_labels, fake_labels)
    real_loss = logitbinarycrossentropy(real_output, real_labels)
    fake_loss = logitbinarycrossentropy(fake_output, fake_labels)
    return real_loss+fake_loss
end

discriminator_loss (generic function with 1 method)

In [35]:
# Function to build generator
function Generator(latent_dim::Int)
    return Chain(
        Dense(latent_dim, 7*7*256),
        BatchNorm(7*7*256, relu),
        x->reshape(x, 7, 7, 256, :),
        ConvTranspose((5,5), 256 => 128; stride=1, pad=2, init=dcgan_init),
        BatchNorm(128,relu),
        ConvTranspose((4,4), 128 => 64; stride=2, pad=1, init=dcgan_init),
        BatchNorm(64,relu),
        ConvTranspose((4,4), 64 => 1, stride=2, pad=1, init=dcgan_init),
        x->tanh.(x)
    )
end

generator_loss(fake_output) = logitbinarycrossentropy(fake_output, real_labels)

Generator (generic function with 1 method)

In [37]:
gen = Generator(latent_dim)
disc = Discriminator()

Chain(
  Conv((4, 4), 1 => 64, pad=1, stride=2),  [90m# 1_088 parameters[39m
  var"#3#6"(),
  Dropout(0.25),
  Conv((4, 4), 64 => 128, pad=1, stride=2),  [90m# 131_200 parameters[39m
  var"#4#7"(),
  Dropout(0.25),
  var"#5#8"(),
  Dense(6272 => 1),                     [90m# 6_273 parameters[39m
) [90m                  # Total: 6 arrays, [39m138_561 parameters, 542.145 KiB.

In [42]:
disc_opt = Flux.setup(Adam(lr_disc), disc)
gen_opt = Flux.setup(Adam(lr_gen), gen)

(layers = ((weight = [32mLeaf(Adam(0.0002, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam(0.0002, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))[32m)[39m, σ = ()), (λ = (), β = [32mLeaf(Adam(0.0002, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))[32m)[39m, γ = [32mLeaf(Adam(0.0002, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0

In [33]:
gen(fixed_noise)

28×28×1×16 Array{Float32, 4}:
[:, :, 1, 1] =
 -0.000745898   0.000961913  -0.00471434   …   0.000640112   0.00068877
 -0.00465893    0.00374137   -0.00296304      -0.0075544    -0.000739787
 -0.00108253    0.0107205    -0.00653486      -0.00108507   -0.00100319
  0.00302239    0.000982488   0.00143781      -0.00777687    0.00421835
 -0.00424564    0.00380425   -0.00757325       0.00214467   -0.000574293
 -0.000369982   0.00462612   -0.00625899   …  -0.00943098    0.00575584
  0.00100281    0.0186923    -0.00964474       0.00501088   -0.00147504
  0.000980961  -7.88132f-7   -0.0095213       -0.00301693    0.00448221
 -0.00338154    0.00168633   -0.0149562        3.35285f-5   -0.0020573
 -0.00353171    0.00735607   -0.00125396      -0.00790458    0.00079305
  0.0016785     0.00890716    0.00138324   …   0.00455573   -0.00318709
  0.000800399  -0.00116952    0.00121243      -0.00751017    0.00331199
 -0.00263419   -0.00467495   -0.0127132        0.00148931   -0.00112025
  ⋮               

In [39]:
disc(image_tensor[:,:,:,1:4])

1×4 Matrix{Float32}:
 -0.0306821  -0.00460608  0.014002  -0.0174576