In [1]:
using Pkg
using Flux
using MLDatasets

In [2]:
train_x, train_y = CIFAR10.traindata(Float32, 1:5000)
test_x, test_y = CIFAR10.testdata(Float32, 1:5000);

In [3]:
println("Size of each image: ", size(train_x))
println("Label of 50th training datapoint: ", train_y[50])
# So here we can see that each training point is a 3D array - a 32x32 image with 3 color channels

Size of each image: (32, 32, 3, 5000)
Label of 50th training datapoint: 0


In [4]:
# Since this is a multi-class classification problem, we can use one hot encoding, just like the MNIST dataset.
# There's 10 classes just like mnist, so we encode from 0 to 9

train_y, test_y = Flux.onehotbatch(train_y, 0:9), Flux.onehotbatch(test_y, 0:9)
nclasses = length(train_y[:,1])
println("number of classes: ", nclasses)

number of classes: 10


In [5]:
# 3 VGG Block
model_VGG3 = Chain(
              Conv((3,3), 3=>32, relu, pad=SamePad()),
              Conv((3,3), 32=>32, relu, pad=SamePad()),
              MaxPool((2,2)),
              Dropout(0.2),
              Conv((3,3), 32=>64, relu, pad=SamePad()),
              Conv((3,3), 64=>64, relu, pad=SamePad()),
              MaxPool((2,2)),
              Dropout(0.2),
              Conv((3,3), 64=>128, relu, pad=SamePad()),
              Conv((3,3), 128=>256, relu, pad=SamePad()),
              MaxPool((2,2)),
              Dropout(0.2),
              Flux.flatten,
              Dense(4096,128,relu),
              Dropout(0.2),
              Dense(128,10),
              softmax)

Chain(
  Conv((3, 3), 3 => 32, relu, pad=1),   [90m# 896 parameters[39m
  Conv((3, 3), 32 => 32, relu, pad=1),  [90m# 9_248 parameters[39m
  MaxPool((2, 2)),
  Dropout(0.2),
  Conv((3, 3), 32 => 64, relu, pad=1),  [90m# 18_496 parameters[39m
  Conv((3, 3), 64 => 64, relu, pad=1),  [90m# 36_928 parameters[39m
  MaxPool((2, 2)),
  Dropout(0.2),
  Conv((3, 3), 64 => 128, relu, pad=1),  [90m# 73_856 parameters[39m
  Conv((3, 3), 128 => 256, relu, pad=1),  [90m# 295_168 parameters[39m
  MaxPool((2, 2)),
  Dropout(0.2),
  Flux.flatten,
  Dense(4096, 128, relu),               [90m# 524_416 parameters[39m
  Dropout(0.2),
  Dense(128, 10),                       [90m# 1_290 parameters[39m
  NNlib.softmax,
)[90m                   # Total: 16 arrays, [39m960_298 parameters, 3.666 MiB.

In [6]:
function loss_and_accuracy(udata, wdata, model)

    ndata = size(udata,4)

    ŵ = model(udata)
    loss = Flux.crossentropy(ŵ, wdata; agg=sum)
    accuracy = sum(Flux.onecold(ŵ) .== Flux.onecold(wdata)) / ndata
    return loss, accuracy
end

loss_and_accuracy (generic function with 1 method)

In [7]:
batch_size = 128
train_loader = Flux.Data.DataLoader((train_x, train_y), batchsize=batch_size, shuffle=true);

In [8]:
function train(model,train_loader,optimizer,train_x,train_y,test_x,test_y,model_name)
    train_losses = []
    train_accuracy = []
    test_losses = []
    test_accuracy = []
    α = 0.001 # <- stepsize; in the ML community, it is often denoted as a `learning rate η`
    #opt = optimizer(α) 
    opt = optimizer 
    K = 20
    for k in 1:K
        for (u, w) in train_loader
            gs = gradient(() -> Flux.Losses.crossentropy(model(u), w), Flux.params(model)) # compute gradient
            Flux.Optimise.update!(opt, Flux.params(model), gs) # update parameters
        end
        println("Epoch $k for $model_name architecture.")
        train_loss, train_acc = loss_and_accuracy(train_x, train_y,  model)

        test_loss, test_acc = loss_and_accuracy(test_x, test_y, model)

        println("  train_loss = $train_loss, train_accuracy = $train_acc")
        println("  test_loss = $test_loss, test_accuracy = $test_acc")
        
        push!(test_losses, test_loss)
        push!(test_accuracy, test_acc)
        push!(train_losses, train_loss)
        push!(train_accuracy, train_acc)
    end
    return train_losses, train_accuracy, test_losses, test_accuracy
end

train (generic function with 1 method)

In [None]:
vgg3_train_loss, 
vgg3_train_accuracy, 
vgg3_test_loss, 
vgg3_test_accuracy = train(model_VGG3, train_loader, ADAM(0.001),train_x, train_y,test_x, test_y, "VGG3");

Epoch 1 for VGG3 architecture.
  train_loss = 10364.922, train_accuracy = 0.248
  test_loss = 10431.513, test_accuracy = 0.2352
