# Recurrent neural network

Character-level language model

<img src="http://karpathy.github.io/assets/rnn/charseq.jpeg">

In [1]:
using Flux
using Flux: @epochs, onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
using CuArrays

## Download data

In [2]:
# Fetch the Shakespeare corpus once; the download is skipped when "input.txt" already exists.
isfile("input.txt") ||
  download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", "input.txt")

true

## Preprocessing

In [3]:
# Load the whole corpus as a vector of characters.
text = collect(read("input.txt", String))
# Vocabulary: every distinct character of the corpus, with an extra '_' symbol appended.
alphabet = vcat(unique(text), '_')
# Re-encode each character as a one-hot vector over the vocabulary.
text = [onehot(ch, alphabet) for ch in text]
# One-hot encoding of the extra '_' symbol.
stop = onehot('_', alphabet);

In [4]:
# Vocabulary size; used as the model's input and output width.
N = length(alphabet)
# Number of time steps per training window (the `partition` size used when building Xs/Ys).
seqlen = 50
# Number of chunks the text is split into (the `chunk` count used when building Xs/Ys).
nbatch = 50

50

In [5]:
# Inputs: split the encoded text into `nbatch` chunks, batch them with `stop` as the fill
# value, then cut the batched sequence into windows of `seqlen` steps.
Xs = let batched = batchseq(chunk(text, nbatch), stop)
    collect(partition(batched, seqlen))
end
# Targets: the same pipeline applied to the text shifted by one character.
Ys = let batched = batchseq(chunk(text[2:end], nbatch), stop)
    collect(partition(batched, seqlen))
end;

## Model

In [6]:
# Two stacked LSTM layers, then a dense projection back to the vocabulary size and a softmax.
m = let hidden = 128
    Chain(
        LSTM(N, hidden),
        LSTM(hidden, hidden),
        Dense(hidden, N),
        softmax)
end

Chain(Recur(LSTMCell(68, 128)), Recur(LSTMCell(128, 128)), Dense(128, 68), softmax)

## Loss function

In [7]:
# Total cross-entropy of the global model `m` over one window.
#
# `xs` and `ys` are equal-length collections of inputs and targets. The model is applied to
# each element of `xs` in turn, each prediction is scored against the matching element of
# `ys` with `crossentropy`, and the per-step losses are summed.
function loss(xs, ys)
    l = sum(crossentropy.(m.(xs), ys))
    # Reset the model's recurrent state after the loss has been computed, so the next call
    # starts from a fresh state.
    Flux.reset!(m)
    return l
end

loss (generic function with 1 method)

## Optimizer

In [8]:
# ADAM optimiser constructed with 0.01 (the step size).
opt = ADAM(0.01)

# Progress callback: evaluate the loss on the fifth window and print it via `@show`.
function evalcb()
    return @show loss(Xs[5], Ys[5])
end

evalcb (generic function with 1 method)

## Training

