# Deep Convolutional GAN (DC-GAN)

In [1]:
using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux.Losses: logitbinarycrossentropy
using Images
using MLDatasets
using Statistics
using Printf
using Random
using CUDA
CUDA.allowscalar(true)

In [2]:
"""
    HyperParams(; kwargs...)

Keyword-constructible bundle of the settings used by the DC-GAN training code
in this file. Every field has a default, so `HyperParams()` is valid.
"""
Base.@kwdef struct HyperParams
    # Size of each training batch and number of noise columns drawn per step.
    batch_size::Int = 128
    # Length of each latent noise vector fed to the generator.
    latent_dim::Int = 100
    # Previously 20 (`epochs::Int = 20`); lowered to 1 — presumably for a quick run.
    epochs::Int = 1
    # A log line and a sample image are produced every `verbose_freq` steps.
    verbose_freq::Int = 1000
    # `output_x * output_y` fixed noise vectors are sampled for the preview grid.
    output_x::Int = 6
    # Also used as the group size when the preview images are tiled.
    output_y::Int = 6
    # Adam learning rate for the discriminator.
    lr_dscr::Float32 = 0.0002
    # Adam learning rate for the generator.
    lr_gen::Float32 = 0.0002
end

HyperParams

In [3]:
"""
    create_output_image(gen, fixed_noise, hparams)

Render the generator's output for every entry of `fixed_noise` into one
grayscale image grid.

Each entry of `fixed_noise` is passed through `gen` and moved to the CPU. The
resulting images are concatenated with `hcat` in groups of `hparams.output_y`,
the groups are stacked with `vcat`, the two singleton trailing dimensions are
dropped and the grid is transposed. Pixel values are shifted by one, wrapped
in `Gray` and halved.
"""
function create_output_image(gen, fixed_noise, hparams)
    # One generated sample per noise entry, brought back to host memory.
    samples = cpu.(gen.(fixed_noise))

    # Tile: hcat inside each group, then vcat the groups together.
    strips = map(group -> reduce(hcat, group), partition(samples, hparams.output_y))
    grid = reduce(vcat, strips)

    # Drop the channel and batch dimensions and transpose the 2-D grid.
    grid = permutedims(dropdims(grid; dims=(3, 4)), (2, 1))

    return Gray.(grid .+ 1f0) ./ 2f0
end

create_output_image (generic function with 1 method)

In [4]:
"""
    Discriminator()

Build the discriminator network: two strided 4×4 convolutions (1 → 64 → 128
channels), each followed by a leaky ReLU with slope `0.2f0` and 25 % dropout,
then a flatten and a single-output `Dense` layer producing a logit.

The flatten step hard-codes `7 * 7 * 128` features, so the input is assumed to
be 28×28 single-channel images.
"""
function Discriminator()
    lrelu(x) = leakyrelu.(x, 0.2f0)
    flatten_features(x) = reshape(x, 7 * 7 * 128, :)

    return Chain(
        Conv((4, 4), 1 => 64; stride=2, pad=1, init=dcgan_init),
        lrelu,
        Dropout(0.25),
        Conv((4, 4), 64 => 128; stride=2, pad=1, init=dcgan_init),
        lrelu,
        Dropout(0.25),
        flatten_features,
        Dense(7 * 7 * 128, 1),
    )
end

Discriminator (generic function with 1 method)

In [5]:
"""
    Generator(latent_dim::Int)

Build the generator network mapping a `latent_dim`-long noise vector to an
image.

A `Dense` layer expands the noise to `7 * 7 * 256` features, which are reshaped
to 7×7×256 feature maps. Three transposed convolutions (256 → 128 → 64 → 1
channels) follow, the last two with stride 2. Batch normalisation with `relu`
is applied after every layer except the last, whose output goes through
`tanh`.
"""
function Generator(latent_dim::Int)
    to_feature_maps(x) = reshape(x, 7, 7, 256, :)
    squash(x) = tanh.(x)

    return Chain(
        Dense(latent_dim, 7 * 7 * 256),
        BatchNorm(7 * 7 * 256, relu),
        to_feature_maps,
        ConvTranspose((5, 5), 256 => 128; stride=1, pad=2, init=dcgan_init),
        BatchNorm(128, relu),
        ConvTranspose((4, 4), 128 => 64; stride=2, pad=1, init=dcgan_init),
        BatchNorm(64, relu),
        ConvTranspose((4, 4), 64 => 1; stride=2, pad=1, init=dcgan_init),
        squash,
    )
end

Generator (generic function with 1 method)

