In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.model import fit
from fastai.dataset import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle
import random

# Fit language model on WordPress

In [3]:
# Batch size and bptt length; both are handed to LanguageModelData.from_text_files below
# (bptt is presumably the back-propagation-through-time window — TODO confirm).
bs = 64
bptt = 70

In [6]:
# The French spaCy model must be installed first:
#   python -m spacy download fr
my_tok = spacy.load('fr')


def my_spacy_tok(x):
    """Run `my_tok.tokenizer` over the string `x` and return the token texts as a list."""
    doc = my_tok.tokenizer(x)
    return [token.text for token in doc]

In [170]:
# Build the language-model data object from the text files under PATH,
# using a lower-casing field tokenized by my_spacy_tok.
# NOTE(review): train, validation and test all point at the same directory ('.'),
# so the validation loss is presumably computed on the training text — confirm
# whether a held-out split was intended.
PATH = 'data/txt/wordpress'
TEXT = data.Field(tokenize=my_spacy_tok, lower=True)
md = LanguageModelData.from_text_files(
    PATH, TEXT,
    train='.', validation='.', test='.',
    bs=bs, bptt=bptt, min_freq=10,
)

In [171]:
# Quick size check: length of the training loader, md.nt (presumably the vocabulary
# size — TODO confirm), number of items in the training dataset, and the length of
# the first item's text.
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

(123, 4002, 1, 558973)

In [10]:
# Sizes passed positionally to md.get_model in the next cell
# (presumably embedding size, hidden units and number of layers — TODO confirm).
em_sz, nh, nl = 200, 500, 3
# Adam constructor with betas pre-bound to (0.7, 0.99).
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

In [11]:
# Create the learner from the data object with the dropout settings below.
learner = md.get_model(
    opt_fn, em_sz, nh, nl,
    dropouti=0.05, dropout=0.05, wdrop=0.2, dropoute=0.02, dropouth=0.1,
)
# Alternative dropout settings kept for reference:
# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
# dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05
# Regularization function: seq2seq_reg with alpha=2, beta=1 pre-bound.
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
# Clip value stored on the learner (presumably gradient clipping — TODO confirm).
learner.clip = 0.3

In [12]:
#warnings.filterwarnings('ignore',category=UserWarning)

In [13]:
# Train for a single epoch (see the recorded output below) at 3e-3 with wds=1e-6.
learner.fit(3e-3, 1, wds=1e-6)

