In [1]:
# Exploring probabilistic classification

# K*(d+1) = 3*(4+1) = 15 parameters are being trained.
# Cross entropy loss function is being used.

In [1]:
using MLDatasets, Flux, Distributions, DataFrames

function probabilistic_classification(X, Y; numiters=500)
    # data sizes
    n,d = size(X)
    K = size(Y, 2)

    W = randn(K, d)
    b = zeros(K)

    model = x -> Flux.softmax(W * x .+ b)

    opt = Descent(0.1)

    loss = (x, y) -> Flux.crossentropy(model(x), y)

    data = zip(eachrow(X), eachrow(Y))
    opt = ADAMW()
    Flux.@epochs 500 Flux.train!(loss, Flux.params(W, b), data, opt)
    return model
end

probabilistic_classification (generic function with 1 method)

In [19]:
using MLDatasets, LinearAlgebra
features = transpose(Iris.features())
labels = Iris.labels()

150-element Vector{String}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"

In [20]:
oh_labels = zeros(150, 3)
for i in 1:150
    if labels[i] == "Iris-setosa"
        oh_labels[i, 1] = 1
    elseif labels[i] == "Iris-versicolor"
        oh_labels[i, 2] = 1
    elseif labels[i] == "Iris-virginica"
        oh_labels[i, 3] = 1   
    end
end

In [21]:
using Random
Random.seed!(0)

df = hcat(features, oh_labels)
df = df[shuffle(1:end), :]

150×7 Matrix{Float64}:
 6.0  2.2  5.0  1.5  0.0  0.0  1.0
 6.7  2.5  5.8  1.8  0.0  0.0  1.0
 5.0  3.4  1.6  0.4  1.0  0.0  0.0
 4.4  2.9  1.4  0.2  1.0  0.0  0.0
 5.0  2.0  3.5  1.0  0.0  1.0  0.0
 5.1  2.5  3.0  1.1  0.0  1.0  0.0
 6.0  2.9  4.5  1.5  0.0  1.0  0.0
 5.0  3.5  1.3  0.3  1.0  0.0  0.0
 5.1  3.8  1.9  0.4  1.0  0.0  0.0
 5.4  3.0  4.5  1.5  0.0  1.0  0.0
 5.1  3.3  1.7  0.5  1.0  0.0  0.0
 4.8  3.0  1.4  0.1  1.0  0.0  0.0
 6.5  3.0  5.5  1.8  0.0  0.0  1.0
 ⋮                        ⋮    
 6.3  2.5  5.0  1.9  0.0  0.0  1.0
 6.1  2.8  4.0  1.3  0.0  1.0  0.0
 6.8  2.8  4.8  1.4  0.0  1.0  0.0
 6.0  2.7  5.1  1.6  0.0  1.0  0.0
 6.1  3.0  4.6  1.4  0.0  1.0  0.0
 5.6  2.7  4.2  1.3  0.0  1.0  0.0
 6.2  2.8  4.8  1.8  0.0  0.0  1.0
 5.7  3.8  1.7  0.3  1.0  0.0  0.0
 5.4  3.9  1.7  0.4  1.0  0.0  0.0
 5.5  2.4  3.7  1.0  0.0  1.0  0.0
 5.6  3.0  4.5  1.5  0.0  1.0  0.0
 6.8  3.0  5.5  2.1  0.0  0.0  1.0

In [22]:
U_train = df[1:120, 1:4]
v_train = df[1:120, 5:7]

U_eval = df[121:150, 1:4]
v_eval = df[121:150, 5:7]

30×3 Matrix{Float64}:
 0.0  1.0  0.0
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 1.0  0.0  0.0
 0.0  0.0  1.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0
 1.0  0.0  0.0
 0.0  0.0  1.0
 ⋮         
 0.0  0.0  1.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0
 1.0  0.0  0.0
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0

In [23]:
model = probabilistic_classification(U_train, v_train)

┌ Info: Epoch 1
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 2
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 3
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia

┌ Info: Epoch 105
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 106
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 107
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 108
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 109
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 110
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 111
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 112
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 113
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 114
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 115
└ @ Main C:\

┌ Info: Epoch 207
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 208
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 209
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 210
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 211
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 212
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 213
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 214
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 215
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 216
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 217
└ @ Main C:\

┌ Info: Epoch 309
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 310
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 311
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 312
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 313
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 314
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 315
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 316
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 317
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 318
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 319
└ @ Main C:\

┌ Info: Epoch 411
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 412
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 413
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 414
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 415
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 416
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 417
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 418
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 419
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 420
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 421
└ @ Main C:\

#2 (generic function with 1 method)

In [44]:
function predictor(model,U)
    n = size(U, 1)
    probs = vcat([model(x) for x in eachrow(U)]...)
    vhat =  zeros(n, 3)
    
    for i in 1:n
        temp = probs[3*i-2:3*i]
        pred = argmax(temp)
        vhat[i, pred] = 1
    end
    
    return probs, vhat
end

predictor (generic function with 1 method)

In [48]:
probs_train, v_train_hat = predictor(model,U_train)
probs_eval, v_eval_hat = predictor(model,U_eval)

90-element Vector{Float64}:
 2.803450046562753e-7
 0.9490694699599238
 0.05093024969507154
 0.9999999453849769
 5.461502300294854e-8
 8.902959150322126e-18
 9.856205523367084e-8
 0.8932665301344727
 0.10673337130347203
 3.3011568757355245e-9
 0.5252729135448336
 0.4747270831540096
 9.874953155208098e-7
 ⋮
 0.9999934419910504
 6.558008921415385e-6
 2.802251457567139e-14
 3.6214391066948704e-5
 0.992317350250003
 0.007646435358930196
 1.216390823176709e-7
 0.7178486658001927
 0.282151212560725
 5.650814570915164e-13
 0.029772586185419933
 0.9702274138140149

In [49]:
using Statistics

accuracy(a, b) = Statistics.mean(a .== b)
train_acc = accuracy(v_train_hat, v_train)
eval_acc = accuracy(v_eval_hat, v_eval)

println(train_acc)
println(eval_acc)

0.9944444444444445
0.9555555555555556


In [54]:
function nll(probs, v)
    n = size(v, 1)
    L = 0
    
    for i in 1:n
        if v[i, 1] == 1
            index = 1
        elseif v[i, 2] == 1
            index = 2
        elseif v[i, 3] == 1
            index = 3
        end
        log_prob = log(probs[3*(i-1)+index])
        L += log_prob
    end
    return -L/n
end 

nll (generic function with 1 method)

In [56]:
nll_train = nll(probs_train, v_train)
nll_eval = nll(probs_eval, v_eval)

0.20549378388866377

In [57]:
# 2c

println(nll_train)
println(nll_eval)
# I want L to be small.

0.04875757913274955
0.20549378388866377