In [6]:
# Loss functions
"""
    discriminator_loss(real_output, fake_output)

Return the discriminator loss: the logit binary cross-entropy of `real_output`
against the target `1` plus that of `fake_output` against the target `0`.
"""
function discriminator_loss(real_output, fake_output)
    return logitbinarycrossentropy(real_output, 1) +
           logitbinarycrossentropy(fake_output, 0)
end

discriminator_loss (generic function with 1 method)

In [7]:
"""
    train_discriminator!(gen, dscr, x, opt_dscr, hparams)

Perform one optimisation step on the discriminator `dscr` and return the loss.

A noise batch of size `(hparams.latent_dim, hparams.batch_size)` is drawn and
passed through `gen` outside the gradient computation, so only `dscr` is
differentiated. The loss compares `dscr`'s output on the real batch `x` and on
the generated batch; `dscr` is then updated in place through `opt_dscr`.
"""
function train_discriminator!(gen, dscr, x, opt_dscr, hparams)
    # Allocate the latent batch from `x` so it takes x's array/element type.
    z = similar(x, (hparams.latent_dim, hparams.batch_size))
    randn!(z)
    fake_input = gen(z)

    # Differentiate the loss with respect to the discriminator only.
    loss, grads = Flux.withgradient(
        d -> discriminator_loss(d(x), d(fake_input)), dscr
    )
    update!(opt_dscr, dscr, first(grads))
    return loss
end

train_discriminator! (generic function with 1 method)

In [8]:
"""
    train_generator!(gen, dscr, x, opt_gen, hparams)

Perform one optimisation step on the generator `gen` and return the loss.

A noise batch of size `(hparams.latent_dim, hparams.batch_size)` is drawn, the
generator loss of `dscr(gen(noise))` is differentiated with respect to `gen`,
and `gen` is updated in place through `opt_gen`. `x` is used only as the
template from which the noise array is allocated.
"""
function train_generator!(gen, dscr, x, opt_gen, hparams)
    # Allocate the latent batch from `x` so it takes x's array/element type.
    z = similar(x, (hparams.latent_dim, hparams.batch_size))
    randn!(z)

    # Differentiate the loss with respect to the generator only.
    loss, grads = Flux.withgradient(g -> generator_loss(dscr(g(z))), gen)
    update!(opt_gen, gen, first(grads))
    return loss
end

train_generator! (generic function with 1 method)

In [9]:
"""
    train(; kws...)

Train the DC-GAN on the MNIST training images.

All keyword arguments are forwarded to `HyperParams`. For every batch the
discriminator is updated first, then the generator. Every
`hparams.verbose_freq` steps the current losses are logged and a grid of
samples generated from a fixed set of noise vectors is written to
`images/dcgan_steps_<step>.png`; one more grid is written after the last
epoch. The `images` directory is created if it does not exist.

Returns the result of the final `save` call.
"""
function train(; kws...)
    # Model Parameters
    hparams = HyperParams(; kws...)

    if CUDA.functional()
        @info "Training on GPU"
    else
        @warn "Training on CPU, this will be very slow!"  # 20 mins/epoch
    end

    # Snapshots are written below `images/`; create it up front so the first
    # `save` does not fail when the directory is missing.
    mkpath("images")

    # Load MNIST dataset
    images = MLDatasets.MNIST(:train).features
    # Normalize to [-1, 1]
    image_tensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
    # Partition into batches, taking the sample count from the data itself
    # instead of hard-coding it.
    num_images = size(image_tensor, 4)
    data = [
        image_tensor[:, :, :, r] |> gpu for
        r in partition(1:num_images, hparams.batch_size)
    ]

    # The same noise vectors are reused for every snapshot so that successive
    # images are comparable.
    num_samples = hparams.output_x * hparams.output_y
    fixed_noise = [randn(Float32, hparams.latent_dim, 1) |> gpu for _ in 1:num_samples]

    # Discriminator
    dscr = Discriminator() |> gpu

    # Generator
    gen = Generator(hparams.latent_dim) |> gpu

    # Optimizers
    opt_dscr = Flux.setup(Adam(hparams.lr_dscr), dscr)
    opt_gen = Flux.setup(Adam(hparams.lr_gen), gen)

    # Render the fixed-noise samples and write them to disk for a given step.
    save_snapshot(step) = save(
        @sprintf("images/dcgan_steps_%06d.png", step),
        create_output_image(gen, fixed_noise, hparams),
    )

    # Training
    train_steps = 0
    for ep in 1:hparams.epochs
        @info "Epoch $ep"
        for x in data
            # Update discriminator and generator
            loss_dscr = train_discriminator!(gen, dscr, x, opt_dscr, hparams)
            loss_gen = train_generator!(gen, dscr, x, opt_gen, hparams)

            if train_steps % hparams.verbose_freq == 0
                @info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")
                # Save generated fake image
                save_snapshot(train_steps)
            end
            train_steps += 1
        end
    end

    return save_snapshot(train_steps)
