In [1]:
using Flux, Statistics
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, @epochs
using Base.Iterators: repeated
using Parameters: @with_kw
using CUDAapi
using MLDatasets

┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1260
│   caller = llvm_compat(::VersionNumber) at compatibility.jl:176
└ @ CUDAnative /home/kit/scc/yy3406/.julia/packages/CUDAnative/ierw8/src/compatibility.jl:176
┌ Info: Precompiling Parameters [d96e819e-fc66-5662-9728-84c9c7592b0a]
└ @ Base loading.jl:1260
┌ Info: Precompiling MLDatasets [eb30cadb-4394-5ae3-aed4-317e484a6458]
└ @ Base loading.jl:1260


In [2]:
if has_cuda()		# Check if CUDA is available
    @info "CUDA is on"
    import CUDA		# If CUDA is available, import CuArrays
    CUDA.allowscalar(false)
end

┌ Info: CUDA is on
└ @ Main In[2]:2


In [3]:
@with_kw mutable struct Args
    η::Float64 = 3e-4       # learning rate
    batchsize::Int = 1024   # batch size
    epochs::Int = 10        # number of epochs
    device::Function = gpu  # set as gpu, if gpu available
end

Args

In [4]:
function getdata(args)
    # Loading Dataset	
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)
	
    # Reshape Data for flatten the each image into linear array
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)

    # One-hot-encode the labels
    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    # Batching
    train_data = DataLoader(xtrain, ytrain, batchsize=args.batchsize, shuffle=true)
    test_data = DataLoader(xtest, ytest, batchsize=args.batchsize)

    return train_data, test_data
end

getdata (generic function with 1 method)

In [5]:
function build_model(; imgsize=(28,28,1), nclasses=10)
    return Chain(
 	    Dense(prod(imgsize), 32, relu),
            Dense(32, nclasses))
end

function loss_all(dataloader, model)
    l = 0f0
    for (x,y) in dataloader
        l += logitcrossentropy(model(x), y)
    end
    l/length(dataloader)
end

function accuracy(data_loader, model)
    acc = 0
    for (x,y) in data_loader
        acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2)
    end
    acc/length(data_loader)
end

accuracy (generic function with 1 method)

In [6]:
# train
args = Args()

Args
  η: Float64 0.0003
  batchsize: Int64 1024
  epochs: Int64 10
  device: gpu (function of type typeof(gpu))


In [7]:
train_data, test_data = getdata(args)

This program has requested access to the data dependency MNIST.
which is not currently installed. It can be installed automatically, and you will not see this message again.

Dataset: THE MNIST DATABASE of handwritten digits
Authors: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
Website: http://yann.lecun.com/exdb/mnist/

[LeCun et al., 1998a]
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
    "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE, 86(11):2278-2324, November 1998

The files are available for download at the offical
website linked above. Note that using the data
responsibly and respecting copyright remains your
responsibility. The authors of MNIST aren't really
explicit about any terms of use, so please read the
website to make sure you want to download the
dataset.



Do you want to download the dataset from ["http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyt

(DataLoader((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0]), 1024, 60000, true, 60000, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  59991, 59992, 59993, 59994, 59995, 59996, 59997, 59998, 59999, 60000], true), DataLoader((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]), 1024, 10000, true, 10000, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999, 10000], false))

In [8]:
m = build_model()

Chain(Dense(784, 32, relu), Dense(32, 10))

In [9]:
# gpu
train_data = args.device.(train_data)
test_data = args.device.(train_data)
m = args.device(m)

Chain(Dense(784, 32, relu), Dense(32, 10))

In [10]:
loss(x,y) = logitcrossentropy(m(x), y)

loss (generic function with 1 method)

In [11]:
evalcb = () -> @show(loss_all(train_data, m))
opt = ADAM(args.η)

ADAM(0.0003, (0.9, 0.999), IdDict{Any,Any}())

In [15]:
@epochs args.epochs Flux.train!(loss, params(m), train_data, opt, cb = evalcb)

loss_all(train_data, m) = 0.32577655f0
loss_all(train_data, m) = 

