In [5]:
#Find top distinguishing words

using DataStructures: counter
using Flux

function multi_logistic(X, Y, reps; lambda=0)
    # data sizes
    n,d = size(X)
    m = size(Y,2)
    
    normsquared(x) = sum(x .* x)
    
    # linear predictor parameter
    theta = zeros(d,m)

    # predictor
    predicty(x) = theta'*x
    D(pi, pj, y) = (2*dot(pi-pj,y) + dot(pj,pj) - dot(pi,pi)) / (2*norm(pi-pj) + 1e-10)
    multilogisticloss(yhat, y) = log(sum([exp(D(r, y, yhat)) for r in reps]))
    loss(x,y) = multilogisticloss(predicty(x), y) + (lambda/2)*normsquared(theta)

    data = zip(eachrow(X), eachrow(Y))
    opt = ADAMW()
    Flux.@epochs 100 Flux.train!(loss, Flux.params(theta), data, opt)
    return predicty, theta
end

function TFIDF(U, words, freqs; k=2000)
    dict = Dict(w=>f for (w,f) in zip(words, freqs))
    top_words = sort(collect(keys(dict)), by=word->dict[word], rev=true)[1:k]
    n = length(U); X_tild = zeros(n, k);
    word_to_idx = Dict(word=>i for (i, word) in enumerate(top_words))
    for (i, u) in enumerate(U)
        frequencies = counter(split(u))
        top_words_in_doc = intersect(keys(frequencies), top_words)
        for word in top_words_in_doc
            X_tild[i, word_to_idx[word]] += frequencies[word]
        end
    end
    return docterm_to_TFIDF(X_tild), top_words
end

function standardize_plus_one(U,means,stds)
    Z = zeros(size(U))
    for i=1:size(U,2)
        if stds[i] != 0
            Z[:,i] = (U[:,i] .- means[i])/stds[i]
        else
            Z[:,i] = U[:,i] .- means[i]
        end
    end
    n = size(U,1)
    Z = [ones(n,1) Z]
    return Z
end

##################### AUXILLARY FUNCTIONS #####################

function docterm_to_TFIDF(A)
    n, d = size(A)
    Z = 1.0 * (A .!= 0)
    return diagm(1 ./ (A * ones(d))) * A * diagm(log.(Z' *ones(n)))
end

docterm_to_TFIDF (generic function with 1 method)

In [1]:
# Multi-class logistic loss with quadratic regulation

In [7]:
include("readclassjson.jl")
data = readclassjson("speeches.json")

U = data["U"]
V = data["V"]
freqs = data["freqs"]
words = data["words"]

102192-element Vector{String}:
 "miracles;"
 "Gynecology"
 "Nafis'"
 "advisers?"
 "\\\"Send"
 "ungoverned"
 "GEAR-UP,"
 "spying.\\\""
 "revictimized"
 "Hungarians."
 "backlogs,"
 "defenseless."
 "599"
 ⋮
 "forces:"
 "believes,"
 "nontraffic,"
 "me],"
 "research.\\\"."
 "B-e-a-v-e-r."
 "unattainable."
 "overtaxing"
 "[are"
 "Adam\\\""
 "airports,"
 "Nevadans"

In [8]:
using Statistics
using LinearAlgebra

U_TFIDF, top_words = TFIDF(U, words, freqs)



([0.5817347425854873 0.29086737129274365 … 0.0 0.003833543555160259; 0.6173591548675902 0.35277665992433727 … 0.0 0.0; … ; 0.45813330052905543 0.3296324967221253 … 0.0 0.0; 0.416053937697905 0.48617538787170916 … 0.0 0.0], ["the", "to", "of", "and", "a", "in", "that", "is", "I", "we"  …  "teaching", "shown", "priorities", "principles", "becoming", "targeted", "shouldn't", "Today", "need.", "infrastructure"])

In [9]:
n = size(U_TFIDF, 1)
means = zeros(n)
stds = zeros(n)
for i in 1:n
    means[i] = mean(U_TFIDF[i, :])
    stds[i] = std(U_TFIDF[i, :])
end

In [10]:
U_std = standardize_plus_one(U_TFIDF, means, stds)

3119×2001 Matrix{Float64}:
 1.0  25.9911  13.3165   12.2004   …  -0.139028   -0.151335    0.00920902
 1.0  27.5928  16.1862   13.8803      -0.139028   -0.151335   -0.165383
 1.0  31.4998  14.7508   17.0514      -0.139028   -0.151335   -0.165383
 1.0  28.8166  24.1791   13.3148      -0.139028   -0.151335   -0.165383
 1.0  23.2565  17.9435    9.8929      -0.139028   -0.151335   -0.165383
 1.0  36.3806   9.9506   14.2173   …  -0.139028   -0.151335   -0.165383
 1.0  22.6503  12.3532    6.94858     -0.139028   -0.151335   -0.165383
 1.0  15.8526  10.8427   11.4411      -0.139028   -0.151335   -0.165383
 1.0  22.0828  18.5226    7.0028      -0.139028   -0.151335   -0.165383
 1.0  41.3606  11.2499   14.0973      -0.139028   -0.151335   -0.165383
 1.0  25.5964  10.8999   17.8053   …  -0.139028   -0.151335   -0.165383
 1.0  17.9216   9.47844  17.9065      -0.139028   -0.151335   -0.165383
 1.0  24.7131  14.574    14.1987      -0.139028   -0.151335   -0.165383
 ⋮                                 

In [11]:
v_std = zeros(n)
for i in 1:n
    if V[i] == "D"
        v_std[i] = 0
    elseif V[i] == "R"
        v_std[i] = 1
    end
end

In [12]:
using Random
Random.seed!(0)

df = hcat(U_std, v_std)
df = df[shuffle(1:end), :]

3119×2002 Matrix{Float64}:
 1.0  22.4826  22.5337   3.48742  10.5332   …  -0.151335  -0.165383  1.0
 1.0  23.3572  10.3203  16.4145   14.6222      -0.151335  -0.165383  1.0
 1.0  29.302   13.4519   9.65295   9.84003     -0.151335  -0.165383  0.0
 1.0  21.4006  17.0329   7.30704  12.9417      -0.151335  -0.165383  1.0
 1.0  34.0642  11.4083  16.8947   15.6347      -0.151335  -0.165383  0.0
 1.0  23.703   21.1595  14.6929   17.8623   …  -0.151335  -0.165383  0.0
 1.0  20.7828  21.9565  12.1746   22.2227      -0.151335  -0.165383  0.0
 1.0  26.5758  15.5061  11.3367    8.19933     -0.151335  -0.165383  0.0
 1.0  23.3932  11.5857  10.1146   13.4336      -0.151335  -0.165383  0.0
 1.0  40.378   11.9086  17.6741    9.11075     -0.151335  -0.165383  1.0
 1.0  31.5052  16.669   15.6085   10.9069   …  -0.151335  -0.165383  0.0
 1.0  22.2758  15.7177  14.0709    7.93339     -0.151335  -0.165383  0.0
 1.0  20.6561  19.582    8.41979   8.28687     -0.151335  -0.165383  1.0
 ⋮                      

In [13]:
U_train = df[1:2495, 1:2001]
v_train = df[1:2495, 2002]

U_eval = df[2496:3119, 1:2001]
v_eval = df[2496:3119, 2002]

624-element Vector{Float64}:
 1.0
 0.0
 0.0
 0.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 0.0
 0.0
 ⋮
 1.0
 1.0
 0.0
 0.0
 1.0
 0.0
 0.0
 1.0
 1.0
 0.0
 1.0
 1.0

In [14]:
k1 = [0]
k2 = [1]
reps = [k1, k2]
lambdas = 10 .^ range(-2,2,length=5)

d = size(U_train, 2)
thetas = zeros(5, d)

for i in 1:5
    predicty, theta = multi_logistic(U_train, v_train, reps; lambda = lambdas[i])
    thetas[i, :] = theta
end

┌ Info: Epoch 1
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 2
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 3
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia

┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 12
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 13
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 14
└ @ Main C:\Users\surface\.ju

┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 12
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 13
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 14
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 15
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 16
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 17
└ @ Main C:\Users\surface\

┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 12
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 13
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 14
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 15
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 16
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 17
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 18
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 19
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 20
└ @ Main C:\Users\surfa

┌ Info: Epoch 13
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 14
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 15
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 16
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 17
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 18
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 19
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 20
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 21
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 22
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 23
└ @ Main C:\Users\surfa

In [15]:
function normalize(y)
    if y == -1
        return 0
    elseif y == 1
        return 1
    end
end

accuracy(a, b) = Statistics.mean(a .== b)
accuracies = zeros(5)

for i in 1:5
    theta = thetas[i, :]
    v_eval_hat = normalize.(sign.(U_eval * theta))
    accuracies[i] = accuracy(v_eval_hat, v_eval)
end

In [16]:
accuracies

5-element Vector{Float64}:
 0.9150641025641025
 0.8798076923076923
 0.8060897435897436
 0.6121794871794872
 0.5144230769230769

In [17]:
# Best accuracy
lambdas[1]

0.01

In [18]:
function confusionMatrix(y_hat, y)
    cm = zeros(2, 2)
    for i in 1:size(y, 1)
        row = Int(y_hat[i] + 1)
        col = Int(y[i] + 1)
        cm[row, col] += 1
    end
    return cm
end

theta = thetas[1, :]
v_eval_hat = normalize.(sign.(U_eval * theta))
cm = confusionMatrix(v_eval_hat, v_eval)

2×2 Matrix{Float64}:
 282.0   31.0
  22.0  289.0

In [19]:
new_theta = theta[2:2001]
indices = sortperm(new_theta, rev=true)

2000-element Vector{Int64}:
  361
 1188
  702
  675
  580
  573
  518
  298
  142
  204
  169
  224
  267
    ⋮
  538
  220
  111
 1560
  150
  607
  199
  251
 1095
  105
  371
  175

In [21]:
indexed_top_words = hcat(indices, top_words)
indexed_top_words[sortperm(indexed_top_words[:, 1]), :]

2000×2 Matrix{Any}:
    1  "Many"
    2  "evidence"
    3  "debate."
    4  "10"
    5  "message"
    6  "intelligence"
    7  "Just"
    8  "Mr."
    9  "increases"
   10  "education"
   11  "average"
   12  "old"
   13  "enforcement"
    ⋮  
 1989  "Under"
 1990  "known"
 1991  "continued"
 1992  "strategic"
 1993  "congressional"
 1994  "region"
 1995  "worked"
 1996  "show"
 1997  "Office"
 1998  "train"
 1999  "argument"
 2000  "measure"

In [22]:
# The 5 most useful words are: 
# "Many", "evidence", "debate.", "10", and "message".