In [None]:
import Pkg; Pkg.add.(["Flux", "UnicodePlots", "Images", "ImageIO", "ImageMagick", "PlutoUI", "PyCall", "Conda", "BSON"])

using Flux
using Flux.Data.MNIST
using UnicodePlots
using Images
using ImageIO
using ImageMagick
using PlutoUI
using PyCall
using Conda
using BSON: @save

In [None]:
# Install the Python `wandb` package from the conda-forge channel via Conda.jl,
# then import it as a Python module through PyCall.
# NOTE(review): assumes PyCall is built against the Conda.jl Python — confirm.
Conda.add("wandb"; channel="conda-forge")
wandb = pyimport("wandb")

In [None]:
# Load the MNIST labels and images through Flux's dataset helper.
labels = MNIST.labels()
images = MNIST.images()

# Number of elements per image. The trailing `[]` only succeeds when the collection
# holds exactly one element, so this also checks that every image has the same length.
n_inputs = unique(length.(images))[]
# Number of distinct label values found in the dataset.
n_outputs = length(unique(labels))

In [None]:
"""
    preprocess(img)

Convert every element of `img` to `Float64` and return the result flattened into a
vector.
"""
function preprocess(img)
    pixels = Float64.(img)
    return vec(pixels)
end

In [None]:
"""
    create_batch(r)

Build a `(features, targets)` tuple from the samples selected by the index collection
`r`. Each selected image is passed through `preprocess` and each selected label is
one-hot encoded over `0:9`; both lists are then combined with `Flux.batch`.
"""
function create_batch(r)
    features = Flux.batch(map(preprocess, images[r]))
    targets = Flux.batch(map(label -> Flux.onehot(label, 0:9), labels[r]))
    return (features, targets)
end

In [None]:
# Samples 1-5000 form the training batch; samples 5001-6000 are held out for evaluation.
trainbatch = create_batch(1:5000)
testbatch = create_batch(5001:6000)

In [None]:
# Number of times the training batch is fed to the optimiser.
epochs = 10
# Loss histories; `update_loss!` appends one value to each per invocation.
train_loss = Float64[]
test_loss = Float64[]

# Start a Weights & Biases run under the "mnist-flux" project.
wandb_run = wandb.init(project="mnist-flux")
"""
    update_loss!()

Evaluate the loss `L` on `trainbatch` and `testbatch`, append the values to
`train_loss` and `test_loss`, and report both to Weights & Biases.
"""
function update_loss!()
    trainL = L(trainbatch...)
    testL = L(testbatch...)
    push!(train_loss, trainL)
    push!(test_loss, testL)
    # Log both metrics in a single call: each `wandb.log` call advances the run's step
    # counter by default, so separate calls would put the two losses on different steps.
    wandb.log(Dict("training_loss" => trainL, "testing_loss" => testL))
    return nothing
end

# A single dense layer followed by softmax, mapping `n_inputs` features to `n_outputs`
# class probabilities.
model = Chain(
    Dense(n_inputs, n_outputs, identity),
    softmax,
)

# Cross-entropy between the model's predicted probabilities and the targets.
L(x, y) = Flux.crossentropy(model(x), y)
opt = Flux.Optimise.Descent()

# The data iterator yields the same training batch `epochs` times, so there are only
# `epochs` optimisation steps in total. The callback is therefore run after every step:
# throttling it to once per second would drop most (possibly all but one) of the loss
# measurements, leaving `train_loss` / `test_loss` nearly empty.
@elapsed Flux.train!(
    L,
    params(model),
    Iterators.repeated(trainbatch, epochs),
    opt;
    cb=update_loss!,
)

In [None]:
# Plot the recorded training loss against its position in the history.
lineplot(eachindex(train_loss), train_loss; title="train_loss")

In [None]:
# Spot-check one sample: print its label, then show the model's highest output value
# together with the digit it corresponds to.
test_index = 50001
println(labels[test_index])
probability, position = findmax(model(preprocess(images[test_index])))
# Output positions are 1-based while the digits start at 0, hence the shift by one.
(probability - 0, position - 1)

In [None]:
# Serialize the trained model to a BSON file, then pass that file to `wandb.save`
# so it is associated with the current run.
@save "mnist-flux.bson" model
wandb.save("mnist-flux.bson")

In [None]:
wandb.termwarn("Done!")
# Explicitly close the run started by `wandb.init`; in a notebook session the process
# keeps running, so without this the run is never marked as finished.
wandb.finish()