In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
import re
from typing import Dict, List, Tuple, Set

import torch
import torch.optim as optim
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import DataLoader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, CharacterTokenizer
from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.common.util import ensure_list
from allennlp.data.data_loaders import MultiProcessDataLoader

In [3]:
import sys
sys.path.insert(0, '../')
from readers.name_reader import NameReader

In [4]:
reader = NameReader()
instances = reader.read('../data/first_names.all.txt')
instances = ensure_list(instances)

In [5]:
instances[0].fields

{'tokens': <allennlp.data.fields.text_field.TextField at 0x159946500>}

# Modelling

In [175]:
class RNNLanguageModel(Model):
    def __init__(self, 
                embedder: TextFieldEmbedder,
                hidden_size: int,
                max_len: int,
                vocab: Vocabulary)->None:
        super().__init__(vocab)
        self.embedder = embedder
        
        self.rnn = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(EMBEDDING_SIZE, HIDDEN_SIZE, batch_first=True))
        
        self.hidden2out = torch.nn.Linear(in_features = self.rnn.get_output_dim(), out_features = vocab.get_vocab_size('tokens'))
        
        self.hidden_size = hidden_size
        self.max_len = max_len
        

    def forward(self,tokens):
        mask = get_text_field_mask(tokens)

        embeddings = self.embedder(tokens)
        rnn_hidden = self.rnn(embeddings, mask)
        out_logits = self.hidden2out(rnn_hidden)
        
        # get output tokens
        token_ids = tokens['tokens']['tokens']
        forward_targets = torch.zeros_like(token_ids)
        forward_targets[:, 0:-1] = token_ids[:, 1:]
            
        loss = sequence_cross_entropy_with_logits(out_logits, forward_targets, mask)
        
        return {'loss': loss}
    
    def generate(self, initial_chars=None)-> Tuple[List[Token], torch.tensor]:
        start_symbol_idx = self.vocab.get_token_index(START_SYMBOL, 'tokens')
        end_symbol_idx = self.vocab.get_token_index(END_SYMBOL, 'tokens')
        
        padding_symbol_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'tokens')
        
        log_likelihood = 0
        words = []
        
        state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))
        import pdb
        pdb.set_trace()
        if initial_chars is not None:
            int_chars = [self.vocab.get_token_index(c.text) for c in reader._tokenizer.tokenize(initial_chars)]
            input_tokens = torch.tensor([int_chars[:-2]])
            embeddings = self.embedder({'tokens': {"tokens": input_tokens}})
            output, state = self.rnn._module(embeddings, state)
            
            words = list(initial_chars[:-1])
            word_idx = int_chars[-2]
        else:
            word_idx = start_symbol_idx
        
        words.append(word_idx)
        
        for i in range(self.max_len):
            tokens = torch.tensor([[word_idx]])
            
            embeddings = self.embedder({'tokens': {"tokens": tokens}})
            output, state = self.rnn._module(embeddings, state)
            output = self.hidden2out(output)

            log_prob = torch.log_softmax(output[0, 0], dim=0)
            
            dist = torch.exp(log_prob)
            
            word_idx = start_symbol_idx

            while word_idx in {start_symbol_idx, padding_symbol_idx}:
                word_idx = torch.multinomial(
                    dist, num_samples=1, replacement=False).item()
                            
            if word_idx in {end_symbol_idx}:
                break
                
            token = Token(text=self.vocab.get_token_from_index(word_idx,'tokens'))
            words.append(token.text)
            
        return ''.join(words), log_likelihood

In [176]:
# vocab = Vocabulary.from_instances(instances)



In [177]:
EMBEDDING_SIZE = 32
HIDDEN_SIZE = 256
BATCH_SIZE = 32

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_SIZE)
embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

model = RNNLanguageModel(embedder=embedder,
                         hidden_size=HIDDEN_SIZE,
                         max_len=80,
                         vocab=vocab)

data_loader = MultiProcessDataLoader(reader,data_path="../data/first_names.all.txt",batch_size=BATCH_SIZE)
data_loader.index_with(vocab)
optimizer = optim.Adam(model.parameters(), lr=0.02)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading instances', layout=Layout(width…




In [178]:
model.generate(initial_chars='rob')

