In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.insert(0, "/notebooks/")
from fastai.imports import *

In [4]:
from fastai.io import *
from fastai.conv_learner import *

from fastai.column_data import *

In [5]:
# Folder that holds the Nietzsche corpus. The trailing slash matters: the
# cells below build file names by plain string formatting (f'{PATH}...').
PATH='data/nietzsche/'

In [6]:
# Fetch the raw corpus via fastai's get_data helper, then read the whole file
# into memory as a single string.
get_data("https://s3.amazonaws.com/text-datasets/nietzsche.txt", f'{PATH}nietzsche.txt')
# Context manager so the file handle is closed deterministically instead of
# being left to the garbage collector.
with open(f'{PATH}nietzsche.txt') as corpus_file:
    text = corpus_file.read()
print('corpus length:', len(text))

nietzsche.txt: 606kB [00:02, 294kB/s]                           

corpus length: 600893





In [7]:
from torchtext import vocab, data

from fastai.nlp import *
from fastai.lm_rnn import *

PATH='data/nietzsche/'

TRN_PATH = 'trn/'
VAL_PATH = 'val/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

%ls {PATH}

nietzsche.txt


In [9]:
# Expected dataset layout under PATH:
# - trn/trn.txt : first 80% of nietzsche.txt
# - val/val.txt : last 20% of nietzsche.txt
# (Preparing these files by hand in the shell is a good exercise for the student.)
# Make sure both split folders exist; exist_ok keeps the cell re-runnable.
os.makedirs(TRN, exist_ok=True)
os.makedirs(VAL, exist_ok=True)

In [12]:
import math

# Split the corpus 80/20 into training and validation text.
# A character at index i belongs to the training part when i < len(text)*0.8,
# i.e. the first ceil(len(text)*0.8) characters. Slicing once replaces the
# per-character string concatenation, which rebuilt the string on every step.
split_at = math.ceil(len(text) * 0.8)
trn, val = text[:split_at], text[split_at:]

In [15]:
# Persist the training split. The context manager guarantees the handle is
# flushed and closed even if the write raises.
with open(f"{TRN}trn.txt", "w") as trn_file:
    trn_file.write(trn)

In [16]:
# Persist the validation split. The context manager guarantees the handle is
# flushed and closed even if the write raises.
with open(f"{VAL}val.txt", "w") as val_file:
    val_file.write(val)

In [17]:
# Hyper-parameters: batch size, back-prop-through-time length, embedding
# width and hidden size.
bs = 64
bptt = 8
n_fac = 42
n_hidden = 256

# Character-level field: text is lower-cased and `list` splits a string into
# its individual characters.
TEXT = data.Field(lower=True, tokenize=list)

# The validation folder is reused as the test set.
FILES = {'train': TRN_PATH, 'validation': VAL_PATH, 'test': VAL_PATH}
md = LanguageModelData.from_text_files(PATH, TEXT, bs=bs, bptt=bptt, min_freq=3, **FILES)

# Quick sanity summary of the resulting data object.
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

(922, 55, 1, 472944)

In [18]:
from fastai import sgdr

# Overrides the n_hidden=256 assigned together with bs/bptt/n_fac earlier.
# CharSeqStatefulLSTM below reads this notebook-level name directly, so this
# cell must run before the model is constructed.
n_hidden=512

In [19]:
class CharSeqStatefulLSTM(nn.Module):
    """Stateful multi-layer LSTM that predicts the next character.

    The hidden/cell state is kept on the instance between forward calls and
    is only reset when the incoming batch size changes.

    Args:
        vocab_size: number of distinct tokens (size of the embedding table
            and of the output layer).
        n_fac: embedding width.
        bs: batch size used to allocate the initial hidden state.
        nl: number of stacked LSTM layers.

    The hidden width is taken from the notebook-level ``n_hidden`` at
    construction time and stored as ``self.n_hidden``.
    """
    def __init__(self, vocab_size, n_fac, bs, nl):
        super().__init__()
        self.vocab_size,self.nl = vocab_size,nl
        # Snapshot the global once. Previously __init__ and init_hidden each
        # re-read the global, so changing n_hidden after construction made
        # the hidden state disagree with the already-built LSTM.
        self.n_hidden = n_hidden
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.LSTM(n_fac, self.n_hidden, nl, dropout=0.5)
        self.l_out = nn.Linear(self.n_hidden, vocab_size)
        self.init_hidden(bs)

    def forward(self, cs):
        """Return log-probabilities flattened to (-1, vocab_size).

        The batch size is read from ``cs[0].size(0)``, i.e. ``cs`` is indexed
        as (sequence, batch).
        """
        bs = cs[0].size(0)
        # Re-allocate the state when the batch size differs (dimension 1 of
        # the hidden tensors), e.g. for a differently sized batch.
        if self.h[0].size(1) != bs: self.init_hidden(bs)
        outp,h = self.rnn(self.e(cs), self.h)
        # Keep the state for the next call; repackage_var is a fastai helper,
        # presumably detaching it from the autograd history (TODO confirm).
        self.h = repackage_var(h)
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)

    def init_hidden(self, bs):
        """Reset hidden and cell state to zeros of shape (nl, bs, n_hidden)."""
        self.h = (V(torch.zeros(self.nl, bs, self.n_hidden)),
                  V(torch.zeros(self.nl, bs, self.n_hidden)))