In [9]:
# Run 10 epochs over all (Xs, Ys) windows, updating the parameters of `m` with `opt`.
# The callback is `evalcb` wrapped in `throttle(…, 30)` (presumably at most one call per
# 30 seconds — confirm against the Flux docs). The throttle expression sits inside the
# `@epochs` body, so it is part of what the macro repeats for every epoch.
@epochs 10 Flux.train!(loss, params(m), zip(Xs, Ys), opt, cb=throttle(evalcb, 30))

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 196.29782f0
loss(Xs[5], Ys[5]) = 168.7379f0
loss(Xs[5], Ys[5]) = 167.81807f0
loss(Xs[5], Ys[5]) = 167.81956f0
loss(Xs[5], Ys[5]) = 168.42319f0
loss(Xs[5], Ys[5]) = 168.04958f0
loss(Xs[5], Ys[5]) = 168.03412f0
loss(Xs[5], Ys[5]) = 167.85292f0
loss(Xs[5], Ys[5]) = 168.15042f0
loss(Xs[5], Ys[5]) = 168.10759f0
loss(Xs[5], Ys[5]) = 168.03957f0
loss(Xs[5], Ys[5]) = 168.34824f0
loss(Xs[5], Ys[5]) = 167.94077f0
loss(Xs[5], Ys[5]) = 168.39522f0
loss(Xs[5], Ys[5]) = 167.70534f0
loss(Xs[5], Ys[5]) = 167.69283f0
loss(Xs[5], Ys[5]) = 168.48396f0
loss(Xs[5], Ys[5]) = 168.29532f0
loss(Xs[5], Ys[5]) = 168.53442f0
loss(Xs[5], Ys[5]) = 167.75645f0


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 168.4532f0
loss(Xs[5], Ys[5]) = 167.8475f0
loss(Xs[5], Ys[5]) = 167.47672f0
loss(Xs[5], Ys[5]) = 167.34839f0
loss(Xs[5], Ys[5]) = 167.8128f0
loss(Xs[5], Ys[5]) = 167.47623f0
loss(Xs[5], Ys[5]) = 167.63483f0
loss(Xs[5], Ys[5]) = 167.7313f0
loss(Xs[5], Ys[5]) = 167.53249f0
loss(Xs[5], Ys[5]) = 167.68056f0
loss(Xs[5], Ys[5]) = 167.01086f0
loss(Xs[5], Ys[5]) = 166.74171f0
loss(Xs[5], Ys[5]) = 166.25209f0
loss(Xs[5], Ys[5]) = 165.73651f0
loss(Xs[5], Ys[5]) = 165.16168f0
loss(Xs[5], Ys[5]) = 164.9013f0
loss(Xs[5], Ys[5]) = 164.63634f0
loss(Xs[5], Ys[5]) = 164.3064f0
loss(Xs[5], Ys[5]) = 164.13612f0
loss(Xs[5], Ys[5]) = 162.87486f0
loss(Xs[5], Ys[5]) = 161.37503f0
loss(Xs[5], Ys[5]) = 158.08737f0
loss(Xs[5], Ys[5]) = 155.00578f0
loss(Xs[5], Ys[5]) = 152.86896f0


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 152.41394f0
loss(Xs[5], Ys[5]) = 149.46277f0
loss(Xs[5], Ys[5]) = 147.84923f0
loss(Xs[5], Ys[5]) = 149.60577f0
loss(Xs[5], Ys[5]) = 145.96623f0
loss(Xs[5], Ys[5]) = 145.66422f0
loss(Xs[5], Ys[5]) = 147.59682f0
loss(Xs[5], Ys[5]) = 146.56323f0
loss(Xs[5], Ys[5]) = 150.19762f0
loss(Xs[5], Ys[5]) = 149.22984f0
loss(Xs[5], Ys[5]) = 148.08849f0
loss(Xs[5], Ys[5]) = 146.63953f0
loss(Xs[5], Ys[5]) = 146.9025f0
loss(Xs[5], Ys[5]) = 150.44743f0
loss(Xs[5], Ys[5]) = 145.1392f0
loss(Xs[5], Ys[5]) = 144.57774f0
loss(Xs[5], Ys[5]) = 144.30167f0
loss(Xs[5], Ys[5]) = 144.52188f0
loss(Xs[5], Ys[5]) = 142.05896f0
loss(Xs[5], Ys[5]) = 141.5173f0
loss(Xs[5], Ys[5]) = 140.72392f0
loss(Xs[5], Ys[5]) = 142.2347f0
loss(Xs[5], Ys[5]) = 170.7372f0
loss(Xs[5], Ys[5]) = 155.88963f0
loss(Xs[5], Ys[5]) = 165.90833f0


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 166.26042f0
loss(Xs[5], Ys[5]) = 162.83278f0
loss(Xs[5], Ys[5]) = 161.90129f0
loss(Xs[5], Ys[5]) = 162.99934f0
loss(Xs[5], Ys[5]) = 202.80118f0
loss(Xs[5], Ys[5]) = 170.65627f0
loss(Xs[5], Ys[5]) = 168.18155f0
loss(Xs[5], Ys[5]) = 167.49104f0
loss(Xs[5], Ys[5]) = 167.1406f0
loss(Xs[5], Ys[5]) = 168.0301f0
loss(Xs[5], Ys[5]) = 168.35104f0
loss(Xs[5], Ys[5]) = 168.02083f0
loss(Xs[5], Ys[5]) = 167.62605f0
loss(Xs[5], Ys[5]) = 169.52359f0
loss(Xs[5], Ys[5]) = 167.58008f0
loss(Xs[5], Ys[5]) = 167.66658f0
loss(Xs[5], Ys[5]) = 169.04176f0
loss(Xs[5], Ys[5]) = 167.5872f0
loss(Xs[5], Ys[5]) = 167.88509f0
loss(Xs[5], Ys[5]) = 166.78961f0
loss(Xs[5], Ys[5]) = 165.6094f0
loss(Xs[5], Ys[5]) = 165.26924f0


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 165.06433f0
loss(Xs[5], Ys[5]) = 163.28801f0
loss(Xs[5], Ys[5]) = 160.95926f0
loss(Xs[5], Ys[5]) = 162.40448f0
loss(Xs[5], Ys[5]) = 162.15633f0
loss(Xs[5], Ys[5]) = 167.62325f0
loss(Xs[5], Ys[5]) = 166.66963f0
loss(Xs[5], Ys[5]) = 166.52834f0
loss(Xs[5], Ys[5]) = 166.09325f0
loss(Xs[5], Ys[5]) = 166.96078f0
loss(Xs[5], Ys[5]) = 167.04877f0
loss(Xs[5], Ys[5]) = 167.08018f0
loss(Xs[5], Ys[5]) = 166.61067f0
loss(Xs[5], Ys[5]) = 161.5909f0
loss(Xs[5], Ys[5]) = 161.64499f0
loss(Xs[5], Ys[5]) = 161.21397f0
loss(Xs[5], Ys[5]) = 164.01053f0
loss(Xs[5], Ys[5]) = 163.63864f0
loss(Xs[5], Ys[5]) = 165.86086f0
loss(Xs[5], Ys[5]) = 167.78708f0
loss(Xs[5], Ys[5]) = 168.20183f0
loss(Xs[5], Ys[5]) = 167.1101f0


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 166.51495f0
loss(Xs[5], Ys[5]) = 167.04291f0
loss(Xs[5], Ys[5]) = 168.38882f0
loss(Xs[5], Ys[5]) = 167.6848f0
loss(Xs[5], Ys[5]) = 168.72707f0
loss(Xs[5], Ys[5]) = 167.10767f0
loss(Xs[5], Ys[5]) = 167.67046f0
loss(Xs[5], Ys[5]) = 167.14622f0
loss(Xs[5], Ys[5]) = 167.44229f0
loss(Xs[5], Ys[5]) = 167.8726f0
loss(Xs[5], Ys[5]) = 166.87134f0
loss(Xs[5], Ys[5]) = 167.91068f0
loss(Xs[5], Ys[5]) = 167.68265f0
loss(Xs[5], Ys[5]) = 168.12315f0
loss(Xs[5], Ys[5]) = 167.16997f0
loss(Xs[5], Ys[5]) = 169.62457f0
loss(Xs[5], Ys[5]) = 167.96577f0
loss(Xs[5], Ys[5]) = 166.2087f0
loss(Xs[5], Ys[5]) = 164.58978f0
loss(Xs[5], Ys[5]) = 164.94624f0
loss(Xs[5], Ys[5]) = 165.65279f0
loss(Xs[5], Ys[5]) = 167.11296f0
loss(Xs[5], Ys[5]) = 165.43365f0


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 164.7974f0
loss(Xs[5], Ys[5]) = 166.50026f0
loss(Xs[5], Ys[5]) = 165.60341f0
loss(Xs[5], Ys[5]) = 165.49095f0
loss(Xs[5], Ys[5]) = 163.27304f0
loss(Xs[5], Ys[5]) = 160.1047f0
loss(Xs[5], Ys[5]) = 160.71657f0
loss(Xs[5], Ys[5]) = 156.91574f0
loss(Xs[5], Ys[5]) = 159.81125f0
loss(Xs[5], Ys[5]) = 156.89055f0
loss(Xs[5], Ys[5]) = 165.76819f0
loss(Xs[5], Ys[5]) = 162.60904f0
loss(Xs[5], Ys[5]) = 163.01083f0
loss(Xs[5], Ys[5]) = 163.61075f0
loss(Xs[5], Ys[5]) = 163.47795f0
loss(Xs[5], Ys[5]) = 167.41824f0
loss(Xs[5], Ys[5]) = 165.51192f0
loss(Xs[5], Ys[5]) = 164.72421f0
loss(Xs[5], Ys[5]) = 167.58214f0
loss(Xs[5], Ys[5]) = 168.5141f0
loss(Xs[5], Ys[5]) = 171.08968f0
loss(Xs[5], Ys[5]) = 169.13731f0


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 168.45049f0
loss(Xs[5], Ys[5]) = 168.30182f0
loss(Xs[5], Ys[5]) = 167.73198f0
loss(Xs[5], Ys[5]) = 168.2007f0
loss(Xs[5], Ys[5]) = 168.40344f0
loss(Xs[5], Ys[5]) = 167.99107f0
loss(Xs[5], Ys[5]) = 167.98953f0
loss(Xs[5], Ys[5]) = 167.33028f0
loss(Xs[5], Ys[5]) = 168.23798f0
loss(Xs[5], Ys[5]) = 167.58928f0
loss(Xs[5], Ys[5]) = 167.28317f0
loss(Xs[5], Ys[5]) = 167.6174f0
loss(Xs[5], Ys[5]) = 167.26097f0
loss(Xs[5], Ys[5]) = 167.39742f0
loss(Xs[5], Ys[5]) = 166.91249f0
loss(Xs[5], Ys[5]) = 167.04086f0
loss(Xs[5], Ys[5]) = 167.46112f0
loss(Xs[5], Ys[5]) = 166.72363f0
loss(Xs[5], Ys[5]) = 165.93913f0
loss(Xs[5], Ys[5]) = 166.25717f0
loss(Xs[5], Ys[5]) = 165.01721f0
loss(Xs[5], Ys[5]) = 164.52182f0
loss(Xs[5], Ys[5]) = 164.6081f0
loss(Xs[5], Ys[5]) = 165.50789f0
loss(Xs[5], Ys[5]) = 165.50307f0


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 165.52336f0
loss(Xs[5], Ys[5]) = 165.59875f0
loss(Xs[5], Ys[5]) = 166.73637f0
loss(Xs[5], Ys[5]) = 166.0372f0
loss(Xs[5], Ys[5]) = 166.0334f0
loss(Xs[5], Ys[5]) = 166.50314f0
loss(Xs[5], Ys[5]) = 165.52846f0
loss(Xs[5], Ys[5]) = 165.41101f0
loss(Xs[5], Ys[5]) = 167.16405f0
loss(Xs[5], Ys[5]) = 168.03386f0
loss(Xs[5], Ys[5]) = 168.07646f0
loss(Xs[5], Ys[5]) = 167.95976f0
loss(Xs[5], Ys[5]) = 168.24438f0
loss(Xs[5], Ys[5]) = 167.3345f0
loss(Xs[5], Ys[5]) = 167.14984f0
loss(Xs[5], Ys[5]) = 166.40466f0
loss(Xs[5], Ys[5]) = 166.45137f0
loss(Xs[5], Ys[5]) = 166.18651f0
loss(Xs[5], Ys[5]) = 166.30598f0
loss(Xs[5], Ys[5]) = 165.27869f0
loss(Xs[5], Ys[5]) = 164.78f0
loss(Xs[5], Ys[5]) = 165.80664f0
loss(Xs[5], Ys[5]) = 166.43628f0


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(Xs[5], Ys[5]) = 166.54372f0
loss(Xs[5], Ys[5]) = 166.34227f0
loss(Xs[5], Ys[5]) = 168.52495f0
loss(Xs[5], Ys[5]) = 168.01495f0
loss(Xs[5], Ys[5]) = 167.33472f0
loss(Xs[5], Ys[5]) = 167.37161f0
loss(Xs[5], Ys[5]) = 166.69885f0
loss(Xs[5], Ys[5]) = 166.59048f0
loss(Xs[5], Ys[5]) = 167.62424f0
loss(Xs[5], Ys[5]) = 166.84758f0
loss(Xs[5], Ys[5]) = 166.76443f0
loss(Xs[5], Ys[5]) = 167.15128f0
loss(Xs[5], Ys[5]) = 167.33199f0
loss(Xs[5], Ys[5]) = 166.72868f0
loss(Xs[5], Ys[5]) = 167.27429f0
loss(Xs[5], Ys[5]) = 166.69878f0
loss(Xs[5], Ys[5]) = 166.58568f0
loss(Xs[5], Ys[5]) = 166.77077f0
loss(Xs[5], Ys[5]) = 166.2737f0
loss(Xs[5], Ys[5]) = 166.35228f0
loss(Xs[5], Ys[5]) = 166.56717f0
loss(Xs[5], Ys[5]) = 167.08582f0
loss(Xs[5], Ys[5]) = 166.86716f0
loss(Xs[5], Ys[5]) = 166.65967f0


