In [1]:
using Pkg
using Flux, MLDatasets, Statistics, Random, BSON
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy, binarycrossentropy, onecold, DataLoader, gradient, setup
using Metal

In [2]:
# Initialize Metal and set the current device
# metal_dir = dirname(@__DIR__)
# Pkg.activate(; temp=true)
# Pkg.add(["Metal", "Flux", "DataFrames", "OneHotArrays"])
# Pkg.develop(path=metal_dir)
# using Metal
device = Flux.get_device(; verbose=true)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mUsing backend: Metal.


(::Flux.FluxMetalDevice) (generic function with 1 method)

In [3]:
# Data Preparation
batchSize, latentDim = 500, 100
epochs = 40
etaD, etaG = 0.0002, 0.0002;

In [4]:
# Data Preparation
images, _ = MNIST(split=:train)[:]
images = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
images = gpu(images);  # Move images to GPU early

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mThe CUDA functionality is being called but
[36m[1m│ [22m[39m`CUDA.jl` must be loaded to access it.
[36m[1m└ [22m[39mAdd `using CUDA` or `import CUDA` to your code.


In [5]:
imageTensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
data = [imageTensor[:, :, :, r] for r in Iterators.partition(1:60000, batchSize)];

In [6]:
# Model Definitions with CPU initialization
dscr = Chain(
    Conv((4,4), 1=>64, stride=2, pad=1),
    x -> leakyrelu.(x, 0.2),
    Dropout(0.25),
    Conv((4,4), 64=>128, stride=2, pad=1),
    x -> leakyrelu.(x, 0.2),
    Dropout(0.25),
    x -> reshape(x, :, size(x, 4)),
    Dense(7*7*128, 1)
) |> gpu  # Move the entire model to GPU after creation

gen = Chain(
    Dense(latentDim, 7*7*256),
    BatchNorm(7*7*256, relu),
    x -> reshape(x, 7, 7, 256, :),
    ConvTranspose((4,4), 256=>128, stride=2, pad=1),
    BatchNorm(128, relu),
    ConvTranspose((4,4), 128=>64, stride=2, pad=1),
    BatchNorm(64, relu),
    ConvTranspose((4,4), 64=>1, tanh, stride=2, pad=1)
) |> gpu  # Move the entire model to GPU after creation


Chain(
  Dense(100 => 12544),                  [90m# 1_266_944 parameters[39m
  BatchNorm(12544, relu),               [90m# 25_088 parameters[39m[90m, plus 25_088[39m
  var"#9#10"(),
  ConvTranspose((4, 4), 256 => 128, pad=1, stride=2),  [90m# 524_416 parameters[39m
  BatchNorm(128, relu),                 [90m# 256 parameters[39m[90m, plus 256[39m
  ConvTranspose((4, 4), 128 => 64, pad=1, stride=2),  [90m# 131_136 parameters[39m
  BatchNorm(64, relu),                  [90m# 128 parameters[39m[90m, plus 128[39m
  ConvTranspose((4, 4), 64 => 1, tanh, pad=1, stride=2),  [90m# 1_025 parameters[39m
) [90m        # Total: 14 trainable arrays, [39m1_948_993 parameters,
[90m          # plus 6 non-trainable, 25_472 parameters, summarysize [39m7.534 MiB.

In [7]:
# Loss Functions
dLoss(realOut, fakeOut) = mean(logitbinarycrossentropy.(realOut, 1f0)) +
                          mean(logitbinarycrossentropy.(fakeOut, 0f0))
gLoss(u) = mean(logitbinarycrossentropy.(u, 1f0))

gLoss (generic function with 1 method)

In [8]:
# Update Functions
function updateD!(gen, dscr, x, opt_dscr)
    noise = randn!(similar(x, (latentDim, batchSize)))
    fakeInput = gen(noise)
    ps = Flux.params(dscr)
    loss, back = Flux.pullback(()->dLoss(dscr(x), dscr(fakeInput)), ps)
    grad = back(1f0)
    update!(opt_dscr, ps, grad)
    return loss
end

updateD! (generic function with 1 method)

In [9]:
function updateG!(gen, dscr, x, optGen)
    noise = randn!(similar(x, (latentDim, batchSize)))
    ps = Flux.params(gen)
    loss, back = Flux.pullback(()->gLoss(dscr(gen(noise))), ps)
    grad = back(1f0)
    update!(optGen, ps, grad)
    return loss
end


updateG! (generic function with 1 method)

In [10]:
# Optimization
optDscr, optGen = ADAM(etaD), ADAM(etaG)

(Adam(0.0002, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), Adam(0.0002, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()))

In [11]:
# Training Loop
cd(@__DIR__)
@time begin
    for ep in 1:epochs
        for (bi, x) in enumerate(data)
            x_gpu = x |> device  # Ensure the batch is moved to GPU
            lossD = updateD!(gen, dscr, x_gpu, optDscr)
            lossG = updateG!(gen, dscr, x_gpu, optGen)
            @info "Epoch $ep, batch $bi, D loss = $(lossD), G loss = $(lossG)"
        end
        @info "Saving generator for epoch $ep"
        BSON.@save "../data/mnistGAN$(ep)_gpu.bson" genParams=cpu.(params(gen))
    end
end

LoadError: ArgumentError: cannot take the CPU address of a MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}