In [1]:
using Flux
using Flux: Data.DataLoader
using Flux: onehotbatch, onecold, crossentropy
using Flux: @epochs
using Statistics
using MLDatasets
using CUDA

In [2]:
x_train, y_train = MLDatasets.MNIST.traindata();
x_valid, y_valid = MLDatasets.MNIST.testdata();

In [3]:
size(x_train)

(28, 28, 60000)

In [4]:
x_train = Flux.unsqueeze(x_train, 3) |> gpu;
x_valid = Flux.unsqueeze(x_valid, 3) |> gpu;

In [5]:
size(x_train)

(28, 28, 1, 60000)

In [6]:
size(y_train)

(60000,)

In [7]:
y_train = onehotbatch(y_train, 0:9) |> gpu;
y_valid = onehotbatch(y_valid, 0:9) |> gpu;

In [8]:
size(y_train)

(10, 60000)

In [11]:
train_data = DataLoader(x_train, y_train, batchsize=128) 

DataLoader{Tuple{CuArray{FixedPointNumbers.Normed{UInt8,8},4},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1}}}}((FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

...

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0

In [12]:
model = Chain(
    Conv((5, 5), 1=>8, pad=2, stride=2, relu),
    Conv((3, 3), 8=>16, pad=1, stride=2, relu),
    Conv((3, 3), 16=>32, pad=1, stride=2, relu),
    Conv((3, 3), 32=>32, pad=1, stride=2, relu),
    GlobalMeanPool(),
    flatten,
    Dense(32, 10),
    softmax) |> gpu

Chain(Conv((5, 5), 1=>8, relu), Conv((3, 3), 8=>16, relu), Conv((3, 3), 16=>32, relu), Conv((3, 3), 32=>32, relu), GlobalMeanPool(), flatten, Dense(32, 10), softmax)

In [13]:
accuracy(y_pred, y) = mean(onecold(y_pred) == onecold(y))
loss(x, y) = Flux.crossentropy(model(x), y)

lr = 0.1
opt = Descent(lr)

ps = Flux.params(model)

Params([Float32[-0.07242501 0.106632955 … -0.00965152 0.13624597; -0.12433287 -0.15191491 … -0.04233487 0.14865729; … ; -0.12237478 0.012550827 … -0.050722957 0.049745645; 0.12603006 -0.061818756 … -0.10731161 -0.08992784]

Float32[-0.047525767 0.030979576 … -0.1338092 0.090883695; -0.12596217 -0.10311656 … -0.051207602 -0.05439973; … ; -0.09695391 0.032853723 … 0.11205037 0.058808025; -0.026465693 0.1387455 … -0.028544435 0.05481044]

Float32[0.14928265 -0.14530586 … -0.063804135 -0.15798111; 0.059205495 -0.068303265 … -0.06501723 -0.15423535; … ; 0.04067046 0.10456882 … 0.11875832 -0.038632713; -0.050904464 -0.023226107 … 0.0365011 -0.1155586]

Float32[0.15825154 -0.15142621 … -0.082832694 -0.10593278; 0.05910606 -0.036107443 … 0.1293205 -0.11692408; … ; -0.1163609 0.10575773 … 0.14785857 -0.14466225; -0.105869934 -0.15496548 … 0.09441404 -0.15528807]

Float32[0.13836072 -0.107684284 … 0.12703848 0.057721775; 0.09273549 0.13758765 … -0.14257045 -0.14747955; … ; -0.15041947 -0.1032215

In [16]:
number_epochs = 10
@epochs number_epochs Flux.train!(loss, ps, train_data, opt)

accuracy(model(x), y_train)

┌ Info: Epoch 1
└ @ Main /home/prakhar/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136
│   yT = Float32
│   T1 = FixedPointNumbers.Normed{UInt8,8}
│   T2 = Float32
└ @ NNlib /home/prakhar/.julia/packages/NNlib/fxLrD/src/conv.jl:206


LoadError: InterruptException: