# Introduction to Julia: Machine Learning Example

This notebook shows how to train a machine learning model in Julia. It is based on the PyTorch tutorial [Training a Classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py). Much of the Julia code will look familiar if you have used R, MATLAB, or Python.

## Imports

In [None]:
using Statistics
using Flux
using Flux.MLUtils: DataLoader
using Flux.Losses: logitcrossentropy
using MLDatasets
using Images
using Flux: onehotbatch, onecold, flatten, gradient
using CUDA

We set the following environment variable to avoid having to type 'y' when downloading the CIFAR-10 dataset.

In [None]:
ENV["DATADEPS_ALWAYS_ACCEPT"] = true;

These are the parameters we will use to train the model.

In [None]:
batchsize = 128
learning_rate = 3e-4
epochs = 2
validationsplit = 0.2;
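
To get a feel for what these settings mean per epoch, the sketch below works out the split sizes and batch counts. It assumes the standard CIFAR-10 training set of 50,000 images, that the split lands exactly on the 80/20 boundary, and that a final partial batch is kept. The names `n_total`, `n_train`, `n_valid`, `train_batches`, and `valid_batches` are illustrative and not part of the notebook.

```julia
# Hyperparameters copied from the cell above
batchsize = 128
validationsplit = 0.2

# Assumption: the standard CIFAR-10 training set has 50,000 images
n_total = 50_000

# Split sizes, assuming the split falls exactly at 1 - validationsplit
n_train = round(Int, n_total * (1 - validationsplit))
n_valid = n_total - n_train

# Batches per pass, assuming a final partial batch is kept
train_batches = cld(n_train, batchsize)
valid_batches = cld(n_valid, batchsize)

println("train: $n_train images in $train_batches batches")
println("valid: $n_valid images in $valid_batches batches")
```

Under these assumptions, each epoch performs 313 parameter updates and averages the validation loss over 79 batches.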

Next, we check whether a working CUDA GPU is available and set the training device accordingly.

In [None]:
if CUDA.functional()
    @info "Training on CUDA GPU"
    CUDA.allowscalar(false)
    device = gpu
else
    @info "Training on CPU"
    device = cpu
end
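
The `device` variable works because functions in Julia are ordinary values, and `x |> f` is another way of writing `f(x)`. The following sketch illustrates the pattern without needing a GPU. `fake_gpu`, `fake_cpu`, and `has_gpu` are hypothetical stand-ins for Flux's `gpu`, `cpu`, and the result of `CUDA.functional()`.

```julia
# Hypothetical stand-ins for Flux's `gpu` and `cpu`; they only tag the data
fake_gpu(x) = (:gpu, x)
fake_cpu(x) = (:cpu, x)

# Stand-in for CUDA.functional(); set by hand here
has_gpu = false

# Functions are ordinary values, so one can be stored in a variable
device = has_gpu ? fake_gpu : fake_cpu

batch = [1.0, 2.0, 3.0]
moved = batch |> device   # same as device(batch)

println(moved)
```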

Next, we download the data and create DataLoaders to pass batches of data when training.

In [None]:
traindata = CIFAR10(; Tx=Float32, split=:train)
testdata = CIFAR10(; Tx=Float32, split=:test)

classes = traindata.metadata["class_names"]

(xtrain, ytrain), (xvalid, yvalid) = Flux.MLUtils.splitobs((traindata.features, traindata.targets), at=1-validationsplit)
xtest, ytest = testdata.features, testdata.targets

ytrain, yvalid, ytest = onehotbatch(ytrain, 0:9), onehotbatch(yvalid, 0:9), onehotbatch(ytest, 0:9)

train_loader = DataLoader((xtrain, ytrain), batchsize=batchsize, shuffle=true)
valid_loader = DataLoader((xvalid, yvalid), batchsize=batchsize)
test_loader = DataLoader((xtest, ytest),  batchsize=batchsize);
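
The one-hot encoding step can be pictured with a plain-Julia sketch. The helpers `onehot_dense` and `onecold_dense` are hypothetical and build an ordinary dense `Bool` matrix, whereas Flux's `onehotbatch` returns a more compact one-hot matrix type. The shape is the same: one row per class and one column per sample.

```julia
# Hypothetical helpers; Flux's onehotbatch returns a compact one-hot matrix,
# whereas this sketch builds an ordinary dense Bool matrix.
onehot_dense(labels, classes) = [l == c for c in classes, l in labels]
onecold_dense(m) = [argmax(col) for col in eachcol(m)]

labels = [6, 9, 0]                  # raw labels in the range 0:9
encoded = onehot_dense(labels, 0:9)

# One row per class, one column per sample
println(size(encoded))              # (10, 3)

# Decoding gives 1-based row indices, so label 6 comes back as 7
decoded = onecold_dense(encoded)
println(decoded)                    # [7, 10, 1]
```

The shift from 0-based labels to 1-based row indices is why the decoded values can be used directly to index the `classes` vector.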

Let's take a look at one of the CIFAR-10 pictures to get a sense of the data. MLDatasets stores each image as width × height × channel, while `colorview` from Images.jl expects channel × height × width, so we permute the dimensions first.

In [None]:
colorview(RGB, permutedims(xtrain[:,:,:,1], (3, 2, 1,)))
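
As a small illustration of what `permutedims` does here, the sketch below applies the same permutation to a stand-in array of the same shape as one image. The array contents and the names `x` and `img` are made up for the example.

```julia
# Stand-in for one image stored as width × height × channel
w, h, c = 32, 32, 3
x = reshape(collect(Float32, 1:w*h*c), w, h, c)

# Reverse the dimension order: channel × height × width
img = permutedims(x, (3, 2, 1))

println(size(x))     # (32, 32, 3)
println(size(img))   # (3, 32, 32)

# The same element is reachable with the indices reversed
println(img[2, 5, 7] == x[7, 5, 2])   # true
```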

This appears to be a picture of a frog. Let's check the actual class.

In [None]:
classes[ytrain[:,1]]

Next, we define a simple CNN model to train. You can build more complex models by combining different layers. Check out the [Metalhead.jl library](https://github.com/FluxML/Metalhead.jl) for some prebuilt standard models.

In [None]:
model = Chain(
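    Conv((5, 5), 3=>16, relu),    # 32×32×3  -> 28×28×16
    MaxPool((2, 2)),              # 28×28×16 -> 14×14×16
    Conv((5, 5), 16=>8, relu),    # 14×14×16 -> 10×10×8
    flatten,                      # 10×10×8  -> 800 features per image
    Dense(8*4*5*5, 120),          # 8*4*5*5 == 800 input features
    Dense(120, 84),
    Dense(84, 10)                 # one output (logit) per class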
);
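
The input size of the first `Dense` layer has to match the number of values left after the convolution and pooling layers. The sketch below traces one spatial dimension through the network, assuming stride 1 and no padding for the convolutions and non-overlapping windows for the pooling. The helper names `conv_out` and `pool_out` are illustrative.

```julia
# Output length of one spatial dimension, assuming stride 1 and no padding
conv_out(n, k) = n - k + 1
# Output length after non-overlapping pooling with window p
pool_out(n, p) = n ÷ p

side = 32                    # CIFAR-10 images are 32×32
side = conv_out(side, 5)     # first 5×5 convolution  -> 28
side = pool_out(side, 2)     # 2×2 max pooling        -> 14
side = conv_out(side, 5)     # second 5×5 convolution -> 10

channels = 8                 # output channels of the second convolution
features = side * side * channels

println(features)            # 800
```

This agrees with the `8*4*5*5` used in the model definition, which also evaluates to 800.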

We need to move the model to the correct device, and then set our loss function and optimizer. We also collect the model's parameters so we can pass them to `gradient` and `update!` during training.

In [None]:
model = model |> device
loss(x, y) = logitcrossentropy(model(x), y)
opt = ADAM(learning_rate)
ps = Flux.params(model);
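
To make the loss less of a black box, here is a plain-Julia sketch of the quantity a logit cross-entropy computes: the mean over the batch of the negative log-softmax score assigned to the true class. The helpers `logsoftmax_cols` and `logit_ce` are hypothetical and written for clarity, so treat them as an approximation of the idea and not as Flux's implementation.

```julia
using Statistics: mean

# Column-wise log-softmax, shifted by the column maximum for numerical stability
function logsoftmax_cols(z)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Mean over the batch of the negative log-probability of the true class
function logit_ce(z, y)
    return mean(-sum(y .* logsoftmax_cols(z); dims=1))
end

# Four samples with ten classes and all-zero logits (a model with no preference)
z = zeros(Float32, 10, 4)
y = [i == j for i in 1:10, j in [3, 1, 7, 10]]   # dense one-hot targets

loss_value = logit_ce(z, y)
println(loss_value)
```

With uniform logits every class gets probability 1/10, so the loss equals log(10), roughly 2.30. That is a useful reference point when reading the validation loss of a ten-class model early in training.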

Let's train the model!

Here we move the data to the correct device, calculate a gradient on the loss function, and update the model based on our optimization function. We then check the model on our validation data and show the output.

In [None]:
for epoch in 1:epochs
    for (x, y) in train_loader
        x, y = x |> device, y |> device
        gs = Flux.gradient(() -> loss(x,y), ps)
        Flux.update!(opt, ps, gs)
    end

    validation_loss = 0f0
    for (x, y) in valid_loader
        x, y = x |> device, y |> device
        validation_loss += loss(x, y)
    end
    validation_loss /= length(valid_loader)
    @show validation_loss
end
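
The loop above follows the usual pattern of gradient-based training: compute a loss, take its gradient, and nudge the parameters. A minimal plain-Julia sketch of that pattern follows. Note that it swaps in ordinary gradient descent with a hand-derived gradient in place of Adam and Flux's automatic differentiation, and it fits a single weight to toy data. All names are illustrative.

```julia
# Toy data generated by y = 2x, so the ideal weight is 2
xs = [1.0, 2.0, 3.0]
ys = 2 .* xs

# Mean squared error and its derivative with respect to the weight w
toy_loss(w) = sum((w .* xs .- ys) .^ 2) / length(xs)
toy_grad(w) = 2 * sum((w .* xs .- ys) .* xs) / length(xs)

# Plain gradient descent: step against the gradient, scaled by the learning rate
function train(w, lr, steps)
    for _ in 1:steps
        w -= lr * toy_grad(w)
    end
    return w
end

w = train(0.0, 0.05, 100)
println("w = $w, loss = $(toy_loss(w))")
```

Adam differs in how it scales each step, but the overall structure of the loop is the same.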

So how well did we do? Here we calculate the overall accuracy of the model on the test set.

In [None]:
correct, total = 0, 0
for (x, y) in test_loader
    x, y = x |> device, y |> device
    correct += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))
    total += size(y, 2)
end
test_accuracy = correct / total;

# Print the final accuracy
@show test_accuracy
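
The accuracy calculation compares, for every sample, the index of the largest model output with the index of the true class. A small plain-Julia sketch with made-up scores for three classes and three samples shows the idea. The column-wise `argmax` here plays the role that `onecold` plays in the cell above.

```julia
# Made-up scores: one row per class, one column per sample
scores = [0.1 2.0 0.3;
          1.5 0.2 0.1;
          0.0 0.1 0.9]

# True class index of each sample
truth = [2, 2, 3]

# Predicted class is the row with the highest score in each column
predicted = [argmax(col) for col in eachcol(scores)]

n_correct = sum(predicted .== truth)
accuracy = n_correct / length(truth)

println(predicted)   # [2, 1, 3]
println(accuracy)
```

Two of the three predictions match, so the accuracy here is 2/3.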

We also want to see how well we did on each individual class of the test data.

In [None]:
correct_pred = Dict(zip(classes,zeros(10)))
total_pred = Dict(zip(classes,zeros(10)))
        
for (x, y) in test_loader
    x, y = x |> device, y |> device
    outputs = cpu(model(x))
    predictions = mapslices(argmax, outputs, dims=1)
    labels = mapslices(argmax, cpu(y), dims=1)
    for (label, prediction) in zip(labels, predictions)
        if label == prediction
            correct_pred[classes[label]] += 1
        end
        total_pred[classes[label]] += 1
    end
end

In [None]:
for (classname, correct_count) in correct_pred
    accuracy = 100 * correct_count/total_pred[classname]
    println("Accuracy for class: $classname is $accuracy%.")
end
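
The same bookkeeping can be tried on a handful of made-up labels. The function `per_class_accuracy` and the three-class data below are hypothetical and only meant to show how the two dictionaries combine into a per-class accuracy.

```julia
# Count correct and total predictions per class, then divide
function per_class_accuracy(labels, predictions, classes)
    correct = Dict(c => 0 for c in classes)
    total = Dict(c => 0 for c in classes)
    for (label, prediction) in zip(labels, predictions)
        name = classes[label]
        total[name] += 1
        if label == prediction
            correct[name] += 1
        end
    end
    return Dict(c => correct[c] / total[c] for c in classes)
end

# Made-up class names, true labels, and predictions (1-based class indices)
toy_classes = ["cat", "dog", "frog"]
toy_labels = [1, 1, 2, 3, 3, 3]
toy_predictions = [1, 2, 2, 3, 3, 1]

acc = per_class_accuracy(toy_labels, toy_predictions, toy_classes)
for name in toy_classes
    println("$name: $(acc[name])")
end
```

Per-class numbers like these can reveal classes the model confuses even when the overall accuracy looks reasonable.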

Lastly, let's check a single image and see what the model predicts.

In [None]:
colorview(RGB, permutedims(xtest[:,:,:,2], (3, 2, 1,)))

In [None]:
classes[ytest[:,2]]

In [None]:
testing = reshape(xtest[:,:,:,2],(32,32,3,1)) |> device;

In [None]:
classes[argmax(model(testing))]