In [20]:
# Two-layer stateful LSTM on the GPU. The third argument is the batch size
# used to allocate the initial hidden state: pass the real `bs` rather than
# 512 (the hidden width), otherwise that state is allocated with the wrong
# shape and immediately discarded on the first batch.
m = CharSeqStatefulLSTM(md.nt, n_fac, bs, 2).cuda()
# Adam wrapped in fastai's LayerOptimizer; 1e-2 and 1e-5 are presumably the
# learning rate and weight decay (TODO confirm against LayerOptimizer).
lo = LayerOptimizer(optim.Adam, m, 1e-2, 1e-5)

In [21]:
# Folder that receives the checkpoints written by the training cells below;
# exist_ok keeps the cell re-runnable.
os.makedirs(f'{PATH}models', exist_ok=True)

In [22]:
# Short warm-up run (third argument 2, shown as 2 epochs in the output) with
# the optimizer from `lo` and negative log-likelihood loss; this matches the
# log_softmax output of the model.
fit(m, md, 2, lo.opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.841173   1.764152  
    1      1.726264   1.672768  



[array([1.67277])]

In [23]:
def on_end(sched, cycle):
    """Cycle-end hook: checkpoint the model under a per-cycle file name."""
    return save_model(m, f'{PATH}models/cyc_{cycle}')

# Cosine-annealed schedule whose cycle length doubles each time
# (cycle_mult=2); 2**4-1 = 15 epochs are requested.
n_epochs = 2**4-1
cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]
fit(m, md, n_epochs, lo.opt, F.nll_loss, callbacks=cb)

HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.556026   1.512069  
    1      1.586439   1.533572  
    2      1.475795   1.448702  
    3      1.610676   1.561134  
    4      1.52845    1.497698  
    5      1.442953   1.430195  
    6      1.390016   1.395813  
    7      1.582066   1.550931  
    8      1.535927   1.506746  
    9      1.518942   1.494451  
    10     1.466848   1.461279  
    11     1.428705   1.426923  
    12     1.383774   1.395188  
    13     1.334992   1.366776  
    14     1.311327   1.353291  



[array([1.35329])]

In [24]:
def on_end(sched, cycle):
    """Save a checkpoint named after the cycle that just finished."""
    return save_model(m, f'{PATH}models/cyc_{cycle}')

# Longer run with a fresh cosine-annealing callback (cycle length doubling);
# 2**6-1 = 63 epochs are requested.
anneal = CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)
cb = [anneal]
fit(m, md, 2**6-1, lo.opt, F.nll_loss, callbacks=cb)

