In [13]:
from copy import copy

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

from torch.autograd import Variable

In [14]:
GPU=1

# Load data

In [2]:
txt = "ceci est un tout petit bout de texte que je n'aime pas trop"

In [3]:
txt = ''

In [4]:
with open('data/one_txt/sanitized_blogger.txt') as f:
    txt += f.read()

len(txt)

442724

In [5]:
with open('data/one_txt/sanitized_wordpress.txt') as f:
    txt += f.read()

len(txt)

3216695

In [6]:
vocab = sorted(list(set(txt)))
n_vocab = len(vocab)
print(''.join(vocab))

 !"$%'()+,-./0123456789:;=>?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~°àâçèéêëîïôùûœо€


In [7]:
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = {idx: char for idx, char in enumerate(vocab)}

In [8]:
train_frac = 3. / 4
train_txt = txt[:int(len(txt) * train_frac)]
test_txt = txt[int(len(txt) * train_frac):]

# Fixed-size RNN

This is a model which operates on a **fixed** amount of input characters (`n_chars`), and attempts to predict the character that comes after them.

The hidden state is reset for each new sequence of `n_chars` characters (*stateless*).

In [9]:
n_chars = 3

In [10]:
def get_n_sized_chunks(s, n):
    """
    Yield successive n-sized chunks from a string.
    Discard the last chunk if not of size n.
    """
    for i in range(0, len(s), n):
        chunk = s[i:i + n]
        if len(chunk) == n:
            yield chunk

In [11]:
def get_data_tensor(txt, n_chars):
    chunks = list(get_n_sized_chunks(txt, n=n_chars))
    data_tensor = torch.tensor([[char_to_idx[char] for char in chunk] for chunk in chunks][:-1])
    if GPU:
        data_tensor = data_tensor.cuda()
    return data_tensor

