In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.insert(0, "/notebooks/")
from fastai.imports import *

In [2]:
from fastai.io import *
from fastai.conv_learner import *

from fastai.column_data import *

In [3]:
PATH='/notebooks/data/nietzsche/'

In [4]:
get_data("https://s3.amazonaws.com/text-datasets/nietzsche.txt", f'{PATH}nietzsche.txt')
text = open(f'{PATH}nietzsche.txt').read()
print('corpus length:', len(text))

corpus length: 600893


In [5]:
text[:400]

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not ground\nfor suspecting that all philosophers, in so far as they have been\ndogmatists, have failed to understand women--that the terrible\nseriousness and clumsy importunity with which they have usually paid\ntheir addresses to Truth, have been unskilled and unseemly methods for\nwinning a woman? Certainly she has never allowed herself '

In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)+1
print('total chars:', vocab_size)

total chars: 85


In [7]:
chars.insert(0, "\0")

''.join(chars[1:-6])

'\n !"\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'

In [8]:
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

In [9]:
idx = [char_indices[c] for c in text]

idx[:10]

[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]

In [10]:
''.join(indices_char[i] for i in idx[:70])

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not gro'

In [11]:
n_hidden = 256

In [12]:
n_fac = 42

In [14]:
cs=8

In [16]:
c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(0, len(idx)-cs-1, cs)]

In [17]:
c_out_dat = [[idx[i+j] for i in range(cs)] for j in range(1, len(idx)-cs, cs)]

In [18]:
xs = np.stack(c_in_dat)
xs.shape

(75111, 8)

In [19]:
ys = np.stack(c_out_dat)
ys.shape

(75111, 8)

In [22]:
val_idx = get_cv_idxs(len(xs)-cs-1)

In [23]:
md = ColumnarModelData.from_arrays('.', val_idx, xs, ys, bs=512)

In [24]:
class CharSeqRnn(nn.Module):
    def __init__(self, vocab_size, n_fac):
        super().__init__()
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.RNN(n_fac, n_hidden)
        self.l_out = nn.Linear(n_hidden, vocab_size)
        
    def forward(self, *cs):
        bs = cs[0].size(0)
        h = V(torch.zeros(1, bs, n_hidden))
        inp = self.e(torch.stack(cs))
        outp,h = self.rnn(inp, h)
        return F.log_softmax(self.l_out(outp), dim=-1)

In [25]:
m = CharSeqRnn(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-3)

In [26]:
it = iter(md.trn_dl)
*xst,yt = next(it)

In [27]:
def nll_loss_seq(inp, targ):
    sl,bs,nh = inp.size()
    targ = targ.transpose(0,1).contiguous().view(-1)
    return F.nll_loss(inp.view(-1,nh), targ)

In [28]:
fit(m, md, 4, opt, nll_loss_seq)

HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))

epoch      trn_loss   val_loss   
    0      2.605278   2.415556  
    1      2.293733   2.20161   
    2      2.137433   2.085508  
    3      2.041875   2.012471  



[array([2.01247])]

In [29]:
set_lrs(opt, 1e-4)

In [30]:
fit(m, md, 1, opt, nll_loss_seq)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.995379   1.996222  



[array([1.99622])]

In [31]:
m = CharSeqRnn(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-2)

In [33]:
fit(m, md, 4, opt, nll_loss_seq)

HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))

epoch      trn_loss   val_loss   
    0      2.330873   2.154481  
    1      2.056388   1.988048  
    2      1.949358   1.926135  
    3      1.896721   1.893031  



[array([1.89303])]

In [34]:
set_lrs(opt, 1e-3)

In [35]:
fit(m, md, 4, opt, nll_loss_seq)

HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.801395   1.823649  
    1      1.788995   1.815818  
    2      1.7804     1.810786  
    3      1.771506   1.806218  



[array([1.80622])]

In [36]:
def get_next(inp):
    idxs = T(np.array([char_indices[c] for c in inp]))
    p = m(*VV(idxs))
    i = np.argmax(to_np(p))
    return chars[i]

In [37]:
get_next('for thos')

IndexError: list index out of range

In [None]:
def get_next_n(inp, n):
    res = inp
    for i in range(n):
        c = get_next(inp)
        res += c
        inp = inp[1:]+c
    return res

In [None]:
get_next_n('for thos', 40)