HBox(children=(IntProgress(value=0, description='Epoch', max=63), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.30778    1.352004  
    1      1.301252   1.350355  
    2      1.302714   1.349592  
    3      1.296127   1.348195  
    4      1.291974   1.346316  
    5      1.287176   1.344816  
    6      1.284781   1.343822  
    7      1.28707    1.342817  
    8      1.283692   1.340758  
    9      1.274088   1.338131  
    10     1.264213   1.337362  
    11     1.261129   1.336304  
    12     1.261063   1.335735  
    13     1.259446   1.335039  
    14     1.257951   1.334479  
    15     1.254081   1.33501   
    16     1.256233   1.333435  
    17     1.252665   1.332503  
    18     1.248097   1.332752  
    19     1.239251   1.330847  
    20     1.237829   1.330884  
    21     1.233665   1.330479  
    22     1.219488   1.330095  
    23     1.220874   1.330426  
    24     1.212397   1.330599  
    25     1.209258   1.33003   
    26     1.208891   1.330065  
    27     1.205595   1.330224  
    28     1.201558   1.330312  
    29   

[array([1.35679])]

In [25]:
def get_next(inp):
    """Sample the character that follows the string `inp`.

    The string is numericalized with `TEXT`, run through the global model
    `m`, and one index is drawn from the predicted distribution of the last
    output row (the model returns log-probabilities, hence the `exp`).

    Returns the sampled token as looked up in `TEXT.vocab.itos`.
    """
    idxs = TEXT.numericalize(inp)
    # Sample with dropout disabled: the LSTM is built with dropout=0.5, which
    # would otherwise randomly perturb the predictions. The previous mode is
    # restored afterwards so training can continue unaffected.
    was_training = m.training
    m.eval()
    try:
        p = m(VV(idxs.transpose(0,1)))
    finally:
        m.train(was_training)
    r = torch.multinomial(p[-1].exp(), 1)
    return TEXT.vocab.itos[to_np(r)[0]]

In [26]:
# Smoke test: sample one character to follow the 8-character seed. The result
# is drawn at random, so it can differ between runs.
get_next('for thos')

'e'

In [27]:
def get_next_n(inp, n):
    """Extend `inp` by `n` sampled characters.

    Each step feeds the current context window to `get_next`, then slides the
    window: the oldest character is dropped and the new one appended.

    Returns the original `inp` followed by everything that was generated.
    """
    window = inp
    generated = []
    for _ in range(n):
        nxt = get_next(window)
        generated.append(nxt)
        window = window[1:] + nxt
    return inp + ''.join(generated)

In [28]:
# Generate 400 characters from the seed with the current weights of `m`.
print(get_next_n('for thos', 400))

for those except--and irrotroussang--how much mysterious?--very mexistemporarily, the fights like at whonow these at one things--derider. and in itself, as weak, slanger, does its attemment--the dispart from que contain. everything that which metaphysic of truths? the restandfingers or utthings when heart flourish genius, impulse in him. _and i salvis and love. justas itmaster and disgust be from the leas


In [29]:
def on_end(sched, cycle):
    """Cycle-end hook: save the model.

    The file name does not include `cycle`, so every cycle overwrites the
    same checkpoint and only the most recent one is kept.
    """
    return save_model(m, f'{PATH}models/first_nietzsche')

# Another 2**6-1 = 63 requested epochs under a fresh cosine-annealing
# callback (cycle length doubling via cycle_mult=2).
steps_per_epoch = len(md.trn_dl)
cb = [CosAnneal(lo, steps_per_epoch, cycle_mult=2, on_cycle_end=on_end)]
fit(m, md, 2**6-1, lo.opt, F.nll_loss, callbacks=cb)

HBox(children=(IntProgress(value=0, description='Epoch', max=63), HTML(value='')))

epoch      trn_loss   val_loss   
    0      1.075272   1.356896  
    1      1.075038   1.356848  
    2      1.069497   1.356822  
    3      1.078831   1.356914  
    4      1.079869   1.357034  
    5      1.07352    1.356947  
    6      1.075501   1.356982  
    7      1.073789   1.357379  
    8      1.074271   1.35722   
    9      1.074571   1.356851  
    10     1.079309   1.357351  
    11     1.072891   1.357408  
    12     1.0768     1.357503  
    13     1.074071   1.357369  
    14     1.075713   1.357184  
    15     1.072792   1.357689  
    16     1.077057   1.357697  
    17     1.070722   1.358103  
    18     1.075098   1.358001  
    19     1.072626   1.358112  
    20     1.070286   1.357969  
    21     1.071924   1.358212  
    22     1.06788    1.358041  
    23     1.076886   1.358287  
    24     1.073448   1.357998  
    25     1.074117   1.358291  
    26     1.071992   1.358329  
    27     1.072054   1.358355  
    28     1.062552   1.358282  
    29   

[array([1.35952])]

In [34]:
# Absolute location of the corpus folder.
# NOTE(review): not referenced by any cell visible here, and the absolute
# path ties the notebook to one machine layout - confirm it is still needed.
MODEL_PATH = "/notebooks/data/nietzsche/"

In [70]:
# Checkpoint to restore.
# NOTE(review): the training cells save to f'{PATH}models/cyc_{cycle}', i.e.
# inside a models/ sub-folder, while this path has no models/ component -
# verify the file really lives here on a fresh run.
file_path = "/notebooks/data/nietzsche/cyc_4"

In [71]:
# Read the saved weights. The map_location callable returns the storage
# unchanged, which keeps every tensor on the CPU regardless of the device it
# was saved from.
state_dict = torch.load(file_path, map_location=lambda storage, loc: storage)

In [72]:
# Back up the current weights before they are replaced by the checkpoint.
# NOTE(review): the name is relative to the working directory, not to
# f'{PATH}models/' like the other checkpoints - confirm that is intended.
save_model(m, "second_nietzsche")

In [73]:
# Replace the weights of `m` with those read from the checkpoint file.
m.load_state_dict(state_dict)

In [74]:
# Generate 400 characters again, now with the restored checkpoint weights.
print(get_next_n('for thos', 400))

for those wayin fraid to shape, among the out, by a state and feminine.;: julad, and in your i skepticism andcanness, from the back of point to victry betrays them--when the germans someof path of city of itself they are so ners have arrived--a singular) horsesthe the age alas, requiring-and instinctions, indeed thegenerations that the sharps likewisselywith roce--were age?--!808.242. "be attack (nake, co
