In [None]:
using IJulia
# Register an extra Jupyter kernel named "Julia nodeps" whose Julia process is
# started with `--depwarn=no` (deprecation warnings suppressed).
# NOTE(review): this presumably (re)writes the kernelspec every time the cell is
# executed; it only needs to run once per machine — consider moving it out of
# the notebook.
IJulia.installkernel("Julia nodeps", "--depwarn=no")

using Zygote
using Flux
using MLDatasets
using Makie
using Flux: mse, throttle, onehotbatch

In [None]:
# Display Makie plots inline in the notebook output rather than in a separate window.
AbstractPlotting.inline!(true)

In [None]:
# Load MNIST with Float32 pixels; individual images are indexed along the third axis.
dat_x, dat_y = MLDatasets.MNIST.traindata(Float32)
val_x, val_y = MLDatasets.MNIST.testdata(Float32)

batchsize = 10

# Training data: a vector of (pixels × samples) mini-batch matrices.
# `Iterators.partition` yields consecutive index ranges of length `batchsize`,
# with a shorter final range when the sample count is not a multiple of
# `batchsize`, so the tail is kept instead of indexing past the end.
dat_x = cpu.([reshape(dat_x[:, :, r], :, length(r))
              for r in Iterators.partition(1:size(dat_x, 3), batchsize)])

# Validation data: one flattened image per column. A column-major `reshape`
# gives exactly `vec(val_x[:, :, i])` as column `i`, without splatting one
# argument per image into `hcat`.
val_x = cpu(reshape(val_x, :, size(val_x, 3)))

In [None]:
"""
    mnistgrid(x, d = 64)

Arrange the first `d` columns of `x` — each a flattened 28×28 image — into a
grid with `s = round(Int, √d)` tile rows and `d ÷ s` tile columns (a square grid
when `d` is a perfect square). Return it as one matrix whose second axis is
reversed (presumably so the digits appear upright under Makie's `image`).

Throws `ArgumentError` if `d` is not positive, if `d` is not divisible by `s`,
or if `x` has fewer than `d` columns.
"""
function mnistgrid(x, d = 64)
    d >= 1 || throw(ArgumentError("d must be positive, got $d"))
    nrows = round(Int, sqrt(d))
    d % nrows == 0 ||
        throw(ArgumentError("d = $d cannot be laid out in $nrows grid rows"))
    size(x, 2) >= d ||
        throw(ArgumentError("x must have at least $d columns, got $(size(x, 2))"))
    # Column-major reshape: consecutive images fill the grid column by column.
    tiles = reshape([reshape(x[:, i], 28, 28) for i in 1:d], nrows, :)
    rows = [reduce(hcat, tiles[r, :]) for r in 1:nrows]
    return reduce(vcat, rows)[:, end:-1:1]
end

"""
    viewmnist(x, d = 64)

Plot the first `d` images stored column-wise in `x` as a single mosaic with
Makie's `image` (axes hidden) and return the resulting scene. See `mnistgrid`
for the layout and the accepted values of `d`.
"""
function viewmnist(x, d = 64)
    return image(cpu(mnistgrid(x, d)), show_axis = false)
end

In [None]:
"""
    cbmnist(x, d = 64)

Training-callback variant of `viewmnist`: build the mosaic of the first `d`
images in `x`, clear the current notebook cell output, display the new mosaic,
then `yield()` to the scheduler (presumably so the frame is flushed while
training continues). Returns `nothing`.
"""
function cbmnist(x, d = 64)
    # Reuse viewmnist instead of duplicating its tiling logic; build the scene
    # before wiping the previous frame.
    scene = viewmnist(x, d)
    # `true` is IJulia's `wait` flag: defer clearing until new output arrives.
    IJulia.clear_output(true)
    display(scene)
    yield()
    return nothing
end

In [None]:
# Activation applied by every hidden layer (`relu` was the commented-out alternative).
act = leakyrelu

# Encoder: flattened 28×28 image → 2-D latent code; the last layer is linear.
encoder = Chain(
    Dense(28^2, 512, act),
    Dense(512, 128, act),
    Dense(128, 10, act),
    Dense(10, 2),
) |> cpu