┌ Info: Epoch 1
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.3254825f0
loss_all(train_data, m) = 0.32524538f0
loss_all(train_data, m) = 0.32504395f0
loss_all(train_data, m) = 0.32484505f0
loss_all(train_data, m) = 0.32463807f0
loss_all(train_data, m) = 0.32442987f0
loss_all(train_data, m) = 0.32423598f0
loss_all(train_data, m) = 0.32404175f0
loss_all(train_data, m) = 0.32383552f0
loss_all(train_data, m) = 0.3236351f0
loss_all(train_data, m) = 0.32341754f0
loss_all(train_data, m) = 0.3231955f0
loss_all(train_data, m) = 0.32297894f0
loss_all(train_data, m) = 0.32276115f0
loss_all(train_data, m) = 0.32256675f0
loss_all(train_data, m) = 0.32238612f0
loss_all(train_data, m) = 0.32217458f0
loss_all(train_data, m) = 0.32193252f0
loss_all(train_data, m) = 0.3216949f0
loss_all(train_data, m) = 0.32144767f0
loss_all(train_data, m) = 0.32121608f0
loss_all(train_data, m) = 0.3209867f0
loss_all(train_data, m) = 0.3207865f0
loss_all(train_data, m) = 0.3205859f0
loss_all(train_data, m) = 0.3204338f0
loss_all(train_data, m) = 0.3202806f0
loss_all(train_data, 

┌ Info: Epoch 2
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.3131164f0
loss_all(train_data, m) = 0.31294242f0
loss_all(train_data, m) = 0.3127759f0
loss_all(train_data, m) = 0.31260094f0
loss_all(train_data, m) = 0.3124247f0
loss_all(train_data, m) = 0.31226325f0
loss_all(train_data, m) = 0.3121023f0
loss_all(train_data, m) = 0.31192937f0
loss_all(train_data, m) = 0.3117596f0
loss_all(train_data, m) = 0.31157193f0
loss_all(train_data, m) = 0.31137815f0
loss_all(train_data, m) = 0.31118965f0
loss_all(train_data, m) = 0.3110006f0
loss_all(train_data, m) = 0.3108375f0
loss_all(train_data, m) = 0.31068873f0
loss_all(train_data, m) = 0.31051084f0
loss_all(train_data, m) = 0.31030416f0
loss_all(train_data, m) = 0.31009975f0
loss_all(train_data, m) = 0.30988458f0
loss_all(train_data, m) = 0.30968168f0
loss_all(train_data, m) = 0.309481f0
loss_all(train_data, m) = 0.30930966f0
loss_all(train_data, m) = 0.30913758f0
loss_all(train_data, m) = 0.30901405f0
loss_all(train_data, m) = 0.30888945f0
loss_all(train_data, m) = 0.308738

┌ Info: Epoch 3
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.30313393f0
loss_all(train_data, m) = 0.30288002f0
loss_all(train_data, m) = 0.3026853f0
loss_all(train_data, m) = 0.3025308f0
loss_all(train_data, m) = 0.30238757f0
loss_all(train_data, m) = 0.30223632f0
loss_all(train_data, m) = 0.3020852f0
loss_all(train_data, m) = 0.30194923f0
loss_all(train_data, m) = 0.3018139f0
loss_all(train_data, m) = 0.30166647f0
loss_all(train_data, m) = 0.30151942f0
loss_all(train_data, m) = 0.30135244f0
loss_all(train_data, m) = 0.30117783f0
loss_all(train_data, m) = 0.30100793f0
loss_all(train_data, m) = 0.300838f0
loss_all(train_data, m) = 0.30069512f0
loss_all(train_data, m) = 0.3005693f0
loss_all(train_data, m) = 0.30041784f0
loss_all(train_data, m) = 0.300241f0
loss_all(train_data, m) = 0.30006266f0
loss_all(train_data, m) = 0.2998715f0
loss_all(train_data, m) = 0.29969f0
loss_all(train_data, m) = 0.29950926f0
loss_all(train_data, m) = 0.29935655f0
loss_all(train_data, m) = 0.29920182f0
loss_all(train_data, m) = 0.2990957f0


┌ Info: Epoch 4
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.2932384f0
loss_all(train_data, m) = 0.29310298f0
loss_all(train_data, m) = 0.29296806f0
loss_all(train_data, m) = 0.29284838f0
loss_all(train_data, m) = 0.29273024f0
loss_all(train_data, m) = 0.2925997f0
loss_all(train_data, m) = 0.29247057f0
loss_all(train_data, m) = 0.29232216f0
loss_all(train_data, m) = 0.2921647f0
loss_all(train_data, m) = 0.29201463f0
loss_all(train_data, m) = 0.29186514f0
loss_all(train_data, m) = 0.29174203f0
loss_all(train_data, m) = 0.2916343f0
loss_all(train_data, m) = 0.2915009f0
loss_all(train_data, m) = 0.2913437f0
loss_all(train_data, m) = 0.2911835f0
loss_all(train_data, m) = 0.2910086f0
loss_all(train_data, m) = 0.2908404f0
loss_all(train_data, m) = 0.2906738f0
loss_all(train_data, m) = 0.2905353f0
loss_all(train_data, m) = 0.29039526f0
loss_all(train_data, m) = 0.2903013f0
loss_all(train_data, m) = 0.2902096f0
loss_all(train_data, m) = 0.29009688f0
loss_all(train_data, m) = 0.28998992f0
loss_all(train_data, m) = 0.28985658f0

┌ Info: Epoch 5
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.28508064f0
loss_all(train_data, m) = 0.2849606f0
loss_all(train_data, m) = 0.28484133f0
loss_all(train_data, m) = 0.2847375f0
loss_all(train_data, m) = 0.28463548f0
loss_all(train_data, m) = 0.2845205f0
loss_all(train_data, m) = 0.2844059f0
loss_all(train_data, m) = 0.28427127f0
loss_all(train_data, m) = 0.2841263f0
loss_all(train_data, m) = 0.28398806f0
loss_all(train_data, m) = 0.2838508f0
loss_all(train_data, m) = 0.28374076f0
loss_all(train_data, m) = 0.2836469f0
loss_all(train_data, m) = 0.2835287f0
loss_all(train_data, m) = 0.28338793f0
loss_all(train_data, m) = 0.2832429f0
loss_all(train_data, m) = 0.28308165f0
loss_all(train_data, m) = 0.2829235f0
loss_all(train_data, m) = 0.28276864f0
loss_all(train_data, m) = 0.28264338f0
loss_all(train_data, m) = 0.28251764f0
loss_all(train_data, m) = 0.28243598f0
loss_all(train_data, m) = 0.28235632f0
loss_all(train_data, m) = 0.28225544f0
loss_all(train_data, m) = 0.2821596f0
loss_all(train_data, m) = 0.2820396f

┌ Info: Epoch 6
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.2776975f0
loss_all(train_data, m) = 0.2775894f0
loss_all(train_data, m) = 0.27748194f0
loss_all(train_data, m) = 0.27738988f0
loss_all(train_data, m) = 0.2772996f0
loss_all(train_data, m) = 0.27719527f0
loss_all(train_data, m) = 0.2770902f0
loss_all(train_data, m) = 0.27696466f0
loss_all(train_data, m) = 0.27682814f0
loss_all(train_data, m) = 0.27669784f0
loss_all(train_data, m) = 0.27657083f0
loss_all(train_data, m) = 0.2764711f0
loss_all(train_data, m) = 0.27638665f0
loss_all(train_data, m) = 0.27627912f0
loss_all(train_data, m) = 0.27615112f0
loss_all(train_data, m) = 0.27601856f0
loss_all(train_data, m) = 0.2758687f0
loss_all(train_data, m) = 0.27572173f0
loss_all(train_data, m) = 0.27557585f0
loss_all(train_data, m) = 0.2754574f0
loss_all(train_data, m) = 0.27533743f0
loss_all(train_data, m) = 0.2752597f0
loss_all(train_data, m) = 0.27518433f0
loss_all(train_data, m) = 0.27508935f0
loss_all(train_data, m) = 0.2750028f0
loss_all(train_data, m) = 0.274895

┌ Info: Epoch 7
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.27113208f0
loss_all(train_data, m) = 0.27102616f0
loss_all(train_data, m) = 0.27093738f0
loss_all(train_data, m) = 0.2708382f0
loss_all(train_data, m) = 0.27074063f0
loss_all(train_data, m) = 0.27065825f0
loss_all(train_data, m) = 0.2705779f0
loss_all(train_data, m) = 0.27048233f0
loss_all(train_data, m) = 0.27038577f0
loss_all(train_data, m) = 0.2702684f0
loss_all(train_data, m) = 0.27014044f0
loss_all(train_data, m) = 0.27001923f0
loss_all(train_data, m) = 0.26990172f0
loss_all(train_data, m) = 0.26981208f0
loss_all(train_data, m) = 0.26973796f0
loss_all(train_data, m) = 0.26964125f0
loss_all(train_data, m) = 0.2695244f0
loss_all(train_data, m) = 0.2694034f0
loss_all(train_data, m) = 0.2692631f0
loss_all(train_data, m) = 0.2691233f0
loss_all(train_data, m) = 0.26898304f0
loss_all(train_data, m) = 0.26886857f0
loss_all(train_data, m) = 0.26875237f0
loss_all(train_data, m) = 0.26867896f0
loss_all(train_data, m) = 0.2686078f0
loss_all(train_data, m) = 0.26851

┌ Info: Epoch 8
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.26467714f0
loss_all(train_data, m) = 0.2645845f0
loss_all(train_data, m) = 0.26449385f0
loss_all(train_data, m) = 0.26441872f0
loss_all(train_data, m) = 0.26434657f0
loss_all(train_data, m) = 0.26425812f0
loss_all(train_data, m) = 0.26416716f0
loss_all(train_data, m) = 0.2640564f0
loss_all(train_data, m) = 0.26393604f0
loss_all(train_data, m) = 0.26382244f0
loss_all(train_data, m) = 0.26371452f0
loss_all(train_data, m) = 0.26363558f0
loss_all(train_data, m) = 0.26356968f0
loss_all(train_data, m) = 0.2634819f0
loss_all(train_data, m) = 0.2633744f0
loss_all(train_data, m) = 0.26326185f0
loss_all(train_data, m) = 0.26312894f0
loss_all(train_data, m) = 0.26299366f0
loss_all(train_data, m) = 0.26285863f0
loss_all(train_data, m) = 0.26275012f0
loss_all(train_data, m) = 0.2626394f0
loss_all(train_data, m) = 0.26257092f0
loss_all(train_data, m) = 0.2625065f0
loss_all(train_data, m) = 0.26242608f0
loss_all(train_data, m) = 0.2623549f0
loss_all(train_data, m) = 0.2622

┌ Info: Epoch 9
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.25913054f0
loss_all(train_data, m) = 0.25898522f0
loss_all(train_data, m) = 0.2588872f0
loss_all(train_data, m) = 0.2588094f0
loss_all(train_data, m) = 0.25872284f0
loss_all(train_data, m) = 0.2586389f0
loss_all(train_data, m) = 0.2585711f0
loss_all(train_data, m) = 0.2585074f0
loss_all(train_data, m) = 0.2584269f0
loss_all(train_data, m) = 0.25834343f0
loss_all(train_data, m) = 0.25824094f0
loss_all(train_data, m) = 0.25812757f0
loss_all(train_data, m) = 0.25801992f0
loss_all(train_data, m) = 0.25791842f0
loss_all(train_data, m) = 0.25784728f0
loss_all(train_data, m) = 0.25778827f0
loss_all(train_data, m) = 0.2577075f0
loss_all(train_data, m) = 0.25760725f0
loss_all(train_data, m) = 0.25750124f0
loss_all(train_data, m) = 0.25737533f0
loss_all(train_data, m) = 0.25724575f0
loss_all(train_data, m) = 0.25711584f0
loss_all(train_data, m) = 0.2570121f0
loss_all(train_data, m) = 0.25690636f0
loss_all(train_data, m) = 0.25684226f0
loss_all(train_data, m) = 0.25678033f0
loss_all(train_data,

┌ Info: Epoch 10
└ @ Main /home/kit/scc/yy3406/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.25381568f0
loss_all(train_data, m) = 0.2536263f0
loss_all(train_data, m) = 0.25348666f0
loss_all(train_data, m) = 0.2533941f0
loss_all(train_data, m) = 0.2533196f0
loss_all(train_data, m) = 0.25323653f0
loss_all(train_data, m) = 0.25315607f0
loss_all(train_data, m) = 0.25309274f0
loss_all(train_data, m) = 0.25303492f0
loss_all(train_data, m) = 0.25295958f0
loss_all(train_data, m) = 0.25288126f0
loss_all(train_data, m) = 0.25278643f0
loss_all(train_data, m) = 0.25268146f0
loss_all(train_data, m) = 0.25258273f0
loss_all(train_data, m) = 0.2524908f0
loss_all(train_data, m) = 0.25242963f0
loss_all(train_data, m) = 0.25237763f0
loss_all(train_data, m) = 0.2523032f0
loss_all(train_data, m) = 0.2522109f0
loss_all(train_data, m) = 0.25211185f0
loss_all(train_data, m) = 0.25199163f0
loss_all(train_data, m) = 0.25186566f0
loss_all(train_data, m) = 0.2517384f0
loss_all(train_data, m) = 0.2516378f0
loss_all(train_data, m) = 0.25153473f0
loss_all(train_data, m) = 0.25147

In [13]:
@show accuracy(train_data, m)
@show accuracy(test_data, m)

accuracy(train_data, m) = 0.9108078724353257
accuracy(test_data, m) = 0.9108078724353257


0.9108078724353257

In [14]:
m

Chain(Dense(784, 32, relu), Dense(32, 10))