In [None]:
using Flux
using MLDatasets
using Flux.Optimise: update!, train!
using Flux.Data: DataLoader
using Flux: logitbinarycrossentropy
using ProgressMeter: Progress, next!

In [2]:
# Load data sets
X,Y = MLDatasets.MNIST.traindata(Float32)

# Image to vector
X = reshape(X, 28^2, :)

# Data shape
D,N = size(X)

# Start DataLoader
data = DataLoader(X,Y, batchsize=32, shuffle=true)

DataLoader((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 32, 60000, true, 60000, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  59991, 59992, 59993, 59994, 59995, 59996, 59997, 59998, 59999, 60000], true)

In [3]:
function encoder(; input_dim=2, hidden_dim=3, latent_dim=1, nonlinearity=tanh)
    "Mapping from observed space to latent space"

    # Map input to hidden layer
    h = Dense(input_dim, hidden_dim, nonlinearity)

    # Map hidden layer activity to mean and log-variance
    μ = Chain(h, Dense(hidden_dim, latent_dim, nonlinearity))
    logσ = Chain(h, Dense(hidden_dim, latent_dim, nonlinearity))

    return μ, logσ
end

encoder (generic function with 1 method)

In [4]:
function decoder(; latent_dim=1, hidden_dim=3, input_dim=2, nonlinearity=tanh)
   "Mapping from latent space to observed space" 
    
    # Latent to hidden
    h = Dense(latent_dim, hidden_dim, nonlinearity)
    
    # Hidden to input
    return Chain(h, Dense(hidden_dim, input_dim, nonlinearity))
end

decoder (generic function with 1 method)

In [5]:
function reconstruct(x, encoder, decoder)
   "Apply encoder and decoder to data"
    
    # Number of samples
    input_dim, num_samples = size(x)
    
    # Encode samples
    μ = encoder[1](x)
    logσ = encoder[2](x)
    
    # Dimensionality of latent space
    latent_dim = size(μ,1)
    
    # Generate samples in latent space
    z = μ + randn(Float32, (latent_dim, num_samples)) .* exp.(logσ)
    
    # Decode generated samples
    x_hat = decoder(z)  
    
    return μ, logσ, x_hat
end

reconstruct (generic function with 1 method)

In [6]:
function loss(x, encoder, decoder; λ=0.1)
    "Loss layer"
    
    # Encode and decode data
    μ, logσ, x_hat = reconstruct(x, encoder, decoder)
   
    # KL-divergence
    KL = 0.5 * sum(@. (exp(2. *logσ) + μ^2 -1. - 2. *logσ)) / N

    # Reconstruction error
    logp_x_z = -sum(logitbinarycrossentropy.(x_hat, x)) / N
    
    # regularization
    reg = λ * sum(x->sum(x.^2), Flux.params(decoder))

    return -logp_x_z + KL + reg
end

loss (generic function with 1 method)

In [30]:
# Training parameters
num_epochs = 2

# Define optimizer
learning_rate = 1e-3
opt = ADAM(learning_rate)

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

In [None]:
# Define encoder and decoder
enc = encoder(input_dim=784, hidden_dim=500, latent_dim=1)
dec = decoder(latent_dim=1, hidden_dim=500, input_dim=784)

# Extract parameters
ps = Flux.params(enc[1], enc[2], dec)

# Create output directory if not present
!ispath("output") && mkpath("output")

@info "Start Training, total $(num_epochs) epochs"
for epoch in 1:num_epochs
    
    # Report progress
    @info "Epoch $(epoch)"
    progress = Progress(length(data),1)
    
    # Iterate over data
    for (x,y) in data
        
        # Define gradient function
        gs = gradient(ps) do
            training_loss = loss(x, enc, dec)
          return training_loss
        end
        
        # Update params
        Flux.update!(opt, ps, gs)
    
        # Update progress meter
        next!(progress; showvalues=[(:loss, loss(x, enc, dec))])
    end
end

┌ Info: Start Training, total 2 epochs
└ @ Main In[41]:11
┌ Info: Epoch 1
└ @ Main In[41]:15
[32mProgress:   0%|▏                                        |  ETA: 0:05:52[39m
[A4m  loss:  46.93286285680531[39m
[32mProgress:   1%|▎                                        |  ETA: 0:04:54[39m
[A4m  loss:  31.598729875600316[39m
[32mProgress:   1%|▌                                        |  ETA: 0:04:35[39m
[A4m  loss:  20.549827077853124[39m
[32mProgress:   2%|▋                                        |  ETA: 0:04:26[39m
[A4m  loss:  12.937808322551204[39m
[32mProgress:   2%|▉                                        |  ETA: 0:04:21[39m
[A4m  loss:  7.907152677457458[39m
[32mProgress:   2%|█                                        |  ETA: 0:04:20[39m
[A4m  loss:  5.033096095058622[39m
[32mProgress:   3%|█▏                                       |  ETA: 0:04:24[39m
[A4m  loss:  3.159264708737334[39m
[32mProgress:   3%|█▎                                       |  ETA: 0: