In [54]:
using Flux, StatsBase, Random
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, params
using Flux.Data: DataLoader

In [61]:
words = split(read("names.txt",String),"\n")
words[1:10]
shuffle!(words)

32033-element Vector{SubString{String}}:
 "nihaan"
 "khizar"
 "dajuan"
 "allianna"
 "talula"
 "panth"
 "ziqi"
 "malory"
 "kathaleia"
 "benicio"
 ⋮
 "jeffrey"
 "zofia"
 "mahelet"
 "journii"
 "sela"
 "emunah"
 "locklen"
 "yahsir"
 "laiyla"

In [62]:
# Create character embeddings. We're going to do this a little
# differently, so that we can have the same embeddings AK uses.
# I.e. the index of "." is 0
chars = ".abcdefghijklmnopqrstuvwxyz"
stoi = Dict( s => i for (i,s) in enumerate(chars))
itos = Dict( i => s for (i,s) in enumerate(chars))
vocab_size = length(chars)
itos

Dict{Int64, Char} with 27 entries:
  5  => 'd'
  16 => 'o'
  20 => 's'
  12 => 'k'
  24 => 'w'
  8  => 'g'
  17 => 'p'
  1  => '.'
  19 => 'r'
  22 => 'u'
  23 => 'v'
  6  => 'e'
  11 => 'j'
  9  => 'h'
  14 => 'm'
  3  => 'b'
  7  => 'f'
  25 => 'x'
  4  => 'c'
  ⋮  => ⋮

In [79]:
# Compile dataset for neural net:
block_size = 3 # context length: how many chars to we use to predict next one?
Xi,Y = [],[]

for w in words
    #println(w)
    context = ones(Int64,block_size)
    for ch in string(w,".")
        ix = stoi[ch]
        push!(Xi,context)
        push!(Y,ix)
        #println(join(itos[i] for i in context)," ---> ", itos[ix])
        context = vcat(context[2:end],[ix])
    end
end


# Make into a multidimensional array
nrows,ncols = length(Xi),length(Xi[1])
X = zeros(Int64,nrows,ncols)
for i in 1:nrows
    X[i,:] = Xi[i]
end

ntrial = nrows

size(X), size(Y)


((228146, 3), (228146,))

In [82]:
# break the code into training, development, and testing sets
n1 = 8*nrows÷10
n2 = 9*nrows÷10

# Ranges are
train = 1:n1
dev = n1:n2
test = n2:nrows

205331:228146

In [83]:
Xoh = reshape(onehotbatch(X[train,:]',1:27),81,:)
Yoh = onehotbatch(Y[train],1:27)

# If you don't want to use a smaller batchsize, you can just use an array
# for data:
#data = [(Xoh,Yoh)]
data = DataLoader((Xoh,Yoh), batchsize=32)


5704-element DataLoader(::Tuple{Base.ReshapedArray{Bool, 2, OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (81×32 Matrix{Bool}, 27×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

In [84]:
loss(X,Y) = logitcrossentropy(model(X),Y)

n_hidden = 200

model = Chain(
    Dense(block_size*vocab_size => n_hidden, tanh),
    Dense(n_hidden => vocab_size),
    softmax
)

Chain(
  Dense(81 => 200, tanh),               [90m# 16_400 parameters[39m
  Dense(200 => 27),                     [90m# 5_427 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m21_827 parameters, 85.512 KiB.

In [85]:
rate = 3e-4
opt = Adam(rate)

Adam(0.0003, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [96]:
println("Loss $(loss(Xoh,Yoh))")
epochs=20
for epoch in 1:epochs
    Flux.train!(loss,params(model),data,opt)
    println("Loss $(loss(Xoh,Yoh))")
end




Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342
Loss 2.3575342


In [97]:
out = []
context = ones(Int64,block_size)
while true
    Xoh = reshape(onehotbatch(context',1:vocab_size),block_size*vocab_size,:)
    Yoh = model(Xoh)
    ix = wsample(1:27,vec(Yoh))
    push!(out,itos[ix])
    context = vcat(context[2:end],[ix])
    @show context
    if ix == 1 break end
end
join(out)


context = [1, 1, 2]
context = [1, 2, 19]
context = [2, 19, 10]
context = [19, 10, 2]
context = [10, 2, 1]


"aria."

# Debugging MLP
Okay, we're getting some bad results from the code. For one thing, everything starts with "a". I'm also not getting much randomness for the successive guesses.

I'm going to make some tools to see whether I can figure out what's going on.

In [108]:
function next_letter_probs(input)
    # Usage: next_letter_probs("..a")
    # output a dict something like "n" => 0.8, "r" => 0.1, ...

    # Force input to be 3 chars:
    context = [stoi[input[i]] for i in 1:block_size] 
    X = reshape(onehotbatch(context',1:vocab_size),block_size*vocab_size,:)
    Y = vec(model(X))
    return letterprobs(Y)
    #out = [(round(p,digits=3),itos[i]) for (i,p) in enumerate(Y) if p>1e-4 ]
    #return reverse(sort(out))
end

letterprobs(Y) = reverse(sort([(round(p,digits=3),itos[i]) for (i,p) in enumerate(Y) if p>1e-4 ]))

letterprobs (generic function with 1 method)

In [109]:
next_letter_probs("...")

1-element Vector{Tuple{Float32, Char}}:
 (1.0, 'a')

In [110]:
next_letter_probs("..a")

3-element Vector{Tuple{Float32, Char}}:
 (0.838, 'r')
 (0.161, 'l')
 (0.0, 'a')

In [112]:
for i in 1:20
    context = X[i,:]
    Xi = reshape(onehotbatch(context',1:vocab_size),block_size*vocab_size,:)
    Yi = vec(model(Xi))
    @show context
    @show letterprobs(Yi)
    @show itos[Y[i]]
end

context = [1, 1, 1]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'n'
context = [1, 1, 15]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'i'
context = [1, 15, 10]
letterprobs(Yi) = Tuple{Float32, Char}[(0.999, 'a'), (0.001, 'e')]
itos[Y[i]] = 'h'
context = [15, 10, 9]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'a'
context = [10, 9, 2]
letterprobs(Yi) = Tuple{Float32, Char}[(0.996, 'n'), (0.004, '.')]
itos[Y[i]] = 'a'
context = [9, 2, 2]
letterprobs(Yi) = Tuple{Float32, Char}[(0.999, '.'), (0.001, 'l'), (0.0, 'r')]
itos[Y[i]] = 'n'
context = [2, 2, 15]
letterprobs(Yi) = Tuple{Float32, Char}[(0.999, '.'), (0.001, 'i')]
itos[Y[i]] = '.'
context = [1, 1, 1]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'k'
context = [1, 1, 12]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'h'
context = [1, 12, 9]
letterprobs(Yi) = Tuple{Float32, Char}[(1.0, 'a')]
itos[Y[i]] = 'i'
context = [12, 9, 10]
letterprobs(Y