# Recurrent neural network

Character-level language model

<img src="http://karpathy.github.io/assets/rnn/charseq.jpeg">

In [1]:
using Flux
using Flux: @epochs, onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
# using CuArrays

## Download data

In [2]:
# Fetch the Shakespeare corpus once; the download is skipped when "input.txt" already exists.
isfile("input.txt") ||
  download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", "input.txt")

true

## Preprocessing

In [3]:
# Load the whole corpus as a vector of characters.
text = collect(read("input.txt", String))
# Vocabulary: every distinct character in the corpus, with '_' appended as an extra symbol.
alphabet = vcat(unique(text), '_')
# Replace each character by its one-hot encoding over `alphabet`.
text = [onehot(ch, alphabet) for ch in text]
# One-hot encoding of '_'; handed to `batchseq` when batching (presumably as the pad value).
stop = onehot('_', alphabet);

In [4]:
# Vocabulary size; used as the model's input and output width.
N = length(alphabet)
# Length passed to `partition` when slicing the batched sequences.
seqlen = 50
# Number of pieces `chunk` splits the corpus into.
nbatch = 50

50

In [5]:
# Build inputs (Xs) and targets (Ys) with the same pipeline; the targets are the corpus
# shifted by one character, so each input position is paired with the character after it.
Xs, Ys = map((text, text[2:end])) do seq
    collect(partition(batchseq(chunk(seq, nbatch), stop), seqlen))
end;

## Model

In [6]:
# Two stacked LSTM layers followed by a dense projection back to the vocabulary size,
# with `softmax` as the last stage of the chain.
m = let hidden = 128
    Chain(
        LSTM(N, hidden),
        LSTM(hidden, hidden),
        Dense(hidden, N),
        softmax,
    )
end

Chain(Recur(LSTMCell(68, 128)), Recur(LSTMCell(128, 128)), Dense(128, 68), softmax)

## Loss function

In [7]:
"""
    loss(xs, ys)

Apply the global model `m` to every element of `xs` in order, sum the `crossentropy`
between each prediction and the matching element of `ys`, and return that sum.

`Flux.reset!(m)` is called after the forward pass and before returning (presumably so
recurrent state does not leak into the next call — confirm against the Flux docs).
"""
function loss(xs, ys)
    predictions = m.(xs)
    total = sum(crossentropy.(predictions, ys))
    Flux.reset!(m)
    return total
end

loss (generic function with 1 method)

## Optimizer

In [8]:
# Optimiser handed to `Flux.train!`.
opt = ADAM(0.01)

# Progress callback: evaluate the loss on one fixed batch (the fifth) and print it.
function evalcb()
    return @show loss(Xs[5], Ys[5])
end

evalcb (generic function with 1 method)

## Training

