In [2]:
using Flux, StatsBase, Random
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, params
using Flux.Data: DataLoader

In [3]:
words = split(read("names.txt",String),"\n")
words[1:10]
shuffle!(words)

32033-element Vector{SubString{String}}:
 "gift"
 "shabaz"
 "neilson"
 "tyriq"
 "jatavion"
 "azlan"
 "aubreerose"
 "binyamin"
 "waleed"
 "kahron"
 ⋮
 "lanston"
 "eniyah"
 "trygg"
 "katiya"
 "psalm"
 "oluwanifemi"
 "kumar"
 "gargi"
 "alyson"

In [4]:
# Create character embeddings. We're going to do this a little
# differently, so that we can have the same embeddings AK uses.
# I.e. the index of "." is 0
chars = ".abcdefghijklmnopqrstuvwxyz"
stoi = Dict( s => i for (i,s) in enumerate(chars))
itos = Dict( i => s for (i,s) in enumerate(chars))
vocab_size = length(chars)
itos

Dict{Int64, Char} with 27 entries:
  5  => 'd'
  16 => 'o'
  20 => 's'
  12 => 'k'
  24 => 'w'
  8  => 'g'
  17 => 'p'
  1  => '.'
  19 => 'r'
  22 => 'u'
  23 => 'v'
  6  => 'e'
  11 => 'j'
  9  => 'h'
  14 => 'm'
  3  => 'b'
  7  => 'f'
  25 => 'x'
  4  => 'c'
  ⋮  => ⋮

In [5]:
# Compile dataset for neural net:
block_size = 3 # context length: how many chars to we use to predict next one?
Xi,Y = [],[]

for w in words
    #println(w)
    context = ones(Int64,block_size)
    for ch in string(w,".")
        ix = stoi[ch]
        push!(Xi,context)
        push!(Y,ix)
        #println(join(itos[i] for i in context)," ---> ", itos[ix])
        context = vcat(context[2:end],[ix])
    end
end


# Make into a multidimensional array
nrows,ncols = length(Xi),length(Xi[1])
X = zeros(Int64,nrows,ncols)
for i in 1:nrows
    X[i,:] = Xi[i]
end

ntrial = nrows

size(X), size(Y)


((228146, 3), (228146,))

In [6]:
# break the code into training, development, and testing sets
nw = length(words)
n1 = 8*nw÷10
n2 = 9*nw÷10

# Ranges are
train = 1:n1
dev = n1:n2
test = n2:nw

28829:32033

In [7]:
Xoh = reshape(onehotbatch(X[train,:]',1:27),81,:)
Yoh = onehotbatch(Y[train],1:27)

# If you don't want to use a smaller batchsize, you can just use an array
# for data:
#data = [(Xoh,Yoh)]
data = DataLoader((Xoh,Yoh), batchsize=32)


801-element DataLoader(::Tuple{Base.ReshapedArray{Bool, 2, OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (81×32 Matrix{Bool}, 27×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

In [8]:
loss(X,Y) = logitcrossentropy(model(X),Y)

n_hidden = 200

model = Chain(
    Dense(block_size*vocab_size => n_hidden, tanh),
    Dense(n_hidden => vocab_size),
    softmax
)

Chain(
  Dense(81 => 200, tanh),               [90m# 16_400 parameters[39m
  Dense(200 => 27),                     [90m# 5_427 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m21_827 parameters, 85.512 KiB.

In [9]:
rate = 3e-4
opt = Adam(rate)

Adam(0.0003, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [10]:
println("Loss $(loss(Xoh,Yoh))")
epochs=200
for epoch in 1:epochs
    Flux.train!(loss,params(model),data,opt)
    println("Loss $(loss(Xoh,Yoh))")

end




Loss 3.296538
Loss 3.110311
Loss 3.099118
Loss 3.095139
Loss 3.0885603
Loss 3.0850244
Loss 3.0832145
Loss 3.0818505
Loss 3.0772765
Loss 3.0738816
Loss 3.0720341
Loss 3.0708532
Loss 3.0697448
Loss 3.0686243
Loss 3.0676422
Loss 3.06674
Loss 3.0656765
Loss 3.064275
Loss 3.0626867
Loss 3.0617065
Loss 3.0609453
Loss 3.060307
Loss 3.059738
Loss 3.0591815
Loss 3.0586722
Loss 3.0582416
Loss 3.057862
Loss 3.0575073
Loss 3.0571716
Loss 3.056859
Loss 3.0565648
Loss 3.056282
Loss 3.055629
Loss 3.0553412
Loss 3.055093
Loss 3.054848
Loss 3.054604
Loss 3.0543575
Loss 3.0541077
Loss 3.0538561
Loss 3.0536053
Loss 3.0533478
Loss 3.0530944
Loss 3.0528677
Loss 3.052651
Loss 3.0524318
Loss 3.0522127
Loss 3.0520155
Loss 3.0518444
Loss 3.051687
Loss 3.0515375
Loss 3.051393
Loss 3.0512514
Loss 3.0511117
Loss 3.050975
Loss 3.0508437
Loss 3.050718
Loss 3.0505974
Loss 3.0504816
Loss 3.0503697
Loss 3.0502613
Loss 3.0501559
Loss 3.0500536
Loss 3.0499535
Loss 3.0498533
Loss 3.0497463
Loss 3.0496042
Loss 3.0494883
L

In [13]:
out = []
context = ones(Int64,block_size)
while true
    Xoh = reshape(onehotbatch(context',1:vocab_size),block_size*vocab_size,:)
    Yoh = model(Xoh)
    ix = wsample(1:27,vec(Yoh))
    push!(out,itos[ix])
    context = vcat(context[2:end],[ix])
    if ix == 1 break end
end
join(out)


"ana."

In [67]:
size(Yoh)

(27, 1)