In [2]:
#Exploring embeddings

using Random
Random.seed!(0)

include("readclassjson.jl")
data = readclassjson("multi_class.json")

v_test = data["v_test"]
v_train = data["v_train"]
U_test = data["U_test"]
U_train = data["U_train"]

300×30 Matrix{Float64}:
  0.679107   -0.965969     1.67644    …   0.156193   -0.9955       1.7369
  0.828413    0.00714434   1.63387        1.43461     2.71259      0.358765
 -0.353007    0.152515    -0.148072      -1.28234    -0.769532    -0.338117
 -0.134854    0.285965    -1.06838        0.88044     0.515654    -1.27669
  0.586617    0.374244    -1.31888        1.86744    -1.06407      0.66056
  0.297336   -0.220057    -1.14857    …  -0.504555    0.999665     1.4975
  0.0649475   1.82168     -0.671691       0.500804    1.52924      0.915797
 -0.109017    0.508609    -0.759928      -0.770549    1.16941      1.75845
 -0.51421     0.225921    -0.291091       0.625452    1.7092      -0.460404
  1.57433    -0.0557794    0.0674179      0.462086    1.93574     -2.31144
 -0.688907    0.290713    -0.40032    …   0.287635   -1.17088     -0.218587
 -0.762804   -2.84865      0.283924       1.21486    -2.37273     -1.16695
  0.397482    1.16911     -0.35135       -0.516151    0.00194606   0.3240

In [3]:
# The label set V is: {-1, 0, 1}.
# The number of labels K is 3.
print(unique(v_test))
print(unique(v_train))

[1, -1, 0][1, 0, -1]

In [4]:
function get_label_nums(v)
    labels = zeros(3)
    for i in 1:size(v, 1)
        if v[i] == -1
            labels[1] += 1
        elseif v[i] == 0
            labels[2] += 1
        elseif v[i] == 1
            labels[3] += 1
        end
    end
    return labels
end

labels_test = get_label_nums(v_test)
labels_train = get_label_nums(v_train)

3-element Vector{Float64}:
 118.0
  68.0
 114.0

In [5]:
print(labels_test)
print(labels_train)

[121.0, 65.0, 114.0][118.0, 68.0, 114.0]

In [6]:
using Flux: onehotbatch
using LinearAlgebra

# I use one-hot embedding for v.
oh_v_train = transpose(onehotbatch(v_train, -1:1))
oh_v_test = transpose(onehotbatch(v_test, -1:1))

300×3 transpose(OneHotMatrix(::Vector{UInt32})) with eltype Bool:
 0  0  1
 0  0  1
 1  0  0
 0  0  1
 1  0  0
 0  1  0
 0  0  1
 0  0  1
 1  0  0
 1  0  0
 0  0  1
 0  1  0
 0  1  0
 ⋮     
 0  1  0
 1  0  0
 0  0  1
 0  0  1
 0  0  1
 0  0  1
 0  1  0
 1  0  0
 0  1  0
 0  1  0
 1  0  0
 0  1  0

In [9]:
using Random
using Flux
using Flux: logitcrossentropy, onehot, onecold, params
using Plots

function multi_logistic(X, Y, reps)
    # data sizes
    n,d = size(X)
    m = size(Y,2)
    
    # linear predictor parameter
    theta = zeros(d,m)

    # predictor
    predicty(x) = theta'*x
    margin(pi, pj, y) = (2*dot(pi-pj,y) + dot(pj,pj) - dot(pi,pi)) / (2*norm(pi-pj) + 1e-10)
    multilogisticloss(yhat, y) = sum([exp(margin(r, y, yhat)) for r in reps])
    loss(x,y) = multilogisticloss(predicty(x), y)

    data = zip(eachrow(X), eachrow(Y))
    opt = ADAMW()
    Flux.@epochs 500 Flux.train!(loss, params(theta), data, opt)
    return predicty, theta
end

multi_logistic (generic function with 1 method)

In [46]:
k1 = [1, 0, 0]
k2 = [0, 1, 0]
k3 = [0, 0, 1]
reps = [k1, k2, k3]

predicty, theta = multi_logistic(U_train, oh_v_train, reps)

┌ Info: Epoch 1
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 2
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 3
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia

┌ Info: Epoch 104
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 105
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 106
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 107
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 108
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 109
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 110
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 111
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 112
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 113
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 114
└ @ Main C:\

┌ Info: Epoch 206
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 207
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 208
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 209
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 210
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 211
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 212
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 213
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 214
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 215
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 216
└ @ Main C:\

┌ Info: Epoch 308
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 309
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 310
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 311
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 312
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 313
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 314
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 315
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 316
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 317
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 318
└ @ Main C:\

┌ Info: Epoch 410
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 411
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 412
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 413
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 414
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 415
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 416
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 417
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 418
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 419
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 420
└ @ Main C:\

(predicty, [-1.824198380822305 -0.05759895870278575 1.8194574475582899; -1.510781069277297 0.045632395838767804 1.5583289618773746; … ; -0.20145143332158622 0.3109306094631418 -0.1743272726221315; -0.010596007200968241 0.04762540764898846 0.06349791591296973])

In [104]:
psiinv(yhat) = reps[argmin([norm(yhat-reps[k]) for k=1:length(reps)])]

function confusionMatrix(y_hat, y)
    cm = zeros(3, 3)
    for i in 1:size(y, 1)
        cm[argmax(y_hat[i, :]), argmax(y[i, :])] += 1
    end
    return cm
end

function accuracy(cm, y)
    count = 0
    for i in 1:size(cm, 1)
        count += cm[i, i]
    end
    return count / size(y, 1)
end

function unembed(v)
    v_hat = zeros(size(v, 1), 3)
    for i in 1:size(v, 1)
        y = psiinv(v[i, :])
        v_hat[i, 1] =y[1]
        v_hat[i, 2] =y[2]
        v_hat[i, 3] =y[3]
    end
    return v_hat
end

unembed (generic function with 1 method)

In [107]:
em_v_train_hat = U_train * theta
oh_v_train_hat = unembed(em_v_train_hat)
cm_train = confusionMatrix(oh_v_train_hat, oh_v_train)

em_v_test_hat = U_test * theta
oh_v_test_hat = unembed(em_v_test_hat)
cm_test = confusionMatrix(oh_v_test_hat, oh_v_test)

3×3 Matrix{Float64}:
 112.0  29.0    1.0
   8.0  16.0    7.0
   1.0  20.0  106.0

In [108]:
display(cm_train)
display(cm_test)
println(accuracy(cm_train, oh_v_train))
println(accuracy(cm_test, oh_v_test))

3×3 Matrix{Float64}:
 115.0  20.0    1.0
   3.0  29.0    1.0
   0.0  19.0  112.0

3×3 Matrix{Float64}:
 112.0  29.0    1.0
   8.0  16.0    7.0
   1.0  20.0  106.0

0.8533333333333334
0.78