# Decoder mirrors the encoder: 2-D latent code → 784 pixel values.
# NOTE(review): the output layer also applies `act`, so reconstructions are not
# squashed into a fixed pixel range — confirm this is intended.
decoder = Chain(
    Dense(2, 10, act),
    Dense(10, 128, act),
    Dense(128, 512, act),
    Dense(512, 28^2, act),
) |> cpu

# Full autoencoder: x ↦ decoder(encoder(x)).
model = Chain(encoder, decoder)

In [None]:
AbstractPlotting.inline!(true)  # render Makie figures inside the notebook
# Mosaic of the first validation digits (viewmnist's default count), i.e. the
# raw inputs the autoencoder will try to reconstruct.
val_x |> viewmnist

In [None]:
# `cbmnist` only draws the first 64 columns (its default `d`), so run the model
# on just those instead of on the whole validation set at every callback.
preview_x = val_x[:, 1:64]
# Redraw the reconstructions at most once every 2 seconds during training.
evalcb = throttle(() -> cbmnist(model(preview_x)), 2)
# Autoencoder objective: mean squared error between reconstruction and input.
loss(x) = mse(model(x), x)
opt = ADAM()

In [None]:
# Display Makie plots inline so the training callback's preview shows up here.
AbstractPlotting.inline!(true)
# Train for 3 epochs. `zip(dat_x)` wraps each mini-batch in a 1-tuple, which
# `train!` presumably splats into `loss(batch)` — the batch is both the input and
# the reconstruction target. `evalcb` (throttled) refreshes the live preview.
Flux.@epochs 3 Flux.train!(loss, params(model), zip(dat_x), opt, cb = evalcb)

In [None]:
AbstractPlotting.inline!(true)  # render Makie figures inside the notebook
# Reconstructions of the first 64 validation digits. viewmnist only draws that
# many columns (default `d = 64`), so don't push all of `val_x` through the model.
viewmnist(model(val_x[:, 1:64]))

In [None]:
using MultivariateStats

In [None]:
# Linear baseline: project raw pixels onto the top 2 principal components to
# compare with the autoencoder's learned 2-D latent space.
tX, tY = MNIST.traindata(Float32)
vX, vY = MNIST.testdata(Float32)

# Flatten to one 784-pixel image per column; `:` infers the sample count
# instead of hard-coding 60000 / 10000.
M = fit(PCA, reshape(tX, 28^2, :); maxoutdim = 2)

components = transform(M, reshape(vX, 28^2, :))
colors = to_colormap(:Set1, 10)  # one color per digit label 0–9 (indexed by y + 1)
# This plot is not the cell's final expression, so without `display` it would
# never be rendered.
display(scatter(components[1, :], components[2, :], color = [colors[y + 1] for y in vY]))

# 2-D latent codes of the validation images from the trained encoder.
latentspace_dnn = encoder(val_x)

In [None]:
# Interactive latent-space explorer: two sliders choose a point (x, y) in the
# 2-D latent space; the decoded image is shown next to a scatter of the encoded
# validation set (colored by label) with the current point marked by '+'.
AbstractPlotting.inline!(false)

# Decode latent point (x, y) into a 28×28 image with the second axis reversed,
# matching the orientation used by viewmnist. The input is built as Float32 to
# match the layers' parameter type (Flux's default), so the forward pass is not
# silently promoted to Float64.
decodeimg(x, y) = reshape(decoder(cpu(Float32[x, y])), 28, 28)[:, end:-1:1]

s1 = slider(-8.0:0.05:8.0, raw = true, camera = campixel!, start = -5.0)
# NOTE(review): s2 uses a finer step (0.01) than s1 (0.05) — confirm this is intended.
s2 = slider(-8.0:0.01:8.0, raw = true, camera = campixel!, start = 5.0)

# `lift` hands the observables' current values to the function, so no
# `to_value` unwrapping is needed.
sx, sy = s1[end][:value], s2[end][:value]
xy = lift((x, y) -> [x y], sx, sy)  # current point as a 1×2 matrix
digit = lift(decodeimg, sx, sy)     # decoded image, recomputed on slider moves
scene_d = image(digit, show_axis = false)
scene_s = scatter(latentspace_dnn[1, :], latentspace_dnn[2, :], color = [colors[y + 1] for y in vY])
scene_s = scatter!(scene_s, xy, color = :black, marker = '+', markersize = 1)
display(vbox(hbox(scene_d, s1, s2, sizes = [0.8, 0.1, 0.1]), scene_s, sizes = [0.3, 0.7]))