## Sampling

In [12]:
"""
    sample(m, alphabet, len; temp=1)

Generate a string of `len` characters by repeatedly feeding the model `m` its own output.

The recurrent state of `m` is reset first, a random character from `alphabet` seeds the
sequence, and every following character is drawn with `wsample` using the model's output
for the previous character as sampling weights.

# Keywords
- `temp`: sampling temperature. `1` (the default) samples from the model's output as is;
  values below `1` sharpen the distribution, values above `1` flatten it.

# Throws
- `ArgumentError` if `temp` is not positive.
"""
function sample(m, alphabet, len; temp=1)
    temp > 0 || throw(ArgumentError("temp must be positive, got $temp"))
    Flux.reset!(m)
    buf = IOBuffer()
    c = rand(alphabet)
    for _ in 1:len
        write(buf, c)
        probs = m(onehot(c, alphabet))
        # For a softmax output p, p .^ (1 / temp) is proportional to softmax(logits ./ temp).
        # `wsample` only needs relative weights, so no renormalisation is required.
        # temp == 1 keeps the untouched model output so the default behaviour is unchanged.
        w = isone(temp) ? probs : probs .^ (1 / temp)
        c = wsample(alphabet, w)
    end
    return String(take!(buf))
end

sample (generic function with 1 method)