> [0;32m<ipython-input-175-f87c24145c12>[0m(47)[0;36mgenerate[0;34m()[0m
[0;32m     45 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m        [0;32mif[0m [0minitial_chars[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m            [0mint_chars[0m [0;34m=[0m [0;34m[[0m[0mself[0m[0;34m.[0m[0mvocab[0m[0;34m.[0m[0mget_token_index[0m[0;34m([0m[0mc[0m[0;34m.[0m[0mtext[0m[0;34m)[0m [0;32mfor[0m [0mc[0m [0;32min[0m [0mreader[0m[0;34m.[0m[0m_tokenizer[0m[0;34m.[0m[0mtokenize[0m[0;34m([0m[0minitial_chars[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m            [0minput_tokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[[0m[0mint_chars[0m

ipdb>  n


> [0;32m<ipython-input-175-f87c24145c12>[0m(48)[0;36mgenerate[0;34m()[0m
[0;32m     46 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m        [0;32mif[0m [0minitial_chars[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 48 [0;31m            [0mint_chars[0m [0;34m=[0m [0;34m[[0m[0mself[0m[0;34m.[0m[0mvocab[0m[0;34m.[0m[0mget_token_index[0m[0;34m([0m[0mc[0m[0;34m.[0m[0mtext[0m[0;34m)[0m [0;32mfor[0m [0mc[0m [0;32min[0m [0mreader[0m[0;34m.[0m[0m_tokenizer[0m[0;34m.[0m[0mtokenize[0m[0;34m([0m[0minitial_chars[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m            [0minput_tokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[[0m[0mint_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34

ipdb>  n


> [0;32m<ipython-input-175-f87c24145c12>[0m(49)[0;36mgenerate[0;34m()[0m
[0;32m     47 [0;31m        [0;32mif[0m [0minitial_chars[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m            [0mint_chars[0m [0;34m=[0m [0;34m[[0m[0mself[0m[0;34m.[0m[0mvocab[0m[0;34m.[0m[0mget_token_index[0m[0;34m([0m[0mc[0m[0;34m.[0m[0mtext[0m[0;34m)[0m [0;32mfor[0m [0mc[0m [0;32min[0m [0mreader[0m[0;34m.[0m[0m_tokenizer[0m[0;34m.[0m[0mtokenize[0m[0;34m([0m[0minitial_chars[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 49 [0;31m            [0minput_tokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[[0m[0mint_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m            [0membeddings[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0membedd

ipdb>  


> [0;32m<ipython-input-175-f87c24145c12>[0m(50)[0;36mgenerate[0;34m()[0m
[0;32m     48 [0;31m            [0mint_chars[0m [0;34m=[0m [0;34m[[0m[0mself[0m[0;34m.[0m[0mvocab[0m[0;34m.[0m[0mget_token_index[0m[0;34m([0m[0mc[0m[0;34m.[0m[0mtext[0m[0;34m)[0m [0;32mfor[0m [0mc[0m [0;32min[0m [0mreader[0m[0;34m.[0m[0m_tokenizer[0m[0;34m.[0m[0mtokenize[0m[0;34m([0m[0minitial_chars[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m            [0minput_tokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[[0m[0mint_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m            [0membeddings[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0membedder[0m[0;34m([0m[0;34m{[0m[0;34m'tokens'[0m[0;34m:[0m [0;34m{[0m[0;34m"tokens"[0m[0;34m:[0m [0minput_tokens[0m[0;34m}[0m[0;34m}[0m[0

ipdb>  


> [0;32m<ipython-input-175-f87c24145c12>[0m(51)[0;36mgenerate[0;34m()[0m
[0;32m     49 [0;31m            [0minput_tokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[[0m[0mint_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m            [0membeddings[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0membedder[0m[0;34m([0m[0;34m{[0m[0;34m'tokens'[0m[0;34m:[0m [0;34m{[0m[0;34m"tokens"[0m[0;34m:[0m [0minput_tokens[0m[0;34m}[0m[0;34m}[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 51 [0;31m            [0moutput[0m[0;34m,[0m [0mstate[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mrnn[0m[0;34m.[0m[0m_module[0m[0;34m([0m[0membeddings[0m[0;34m,[0m [0mstate[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m     53 [0;31m            [0mwords[0m [0;34m=[0m [0mlist[0m[

ipdb>  


> [0;32m<ipython-input-175-f87c24145c12>[0m(53)[0;36mgenerate[0;34m()[0m
[0;32m     51 [0;31m            [0moutput[0m[0;34m,[0m [0mstate[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mrnn[0m[0;34m.[0m[0m_module[0m[0;34m([0m[0membeddings[0m[0;34m,[0m [0mstate[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m            [0mwords[0m [0;34m=[0m [0mlist[0m[0;34m([0m[0minitial_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m            [0mword_idx[0m [0;34m=[0m [0mint_chars[0m[0;34m[[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m<ipython-input-175-f87c24145c12>[0m(54)[0;36mgenerate[0;34m()[0m
[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m     53 [0;31m            [0mwords[0m [0;34m=[0m [0mlist[0m[0;34m([0m[0minitial_chars[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 54 [0;31m            [0mword_idx[0m [0;34m=[0m [0mint_chars[0m[0;34m[[0m[0;34m-[0m[0;36m2[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m            [0mword_idx[0m [0;34m=[0m [0mstart_symbol_idx[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  words


['r']


ipdb>  initial_chars


'rob'


ipdb>  exit()


BdbQuit: 

In [174]:
%debug

> [0;32m<ipython-input-170-f97478d46288>[0m(82)[0;36mgenerate[0;34m()[0m
[0;32m     78 [0;31m[0;34m[0m[0m
[0m[0;32m     79 [0;31m            [0mtoken[0m [0;34m=[0m [0mToken[0m[0;34m([0m[0mtext[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mvocab[0m[0;34m.[0m[0mget_token_from_index[0m[0;34m([0m[0mword_idx[0m[0;34m,[0m[0;34m'tokens'[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m            [0mwords[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mtoken[0m[0;34m.[0m[0mtext[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        [0;32mreturn[0m [0;34m''[0m[0;34m.[0m[0mjoin[0m[0;34m([0m[0mwords[0m[0;34m)[0m[0;34m,[0m [0mlog_likelihood[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  words


['r', 20, '0', 'u', 'ñ', 'c', 'd', 's', 'ą', 't', 'ľ', '\x9a', '-', 'ş', '्', 'व', 't', 'a', 'ô', 'ò', 'ñ', 'ě', 'f', 'ş', 'ş', 'ď', 'ि', 'ü', 'ã', 'ĳ', 'í', 'ñ', 'ę', 'â', 'è', 'ă', 'đ', 'अ', 'ę', 'ř', 'ę', 'å', 'ü', 'ć', 'ै', 'ą', 'ř', '(', 'ग', 'ö', 'z', 'ô', '/', 'ô', 'č', 'छ', 'ñ', 'ज', 'ë', '\u200d', 'ू', 'न', 'ण', 'ą', "'", 'ģ', 'ü', 'a', 'त', 'ï', 'ô', 'ô']


ipdb>  exit()


In [124]:
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=data_loader,
    num_epochs=10)

trainer.train()

INFO:allennlp.training.trainer:Beginning training.
INFO:allennlp.training.trainer:Epoch 0/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     1.911  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1602.000  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:04:36.359182
INFO:allennlp.training.trainer:Estimated training time remaining: 0:41:27
INFO:allennlp.training.trainer:Epoch 1/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     1.997  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.730  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:11:08.848470
INFO:allennlp.training.trainer:Estimated training time remaining: 1:03:00
INFO:allennlp.training.trainer:Epoch 2/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.018  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.828  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:11:54.302464
INFO:allennlp.training.trainer:Estimated training time remaining: 1:04:32
INFO:allennlp.training.trainer:Epoch 3/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.082  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.859  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:10:48.541608
INFO:allennlp.training.trainer:Estimated training time remaining: 0:57:42
INFO:allennlp.training.trainer:Epoch 4/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.121  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.895  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:08:23.264459
INFO:allennlp.training.trainer:Estimated training time remaining: 0:46:51
INFO:allennlp.training.trainer:Epoch 5/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.059  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.930  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:07:22.342843
INFO:allennlp.training.trainer:Estimated training time remaining: 0:36:09
INFO:allennlp.training.trainer:Epoch 6/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.086  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1620.961  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:06:32.990455
INFO:allennlp.training.trainer:Estimated training time remaining: 0:26:02
INFO:allennlp.training.trainer:Epoch 7/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.024  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1621.203  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:05:34.182002
INFO:allennlp.training.trainer:Estimated training time remaining: 0:16:35
INFO:allennlp.training.trainer:Epoch 8/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.025  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1621.234  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:05:13.303391
INFO:allennlp.training.trainer:Estimated training time remaining: 0:07:57
INFO:allennlp.training.trainer:Epoch 9/9
INFO:allennlp.training.trainer:Worker 0 memory usage: 1.6G
INFO:allennlp.training.trainer:Training


HBox(children=(FloatProgress(value=0.0, max=5140.0), HTML(value='')))


INFO:allennlp.training.callbacks.console_logger:                       Training |  Validation
INFO:allennlp.training.callbacks.console_logger:loss               |     2.028  |       N/A
INFO:allennlp.training.callbacks.console_logger:worker_0_memory_MB |  1621.270  |       N/A
INFO:allennlp.training.trainer:Epoch duration: 0:04:45.028405


{'best_epoch': 9,
 'peak_worker_0_memory_MB': 1621.26953125,
 'training_duration': '1:16:19.178905',
 'training_start_epoch': 0,
 'training_epochs': 9,
 'epoch': 9,
 'training_loss': 2.0277163025933946,
 'training_worker_0_memory_MB': 1621.26953125}

In [125]:
def predict(text: str, model: Model) -> float:
    tokens = reader._tokenizer.tokenize(text)
    instance = reader.text_to_instance(name=text, tokens=tokens)
    
    output = model.forward_on_instance(instance)
    print(output)

In [126]:
predict("chris", model)

{'loss': 2.7005694}


In [133]:
model.generate()

('irvtztvsumyyyyzulivtnmymsjvlyttzllntvvjmuzotzuzzztzyvyzyjzuyvuuvzykvtuzspzzzmlsu',
 0)

In [137]:
model.generate()

('rvsktzzvyzzzqvyzmbyystvzuskzsszvpystzuvsyyyvvyzztviszytzruvyrzlzvytztwvtyyyznyvw',
 0)

In [139]:
for _ in range(10):
    name, _ =model.generate()
    print(name)
    predict(name, model)

ujzzhyytynzsusvvyzytvytvvjrzzzspkzvpvrytfyswyiyptxzzyitinkvjztznyyyzzmzszrrzyozy
{'loss': 6.7684035}
yuyyzcmzisrytyznvvuzyzriinzsyryyzwivulyzyityizzxuupiupqáhyvzlzyyvzvzsrytttskrzup
{'loss': 5.87415}
izynuyvzrvzzwxzjlzmzismpzqszyyimkyryysutzyvzrdzyzsyvzyziqvvvmyrzzztuvzzunzvvvvsv
{'loss': 7.0931807}
zvnwyuvryzvlzsqanzzdvzsznztsvyysyyyzlzlvtvlzwvmwryvyvstrzwtyiwoztzvyywzwuursyntt
{'loss': 6.5184765}
yzzxzzqajzeyuruytrzzyyuvyztzsyzysivzulxzyxzyvkzzzzlgrzzryzzvynnzivwzxmyszzvvyvst
{'loss': 7.6409526}
pstszywvzleyvyymztlvyztwzytskmvvuzttyydttspzvvsrsyzpkzvvztnvlozzvyznreyyztcuyyvz
{'loss': 6.3231363}
ynzpyyzszvzqyyztvuyyrtxryzymrmzyyyxsyiskzyyvuzatqttltstrzssylzvrvyzzqyvzztouuksz
{'loss': 6.5165153}
yzvnvyzimozyyzzzmlusyzzyvyzrtyjwfvlvtunvartzxzyyotyzyndzyuzleyzzyzvziznuzsutyzzv
{'loss': 6.232692}
ywsytywsvvwzzzzutokzziyyuvtvvzzzszdnnpátynnussnztyvnvvyssrznuvvzzwstszzyryttiuee
{'loss': 7.0649877}
ooulezzuyzyuznyiyzvzsuymvztzvrvsuyzvyyyyzwsnxmzvuzkzzzyozszuyyztizzivxniuturuzzz
{'loss': 6.96

In [145]:
predict("kristi",model)

{'loss': 2.627832}