In [12]:
def get_labels_tensor(txt, n_chars):
    chars = txt[n_chars::n_chars][:len(txt) // n_chars - 1]
    labels_tensor = torch.tensor([char_to_idx[char] for char in chars])
    if GPU:
        labels_tensor = labels_tensor.cuda()
    return labels_tensor

In [13]:
train_data_tensor = get_data_tensor(train_txt, n_chars)
print(train_data_tensor.size())

train_labels_tensor = get_labels_tensor(train_txt, n_chars)
print(train_labels_tensor.size())

torch.Size([804172, 3])
torch.Size([804172])


In [14]:
train_ds = TensorDataset(train_data_tensor, train_labels_tensor)
train_dl = DataLoader(train_ds, batch_size=1024)

In [15]:
test_data_tensor = get_data_tensor(test_txt, n_chars)
print(test_data_tensor.size())

test_labels_tensor = get_labels_tensor(test_txt, n_chars)
print(test_labels_tensor.size())

torch.Size([268057, 3])
torch.Size([268057])


In [16]:
test_ds = TensorDataset(test_data_tensor, test_labels_tensor)
test_dl = DataLoader(test_ds, batch_size=1024)

In [17]:
def generate_fixed_size(model, s, n, kind):

    assert kind in ('top', 'multinomial')
    assert len(s) == n_chars

    final_s = s

    for _ in range(n):

        chars = get_data_tensor(s + '   ', n_chars)
        preds = model(chars)

        if kind == 'top':
            pred_idx = preds.argmax().item()

        elif kind == 'multinomial':
            pred_idx = torch.multinomial(preds.exp(), 1).item()
            
        pred_char = idx_to_char[pred_idx]
        s = s[1:] + pred_char
        final_s += pred_char

    return final_s

![](img/rnn1.jpg)

In [18]:
class FixedSizeRNN(nn.Module):
    def __init__(self, n_vocab, n_factors, n_hidden, n_chars):
        super().__init__()
        self.n_chars = n_chars
        self.e = nn.Embedding(n_vocab, n_factors)
        self.input_weights = nn.Linear(n_factors, n_hidden)
        self.hidden_weights = nn.Linear(n_hidden, n_hidden)
        self.output_weights = nn.Linear(n_hidden, n_vocab)

    def forward(self, chars):

        # Reset hidden state for each mini-batch
        hidden_state = torch.zeros([len(chars), n_hidden])
        if GPU:
            hidden_state = hidden_state.cuda()

        for i in range(self.n_chars):
            input = F.relu(self.input_weights(self.e(chars[:, i])))
            hidden_state = torch.tanh(self.hidden_weights(input + hidden_state))

        output = F.log_softmax(self.output_weights(hidden_state), dim=1)
        
        return output

In [19]:
n_fac = n_vocab // 2
n_hidden = 100

In [20]:
model1 = FixedSizeRNN(n_vocab, n_fac, n_hidden, n_chars)
if GPU:
    model1 = model1.cuda()

In [21]:
optimizer1 = torch.optim.Adam(model1.parameters(), 1e-2)
criterion1 = nn.NLLLoss()

In [22]:
%%time

epochs = 50

for epoch in range(1, epochs + 1):
    
    print(f'epoch: {epoch}')
    
    train_loss_sum, train_batches_nb = 0, 0
    for i, (data, labels) in enumerate(train_dl, 1):
        output = model1(data)
        optimizer1.zero_grad()
        loss = criterion1(output, labels)
        train_loss_sum, train_batches_nb = train_loss_sum + loss.item(), train_batches_nb + 1
        loss.backward()
        optimizer1.step()

    test_loss_sum, test_batches_nb = 0, 0
    for data, labels in test_dl:
        loss = criterion1(model1(data), labels)
        test_loss_sum, test_batches_nb = test_loss_sum + loss.item(), test_batches_nb + 1

    if epoch == 1 or epoch % 10 == 0 or epoch == epochs:

        print()
        
        print(f'train loss: {round(train_loss_sum / train_batches_nb, 2)}')
        print(f'test loss: {round(test_loss_sum / test_batches_nb, 2)}')
        
        print()
        
        for kind in ('top', 'multinomial'):
            print(f'sample {kind}: ' + generate_fixed_size(model1, 'je ', 200, kind))
            print()

epoch: 1

train loss: 2.09
test loss: 1.98

sample top: je de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res de la res 

sample multinomial: je plus les ce le de ques rey un en de l'on étorren de teltinde toux le ou re tue, avroclementait mlitakplant de ca toute, moumade hier settre de ce dant. La monille du vonsur dans, il e de est maxtosait

epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10

train loss: 1.84
test loss: 1.87

sample top: je sont de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cuis de la cu

sample multinomial: je que syTga cultude le sur que l'hons troidans le si pouill longs tines, ex ad le puis ! Chau saw de le confingaruter dorgème de la me ont suavec tout aux cess

# Variable-size model

This is a model which operates on a **variable** amount of input characters, and attempts to predict the next character **after each input character**.

In [5]:
def get_data(txt, bs):
    """
    Split `txt` into `bs` chunks.

    Each chunk has size `n`, `n` being as big as possible.
    Chunks are organized as columns in the result, making the final size `n * bs`.
    """

    txt = [char_to_idx[c] for c in txt]
    
    # Shrink `len(txt)` to a multiple of `bs`
    txt_len = (len(txt) // bs) * bs
    txt = txt[:txt_len]

    # Cut `txt` into `bs` distinct chunks
    data = torch.tensor(txt).view(bs, -1)
    data = data.transpose(0, 1).contiguous()

    if GPU:
        data = data.cuda()

    return data

In [6]:
def get_batches(data, bptt):
    """
    Yield `(data_batch, labels_batch)` batches from `data`.

    At each iteration, the two batches have the same `bptt * bs` size,
    except for the last batch which may have less than `bptt` rows.

    `data_batch` contains `bptt`-sized chunks of `data`.
    `labels_batch` contains `bptt`-sized chunks of `data`, offseted by 1.
    """

    # Cut `data` into two 2-dimensional chunks of size `bptt * bs`.
    # Last chunk may be less than `bptt` rows.
    while len(data) != 0:

        # Take (at most) bptt rows with offset 1 for labels
        labels_batch = data[1:bptt+1, :]
        # Take bptt rows as the labels with offset 0 for train
        data_batch = data[:len(labels_batch), :]

        if len(labels_batch) > 0:
            yield data_batch, labels_batch

        # Move on to next train train/labels rows
        data = data[bptt:]

In [11]:
i = 1
data = get_data(train_txt, bs=3)
for data_batch, labels_batch in get_batches(data, bptt=5):
    
    print(f'data:')
    print(data_batch)

    print(f'labels:')
    print(labels_batch)

    print()
    print()
    
    i += 1
    if i > 2:
        break

data:
tensor([[ 2, 11,  4],
        [ 4,  6, 10],
        [ 2,  0,  5],
        [ 5, 10, 10],
        [ 0,  7,  0]])
labels:
tensor([[ 4,  6, 10],
        [ 2,  0,  5],
        [ 5, 10, 10],
        [ 0,  7,  0],
        [ 4, 11,  1]])


data:
tensor([[ 4, 11,  1],
        [ 9, 10,  7],
        [10,  0, 11]])
labels:
tensor([[ 9, 10,  7],
        [10,  0, 11],
        [ 0,  8, 10]])




![](img/rnn2.jpg)

In [27]:
class VariableLengthRNN(nn.Module):
    def __init__(self, n_vocab, n_fac, n_hidden, kind):
        super().__init__()
        
        assert kind in ('stateless', 'stateful')
        self.kind = kind
        
        self.rnn = nn.RNN(n_fac, n_hidden)
        self.e = nn.Embedding(n_vocab, n_fac)
        self.output_weights = nn.Linear(n_hidden, n_vocab)

        self.reset(1)
        
    def forward(self, data):
        input = self.e(data)
        output, h = self.rnn(input, self.hidden_state)
        
        if self.kind == 'stateful':
            # Keep the hidden state between each minibatch
            self.hidden_state = Variable(h.data)
        
        output = self.output_weights(output)
        output = F.log_softmax(output, dim=-1)
        return output

    def reset(self, bs):
        self.hidden_state = torch.zeros([1, bs, n_hidden])
        if GPU:
            self.hidden_state = self.hidden_state.cuda()

## Stateless RNN

The hidden state is thown away from one mini-batch to another.

In [28]:
n_fac = n_vocab // 2
n_hidden = 256
bs = 1024
bptt = 70

In [29]:
model2 = VariableLengthRNN(n_vocab, n_fac, n_hidden, 'stateless')
if GPU:
    model2 = model2.cuda()

In [30]:
def nll_loss_seq(output, labels):
    _, _, n_vocab = output.size()
    output = output.view(-1, n_vocab)
    labels = labels.reshape(-1)
    return F.nll_loss(output, labels)

In [31]:
optimizer2 = torch.optim.Adam(model2.parameters(), 1e-2)
criterion2 = nll_loss_seq

In [32]:
train_data = get_data(train_txt, bs)
test_data = get_data(test_txt, bs)

In [33]:
%%time

epochs = 300

for epoch in range(1, epochs + 1):
    
    print(f'epoch: {epoch}')

    model2.reset(bs)

    train_loss_sum, train_batches_nb = 0, 0
    for i, (data, labels) in enumerate(get_batches(train_data, bptt), 1):
        output = model2(data)
        optimizer2.zero_grad()
        loss = criterion2(output, labels)
        train_loss_sum, train_batches_nb = train_loss_sum + loss.item(), train_batches_nb + 1
        loss.backward()
        optimizer2.step()

    test_loss_sum, test_batches_nb = 0, 0
    for data, labels in get_batches(test_data, bptt):
        loss = criterion2(model2(data), labels)
        test_loss_sum, test_batches_nb = test_loss_sum + loss.item(), test_batches_nb + 1

    if epoch == 1 or epoch % 10 == 0 or epoch == epochs:

        print()
        
        print(f'train loss: {round(train_loss_sum / train_batches_nb, 2)}')
        print(f'test loss: {round(test_loss_sum / test_batches_nb, 2)}')
        
        print()
        
        for kind in ('top', 'multinomial'):
            print(f'sample {kind}: ' + generate(model2, 'je ', 200, kind))
            print()

epoch: 1

train loss: 2.68
test loss: 2.25

sample top: je pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars pars 

sample multinomial: je pa te) unemans ac/ais plé qut ie don Qres grileinroul7er céauprune mibauruene  adhairalid'ile, ress papéle f éues un ra nt de jelles rche ronsi les n'air'e chônt un R1ente réréouper à où deut en, jont

epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10

train loss: 1.55
test loss: 1.56

sample top: je pour les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de les partire de 

sample multinomial: je noture le rautre qu'à 517 kilomène jusque pcuppient plisiatienter l'artiflle lité appremaintion premais étostop in en nous l'Atairer metit dou séclangeau n'e

## Stateful RNN

The hidden state is be memorized from one mini-batch to another (hence *stateful*), but reset between epochs, and at predict time.

In [34]:
n_fac = n_vocab // 2
n_hidden = 256
bs = 1024
bptt = 70

In [35]:
model3 = VariableLengthRNN(n_vocab, n_fac, n_hidden, 'stateful')
if GPU:
    model3 = model3.cuda()

In [36]:
def nll_loss_seq(output, labels):
    _, _, n_vocab = output.size()
    output = output.view(-1, n_vocab)
    labels = labels.reshape(-1)
    return F.nll_loss(output, labels)

In [37]:
optimizer3 = torch.optim.Adam(model3.parameters(), 1e-2)
criterion3 = nll_loss_seq

In [38]:
train_data = get_data(train_txt, bs)
test_data = get_data(test_txt, bs)

In [39]:
%%time

epochs = 300

for epoch in range(1, epochs + 1):
    
    print(f'epoch: {epoch}')

    model3.reset(bs)

    train_loss_sum, train_batches_nb = 0, 0
    for i, (data, labels) in enumerate(get_batches(train_data, bptt), 1):
        output = model3(data)
        optimizer3.zero_grad()
        loss = criterion3(output, labels)
        train_loss_sum, train_batches_nb = train_loss_sum + loss.item(), train_batches_nb + 1
        loss.backward()
        optimizer3.step()

    test_loss_sum, test_batches_nb = 0, 0
    for data, labels in get_batches(test_data, bptt):
        loss = criterion3(model3(data), labels)
        test_loss_sum, test_batches_nb = test_loss_sum + loss.item(), test_batches_nb + 1

    if epoch == 1 or epoch % 10 == 0 or epoch == epochs:

        print()
        
        print(f'train loss: {round(train_loss_sum / train_batches_nb, 2)}')
        print(f'test loss: {round(test_loss_sum / test_batches_nb, 2)}')
        
        print()
        
        for kind in ('top', 'multinomial'):
            print(f'sample {kind}: ' + generate(model3, 'je ', 200, kind))
            print()

epoch: 1

train loss: 2.69
test loss: 2.25

sample top: je de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de

sample multinomial: je l'entsiot sals avése aier mallinainons ant Mtrait n vous ge coineste l'enu llêume et du Ape tamen reura der.sie. Ily émuies r'ues l'oide "éle. Be phèmlorque par, ent Cemectamerxun atpoiph me ntrime pl

epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10

train loss: 1.52
test loss: 1.52

sample top: je de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de

sample multinomial: je cotposti en en négion poutelle me fais en don, pases calié Oclais sourant du toure-Nomètrexse serveir lient est de site plusé. La ou =kuxi guité en tant exté

## Compare stateful and stateless model

### Top predictions

In [47]:
print('stateless: ' + generate(model2, 'je ', 200, kind='top'))
print()
print('stateful: ' + generate(model3, 'je ', 200, kind='top'))

stateless: je sur le partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité de la partictrité d

stateful: je de sons en ester de distes distique de sons en ester de distes distique de sons en ester de distes distique de sons en ester de distes distique de sons en ester de distes distique de sons en ester de 


### Multinomial-sampled predictions

In [50]:
print('stateless: ' + generate(model2, 'je ', 200, kind='multinomial'))
print()
print('stateful: ' + generate(model3, 'je ', 200, kind='multinomial'))

stateless: je sufflitier la roup pas dozuventernent en moyen.100 Réner : mon ques donnessionne, vosantière un voir 10% Manguei au de nous autre la nous aimembre, après assiontrès prent ginille, et de che, l'une sur

stateful: je dent de prent cerner et envie. Tour le prent levoire à meraresse prenelles à crées expre. a lac ester  suisée plumée pristicule ! . . . Inviesé  chalent pies dana autéeux-à Katament poliereuses avons 


# LSTM

In [15]:
def get_data(txt, bs):
    """
    Split `txt` into `bs` chunks.

    Each chunk has size `n`, `n` being as big as possible.
    Chunks are organized as columns in the result, making the final size `n * bs`.
    """

    txt = [char_to_idx[c] for c in txt]
    
    # Shrink `len(txt)` to a multiple of `bs`
    txt_len = (len(txt) // bs) * bs
    txt = txt[:txt_len]

    # Cut `txt` into `bs` distinct chunks
    data = torch.tensor(txt).view(bs, -1)
    data = data.transpose(0, 1).contiguous()

    if GPU:
        data = data.cuda()

    return data

In [16]:
def get_batches(data, bptt):
    """
    Yield `(data_batch, labels_batch)` batches from `data`.

    At each iteration, the two batches have the same `bptt * bs` size,
    except for the last batch which may have less than `bptt` rows.

    `data_batch` contains `bptt`-sized chunks of `data`.
    `labels_batch` contains `bptt`-sized chunks of `data`, offseted by 1.
    """

    # Cut `data` into two 2-dimensional chunks of size `bptt * bs`.
    # Last chunk may be less than `bptt` rows.
    while len(data) != 0:

        # Take (at most) bptt rows with offset 1 for labels
        labels_batch = data[1:bptt+1, :]
        # Take bptt rows as the labels with offset 0 for train
        data_batch = data[:len(labels_batch), :]

        if len(labels_batch) > 0:
            yield data_batch, labels_batch

        # Move on to next train train/labels rows
        data = data[bptt:]

In [17]:
i = 1
data = get_data(train_txt, bs=3)
for data_batch, labels_batch in get_batches(data, bptt=5):
    
    print(f'data:')
    print(data_batch)

    print(f'labels:')
    print(labels_batch)

    print()
    print()
    
    i += 1
    if i > 2:
        break

data:
tensor([[42, 59, 62],
        [60, 67, 55],
        [60, 67, 68],
        [63, 59, 61],
        [57, 68, 87]], device='cuda:0')
labels:
tensor([[60, 67, 55],
        [60, 67, 68],
        [63, 59, 61],
        [57, 68, 87],
        [63, 59, 73]], device='cuda:0')


data:
tensor([[63, 59, 73],
        [59, 72,  0],
        [75,  0, 55],
        [73, 70, 76],
        [59, 66, 59]], device='cuda:0')
labels:
tensor([[59, 72,  0],
        [75,  0, 55],
        [73, 70, 76],
        [59, 66, 59],
        [67, 75, 57]], device='cuda:0')




In [18]:
def generate(model, s, n, kind):

    assert kind in ('top', 'multinomial')

    model.reset(1)

    res = s
    for _ in range(n):
        data = get_data(s, 1)
        preds = model(data)[-1]

        if kind == 'top':
            pred_idx = preds.argmax().item()

        elif kind == 'multinomial':
            pred_idx = torch.multinomial(preds.exp(), 1).item()

        pred_char = idx_to_char[pred_idx]
        res += pred_char
        s = s[1:] + pred_char
        
    return res

![](img/lstm.jpg)

In [25]:
class LSTMCell(nn.Module):
    
    def __init__(self, n_fac, n_hidden):
        
        super().__init__()
        
        self.n_fac = n_fac
        self.n_hidden = n_hidden
        
        self.forget_gate = nn.Linear(n_fac + n_hidden, n_hidden)
        self.input_gate = nn.Linear(n_fac + n_hidden, n_hidden)
        self.cell_update_gate = nn.Linear(n_fac + n_hidden, n_hidden)
        self.hidden_update_gate = nn.Linear(n_fac + n_hidden, n_hidden)
        
    def forward(self, x, hidden_state, cell_state):
        """
        `x` is of size `bs * n_fac`
        `hidden_state` are of size `bs * n_hidden`
        """

        # `x` is now of size `bs * (n_fac + n_hidden)`
        x = torch.cat([x, hidden_state], dim=1)

        # Forget relevant bits of the cell state
        cell_state *= torch.sigmoid(self.forget_gate(x))
        # Update relevant bits of the cell state
        cell_state += torch.tanh(self.cell_update_gate(x)) * torch.sigmoid(self.input_gate(x))

        # Forget relevant bits of the hidden state
        # Use `1 *` to avoid in-place in-place operation that blocks autograd
        hidden_state = 1 * torch.sigmoid(self.hidden_update_gate(x))
        # Integrate cell state to hidden_state
        hidden_state *= Variable(torch.tanh(cell_state))
        
        return hidden_state, cell_state

In [26]:
class LSTM(nn.Module):
    def __init__(self, n_vocab, n_fac, n_hidden):

        super().__init__()
        
        self.lstm_cell = LSTMCell(n_fac, n_hidden)
        self.e = nn.Embedding(n_vocab, n_fac)
        self.output_weights = nn.Linear(n_hidden, n_vocab)

        self.reset(1)
        
    def forward(self, data):

        input = self.e(data)

        hidden_state = self.hidden_state
        cell_state = self.cell_state

        hidden_state_history = []
        # RNN loop on `input` of size: `bptt * bs * n_fac`:
        # bptt times for each `x` of size `bs * n_fac`
        for x in input:
            hidden_state, cell_state = self.lstm_cell(x, hidden_state, cell_state)
            hidden_state_history.append(hidden_state)

        # Throw away state histories
        self.hidden_state = Variable(hidden_state.data)
        self.cell_state = Variable(cell_state.data)
        
        # Get output
        output = self.output_weights(torch.stack(hidden_state_history))
        output = F.log_softmax(output, dim=-1)

        return output

    def reset(self, bs):

        self.hidden_state = torch.zeros([bs, n_hidden])
        self.cell_state = torch.zeros([bs, n_hidden])

        if GPU:
            self.hidden_state = self.hidden_state.cuda()
            self.cell_state = self.cell_state.cuda()

In [28]:
n_fac = n_vocab // 2
n_hidden = 256
bs = 1024
bptt = 70

In [29]:
model4 = LSTM(n_vocab, n_fac, n_hidden)
if GPU:
    model4 = model4.cuda()

In [30]:
def nll_loss_seq(output, labels):
    _, _, n_vocab = output.size()
    output = output.view(-1, n_vocab)
    labels = labels.reshape(-1)
    return F.nll_loss(output, labels)

In [31]:
optimizer4 = torch.optim.Adam(model4.parameters(), 1e-2)
criterion4 = nll_loss_seq

In [32]:
train_data = get_data(train_txt, bs)
test_data = get_data(test_txt, bs)

In [33]:
%%time

epochs = 300

for epoch in range(1, epochs + 1):
    
    print(f'epoch: {epoch}')

    model4.reset(bs)

    train_loss_sum, train_batches_nb = 0, 0
    for i, (data, labels) in enumerate(get_batches(train_data, bptt), 1):
        output = model4(data)
        optimizer4.zero_grad()
        loss = criterion4(output, labels)
        train_loss_sum, train_batches_nb = train_loss_sum + loss.item(), train_batches_nb + 1
        loss.backward()
        optimizer4.step()

    test_loss_sum, test_batches_nb = 0, 0
    for data, labels in get_batches(test_data, bptt):
        loss = criterion4(model4(data), labels)
        test_loss_sum, test_batches_nb = test_loss_sum + loss.item(), test_batches_nb + 1

    if epoch == 1 or epoch % 10 == 0 or epoch == epochs:

        print()
        
        print(f'train loss: {round(train_loss_sum / train_batches_nb, 2)}')
        print(f'test loss: {round(test_loss_sum / test_batches_nb, 2)}')
        
        print()
        
        for kind in ('top', 'multinomial'):
            print(f'sample {kind}: ' + generate(model4, 'je ', 200, kind))
            print()

epoch: 1

train loss: 3.08
test loss: 2.52

sample top: je de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de

sample multinomial: je n0e s ante,e hvs oure touis daygblas et  cansso chen parss oindsoc aua jad'mabouùiis fri qatcoaniuurq?uavquvise lanher ayusi leesne x revt,leuvono tue l'e poume pour ce sotu bivéomjééer dend jor,esus 

epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10

train loss: 1.75
test loss: 1.74

sample top: je décient de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la cons de la

sample multinomial: je racoup deer le voyatour la c'est plus intier vie à 10 de cétaine, a travoins à craembien savie Simente doiriment stes. et par cest en re moi n'un ren. En éco