In [13]:
sample(m, alphabet, 1000) |> println

it lp
iaafeiwaw 
ta rto
idnpabaie ay  B utnWnftier lYs
toydsa hhhy ,rs: te Ralieirreowd enrrn w.nl?g yiiwn gstn E
e arpi
,,bt doootrnaslub o,bG  l  gesd E oOtgnlrholnnnrp uLt ados  aueo 

   eetilyayu  ttl  mnlo?:slayp  yfst  d s
lb  iytmHeEl:
 ptunioofbA nndoeww ft:
JmVrtuyrCd n bytuotp no  oCrc
Fed  meto;rU
l rshl 'oelsnnd.  rswS ynd.yekwtmh
m renf.np
etm,wIliodem swetepanucm.
n sog mbL n yts wil.orndtsuwo hoy:w urld u y sebH
erefcrusLwrwno
m
tdyo
ederooS ,erHphw E
frrs,  tAlnon huwuys r hi aSue,Bs:ystPerh  h  uhnoTmwouP uuUsErry  ,tEshtt,cestm
ac dDn,uet smtEefsTlENGaR rBihm
e rd, y   b

ch he wSr   t noA
ir u dmepa ;e t d?  .svy  rusrhw oeMese seIUtilt:Cp  icpeo 'Fhvutrlt yyrAo auid mono'uY teE
tftutesos lAhkgirwl:m
t cl 
hm K e hpaBhekswfbe
HeAk,oGlittonet mu :t
rs,oeIwu?,ESHTrncttn lrm
nolrnrd.es jT :vt rebe r
lh,f G nlmnf :ryyidVks
n rfWtCarodmIs 
ve.uA wledosi  aeu:bo,AlN hwksooy
e
setmEnocno I:wturt  es .t
tyo?rmwertelinfL tolseogyA    . t,dfk 
ac fyynywursOod,ybprL e au:Stt,e