In [3]:
@info "Loading libraries"
@info " Loading Flux"
using Flux
using Statistics
using Flux: onehotbatch, crossentropy, Momentum, update!, onecold
@info " Loading MLDatasets"
using MLDatasets: CIFAR10
using Base.Iterators: partition

batchsize = 1000
trainsize = 50000 - batchsize

@info "Loading training data"
# trainimgs = CIFAR10.traintensor(Float32);
trainimgs = CIFAR10(split=:train).features
# trainlabels = onehotbatch(CIFAR10.trainlabels(Float32) .+ 1, 1:10);
trainlabels = onehotbatch(CIFAR10(split=:train).targets .+ 1, 1:10);

@info "Building the trainset"
trainset = [(trainimgs[:,:,:,i], trainlabels[:,i]) for i in partition(1:trainsize, batchsize)];
batchnum = size(trainset)[1]

@info "Loading validation data"
valset = (trainsize+1):(trainsize+batchsize)
valX = trainimgs[:,:,:,valset] |> gpu;
valY = trainlabels[:, valset] |> gpu;

loss(x, y) = sum(crossentropy(m(x), y))
opt = Momentum(0.01)
max_pred(x) = [findmax(m(x[:,:,:,i:i]))[2][1] for i in 1:(size(x)[4])] |> gpu
max_lab(y) = [findmax(y[:,i])[2] for i in 1:(size(y)[2])] |> gpu
accuracy(x, y) = mean(max_pred(x) .== max_lab(y)) |> gpu

@info "Loading the model"
include("yiyu-resnet.jl")
m = ResNet([2,2,2,2], 10) |> gpu; #ResNet18

epochs = 10

for epoch = 1:epochs
        @info "epoch" epoch
        for i in 1:batchnum
                batch = trainset[i] |> gpu
                gs = gradient(Flux.params(m)) do
                        l = loss(batch...)
                end
                @info "batch fraction" i/batchnum
                update!(opt, Flux.params(m), gs)
        end
        @show accuracy(valX, valY)
end

@info "Loading test data"
# testimgs = CIFAR10.testtensor(Float32);
testimgs = CIFAR10(split=:test).features
# testlabels = onehotbatch(CIFAR10.testlabels(Float32) .+ 1, 1:10);
testlabels = onehotbatch(CIFAR10(split=:test).targets .+ 1, 1:10);

testset = [(testimgs[:,:,:,i], testlabels[:,i]) for i in partition(1:10000, batchsize)] |> gpu;

class_correct = zeros(10)
class_total = zeros(10)
for i in 1:(10000/batchsize)
  @info "Evaluating testset batch " i
  preds = m(testset[i][1])
  lab = testset[i][2]
  for j = 1:batchsize
    pred_class = findmax(preds[:, j])[2]
    actual_class = findmax(lab[:, j])[2]
    if pred_class == actual_class
      class_correct[pred_class] += 1
    end
    class_total[actual_class] += 1
  end
end

class_correct ./ class_total

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoading libraries
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m Loading Flux
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m Loading MLDatasets
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoading training data
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mBuilding the trainset
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoading validation data
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoading the model
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mepoch
[36m[1m└ [22m[39m  epoch = 1


LoadError: DomainError with -0.9952224:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).