end

train (generic function with 1 method)

In [10]:
# weight initialization as given in the paper https://arxiv.org/abs/1511.06434
"""
    dcgan_init(shape...)

Return a `Float32` array of size `shape` filled with standard-normal samples
scaled by `0.02f0`.
"""
function dcgan_init(shape...)
    return 0.02f0 .* randn(Float32, shape...)
end

"""
    generator_loss(fake_output)

Return the logit binary cross-entropy of `fake_output` against the target `1`.
"""
function generator_loss(fake_output)
    return logitbinarycrossentropy(fake_output, 1)
end

# Entry point. The "run only when executed as a script" guard is left commented
# out, so `train()` is called unconditionally whenever this code is evaluated.
# if abspath(PROGRAM_FILE) == @__FILE__
    train()
# end


[33m[1m└ [22m[39m[90m@ Main In[9]:8[39m
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mThe CUDA function is being called but CUDA.jl is not functional.
[36m[1m└ [22m[39mDefaulting back to the CPU. (No action is required if you want to run on the CPU).
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrain step 0, Discriminator loss = 1.3771783, Generator loss = 0.68713003


LoadError: TaskFailedException

[91m    nested task error: [39mTaskFailedException
    
    [91m    nested task error: [39mInterruptException:
        Stacktrace:
         [1] [0m[1mcol2im![22m[0m[1m([22m[90mx[39m::[0mSubArray[90m{Float32, 4, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, true}[39m, [90mcol[39m::[0mSubArray[90m{Float32, 2, Array{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}[39m, [90mcdims[39m::[0mDenseConvDims[90m{3, 3, 3, 6, 3}[39m, [90mbeta[39m::[0mFloat32[0m[1m)[22m
        [90m   @[39m [35mNNlib[39m [90m~/.julia/packages/NNlib/O0zGY/src/impl/[39m[90m[4mconv_im2col.jl:290[24m[39m
         [2] [0m[1m(::NNlib.var"#652#653"{Float32, Array{Float32, 3}, Float32, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})[22m[0m[1m([22m[0m[1m)[22m
        [90m   @[39m [35mNNlib[39m [90m~/.julia/packages/NNlib/O0zGY/src/impl/[39m[90m[4mconv_im2col.jl:165[24m[39m
    Stacktrace:
     [1] [0m[1msync_end[22m[0m[1m([22m[90mc[39m::[0mChannel[90m{Any}[39m[0m[1m)[22m
    [90m   @[39m [90mBase[39m [90m./[39m[90m[4mtask.jl:448[24m[39m
     [2] [0m[1mmacro expansion[22m
    [90m   @[39m [90m./[39m[90m[4mtask.jl:480[24m[39m[90m [inlined][39m
     [3] [0m[1m∇conv_data_im2col![22m[0m[1m([22m[90mdx[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}[39m, [90mdy[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}[39m, [90mw[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}[39m, [90mcdims[39m::[0mDenseConvDims[90m{3, 3, 3, 6, 3}[39m; [90mcol[39m::[0mArray[90m{Float32, 3}[39m, [90malpha[39m::[0mFloat32, [90mbeta[39m::[0mFloat32, [90mntasks[39m::[0mInt64[0m[1m)[22m
    [90m   @[39m [35mNNlib[39m [90m~/.julia/packages/NNlib/O0zGY/src/impl/[39m[90m[4mconv_im2col.jl:155[24m[39m
     [4] [0m[1m∇conv_data_im2col![22m[0m[1m([22m[90mdx[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}[39m, [90mdy[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}[39m, [90mw[39m::[0mSubArray[90m{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}[39m, [90mcdims[39m::[0mDenseConvDims[90m{3, 3, 3, 6, 3}[39m[0m[1m)[22m
    [90m   @[39m [35mNNlib[39m [90m~/.julia/packages/NNlib/O0zGY/src/impl/[39m[90m[4mconv_im2col.jl:126[24m[39m
     [5] [0m[1m(::NNlib.var"#324#328"{@Kwargs{}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})[22m[0m[1m([22m[0m[1m)[22m
    [90m   @[39m [35mNNlib[39m [90m~/.julia/packages/NNlib/O0zGY/src/[39m[90m[4mconv.jl:253[24m[39m

# References

- [ ] [Deep Convolutional GAN (DC-GAN)](https://github.com/FluxML/model-zoo/tree/master/vision/dcgan_mnist)