In [4]:
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

In [7]:
# Fetch the Tiny Shakespeare corpus unless a local copy already exists.
# The `let` evaluates to `true` when the file is present, otherwise to the
# result of `download`, exactly like the bare `||` expression would.
let url = "http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
    dest = "input.txt"
    isfile(dest) || download(url, dest)
end

true

In [16]:
# Show the documentation for `collect`. `?name` is help-mode syntax understood
# only by the REPL / IJulia front end, not by the Julia parser; `@doc` returns
# the same docstring and also works when this code is run as a plain script.
@doc collect

search: collect Collections



```
collect(element_type, collection)
```

Return an `Array` with the given element type of all items in a collection or iterable. The result has the same shape and number of dimensions as `collection`.

```jldoctest
julia> collect(Float64, 1:2:5)
3-element Array{Float64,1}:
 1.0
 3.0
 5.0
```

```
collect(collection)
```

Return an `Array` of all items in a collection or iterator. For associative collections, returns `Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the `HasShape()` trait, the result will have the same shape and number of dimensions as the argument.

# Example

```jldoctest
julia> collect(1:2:13)
7-element Array{Int64,1}:
  1
  3
  5
  7
  9
 11
 13
```


In [21]:
# Load the whole corpus as a vector of characters. The vocabulary is every
# distinct character in first-appearance order (as `unique` returns them),
# followed by an extra '_' that is used as the `stop` symbol further down.
text = collect(readstring("input.txt"))
alphabet = push!(unique(text), '_')

68-element Array{Char,1}:
 'F' 
 'i' 
 'r' 
 's' 
 't' 
 ' ' 
 'C' 
 'z' 
 'e' 
 'n' 
 ':' 
 '\n'
 'B' 
 ⋮   
 'J' 
 'G' 
 'K' 
 'Q' 
 '&' 
 'Z' 
 'X' 
 '3' 
 '$' 
 '[' 
 ']' 
 '_' 

In [20]:
# Print the corpus's distinct characters in sorted order, all on one line.
foreach(print, sort(unique(text)))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz

In [23]:
# One-hot encode each character against `alphabet`. Note that `text` is rebound:
# from here on it holds one-hot vectors instead of `Char`s.
text = [onehot(c, alphabet) for c in text]

4573338-element Array{Flux.OneHotVector,1}:
 Bool[true, false, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, true, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, false, true, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, false, false, true, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, false, false, false, true, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, false, false, false, false, true, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false] 
 Bool[false, false, false, false, false, false, true, false, f

In [26]:
# One-hot vector for the '_' sentinel appended to `alphabet`. It is handed to
# `batchseq` below (presumably as the padding value for ragged chunks).
stop = Flux.onehot('_', alphabet)

68-element Flux.OneHotVector:
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
     ⋮
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
  true

In [27]:
# N: vocabulary size; seqlen: timesteps per training window (the `partition`
# size below); nbatch: number of pieces requested from `chunk`.
N, seqlen, nbatch = length(alphabet), 50, 50

50

In [28]:
# Build the training windows. The inputs are the corpus without its last
# character and the targets are the corpus without its first character, so
# target t is always the character that follows input t.
# Both slices have exactly length(text) - 1 elements, so `chunk` cuts them at
# identical offsets. Slices of different lengths (e.g. `text` vs `text[2:end]`)
# can be cut at different offsets, because `chunk`'s piece size presumably
# depends on the input length. That would silently misalign inputs and targets.
X_train = collect(partition(batchseq(chunk(text[1:end-1], nbatch), stop), seqlen))
y_train = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))

1830-element Array{Array{Flux.OneHotMatrix{Array{Flux.OneHotVector,1}},1},1}:
 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}[Bool[false false … false false; true false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; true false … false false; … ; false false … false false; false false … false false], Bool[false false … false false; false false … false

In [36]:
# Two stacked LSTM layers of width 128, a dense projection back to the
# vocabulary size N, and a softmax over the N output scores.
model = let hidden = 128
    Chain(
        LSTM(N, hidden),
        LSTM(hidden, hidden),
        Dense(hidden, N),
        softmax)
end

Chain(Recur(LSTMCell(68, 128)), Recur(LSTMCell(128, 128)), Dense(128, 68), NNlib.softmax)

In [37]:
"""
    loss(xs, ys)

Return the sum of the per-timestep `crossentropy` between the global `model`'s
outputs on `xs` and the targets `ys`.

`model` is applied to the elements of `xs` one at a time, in order, so its
recurrent layers carry state across the sequence. Afterwards
`Flux.truncate!(model)` is called on that recurrent state; presumably this cuts
the gradient history so the next call does not backpropagate into this
sequence (confirm against the Flux version in use).
"""
function loss(xs, ys)
    outputs = map(model, xs)                 # step the recurrent model in order
    total = sum(crossentropy.(outputs, ys))  # one loss term per timestep
    Flux.truncate!(model)
    return total
end

loss (generic function with 1 method)

In [39]:
# ADAM optimiser over all of `model`'s parameters, learning rate 1e-2 (= 0.01).
opt = ADAM(Flux.params(model), 1e-2)

(::#71) (generic function with 1 method)

In [43]:
# Progress callback: print (via @show) the loss on training window 5 and
# return it.
# NOTE(review): `loss` runs `model` forward, so each call also advances the
# recurrent state that training continues from — confirm this is intended.
evalcb = function ()
    @show loss(X_train[5], y_train[5])
end

(::#7) (generic function with 1 method)

In [44]:
# Train on every (input window, target window) pair. The progress callback is
# rate-limited by `throttle(evalcb, 30)` (presumably at most once per 30 s).
let batches = zip(X_train, y_train), progress = throttle(evalcb, 30)
    Flux.train!(loss, batches, opt; cb = progress)
end

loss(X_train[5], y_train[5]) = param(180.317)
loss(X_train[5], y_train[5]) = param(166.509)
loss(X_train[5], y_train[5]) = param(156.282)
loss(X_train[5], y_train[5]) = param(136.673)
loss(X_train[5], y_train[5]) = param(127.755)
loss(X_train[5], y_train[5]) = param(123.901)
loss(X_train[5], y_train[5]) = param(119.881)
loss(X_train[5], y_train[5]) = param(117.02)
loss(X_train[5], y_train[5]) = param(113.634)
loss(X_train[5], y_train[5]) = param(112.524)
loss(X_train[5], y_train[5]) = param(109.624)
loss(X_train[5], y_train[5]) = param(108.195)
loss(X_train[5], y_train[5]) = param(106.429)
loss(X_train[5], y_train[5]) = param(105.297)
loss(X_train[5], y_train[5]) = param(103.111)
loss(X_train[5], y_train[5]) = param(102.975)
loss(X_train[5], y_train[5]) = param(102.287)
loss(X_train[5], y_train[5]) = param(100.954)
loss(X_train[5], y_train[5]) = param(99.997)
loss(X_train[5], y_train[5]) = param(99.0703)
loss(X_train[5], y_train[5]) = param(99.0709)
loss(X_train[5], y_train[5]) = param

LoadError: InterruptException:

In [51]:
# Inspect the first timestep of the first input window (a one-hot matrix).
first(first(X_train))

68×50 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
  true  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false   true     false   true   true  false  false
 false  false  false  false  false      true  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false   true  false  false  false     false  false  false  false   true
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  fals

In [48]:
# Inspect the first timestep of the first target window (a one-hot matrix).
first(first(y_train))

68×50 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false  false  false  …  false  false  false  false  false
  true  false  false  false   true     false  false  false  false  false
 false  false  false  false  false     false  false  false  false   true
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false   true  false  …  false   true   true  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  fals