In [9]:
# Run 10 epochs over the (input, target) batch pairs. The `throttle(evalcb, 30)` call sits
# inside the `@epochs` expression; do not hoist it out without checking whether the
# callback's throttling state is meant to persist across epochs — TODO confirm.
@epochs 10 Flux.train!(loss, params(m), zip(Xs, Ys), opt, cb=throttle(evalcb, 30))

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 196.86156f0
loss(Xs[5], Ys[5]) = 170.97137f0
loss(Xs[5], Ys[5]) = 169.84167f0
loss(Xs[5], Ys[5]) = 168.2907f0
loss(Xs[5], Ys[5]) = 168.31401f0
loss(Xs[5], Ys[5]) = 168.37509f0
loss(Xs[5], Ys[5]) = 167.86536f0
loss(Xs[5], Ys[5]) = 168.34f0
loss(Xs[5], Ys[5]) = 167.99153f0
loss(Xs[5], Ys[5]) = 167.98141f0
loss(Xs[5], Ys[5]) = 167.77316f0
loss(Xs[5], Ys[5]) = 168.76381f0
loss(Xs[5], Ys[5]) = 168.41063f0
loss(Xs[5], Ys[5]) = 168.70741f0
loss(Xs[5], Ys[5]) = 168.80412f0
loss(Xs[5], Ys[5]) = 168.1425f0
loss(Xs[5], Ys[5]) = 168.09239f0
loss(Xs[5], Ys[5]) = 168.80937f0
loss(Xs[5], Ys[5]) = 168.13542f0
loss(Xs[5], Ys[5]) = 167.81508f0
loss(Xs[5], Ys[5]) = 167.99681f0
loss(Xs[5], Ys[5]) = 168.23538f0
loss(Xs[5], Ys[5]) = 168.37451f0
loss(Xs[5], Ys[5]) = 167.92474f0
loss(Xs[5], Ys[5]) = 168.17537f0
loss(Xs[5], Ys[5]) = 168.06491f0
loss(Xs[5], Ys[5]) = 167.89027f0
loss(Xs[5], Ys[5]) = 169.08005f0
loss(Xs[5], Ys[5]) = 168.41383f0
loss(Xs[5], Ys[5]) = 167.57692f0
loss(Xs[5], Ys[

┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 168.86028f0
loss(Xs[5], Ys[5]) = 167.71144f0
loss(Xs[5], Ys[5]) = 168.09505f0
loss(Xs[5], Ys[5]) = 168.1105f0
loss(Xs[5], Ys[5]) = 167.88829f0
loss(Xs[5], Ys[5]) = 167.9789f0
loss(Xs[5], Ys[5]) = 168.57298f0
loss(Xs[5], Ys[5]) = 169.61935f0
loss(Xs[5], Ys[5]) = 168.56584f0
loss(Xs[5], Ys[5]) = 168.31723f0
loss(Xs[5], Ys[5]) = 168.2282f0
loss(Xs[5], Ys[5]) = 168.36494f0
loss(Xs[5], Ys[5]) = 168.01097f0
loss(Xs[5], Ys[5]) = 167.95918f0
loss(Xs[5], Ys[5]) = 167.87367f0
loss(Xs[5], Ys[5]) = 168.13208f0
loss(Xs[5], Ys[5]) = 168.22087f0
loss(Xs[5], Ys[5]) = 168.4778f0
loss(Xs[5], Ys[5]) = 167.998f0
loss(Xs[5], Ys[5]) = 168.08685f0
loss(Xs[5], Ys[5]) = 169.03027f0
loss(Xs[5], Ys[5]) = 168.1971f0
loss(Xs[5], Ys[5]) = 168.80812f0
loss(Xs[5], Ys[5]) = 168.26886f0
loss(Xs[5], Ys[5]) = 168.06998f0
loss(Xs[5], Ys[5]) = 168.33766f0
loss(Xs[5], Ys[5]) = 167.91548f0
loss(Xs[5], Ys[5]) = 168.3306f0
loss(Xs[5], Ys[5]) = 168.75117f0
loss(Xs[5], Ys[5]) = 167.84532f0
loss(Xs[5], Ys[5])

┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 168.85965f0
loss(Xs[5], Ys[5]) = 169.08943f0
loss(Xs[5], Ys[5]) = 167.57013f0
loss(Xs[5], Ys[5]) = 168.09193f0
loss(Xs[5], Ys[5]) = 168.608f0
loss(Xs[5], Ys[5]) = 167.98546f0
loss(Xs[5], Ys[5]) = 167.77333f0
loss(Xs[5], Ys[5]) = 168.18277f0
loss(Xs[5], Ys[5]) = 168.55247f0
loss(Xs[5], Ys[5]) = 168.12564f0
loss(Xs[5], Ys[5]) = 167.95721f0
loss(Xs[5], Ys[5]) = 167.70203f0
loss(Xs[5], Ys[5]) = 168.66528f0
loss(Xs[5], Ys[5]) = 168.03331f0
loss(Xs[5], Ys[5]) = 168.69386f0
loss(Xs[5], Ys[5]) = 168.06055f0
loss(Xs[5], Ys[5]) = 167.77472f0
loss(Xs[5], Ys[5]) = 168.18085f0
loss(Xs[5], Ys[5]) = 168.06151f0
loss(Xs[5], Ys[5]) = 168.09793f0
loss(Xs[5], Ys[5]) = 167.93472f0
loss(Xs[5], Ys[5]) = 168.13292f0
loss(Xs[5], Ys[5]) = 167.89159f0
loss(Xs[5], Ys[5]) = 168.53117f0
loss(Xs[5], Ys[5]) = 168.07262f0
loss(Xs[5], Ys[5]) = 167.54462f0
loss(Xs[5], Ys[5]) = 168.16495f0
loss(Xs[5], Ys[5]) = 167.94301f0
loss(Xs[5], Ys[5]) = 168.0119f0
loss(Xs[5], Ys[5]) = 167.85399f0
loss(Xs[5], Y

┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 167.75409f0
loss(Xs[5], Ys[5]) = 166.96684f0
loss(Xs[5], Ys[5]) = 166.83955f0
loss(Xs[5], Ys[5]) = 166.24744f0
loss(Xs[5], Ys[5]) = 166.25064f0
loss(Xs[5], Ys[5]) = 166.3023f0
loss(Xs[5], Ys[5]) = 167.10817f0
loss(Xs[5], Ys[5]) = 166.7003f0
loss(Xs[5], Ys[5]) = 165.81107f0
loss(Xs[5], Ys[5]) = 166.7902f0
loss(Xs[5], Ys[5]) = 166.6875f0
loss(Xs[5], Ys[5]) = 166.05988f0
loss(Xs[5], Ys[5]) = 166.2498f0
loss(Xs[5], Ys[5]) = 167.28523f0
loss(Xs[5], Ys[5]) = 166.96425f0
loss(Xs[5], Ys[5]) = 167.23903f0
loss(Xs[5], Ys[5]) = 167.67125f0
loss(Xs[5], Ys[5]) = 166.71468f0
loss(Xs[5], Ys[5]) = 167.1204f0
loss(Xs[5], Ys[5]) = 166.32086f0
loss(Xs[5], Ys[5]) = 166.57033f0
loss(Xs[5], Ys[5]) = 162.51767f0
loss(Xs[5], Ys[5]) = 163.29483f0
loss(Xs[5], Ys[5]) = 160.952f0
loss(Xs[5], Ys[5]) = 158.83852f0
loss(Xs[5], Ys[5]) = 156.06624f0
loss(Xs[5], Ys[5]) = 158.57404f0
loss(Xs[5], Ys[5]) = 165.76793f0
loss(Xs[5], Ys[5]) = 167.9492f0
loss(Xs[5], Ys[5]) = 166.44243f0
loss(Xs[5], Ys[5]) 

┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 165.95465f0
loss(Xs[5], Ys[5]) = 165.1177f0
loss(Xs[5], Ys[5]) = 166.04535f0
loss(Xs[5], Ys[5]) = 161.67506f0
loss(Xs[5], Ys[5]) = 163.0157f0
loss(Xs[5], Ys[5]) = 160.49889f0
loss(Xs[5], Ys[5]) = 161.75159f0
loss(Xs[5], Ys[5]) = 162.60327f0
loss(Xs[5], Ys[5]) = 160.52194f0
loss(Xs[5], Ys[5]) = 156.73265f0
loss(Xs[5], Ys[5]) = 162.00941f0
loss(Xs[5], Ys[5]) = 159.91368f0
loss(Xs[5], Ys[5]) = 164.69412f0
loss(Xs[5], Ys[5]) = 160.67368f0
loss(Xs[5], Ys[5]) = 160.52255f0
loss(Xs[5], Ys[5]) = 160.85634f0
loss(Xs[5], Ys[5]) = 160.27924f0
loss(Xs[5], Ys[5]) = 162.32668f0
loss(Xs[5], Ys[5]) = 160.86368f0
loss(Xs[5], Ys[5]) = 159.58076f0
loss(Xs[5], Ys[5]) = 159.10443f0
loss(Xs[5], Ys[5]) = 156.72688f0
loss(Xs[5], Ys[5]) = 160.04004f0
loss(Xs[5], Ys[5]) = 159.62425f0
loss(Xs[5], Ys[5]) = 158.40779f0
loss(Xs[5], Ys[5]) = 161.83722f0
loss(Xs[5], Ys[5]) = 159.83101f0
loss(Xs[5], Ys[5]) = 160.95677f0
loss(Xs[5], Ys[5]) = 156.9528f0
loss(Xs[5], Ys[5]) = 162.33778f0
loss(Xs[5], Y

┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 153.9273f0
loss(Xs[5], Ys[5]) = 157.1446f0
loss(Xs[5], Ys[5]) = 159.2383f0
loss(Xs[5], Ys[5]) = 156.92119f0
loss(Xs[5], Ys[5]) = 156.89168f0
loss(Xs[5], Ys[5]) = 157.43753f0
loss(Xs[5], Ys[5]) = 159.29897f0
loss(Xs[5], Ys[5]) = 154.6009f0
loss(Xs[5], Ys[5]) = 158.611f0
loss(Xs[5], Ys[5]) = 158.73991f0
loss(Xs[5], Ys[5]) = 159.66612f0
loss(Xs[5], Ys[5]) = 159.44124f0
loss(Xs[5], Ys[5]) = 161.84943f0
loss(Xs[5], Ys[5]) = 165.27264f0
loss(Xs[5], Ys[5]) = 166.86641f0
loss(Xs[5], Ys[5]) = 165.97522f0
loss(Xs[5], Ys[5]) = 164.99353f0
loss(Xs[5], Ys[5]) = 165.46254f0
loss(Xs[5], Ys[5]) = 164.8418f0
loss(Xs[5], Ys[5]) = 161.4043f0
loss(Xs[5], Ys[5]) = 159.93138f0
loss(Xs[5], Ys[5]) = 160.40009f0
loss(Xs[5], Ys[5]) = 161.48247f0
loss(Xs[5], Ys[5]) = 160.57643f0
loss(Xs[5], Ys[5]) = 159.40736f0
loss(Xs[5], Ys[5]) = 160.96165f0
loss(Xs[5], Ys[5]) = 162.41617f0
loss(Xs[5], Ys[5]) = 164.19185f0
loss(Xs[5], Ys[5]) = 164.68063f0
loss(Xs[5], Ys[5]) = 163.5438f0
loss(Xs[5], Ys[5]) 

┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 164.85988f0
loss(Xs[5], Ys[5]) = 163.82465f0
loss(Xs[5], Ys[5]) = 164.01785f0
loss(Xs[5], Ys[5]) = 162.23225f0
loss(Xs[5], Ys[5]) = 161.79378f0
loss(Xs[5], Ys[5]) = 159.81462f0
loss(Xs[5], Ys[5]) = 160.57309f0
loss(Xs[5], Ys[5]) = 160.58485f0
loss(Xs[5], Ys[5]) = 159.5646f0
loss(Xs[5], Ys[5]) = 163.64143f0
loss(Xs[5], Ys[5]) = 166.32063f0
loss(Xs[5], Ys[5]) = 163.49265f0
loss(Xs[5], Ys[5]) = 165.03557f0
loss(Xs[5], Ys[5]) = 164.09793f0
loss(Xs[5], Ys[5]) = 165.29993f0
loss(Xs[5], Ys[5]) = 164.5456f0
loss(Xs[5], Ys[5]) = 163.16197f0
loss(Xs[5], Ys[5]) = 163.34367f0
loss(Xs[5], Ys[5]) = 162.01894f0
loss(Xs[5], Ys[5]) = 163.60205f0
loss(Xs[5], Ys[5]) = 164.60349f0
loss(Xs[5], Ys[5]) = 163.55638f0
loss(Xs[5], Ys[5]) = 164.69257f0
loss(Xs[5], Ys[5]) = 164.73477f0
loss(Xs[5], Ys[5]) = 164.4086f0
loss(Xs[5], Ys[5]) = 165.74472f0
loss(Xs[5], Ys[5]) = 163.0393f0
loss(Xs[5], Ys[5]) = 163.28673f0
loss(Xs[5], Ys[5]) = 164.5841f0
loss(Xs[5], Ys[5]) = 164.83788f0
loss(Xs[5], Ys[

┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 159.11238f0
loss(Xs[5], Ys[5]) = 156.66805f0
loss(Xs[5], Ys[5]) = 160.08514f0
loss(Xs[5], Ys[5]) = 159.78325f0
loss(Xs[5], Ys[5]) = 161.66917f0
loss(Xs[5], Ys[5]) = 158.90451f0
loss(Xs[5], Ys[5]) = 156.35321f0
loss(Xs[5], Ys[5]) = 159.9816f0
loss(Xs[5], Ys[5]) = 160.54498f0
loss(Xs[5], Ys[5]) = 155.75562f0
loss(Xs[5], Ys[5]) = 156.60582f0
loss(Xs[5], Ys[5]) = 155.89784f0
loss(Xs[5], Ys[5]) = 156.42815f0
loss(Xs[5], Ys[5]) = 157.31784f0
loss(Xs[5], Ys[5]) = 155.60796f0
loss(Xs[5], Ys[5]) = 155.92268f0
loss(Xs[5], Ys[5]) = 157.10999f0
loss(Xs[5], Ys[5]) = 155.6836f0
loss(Xs[5], Ys[5]) = 155.79846f0
loss(Xs[5], Ys[5]) = 156.40143f0
loss(Xs[5], Ys[5]) = 152.168f0
loss(Xs[5], Ys[5]) = 156.70766f0
loss(Xs[5], Ys[5]) = 154.47385f0
loss(Xs[5], Ys[5]) = 153.82005f0
loss(Xs[5], Ys[5]) = 155.04102f0
loss(Xs[5], Ys[5]) = 154.01598f0
loss(Xs[5], Ys[5]) = 153.91731f0
loss(Xs[5], Ys[5]) = 154.34875f0
loss(Xs[5], Ys[5]) = 157.37431f0
loss(Xs[5], Ys[5]) = 155.24422f0
loss(Xs[5], Ys

┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 158.10129f0
loss(Xs[5], Ys[5]) = 162.8939f0
loss(Xs[5], Ys[5]) = 160.13332f0
loss(Xs[5], Ys[5]) = 161.37737f0
loss(Xs[5], Ys[5]) = 163.46797f0
loss(Xs[5], Ys[5]) = 163.2606f0
loss(Xs[5], Ys[5]) = 169.87912f0
loss(Xs[5], Ys[5]) = 165.47482f0
loss(Xs[5], Ys[5]) = 165.43983f0
loss(Xs[5], Ys[5]) = 167.01282f0
loss(Xs[5], Ys[5]) = 166.0133f0
loss(Xs[5], Ys[5]) = 166.09746f0
loss(Xs[5], Ys[5]) = 164.38905f0
loss(Xs[5], Ys[5]) = 166.08128f0
loss(Xs[5], Ys[5]) = 164.15329f0
loss(Xs[5], Ys[5]) = 164.22894f0
loss(Xs[5], Ys[5]) = 162.99115f0
loss(Xs[5], Ys[5]) = 162.27116f0
loss(Xs[5], Ys[5]) = 164.93349f0
loss(Xs[5], Ys[5]) = 162.51253f0
loss(Xs[5], Ys[5]) = 164.2315f0
loss(Xs[5], Ys[5]) = 165.78026f0
loss(Xs[5], Ys[5]) = 163.5426f0
loss(Xs[5], Ys[5]) = 162.89192f0
loss(Xs[5], Ys[5]) = 161.86746f0
loss(Xs[5], Ys[5]) = 162.60977f0
loss(Xs[5], Ys[5]) = 162.72746f0
loss(Xs[5], Ys[5]) = 163.26643f0
loss(Xs[5], Ys[5]) = 163.12744f0
loss(Xs[5], Ys[5]) = 164.5717f0
loss(Xs[5], Ys[5

┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(Xs[5], Ys[5]) = 159.06087f0
loss(Xs[5], Ys[5]) = 157.6897f0
loss(Xs[5], Ys[5]) = 157.03465f0
loss(Xs[5], Ys[5]) = 158.1379f0
loss(Xs[5], Ys[5]) = 160.42711f0
loss(Xs[5], Ys[5]) = 160.628f0
loss(Xs[5], Ys[5]) = 160.90392f0
loss(Xs[5], Ys[5]) = 159.52496f0
loss(Xs[5], Ys[5]) = 160.11284f0
loss(Xs[5], Ys[5]) = 163.84242f0
loss(Xs[5], Ys[5]) = 161.8139f0
loss(Xs[5], Ys[5]) = 161.48235f0
loss(Xs[5], Ys[5]) = 161.15024f0
loss(Xs[5], Ys[5]) = 159.46169f0
loss(Xs[5], Ys[5]) = 158.91043f0
loss(Xs[5], Ys[5]) = 159.99918f0
loss(Xs[5], Ys[5]) = 159.62012f0
loss(Xs[5], Ys[5]) = 158.04338f0
loss(Xs[5], Ys[5]) = 159.36693f0
loss(Xs[5], Ys[5]) = 157.99722f0
loss(Xs[5], Ys[5]) = 156.91681f0
loss(Xs[5], Ys[5]) = 157.537f0
loss(Xs[5], Ys[5]) = 157.55412f0
loss(Xs[5], Ys[5]) = 160.57976f0
loss(Xs[5], Ys[5]) = 158.43204f0
loss(Xs[5], Ys[5]) = 160.33777f0
loss(Xs[5], Ys[5]) = 160.06976f0
loss(Xs[5], Ys[5]) = 160.64545f0
loss(Xs[5], Ys[5]) = 167.16983f0
loss(Xs[5], Ys[5]) = 166.75732f0
loss(Xs[5], Ys[5]

## Sampling

In [12]:
"""
    sample(m, alphabet, len; temp=1)

Generate a `String` of `len` characters by repeatedly feeding the model `m` its own output.

The first character is drawn uniformly from `alphabet`. Each following character is drawn
with `wsample`, using the model's output for the previous character as the weights.

# Arguments
- `m`: callable model; `m(onehot(c, alphabet))` must return one non-negative weight per
  entry of `alphabet` (the softmax-terminated `Chain` in this notebook satisfies this).
- `alphabet`: collection of characters the model was trained over.
- `len`: number of characters to generate.

# Keywords
- `temp=1`: sampling temperature. The weights are raised to the power `1 / temp`, which for
  softmax outputs matches dividing the logits by `temp`. Values below 1 sharpen the
  distribution and values above 1 flatten it. With the default of 1 the model output is
  used unchanged.

# Throws
- `ArgumentError` if `temp` is not positive.
"""
function sample(m, alphabet, len; temp=1)
    temp > 0 || throw(ArgumentError("temp must be positive, got $temp"))
    # Start from a clean state so earlier calls to `m` do not influence this sample.
    Flux.reset!(m)
    buf = IOBuffer()
    c = rand(alphabet)
    for _ in 1:len
        write(buf, c)
        probs = m(onehot(c, alphabet))
        # `wsample` accepts unnormalised weights, so no renormalisation is needed here.
        weights = temp == 1 ? probs : probs .^ (1 / temp)
        c = wsample(alphabet, weights)
    end
    return String(take!(buf))
end

sample (generic function with 1 method)

In [13]:
# Draw 1000 characters from the trained model and print them.
println(sample(m, alphabet, 1000))

it lp
iaafeiwaw 
ta rto
idnpabaie ay  B utnWnftier lYs
toydsa hhhy ,rs: te Ralieirreowd enrrn w.nl?g yiiwn gstn E
e arpi
,,bt doootrnaslub o,bG  l  gesd E oOtgnlrholnnnrp uLt ados  aueo 

   eetilyayu  ttl  mnlo?:slayp  yfst  d s
lb  iytmHeEl:
 ptunioofbA nndoeww ft:
JmVrtuyrCd n bytuotp no  oCrc
Fed  meto;rU
l rshl 'oelsnnd.  rswS ynd.yekwtmh
m renf.np
etm,wIliodem swetepanucm.
n sog mbL n yts wil.orndtsuwo hoy:w urld u y sebH
erefcrusLwrwno
m
tdyo
ederooS ,erHphw E
frrs,  tAlnon huwuys r hi aSue,Bs:ystPerh  h  uhnoTmwouP uuUsErry  ,tEshtt,cestm
ac dDn,uet smtEefsTlENGaR rBihm
e rd, y   b

ch he wSr   t noA
ir u dmepa ;e t d?  .svy  rusrhw oeMese seIUtilt:Cp  icpeo 'Fhvutrlt yyrAo auid mono'uY teE
tftutesos lAhkgirwl:m
t cl 
hm K e hpaBhekswfbe
HeAk,oGlittonet mu :t
rs,oeIwu?,ESHTrncttn lrm
nolrnrd.es jT :vt rebe r
lh,f G nlmnf :ryyidVks
n rfWtCarodmIs 
ve.uA wledosi  aeu:bo,AlN hwksooy
e
setmEnocno I:wturt  es .t
tyo?rmwertelinfL tolseogyA    . t,dfk 
ac fyynywursOod,ybprL e au:Stt,e