In [1]:
using Flux, MLDatasets, Statistics, Random, BSON
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy

In [2]:
using Metal

Metal.functional()
device = Flux.get_device(; verbose=true)
device.deviceID

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mUsing backend: Metal.


<AGXG14GDevice: 0x1562bd800>
    name = Apple M2

In [3]:
batchSize, latentDim = 500, 100
epochs = 40
etaD, etaG = 0.0002, 0.0002;

In [4]:
# images, _ = MLDatasets.MNIST.traindata(Float32)
images, _ = MNIST(split=:train)[:]
imageTensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
data = [imageTensor[:, :, :, r] for r in Iterators.partition(1:60000, batchSize)];

In [5]:
dscr = Chain(Conv((4,4),1=>64;stride=2,pad=1),x->leakyrelu.(x,0.2f0),
        Dropout(0.25),Conv((4,4),64=>128;stride=2,pad=1),x->leakyrelu.(x,0.2f0),
        Dropout(0.25), x->reshape(x, 7 * 7 * 128, :), Dense(7 * 7 * 128, 1))
gen =  Chain(Dense(latentDim,7*7*256),BatchNorm(7*7*256,relu),
        x->reshape(x,7,7,256,:),ConvTranspose((5,5),256=>128;stride=1,pad=2),
        BatchNorm(128,relu),ConvTranspose((4,4),128=>64;stride=2,pad=1),
        BatchNorm(64,relu),ConvTranspose((4,4),64=>1,tanh;stride=2,pad=1))

Chain(
  Dense(100 => 12544),                  [90m# 1_266_944 parameters[39m
  BatchNorm(12544, relu),               [90m# 25_088 parameters[39m[90m, plus 25_088[39m
  var"#9#10"(),
  ConvTranspose((5, 5), 256 => 128, pad=2),  [90m# 819_328 parameters[39m
  BatchNorm(128, relu),                 [90m# 256 parameters[39m[90m, plus 256[39m
  ConvTranspose((4, 4), 128 => 64, pad=1, stride=2),  [90m# 131_136 parameters[39m
  BatchNorm(64, relu),                  [90m# 128 parameters[39m[90m, plus 128[39m
  ConvTranspose((4, 4), 64 => 1, tanh, pad=1, stride=2),  [90m# 1_025 parameters[39m
) [90m        # Total: 14 trainable arrays, [39m2_243_905 parameters,
[90m          # plus 6 non-trainable, 25_472 parameters, summarysize [39m8.659 MiB.

In [6]:
dLoss(realOut,fakeOut) =    mean(logitbinarycrossentropy.(realOut,1f0)) +
                            mean(logitbinarycrossentropy.(fakeOut,0f0))
gLoss(u) = mean(logitbinarycrossentropy.(u, 1f0));

In [7]:
function updateD!(gen, dscr, x, opt_dscr)
    noise = randn!(similar(x, (latentDim, batchSize)))
    fakeInput = gen(noise)
    ps = Flux.params(dscr)
    loss, back = Flux.pullback(()->dLoss(dscr(x), dscr(fakeInput)), ps)
    grad = back(1f0)
    update!(opt_dscr, ps, grad)
    return loss
end

updateD! (generic function with 1 method)

In [8]:
function updateG!(gen, dscr, x, optGen)
    noise = randn!(similar(x, (latentDim, batchSize)))
    ps = Flux.params(gen)
    loss, back = Flux.pullback(()->gLoss(dscr(gen(noise))),ps)
    grad = back(1f0)
    update!(optGen, ps, grad)
    return loss
end

updateG! (generic function with 1 method)

In [9]:
optDscr, optGen = ADAM(etaD), ADAM(etaG)
cd(@__DIR__)
@time begin
    for ep in 1:epochs
        for (bi,x) in enumerate(data)
            lossD = updateD!(gen, dscr, x, optDscr)
            lossG = updateG!(gen, dscr, x, optGen)
            @info "Epoch $ep, batch $bi, D loss = $(lossD), G loss = $(lossG)"
        end
        @info "Saving generator for epcoh $ep"
        BSON.@save "../data/mnistGAN$(ep).bson" genParams=cpu.(params(gen))
    end
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 1, D loss = 1.3860536, G loss = 0.68386394
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 2, D loss = 1.2977568, G loss = 0.633849
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 3, D loss = 1.2196785, G loss = 0.58668965
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 4, D loss = 1.1495194, G loss = 0.54042566
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 5, D loss = 1.0852562, G loss = 0.49773306
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 6, D loss = 1.031728, G loss = 0.45863298
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 7, D loss = 0.9815934, G loss = 0.41732886
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 8, D loss = 0.9550576, G loss = 0.37962782
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 9, D loss = 0.9315155, G loss = 0.34747088
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch 1, batch 10, D 

LoadError: UndefVarError: `params` not defined