In [54]:
using Flux, StatsBase, Random
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, params
using Flux.Data: DataLoader

In [55]:
words = split(read("names.txt",String),"\n")
words[1:10]
shuffle!(words)

32033-element Vector{SubString{String}}:
 "lynden"
 "laysha"
 "arina"
 "marguerite"
 "wilbert"
 "haydn"
 "arhaan"
 "jiliana"
 "miyani"
 "derin"
 ⋮
 "termaine"
 "honesti"
 "haru"
 "bryli"
 "audrie"
 "tavaris"
 "tylasia"
 "sebastiano"
 "mattison"

In [56]:
# Create character embeddings. We're going to do this a little
# differently, so that we can have the same embeddings AK uses.
# I.e. the index of "." is 0
chars = ".abcdefghijklmnopqrstuvwxyz"
stoi = Dict( s => i for (i,s) in enumerate(chars))
itos = Dict( i => s for (i,s) in enumerate(chars))
vocab_size = length(chars)
itos

Dict{Int64, Char} with 27 entries:
  5  => 'd'
  16 => 'o'
  20 => 's'
  12 => 'k'
  24 => 'w'
  8  => 'g'
  17 => 'p'
  1  => '.'
  19 => 'r'
  22 => 'u'
  23 => 'v'
  6  => 'e'
  11 => 'j'
  9  => 'h'
  14 => 'm'
  3  => 'b'
  7  => 'f'
  25 => 'x'
  4  => 'c'
  ⋮  => ⋮

In [57]:
# Compile dataset for neural net:
block_size = 3 # context length: how many chars to we use to predict next one?
Xi,Y = [],[]

for w in words
    #println(w)
    context = ones(Int64,block_size)
    for ch in string(w,".")
        ix = stoi[ch]
        push!(Xi,context)
        push!(Y,ix)
        #println(join(itos[i] for i in context)," ---> ", itos[ix])
        context = vcat(context[2:end],[ix])
    end
end


# Make into a multidimensional array
nrows,ncols = length(Xi),length(Xi[1])
X = zeros(Int64,nrows,ncols)
for i in 1:nrows
    X[i,:] = Xi[i]
end

ntrial = nrows

size(X), size(Y)


((228146, 3), (228146,))

In [58]:
# break the code into training, development, and testing sets
nw = length(words)
n1 = 8*nw÷10
n2 = 9*nw÷10

# Ranges are
train = 1:n1
dev = n1:n2
test = n2:nw

28829:32033

In [59]:
Xoh = reshape(onehotbatch(X[train,:]',1:27),81,:)
Yoh = onehotbatch(Y[train],1:27)

# If you don't want to use a smaller batchsize, you can just use an array
# for data:
#data = [(Xoh,Yoh)]
data = DataLoader((Xoh,Yoh), batchsize=32)


801-element DataLoader(::Tuple{Base.ReshapedArray{Bool, 2, OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (81×32 Matrix{Bool}, 27×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

In [60]:
loss(X,Y) = logitcrossentropy(model(X),Y)

n_hidden = 200

model = Chain(
    Dense(block_size*vocab_size => n_hidden, tanh),
    Dense(n_hidden => vocab_size),
    softmax
)

Chain(
  Dense(81 => 200, tanh),               [90m# 16_400 parameters[39m
  Dense(200 => 27),                     [90m# 5_427 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m21_827 parameters, 85.512 KiB.

In [61]:
rate = 3e-4
opt = Adam(rate)

Adam(0.0003, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [62]:
println("Loss $(loss(Xoh,Yoh))")
epochs=200
for epoch in 1:epochs
    Flux.train!(loss,params(model),data,opt)
    println("Loss $(loss(Xoh,Yoh))")

end




Loss 3.2956135
Loss 3.1282215
Loss 3.106868
Loss 3.096329
Loss 3.0913324
Loss 3.0850065
Loss 3.0792758
Loss 3.0759058
Loss 3.0731313
Loss 3.0706923
Loss 3.0687437
Loss 3.067095
Loss 3.0656605
Loss 3.0642598
Loss 3.063037
Loss 3.0620217
Loss 3.0610952
Loss 3.0602465
Loss 3.059474
Loss 3.0587769
Loss 3.0581422
Loss 3.0575576
Loss 3.0570166
Loss 3.056521
Loss 3.056061
Loss 3.0556285
Loss 3.055218
Loss 3.0548246
Loss 3.0544417
Loss 3.0540574
Loss 3.0537083
Loss 3.053391
Loss 3.053094
Loss 3.0528162
Loss 3.0525553
Loss 3.0523062
Loss 3.0520637
Loss 3.0518239
Loss 3.0515828
Loss 3.0513391
Loss 3.0510952
Loss 3.0508602
Loss 3.0506365
Loss 3.0504265
Loss 3.05023
Loss 3.0500255
Loss 3.0496838
Loss 3.049415
Loss 3.0492072
Loss 3.0490181
Loss 3.0488
Loss 3.048576
Loss 3.0484085
Loss 3.0482569
Loss 3.0481145
Loss 3.0479786
Loss 3.0478482
Loss 3.0477204
Loss 3.0475938
Loss 3.0474637
Loss 3.0473228
Loss 3.0471864
Loss 3.047062
Loss 3.0469406
Loss 3.0468137
Loss 3.046692
Loss 3.0465796
Loss 3.0464728

In [66]:
out = []
context = ones(Int64,block_size)
while true
    Xoh = reshape(onehotbatch(context',1:vocab_size),block_size*vocab_size,:)
    Yoh = model(Xoh)
    ix = wsample(1:27,vec(Yoh))
    push!(out,itos[ix])
    context = vcat(context[2:end],[ix])
    if ix == 1 break end
end
join(out)


"alee."

In [67]:
size(Yoh)

(27, 1)