In [1]:
import Pkg

#Pkg.add("Zygote")

using Zygote
using Flux
using Flux: onehotbatch, onecold, crossentropy
using MLDatasets
#using CUDA
using Statistics

In [2]:
# MLDatasets returns MNIST features as Float32 pixel intensities already scaled
# to [0, 1] (its default element type), so they are NOT divided by 255 again.
# `Float32.` keeps the data in the same precision as Flux's default Float32
# weights (dividing by `255.0` would have promoted everything to Float64).
train_x, train_y = MNIST(split=:train)[:]
test_x, test_y = MNIST(split=:test)[:]

# Conv layers consume WHCN arrays: width × height × channels × batch.
train_x = reshape(Float32.(train_x), 28, 28, 1, :);
test_x = reshape(Float32.(test_x), 28, 28, 1, :);

# One-hot encode the digit labels 0–9 into 10×N matrices.
train_y = onehotbatch(train_y, 0:9);
test_y = onehotbatch(test_y, 0:9);


In [3]:
# Small CNN for 28×28×1 inputs (WHCN layout). With unpadded 3×3 convolutions
# and 2×2 max-pooling the spatial size goes
#   28 → conv 26 → pool 13 → conv 11 → pool 5 → conv 3 → pool 1,
# so `flatten` yields 1·1·64 = 64 features per sample, matching `dense1`.
conv1 = Conv((3, 3), 1 => 16, relu)
pool1 = MaxPool((2, 2))
conv2 = Conv((3, 3), 16 => 32, relu)
pool2 = MaxPool((2, 2))
conv3 = Conv((3, 3), 32 => 64, relu)
pool3 = MaxPool((2, 2))
flatten = Flux.flatten
dense1 = Dense(64 => 128, relu)
dense2 = Dense(128 => 10)
# Final softmax turns the 10 logits into class probabilities, which is what
# `crossentropy` below expects as its first argument.
softmax_layer = softmax

model = Chain(conv1, pool1, conv2, pool2, conv3, pool3,
              flatten, dense1, dense2, softmax_layer)

# Cross-entropy between predicted probabilities and one-hot targets.
loss(x, y) = crossentropy(model(x), y)
# `Adam` is the current optimiser name; `ADAM` is a deprecated alias.
opt = Adam()

Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

In [4]:
"""
    visualize_layers(model, x)

Run `x` through every layer of `model` in order. Print the input size, then
the output size after each layer, and return the final output.

`model` must expose its layers as an iterable `model.layers` (as a Flux
`Chain` does). Walking the model's own layers, rather than a hand-written
list of globals, guarantees that no layer is skipped.
"""
function visualize_layers(model, x)
    println("Input shape: ", size(x))
    for (i, layer) in enumerate(model.layers)
        x = layer(x)
        println("After layer ", i, " (", _layer_label(layer), "): ", size(x))
    end
    return x
end

# Human-readable label for a layer: the function name for plain functions
# (e.g. `flatten`, `softmax`), the type name for layer structs (e.g. `Conv`).
_layer_label(layer::Function) = string(nameof(layer))
_layer_label(layer) = string(nameof(typeof(layer)))


visualize_layers (generic function with 1 method)

In [11]:
"""
    train_model!(model, train_X, train_Y, opt, epochs, batch_size;
                 lossfn = Flux.Losses.crossentropy)

Train `model` in place for `epochs` passes over `(train_X, train_Y)`. Each
pass uses shuffled minibatches of `batch_size` samples and the optimiser
rule `opt`. Return the (mutated) `model`.

`lossfn(ŷ, y)` is the per-batch objective. It defaults to cross-entropy,
which matches a model whose last layer is `softmax` and one-hot targets.
"""
function train_model!(model, train_X, train_Y, opt, epochs, batch_size;
                      lossfn = Flux.Losses.crossentropy)
    data_loader = Flux.DataLoader((train_X, train_Y);
                                  batchsize = batch_size, shuffle = true)

    # Optimiser state mirrors the structure of `model` itself.
    opt_state = Flux.setup(opt, model)

    for epoch in 1:epochs
        for (x, y) in data_loader
            # Differentiate w.r.t. the closure argument `m` (not the captured global).
            grads = Flux.gradient(m -> lossfn(m(x), y), model)
            # State, model and gradient must share the same tree structure,
            # so pass `model` itself, not `Flux.trainable(model)`.
            Flux.update!(opt_state, model, grads[1])
        end
        println("Epoch $epoch complete")
    end
    return model
end


train_model! (generic function with 1 method)

In [12]:
# Training hyper-parameters: number of full passes and minibatch size.
n_epochs, batch_size = 5, 64
train_model!(model, train_x, train_y, opt, n_epochs, batch_size)



Epoch 1 complete
Epoch 2 complete
Epoch 3 complete


In [None]:
# Trace shapes through the network for a single training image (batch of 1).
visualize_layers(model, train_x[:, :, :, 1:1])

# Fraction of samples whose predicted class (argmax of the model output)
# equals the one-hot label. The model is an explicit argument so the
# result does not silently depend on the non-constant global `model`.
accuracy(m, x, y) = mean(onecold(m(x)) .== onecold(y))
# Backward-compatible form that evaluates the global `model`.
accuracy(x, y) = accuracy(model, x, y)

println("Test Accuracy: ", accuracy(model, test_x, test_y))
