# Tutorial: Classification with Flux

This tutorial introduces classification in [Flux](https://github.com/fluxml/flux.jl), using the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset of handwritten digits to compare the classification ability of different networks. MNIST consists of 60000 training images and 10000 test images, and the lowest error rate reported on the dataset is $0.23\%$. An extended dataset (EMNIST) has been available since 2017, but MNIST remains a common benchmark for comparing approaches. The tutorial combines and reworks two examples from the [Flux model zoo](https://github.com/FluxML/model-zoo/blob/master/vision/mnist/conv.jl).

Structure:

1. Classification using a multilayer perceptron
2. Classification using a convolutional neural network
3. Exercises

In [None]:
using Pkg

In [None]:
using Flux
using Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf

## 1. Classification using a multilayer perceptron (MLP)

In [None]:
# Definition of the MLP to go here

In [None]:
# Load dataset
imgs = Flux.Data.MNIST.images()

# Stack into one batch
X = hcat(float.(reshape.(imgs, :))...);
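
The stacking step above can be illustrated on toy data. The following is a minimal sketch using two made-up 2x2 matrices in place of the MNIST images (the names `toy_imgs` and `toy_X` are hypothetical and not part of the tutorial). Each image is flattened in column-major order and becomes one column of the batch matrix.

```julia
# Two toy 2x2 "images" standing in for the 28x28 MNIST digits (made-up values)
toy_imgs = [[0.0 0.5; 1.0 0.25], [0.1 0.2; 0.3 0.4]]

# reshape.(toy_imgs, :) flattens each image into a vector, column by column
flat = reshape.(toy_imgs, :)

# hcat(...) places each flattened image in its own column of the batch
toy_X = hcat(float.(flat)...)

size(toy_X)     # (4, 2): 4 pixels per image, 2 images
toy_X[:, 1]     # [0.0, 1.0, 0.5, 0.25]
```

By the same logic, the full training set of 60000 images with $28^2 = 784$ pixels each should yield a matrix `X` with 784 rows and 60000 columns.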

In [None]:
# Load labels
labels = Flux.Data.MNIST.labels()

# One-hot-encode the labels
Y = onehotbatch(labels, 0:9);
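
To make the encoding concrete, here is a hand-written sketch of what one-hot encoding and its inverse do. The helpers `toy_onehot` and `toy_onecold` are hypothetical stand-ins written in plain Julia; they are not Flux's implementation and return ordinary Boolean matrices rather than Flux's specialized one-hot types.

```julia
# Hypothetical stand-ins for onehotbatch and onecold, for illustration only
toy_onehot(labels, classes) = [c == l for c in classes, l in labels]
toy_onecold(Y) = [argmax(col) for col in eachcol(Y)]

toy_labels = [3, 0, 9]
toy_Y = toy_onehot(toy_labels, 0:9)

size(toy_Y)           # (10, 3): one row per class, one column per label
toy_onecold(toy_Y)    # [4, 1, 10]: row indices, not the digits themselves
```

Note that the decoded values are positions within `0:9`, so digit 3 maps to index 4. This does not affect an accuracy computation as long as predictions and targets are decoded the same way.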

In [None]:
# Set up the MLP
m = Chain(
    Dense(28^2, 32, relu),
    Dense(32, 10),
    softmax
)

loss(x, y) = crossentropy(m(x), y)
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y));
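
The loss and accuracy definitions above can be unpacked with a small hand-written sketch. The functions `toy_softmax` and `toy_crossentropy` are hypothetical plain-Julia versions, assuming a column-wise softmax and a cross-entropy averaged over the batch; they are meant to show the arithmetic, not to reproduce Flux's code.

```julia
using Statistics

# Column-wise softmax: shift for numerical stability, exponentiate, normalize
function toy_softmax(z)
    e = exp.(z .- maximum(z, dims=1))
    return e ./ sum(e, dims=1)
end

# Cross-entropy summed over classes and averaged over the batch (columns)
toy_crossentropy(p, y) = -sum(y .* log.(p)) / size(y, 2)

logits  = [2.0 0.0; 0.0 0.0; 0.0 3.0]    # 3 classes, 2 samples
targets = Bool[1 0; 0 1; 0 0]            # true classes: 1 and 2

probs = toy_softmax(logits)
toy_loss = toy_crossentropy(probs, targets)

# Accuracy: fraction of samples whose most probable class matches the target
pred  = [argmax(c) for c in eachcol(probs)]     # [1, 3]
truth = [argmax(c) for c in eachcol(targets)]   # [1, 2]
toy_acc = mean(pred .== truth)                  # 0.5
```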

In [None]:
# Set up the training
dataset = repeated((X, Y), 200)
evalcb = () -> @show(loss(X, Y))
opt = ADAM()

Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))
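
The role of `repeated` in the training setup can be seen in isolation. In this sketch, `toy_batch` and `toy_dataset` are hypothetical names: the iterator yields the very same tuple every time, so in the cell above each of the 200 training steps sees the full dataset rather than a minibatch.

```julia
using Base.Iterators: repeated

# A made-up (input, target) pair standing in for (X, Y)
toy_batch = ([1.0 2.0], [true false])
toy_dataset = repeated(toy_batch, 3)

length(toy_dataset)                       # 3
all(d -> d === toy_batch, toy_dataset)    # true: the same object each time
```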

In [None]:
# Assess the accuracy
accuracy(X, Y)

In [None]:
# Compute the test set accuracy
tX = hcat(float.(reshape.(Flux.Data.MNIST.images(:test), :))...)
tY = onehotbatch(Flux.Data.MNIST.labels(:test), 0:9)

accuracy(tX, tY)

## 2. Classification using a convolutional neural network

In [None]:
# Load labels and images
train_labels = Flux.Data.MNIST.labels()
train_imgs = Flux.Data.MNIST.images();

In [None]:
# Construct minibatches
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[idxs[i]])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end

batch_size = 128
mb_idxs = partition(1:length(train_imgs), batch_size)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

# Test set as one minibatch
test_imgs = Flux.Data.MNIST.images(:test)
test_labels = Flux.Data.MNIST.labels(:test)
test_set = make_minibatch(test_imgs, test_labels, 1:length(test_imgs));
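
How `partition` splits the training indices can be checked without loading any data. The sketch below assumes the 60000 training images stated in the introduction; `n_train`, `bs`, and `idx_chunks` are hypothetical names chosen so as not to clash with the variables above.

```julia
using Base.Iterators: partition

n_train = 60000    # number of MNIST training images, as stated in the introduction
bs = 128           # same batch size as above

idx_chunks = collect(partition(1:n_train, bs))

length(idx_chunks)         # 469 minibatches
length(idx_chunks[1])      # 128
length(idx_chunks[end])    # 96, the remainder 60000 - 468 * 128
```

The last minibatch is smaller than the others, which is harmless here because `make_minibatch` sizes each batch from `length(idxs)`.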

In [None]:
# Convolutional architecture with three iterations of Conv -> ReLU -> MaxPool followed by
# a final dense layer fed into a softmax probability output
model = Chain(
    # 1st convolutional layer, taking a 28x28 image
    Conv((3, 3), 1=>16, pad=(1, 1), relu),
    MaxPool((2, 2)),
    
    # 2nd convolutional layer, taking a 14x14 image
    Conv((3, 3), 16=>32, pad=(1, 1), relu),
    MaxPool((2, 2)),
    
    # 3rd convolutional layer, taking a 7x7 image
    Conv((3, 3), 32=>32, pad=(1, 1), relu),
    MaxPool((2, 2)),
    
    # Flatten the 4d tensor of shape (3, 3, 32, N) into a 2d tensor of shape (288, N)
    x -> reshape(x, :, size(x, 4)),
    Dense(288, 10),
    
    # Softmax output layer
    softmax,
);
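
The image sizes quoted in the comments, and the 288 inputs of the dense layer, follow from simple arithmetic. The sketch below assumes a stride of 1 for the convolutions and non-overlapping 2x2 pooling windows with no padding; `conv_out`, `pool_out`, and `spatial_sizes` are hypothetical helpers, not Flux functions.

```julia
# Output size of a convolution along one spatial dimension
conv_out(n; k=3, pad=1, stride=1) = fld(n + 2pad - k, stride) + 1

# Output size of a k x k max-pool with stride k (partial windows are dropped)
pool_out(n; k=2) = fld(n - k, k) + 1

# Track the spatial size through three Conv -> MaxPool stages
function spatial_sizes(n)
    sizes = [n]
    for _ in 1:3
        n = pool_out(conv_out(n))
        push!(sizes, n)
    end
    return sizes
end

sizes = spatial_sizes(28)            # [28, 14, 7, 3]
flat_features = sizes[end]^2 * 32    # 3 * 3 * 32 = 288
```

With a 3x3 kernel and a padding of 1, each convolution preserves the spatial size, so only the pooling layers shrink the image.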

Training can optionally run on a GPU, provided that both the data and the model are moved to the available GPU.

In [None]:
# If GPU is enabled uncomment the lines below to load the model onto a GPU
#train_set = gpu.(train_set)
#test_set = gpu.(test_set)
#model = gpu(model)

In [None]:
# Run one forward pass so that the model is compiled before training starts
model(train_set[1][1])

In [None]:
# Crossentropy loss between prediction and ground truth, add Gaussian noise to make model more robust
function loss(x, y)
    # Add random noise to x
    x_aug = x .+ 0.1f0 * randn(eltype(x), size(x))
    
    y_hat = model(x_aug)
    return crossentropy(y_hat, y)
end

accuracy(x, y) = mean(onecold(model(x)) .== onecold(y));
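
The noise augmentation in `loss` can be inspected on its own. This sketch uses an all-zero array with the same layout as one minibatch (`toy_x` and `toy_x_aug` are hypothetical names) to confirm that adding the noise leaves the shape and the `Float32` element type unchanged.

```julia
# A made-up minibatch of 4 blank images: width x height x channels x batch
toy_x = zeros(Float32, 28, 28, 1, 4)

# Same expression as in the loss: Gaussian noise scaled by 0.1
toy_x_aug = toy_x .+ 0.1f0 * randn(eltype(toy_x), size(toy_x))

size(toy_x_aug) == size(toy_x)    # true
eltype(toy_x_aug)                 # Float32
```

Writing the scale as `0.1f0` rather than `0.1` matters: a `Float64` literal would promote the whole array to `Float64`.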

In [None]:
# Use ADAM as the optimizer
opt = ADAM(0.001)

best_acc = 0.0
last_improvement = 0;

In [None]:
# Define the training loop and train for up to 100 epochs
for epoch_idx in 1:100
    global best_acc, last_improvement
    
    # Train for one epoch
    Flux.train!(loss, params(model), train_set, opt)
    
    # Calculate the accuracy
    acc = accuracy(test_set...)
    @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))
    
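    # Stop if accuracy is good enough
    if acc >= 0.999
        @info(" -> Early-exiting: We reached our target accuracy of 99.9%")
        break
    end
    
    # Record the best accuracy so far and the epoch in which it was reached
    if acc >= best_acc
        best_acc = acc
        last_improvement = epoch_idx
    end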
    
    # Reduce learning rate if there has been no improvement for 5 epochs
    if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
        opt.eta /= 10.0
        @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")
        
        last_improvement = epoch_idx
    end
    
    if epoch_idx - last_improvement >= 10
        @warn(" -> The model has converged.")
        break
    end
end
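
The learning-rate schedule can be replayed on a made-up accuracy history without training anything. The following is a sketch only: `simulate_schedule` is a hypothetical helper, and it assumes that the best accuracy and the epoch of the last improvement are updated whenever the accuracy matches or exceeds the previous best.

```julia
# Replays the schedule on a given accuracy history (no model involved)
function simulate_schedule(accs; eta=1e-3)
    best_acc, last_improvement = 0.0, 0
    events = String[]
    for (epoch_idx, acc) in enumerate(accs)
        # Assumed bookkeeping: remember the best accuracy and when it occurred
        if acc >= best_acc
            best_acc = acc
            last_improvement = epoch_idx
        end
        # Drop the learning rate after 5 epochs without improvement
        if epoch_idx - last_improvement >= 5 && eta > 1e-6
            eta /= 10.0
            push!(events, "epoch $epoch_idx: drop eta")
            last_improvement = epoch_idx
        end
        # Declare convergence after 10 epochs without improvement
        if epoch_idx - last_improvement >= 10
            push!(events, "epoch $epoch_idx: converged")
            break
        end
    end
    return eta, events
end

# Accuracy improves for 3 epochs, then plateaus below its best value
final_eta, events = simulate_schedule([0.90, 0.95, 0.97, 0.96, 0.96, 0.96, 0.96, 0.96, 0.96])
events    # ["epoch 8: drop eta"]
```

Because every learning-rate drop resets the counter, the convergence branch can only trigger once the learning rate has fallen to roughly `1e-6` and no further drops are allowed.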

## 3. Exercise: Construct a different neural network for classification

- Experiment with other neural network architectures for classification:
    - Construct a radial basis network and test it on MNIST
         Hint: Combine the network with Stheno so that Gaussians, a special case of radial basis functions, serve as activations in a feed-forward network
    - Construct an autoencoder and test its classification ability on MNIST
         Hint: An [autoencoder](https://www.jeremyjordan.me/autoencoders/) consists of an encoder-decoder structure, which can be built from feed-forward dense layers only or from convolutional layers; the latter gives a convolutional autoencoder.
    - Challenge: Build a variational autoencoder and test its performance
- Experiment with a mixture of the initial MLP classifier and the convolutional classifier by introducing dense layers between convolutional layers.
    - How does the training time change?
    - How large is the influence of the activation functions on the performance?