HBox(children=(IntProgress(value=0, description='Epoch', max=1, style=ProgressStyle(description_width='initial…

epoch      trn_loss   val_loss                               
    0      5.473215   5.197832  



[array([5.19783])]

In [168]:
# Save the encoder under the name 'encoder'; it is reloaded with load_encoder('encoder')
# in the Blogger section further down.
learner.save_encoder('encoder')

# Show most common predictions

In [14]:
def proc_str(s):
    """Tokenize `s` with `TEXT.tokenize`, then pass the tokens through `TEXT.preprocess`."""
    tokens = TEXT.tokenize(s)
    return TEXT.preprocess(tokens)
def num_str(s):
    """Numericalize `s` as a one-sequence batch via `TEXT.numericalize` with device=-1
    (presumably CPU in this torchtext version — TODO confirm)."""
    batch = [proc_str(s)]
    return TEXT.numericalize(batch, device=-1)

In [146]:
# Move the trained model to the CPU and numericalize a short French prompt for it.
m = learner.model.cpu()
s = "j'aime"
t = num_str(s)

In [147]:
# Run the prompt `t` through the model once with batch size 1.
# Remember the batch size the model currently has so exactly that value is restored.
orig_bs = m[0].bs
# Set batch size to 1
m[0].bs=1
try:
    # Turn off dropout
    m.eval()
    # Reset hidden state
    m.reset()
    # Get predictions from model
    res,*_ = m(t)
finally:
    # Put the batch size back to what it was, even if the forward pass raised,
    # so a failed run does not leave the model stuck at batch size 1.
    m[0].bs=orig_bs

In [148]:
# Indices of the 10 highest-scoring entries in the last output row, mapped back to words.
nexts = res[-1].topk(10)[1]
[TEXT.vocab.itos[idx] for idx in to_np(nexts)]

['<unk>', ',', 'de', 'est', 'et', 'à', 'en', 'des', 'un', 'a']

In [149]:
# Draw 10 indices with torch.multinomial, weighted by exp() of the last output row,
# and map them back to words.
nexts = torch.multinomial(res[-1].exp(), 10)
[TEXT.vocab.itos[idx] for idx in to_np(nexts)]

['marches',
 'devoir',
 'voilà',
 'températures',
 'amis',
 'mardi',
 'repas',
 'admirant',
 'totale',
 'comprennent']

# Sample the model

In [161]:
def sample_model(m, s, l=50, style='topk'):
    """Print up to `l` words generated by model `m` as a continuation of prompt `s`.

    Args:
        m: the model. `m[0].bs` is temporarily set to 1 and restored on exit;
            `m.eval()` and `m.reset()` are called, and the model is left in eval mode.
        s: prompt string, converted with `num_str`.
        l: maximum number of words to generate.
        style: 'topk' takes the highest-scoring candidate at each step; any other
            value draws candidates with `torch.multinomial` over `exp()` of the
            last output row. In both cases vocabulary index 0 is skipped.

    Generation stops early once '<eos>' has been printed. Returns None; the words
    are written to stdout, separated by spaces.
    """
    t = num_str(s)
    # Remember the model's own batch size instead of relying on the notebook-level
    # `bs`, and restore it in `finally` so an error mid-generation cannot leave bs=1.
    orig_bs = m[0].bs
    m[0].bs=1
    try:
        m.eval()
        m.reset()
        res,*_ = m(t)
        print('...', end='')

        for _ in range(l):
            # Take two candidates so there is a fallback when the first one is index 0.
            if style == 'topk':
                n = res[-1].topk(2)[1]
            else:
                n = torch.multinomial(res[-1].exp(), 2)
            # Index 0 is presumably '<unk>' (TODO confirm); use the second candidate instead.
            n = n[1] if n.data[0]==0 else n[0]
            word = TEXT.vocab.itos[n.data[0]]
            print(word, end=' ')
            if word=='<eos>': break
            # Feed the chosen word back in as the next input.
            res,*_ = m(n[0].unsqueeze(0))
    finally:
        m[0].bs=orig_bs

In [162]:
# Continue the prompt with the default 'topk' style.
sample_model(m, "j'aime")

..., le ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la ville , la 

In [166]:
# Same prompt, but drawing each word with torch.multinomial instead of taking the top candidate.
sample_model(m, "j'aime", style='multinomial')

...sachant environ le nouvelle-zélande de internet . en me hill même cette voyage environ des mois , ce y locale au soir sur correctement … le en ouvriers je avons yangon mais vignoble de les laos très 2 région , la stop . tous un soleil le plus petites à 

# Fine-tune on Blogger data

In [172]:
# Point the data object at the Blogger text files. TEXT is not re-created here (the line
# is commented out), so the field built for the WordPress data is reused; the recorded
# md.nt is 4002 for both datasets, consistent with a shared vocabulary.
# NOTE(review): if the vocabulary is not rebuilt, min_freq=10 presumably has no effect
# in this call — confirm. As above, train/validation/test all use the same directory.
PATH = 'data/txt/blogger'
#TEXT = data.Field(lower=True, tokenize=my_spacy_tok)
md = LanguageModelData.from_text_files(PATH, TEXT, train='.', validation='.', test='.',
                                       bs=bs, bptt=bptt, min_freq=10)

In [173]:
# Same size check as for the WordPress data: training-loader length, md.nt, number of
# items in the training dataset, and the length of the first item's text.
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

(20, 4002, 1, 95914)

In [174]:
# Same sizes and optimizer settings as used for the WordPress model
# (presumably required so the saved encoder can be loaded — TODO confirm).
em_sz, nh, nl = 200, 500, 3
# Adam constructor with betas pre-bound to (0.7, 0.99).
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

In [175]:
# Create a fresh learner on the Blogger data with the same settings as the WordPress one.
learner = md.get_model(
    opt_fn, em_sz, nh, nl,
    dropouti=0.05, dropout=0.05, wdrop=0.2, dropoute=0.02, dropouth=0.1,
)
# Alternative dropout settings kept for reference:
# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
# dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05
# Regularization function: seq2seq_reg with alpha=2, beta=1 pre-bound.
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
# Clip value stored on the learner (presumably gradient clipping — TODO confirm).
learner.clip = 0.3

In [177]:
# Start from the encoder saved earlier with save_encoder('encoder') after the WordPress run.
learner.load_encoder('encoder')

In [178]:
# One epoch on the Blogger data (see the recorded output below) at 3e-3 with wds=1e-6.
learner.fit(3e-3, 1, wds=1e-6)

HBox(children=(IntProgress(value=0, description='Epoch', max=1, style=ProgressStyle(description_width='initial…

epoch      trn_loss   val_loss                             
    0      4.79059    4.586231  


[array([4.58623])]

In [183]:
# Four cycles with cycle_len=1 and cycle_mult=2; the recorded progress bar shows 15 epochs in total.
# NOTE(review): the saved run of this cell was stopped with KeyboardInterrupt, so no
# results from this schedule are recorded.
learner.fit(3e-3, 4, wds=1e-6, cycle_len=1, cycle_mult=2)

HBox(children=(IntProgress(value=0, description='Epoch', max=15, style=ProgressStyle(description_width='initia…

 50%|█████     | 10/20 [00:29<00:27,  2.74s/it, loss=4.6]

KeyboardInterrupt: 

In [None]:
# NOTE(review): this cell has no execution count in the saved notebook, i.e. it was not
# run; the samples below therefore come from the model as trained by the earlier cells.
learner.fit(3e-3, 1, wds=1e-6, cycle_len=10)

# Sample predictions

In [180]:
# Move the fine-tuned model to the CPU for sampling.
m = learner.model.cpu()

In [181]:
# Continue the prompt with the default 'topk' style, now using the fine-tuned model.
sample_model(m, "j'aime")

..., je me a pas pas de la route . je me a pas pas de la route . je me a pas pas de la route . je me a pas pas de la route . je me a pas pas de la route . je me a pas 

In [182]:
# Same prompt with the fine-tuned model, drawing each word with torch.multinomial.
sample_model(m, "j'aime", style='multinomial')

...ces sommes de trouver plus et de conséquent dessus . il ai train dans un heure je ait litres de propres non de mieux et les ans vers une peu habitants à problème , puisque les nous le voiture dans je mauritanie ! du consommation . à la découvrir sont 