# GDL2e Chapter 2 Examples

## Some basic utility functions

### Flux and required packages.

In [None]:
using Flux, Metal, MLUtils, OneHotArrays, Statistics, ProgressMeter


### Get the CIFAR10 image data

In [None]:
using MLDatasets: CIFAR10

"""
    cifar10_data()

Load the CIFAR10 colour-image dataset (60,000 images of 32x32 pixels) as a
`(train, test)` tuple of `CIFAR10` datasets, with pixel features requested as
`Float32`.
"""
function cifar10_data()
    # Both splits share the same element type; only the split name differs.
    getsplit(which) = CIFAR10(Tx=Float32, split=which)
    trainset = getsplit(:train)
    testset = getsplit(:test)
    return (trainset, testset)
end


### Create one-hot encoding of targets/labels

In [None]:
"""
    onehotlabels(data::CIFAR10)

One-hot encode `data.targets` over the contiguous label range spanning the
smallest to the largest target present (presumably `0:9` for a CIFAR10 split).
"""
function onehotlabels(data::CIFAR10)
    labels = data.targets
    lo, hi = extrema(labels)
    return onehotbatch(labels, range(lo, hi))
end


### Dataloader

In [None]:
"""
    loader(data::CIFAR10; batchsize)

Wrap the dataset's features and their one-hot labels in a shuffling
`Flux.DataLoader` that yields `(x, y)` mini-batches of `batchsize` observations.
"""
function loader(data::CIFAR10; batchsize)
    features = data.features
    targets = onehotlabels(data)
    return Flux.DataLoader((features, targets); batchsize=batchsize, shuffle=true)
end


### Hyperparameters

We provide a set of training *hyperparameters* but prefer to call these
training parameters.

In [None]:
"""
    TrainingParams(batchsize, epochs, learnrate)

Training (hyper)parameters: mini-batch size, number of passes over the
training data, and the optimiser learning rate.
"""
struct TrainingParams
    batchsize :: Int
    epochs    :: Int
    learnrate :: Float64
end

"""
    trainparams(; batchsize, epochs, learnrate)

Keyword constructor for `TrainingParams`.

Any `Integer` is accepted for `batchsize` and `epochs`, and any `Real` for
`learnrate` (e.g. `learnrate=1`). Values are converted to the field types.

Throws `ArgumentError` if any value is not strictly positive; such values
would otherwise only fail later, inside the data loader or when plotting an
empty loss history.
"""
function trainparams(; batchsize::Integer, epochs::Integer, learnrate::Real)
    batchsize > 0 || throw(ArgumentError("batchsize must be positive, got $batchsize"))
    epochs > 0 || throw(ArgumentError("epochs must be positive, got $epochs"))
    # `NaN > 0` is false, so NaN learning rates are rejected as well.
    learnrate > 0 || throw(ArgumentError("learnrate must be positive, got $learnrate"))
    return TrainingParams(batchsize, epochs, learnrate)
end


### Compute accuracy of predictions

In [None]:
"""
    accuracy(m, data::CIFAR10, args::TrainingParams; evalbatch=1000)

Return the percentage of observations in `data` whose predicted class
(`Flux.onecold` of the model output) equals the true class, rounded to 2 decimals.

The split is evaluated in mini-batches of `evalbatch` observations rather than
as one batch holding every image. This bounds peak memory: for the CNN, the first
convolution alone would otherwise produce a 32x32x32x50000 Float32 activation
(about 6.5 GB) for the training split. Correct predictions are counted across
batches, so the result equals the single-batch computation.

`args` is currently unused and is kept for interface compatibility.
"""
function accuracy(m, data::CIFAR10, args::TrainingParams; evalbatch::Integer=1000)
    evalbatch > 0 || throw(ArgumentError("evalbatch must be positive, got $evalbatch"))
    ncorrect = 0
    ntotal = 0
    # Shuffling by `loader` is harmless here: only the count of matches matters.
    for (x, y) in loader(data; batchsize=evalbatch)
        iscorrect = Flux.onecold(m(x)) .== Flux.onecold(y)
        ncorrect += count(iscorrect)
        ntotal += length(iscorrect)
    end
    # Same arithmetic as `100 * mean(iscorrect)` over the full split.
    return round(100 * (ncorrect / ntotal); digits=2)
end


## Putting those functions to work: training the model

In [None]:
# constructor with keyword args
function trainwith(model, train_data::CIFAR10, args::TrainingParams; device)
    @info "trainwith" args
    # model
    md = device(model)
    # loader
    train_loader = loader(train_data, batchsize=args.batchsize)
    # optimizer state with training rate
    opt_state = Flux.setup(Adam(args.learnrate), md)

    losses = [] # keep track of loss at each epoch

    @showprogress for epoch in 1:args.epochs
        for (x_batch, y_batch) in train_loader
            # device transfer if required
            x, y = device(x_batch), device(y_batch)
            # compute loss and gradients
            l, gs = Flux.withgradient(m -> Flux.crossentropy(m(x), y), md)
            # update model parameters
            Flux.update!(opt_state, md, gs[1]) # see: withgradient
            # accumulate losses for logging
            push!(losses, l)
        end
    end
    return (md, losses, length(train_loader))
end


## Main train and test

In [None]:
using Plots

"""
    trainandtest(model, tparam::TrainingParams; device=cpu)

Load CIFAR10, train `model` on the training split (on `device`), and log the
train/test accuracy of the trained model after moving it back to the CPU.

Returns a plot of the per-batch loss (log-scaled iteration axis) overlaid with
the mean loss of each epoch.
"""
function trainandtest(model, tparam::TrainingParams; device=cpu)
    @info "Loading CIFAR10 data..."
    traindata, testdata = cifar10_data()

    @info "Training..."
    trained, losses, nbatches = trainwith(model, traindata, tparam; device=device)

    @info "Testing..."
    cpumodel = cpu(trained)
    # These local names become the log keys, so keep them as they are.
    train_a = accuracy(cpumodel, traindata, tparam)
    test_a = accuracy(cpumodel, testdata, tparam)
    @info "Accuracy:" train_a test_a

    # Each consecutive run of `nbatches` losses is one epoch.
    epochmeans = [mean(chunk) for chunk in Iterators.partition(losses, nbatches)]
    epochends = nbatches:nbatches:length(losses)

    plot(losses; xaxis=(:log10, "iteration"), yaxis="loss", label="per batch")
    return plot!(epochends, epochmeans; label="epoch mean", dpi=200)
end


## The simple model from GDL2e Chapter 2

Note:
[softmax](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib.softmax)
must not be passed as the activation function of layers like Dense. A
layer's activation is broadcast element-wise, whereas softmax normalizes
a whole column of outputs, so add softmax as a separate layer in the
Chain instead. If you get errors dispatching softmax, this might be the
problem.

In [None]:
"""
    simplemodel()

Build the fully connected CIFAR10 classifier. Each 32x32 RGB image is
flattened, passed through two ReLU hidden layers (200 and 150 units) and a
10-unit output layer, and normalised by a separate softmax layer.
"""
function simplemodel()
    ninputs = 32 * 32 * 3  # flattened 32x32 image with 3 colour channels
    layers = (MLUtils.flatten,
              Dense(ninputs => 200, relu),
              Dense(200 => 150, relu),
              Dense(150 => 10),
              softmax)  # a standalone layer, not a Dense activation
    return Chain(layers...)
end


## Train and test the results of simple model

Note: training this on the GPU takes about twice as long as on the CPU,
which takes 1 min 17 s on my laptop using 8 cores.

In [None]:
# Train and evaluate the dense model on the CPU; the result (the loss plot)
# is the cell's value.
let model = simplemodel(),
    tp = trainparams(batchsize=32, epochs=10, learnrate=5e-4)
    trainandtest(model, tp; device=cpu)
end


## The convolutional model (CNN)

This follows the batch normalization, activation, dropout (BAD) ordering.
NB: a `kernel_size` of 3 in a *Keras* `Conv2D` layer is written `(3,3)` in
the more general *Flux* `Conv`.

In [None]:
"""
    cnnmodel(; activation = x -> leakyrelu(x, 0.3f0))

Build the convolutional CIFAR10 classifier. Every convolution or dense
hidden layer is followed by batch normalisation with `activation`, and then
by dropout where used (the "BAD" ordering).

The default activation is a leaky ReLU with slope 0.3, the default of Keras'
`LeakyReLU` layer. The previous `rrelu` samples a new random slope on every
call, including at inference, which made predictions and accuracy
non-deterministic.
"""
function cnnmodel(; activation = x -> leakyrelu(x, 0.3f0))
    return Chain(
        # 1: 32x32x3 -> 32x32x32
        Conv((3, 3), 3 => 32; pad=SamePad(), stride=1),
        BatchNorm(32, activation),
        # 2: stride 2 halves the spatial size -> 16x16x32
        Conv((3, 3), 32 => 32; pad=SamePad(), stride=2),
        BatchNorm(32, activation),
        # 3: -> 16x16x64
        Conv((3, 3), 32 => 64; pad=SamePad(), stride=1),
        BatchNorm(64, activation),
        # 4: stride 2 -> 8x8x64
        Conv((3, 3), 64 => 64; pad=SamePad(), stride=2),
        BatchNorm(64, activation),
        # 5: flatten the 8x8x64 feature maps (4096 values per image)
        MLUtils.flatten,
        Dense(8 * 8 * 64 => 128),
        BatchNorm(128, activation),
        Dropout(0.5),
        # 6: class scores, normalised by a standalone softmax layer
        Dense(128 => 10),
        softmax
    )
end


## Train and test the convolutional model

Note: running this on the GPU fails with a scalar-indexing error in one of
the layers. This was with the *Metal* backend, so it is probably not worth
investigating at this point; I'll update the packages and try again
sometime (18-Dec-23). On my laptop using 8 cores this takes about 30
minutes to train.

In [None]:
# Train and evaluate the convolutional model on the CPU; the result (the loss
# plot) is the cell's value.
let model = cnnmodel(),
    tp = trainparams(batchsize=32, epochs=10, learnrate=5e-4)
    trainandtest(model, tp; device=cpu)
end
