<a href="https://www.kaggle.com/code/marinabalakina/dll30-dz6-2-ipynb?scriptVersionId=156621214" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# **Задание 2.**

Выполнить практическую работу из лекционного ноутбука.

* Построить RNN-ячейку на основе полносвязных слоев
* Применить построенную ячейку для генерации текста с выражениями героев сериала “Симпсоны”

## 2.1. Импорт библиотек и пользовательские функции

In [4]:
import torch
import torch.nn as nn

from collections import Counter
import warnings

import string
import nltk
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
%matplotlib inline
warnings.filterwarnings('ignore')



In [5]:
torch.autograd.set_detect_anomaly(True)
batch_size = 100
seq_size = 32
embedding_size = 64
lstm_size = 64
gradients_norm = 5
# GPU или CPU?
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
def doc2words(doc):
  # Это мешок слов
    words=[]
    for line in doc:
      words += line.split()
    return words

In [7]:
def removepunct(words):
  # Удаляем пунктуацию
    punct = set(string.punctuation)
    words = [''.join([char for char in list(word) if char not in punct]) for word in words]
    return words

In [8]:
def getvocab(words):
  # Словарь из мешка слов
    wordfreq = Counter(words)
    sorted_wordfreq = sorted(wordfreq, key=wordfreq.get)
    return sorted_wordfreq

In [9]:
def vocab_map(vocab):
    # 2 словаря - int to words and word to int
    int_to_vocab = {k:w for k,w in enumerate(vocab)}
    vocab_to_int = {w:k for k,w in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int

In [10]:
def get_batches(words, vocab_to_int, batch_size, seq_size):
    # Генерируем батчи для  Xs и Ys: shape = (batchsize * num_batches) * seq_size
    word_ints = [vocab_to_int[word] for word in words]
    num_batches = int(len(word_ints) / (batch_size * seq_size))
    Xs = word_ints[:num_batches*batch_size*seq_size]
    Ys = np.zeros_like(Xs)
    Ys[:-1] = Xs[1:]
    Ys[-1] = Xs[0]
    Xs = np.reshape(Xs, (num_batches*batch_size, seq_size))
    Ys = np.reshape(Ys, (num_batches*batch_size, seq_size))
    
    # iterate over rows of Xs and Ys to generate batches
    for i in range(0, num_batches*batch_size, batch_size):
        yield Xs[i:i+batch_size, :], Ys[i:i+batch_size, :]

In [11]:
class RNNModule(nn.Module):
    # initialize RNN module
    def __init__(self, n_vocab, seq_size=32, embedding_size=64, lstm_size=64):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)
        
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)

        return logits, state
    
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),torch.zeros(1, batch_size, self.lstm_size))

In [12]:
def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    return criterion, optimizer

In [13]:
def generate_text(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    net.eval()

    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
    
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])

    words.append(int_to_vocab[choice])
    
    for _ in range(100):
        ix = torch.tensor([[choice]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

    print(' '.join(words))

In [14]:
def train_rnn(words, vocab_to_int, int_to_vocab, n_vocab):
    
    # ячейка RNN
    net = RNNModule(n_vocab, seq_size, embedding_size, lstm_size)
    net = net.to(device)
    criterion, optimizer = get_loss_and_train_op(net, 0.01)

    iteration = 0
    
    # итерируемся по эпохам
    for e in tqdm(range(50)):
        # получаем батчи
        batches = get_batches(words, vocab_to_int, batch_size, seq_size)
        # инициализируем выход и сккрытое состояние
        state_h, state_c = net.zero_state(batch_size)

        # Передаем данные на GPU
        state_h = state_h.to(device)
        state_c = state_c.to(device)
        # итерируемся по батчам
        for x, y in tqdm(batches):
            iteration += 1

            # Переходим  в режим обучения
            net.train()

            # Обнуляем градиенты
            optimizer.zero_grad()

            # Передаем x и y на GPU
            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)
            
            # Модель возвращает логиты, последнее скрытое состояние и новый выход
            logits, (state_h, state_c) = net(x, (state_h, state_c))
            loss = criterion(logits.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss_value = loss.item()

            # back-propagation
            loss.backward(retain_graph=True)

            _ = torch.nn.utils.clip_grad_norm_(net.parameters(), gradients_norm)
            
            # Обновляем параметры, выполняя шаг обучения
            optimizer.step()

            if iteration % 100 == 0:
                print('Epoch: {}/{}'.format(e, 200),'Iteration: {}'.format(iteration),'Loss: {}'.format(loss_value))

            # if iteration % 1000 == 0:
                # predict(device, net, flags.initial_words, n_vocab,vocab_to_int, int_to_vocab, top_k=5)
                # torch.save(net.state_dict(),'checkpoint_pt/model-{}.pth'.format(iteration))
                
    return net

## 2.2. Загрузка и предобработка данных

In [15]:
doc = pd.read_csv('/kaggle/input/the-simpsons-dataset/simpsons_script_lines.csv', usecols=['normalized_text'],low_memory=False)['normalized_text'].astype(str).to_list()

In [16]:
doc[:5]

['no actually it was a little of both sometimes when a disease is in all the magazines and all the news shows its only natural that you think you have it',
 'wheres mr bergstrom',
 'i dont know although id sure like to talk to him he didnt touch my lesson plan what did he teach you',
 'that life is worth living',
 'the polls will be open from now until the end of recess now just in case any of you have decided to put any thought into this well have our final statements martin']

In [17]:
# получаем мешок слов, удаляем пунктуацию
words = removepunct(doc2words(doc))
# Словарь из мешка слов
vocab = getvocab(words)
# 2 словаря - int_to_vocab и vocab_to_int
int_to_vocab, vocab_to_int = vocab_map(vocab)

## 2.3. Обучение модели

In [18]:
rnn_net = train_rnn(words, vocab_to_int, int_to_vocab, len(vocab))

  0%|          | 0/50 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Epoch: 0/200 Iteration: 100 Loss: 6.869081974029541
Epoch: 0/200 Iteration: 200 Loss: 6.405984401702881
Epoch: 0/200 Iteration: 300 Loss: 6.204434871673584
Epoch: 0/200 Iteration: 400 Loss: 6.468569278717041


0it [00:00, ?it/s]

Epoch: 1/200 Iteration: 500 Loss: 6.069814682006836
Epoch: 1/200 Iteration: 600 Loss: 5.84013557434082
Epoch: 1/200 Iteration: 700 Loss: 5.847909450531006
Epoch: 1/200 Iteration: 800 Loss: 5.760141372680664


0it [00:00, ?it/s]

Epoch: 2/200 Iteration: 900 Loss: 5.641824245452881
Epoch: 2/200 Iteration: 1000 Loss: 5.5812554359436035
Epoch: 2/200 Iteration: 1100 Loss: 5.646803379058838
Epoch: 2/200 Iteration: 1200 Loss: 5.664811611175537


0it [00:00, ?it/s]

Epoch: 3/200 Iteration: 1300 Loss: 5.701373100280762
Epoch: 3/200 Iteration: 1400 Loss: 5.524514675140381
Epoch: 3/200 Iteration: 1500 Loss: 5.383044242858887
Epoch: 3/200 Iteration: 1600 Loss: 5.50910758972168


0it [00:00, ?it/s]

Epoch: 4/200 Iteration: 1700 Loss: 5.548447132110596
Epoch: 4/200 Iteration: 1800 Loss: 5.330428600311279
Epoch: 4/200 Iteration: 1900 Loss: 5.406790256500244
Epoch: 4/200 Iteration: 2000 Loss: 5.339992046356201


0it [00:00, ?it/s]

Epoch: 5/200 Iteration: 2100 Loss: 5.266122341156006
Epoch: 5/200 Iteration: 2200 Loss: 5.289240837097168
Epoch: 5/200 Iteration: 2300 Loss: 5.2405686378479
Epoch: 5/200 Iteration: 2400 Loss: 5.138582706451416


0it [00:00, ?it/s]

Epoch: 6/200 Iteration: 2500 Loss: 5.112853527069092
Epoch: 6/200 Iteration: 2600 Loss: 5.232212066650391
Epoch: 6/200 Iteration: 2700 Loss: 5.255558490753174
Epoch: 6/200 Iteration: 2800 Loss: 5.176302433013916
Epoch: 6/200 Iteration: 2900 Loss: 5.279224395751953


0it [00:00, ?it/s]

Epoch: 7/200 Iteration: 3000 Loss: 5.078824043273926
Epoch: 7/200 Iteration: 3100 Loss: 5.180334568023682
Epoch: 7/200 Iteration: 3200 Loss: 5.0275163650512695
Epoch: 7/200 Iteration: 3300 Loss: 5.085385799407959


0it [00:00, ?it/s]

Epoch: 8/200 Iteration: 3400 Loss: 5.228656768798828
Epoch: 8/200 Iteration: 3500 Loss: 5.098645210266113
Epoch: 8/200 Iteration: 3600 Loss: 5.018154621124268
Epoch: 8/200 Iteration: 3700 Loss: 4.914144992828369


0it [00:00, ?it/s]

Epoch: 9/200 Iteration: 3800 Loss: 4.772495746612549
Epoch: 9/200 Iteration: 3900 Loss: 5.101226806640625
Epoch: 9/200 Iteration: 4000 Loss: 4.939760684967041
Epoch: 9/200 Iteration: 4100 Loss: 4.945730209350586


0it [00:00, ?it/s]

Epoch: 10/200 Iteration: 4200 Loss: 4.9937310218811035
Epoch: 10/200 Iteration: 4300 Loss: 4.986006736755371
Epoch: 10/200 Iteration: 4400 Loss: 4.941694259643555
Epoch: 10/200 Iteration: 4500 Loss: 4.925002574920654


0it [00:00, ?it/s]

Epoch: 11/200 Iteration: 4600 Loss: 4.966299057006836
Epoch: 11/200 Iteration: 4700 Loss: 5.027177810668945
Epoch: 11/200 Iteration: 4800 Loss: 4.98902702331543
Epoch: 11/200 Iteration: 4900 Loss: 4.776034355163574


0it [00:00, ?it/s]

Epoch: 12/200 Iteration: 5000 Loss: 4.8873090744018555
Epoch: 12/200 Iteration: 5100 Loss: 4.896315574645996
Epoch: 12/200 Iteration: 5200 Loss: 4.835824489593506
Epoch: 12/200 Iteration: 5300 Loss: 4.937067985534668
Epoch: 12/200 Iteration: 5400 Loss: 4.782135486602783


0it [00:00, ?it/s]

Epoch: 13/200 Iteration: 5500 Loss: 4.973174095153809
Epoch: 13/200 Iteration: 5600 Loss: 4.901470184326172
Epoch: 13/200 Iteration: 5700 Loss: 4.9030375480651855
Epoch: 13/200 Iteration: 5800 Loss: 4.939515113830566


0it [00:00, ?it/s]

Epoch: 14/200 Iteration: 5900 Loss: 4.8842973709106445
Epoch: 14/200 Iteration: 6000 Loss: 4.9206438064575195
Epoch: 14/200 Iteration: 6100 Loss: 4.7937211990356445
Epoch: 14/200 Iteration: 6200 Loss: 4.866641044616699


0it [00:00, ?it/s]

Epoch: 15/200 Iteration: 6300 Loss: 4.907231330871582
Epoch: 15/200 Iteration: 6400 Loss: 4.793619155883789
Epoch: 15/200 Iteration: 6500 Loss: 4.860086441040039
Epoch: 15/200 Iteration: 6600 Loss: 4.855235576629639


0it [00:00, ?it/s]

Epoch: 16/200 Iteration: 6700 Loss: 4.789122104644775
Epoch: 16/200 Iteration: 6800 Loss: 4.9032793045043945
Epoch: 16/200 Iteration: 6900 Loss: 4.738424777984619
Epoch: 16/200 Iteration: 7000 Loss: 4.858748912811279


0it [00:00, ?it/s]

Epoch: 17/200 Iteration: 7100 Loss: 4.733179092407227
Epoch: 17/200 Iteration: 7200 Loss: 4.9205732345581055
Epoch: 17/200 Iteration: 7300 Loss: 4.815291404724121
Epoch: 17/200 Iteration: 7400 Loss: 4.702585220336914


0it [00:00, ?it/s]

Epoch: 18/200 Iteration: 7500 Loss: 4.7101640701293945
Epoch: 18/200 Iteration: 7600 Loss: 4.768651008605957
Epoch: 18/200 Iteration: 7700 Loss: 4.832699775695801
Epoch: 18/200 Iteration: 7800 Loss: 4.825465202331543
Epoch: 18/200 Iteration: 7900 Loss: 4.751956939697266


0it [00:00, ?it/s]

Epoch: 19/200 Iteration: 8000 Loss: 4.773531913757324
Epoch: 19/200 Iteration: 8100 Loss: 4.769969463348389
Epoch: 19/200 Iteration: 8200 Loss: 4.692863941192627
Epoch: 19/200 Iteration: 8300 Loss: 4.690934658050537


0it [00:00, ?it/s]

Epoch: 20/200 Iteration: 8400 Loss: 4.747724533081055
Epoch: 20/200 Iteration: 8500 Loss: 4.774134159088135
Epoch: 20/200 Iteration: 8600 Loss: 4.844317436218262
Epoch: 20/200 Iteration: 8700 Loss: 4.7252516746521


0it [00:00, ?it/s]

Epoch: 21/200 Iteration: 8800 Loss: 4.790762901306152
Epoch: 21/200 Iteration: 8900 Loss: 4.8251543045043945
Epoch: 21/200 Iteration: 9000 Loss: 4.679671287536621
Epoch: 21/200 Iteration: 9100 Loss: 4.8052215576171875


0it [00:00, ?it/s]

Epoch: 22/200 Iteration: 9200 Loss: 4.723508358001709
Epoch: 22/200 Iteration: 9300 Loss: 4.905977249145508
Epoch: 22/200 Iteration: 9400 Loss: 4.746315956115723
Epoch: 22/200 Iteration: 9500 Loss: 4.700112819671631


0it [00:00, ?it/s]

Epoch: 23/200 Iteration: 9600 Loss: 4.636731147766113
Epoch: 23/200 Iteration: 9700 Loss: 4.866511344909668
Epoch: 23/200 Iteration: 9800 Loss: 4.728062629699707
Epoch: 23/200 Iteration: 9900 Loss: 4.766511917114258


0it [00:00, ?it/s]

Epoch: 24/200 Iteration: 10000 Loss: 4.6858367919921875
Epoch: 24/200 Iteration: 10100 Loss: 4.721190452575684
Epoch: 24/200 Iteration: 10200 Loss: 4.753899097442627
Epoch: 24/200 Iteration: 10300 Loss: 4.6952805519104
Epoch: 24/200 Iteration: 10400 Loss: 4.6522536277771


0it [00:00, ?it/s]

Epoch: 25/200 Iteration: 10500 Loss: 4.614972114562988
Epoch: 25/200 Iteration: 10600 Loss: 4.725205898284912
Epoch: 25/200 Iteration: 10700 Loss: 4.704654216766357
Epoch: 25/200 Iteration: 10800 Loss: 4.645296096801758


0it [00:00, ?it/s]

Epoch: 26/200 Iteration: 10900 Loss: 4.765614032745361
Epoch: 26/200 Iteration: 11000 Loss: 4.705120086669922
Epoch: 26/200 Iteration: 11100 Loss: 4.76780891418457
Epoch: 26/200 Iteration: 11200 Loss: 4.713863372802734


0it [00:00, ?it/s]

Epoch: 27/200 Iteration: 11300 Loss: 4.775000095367432
Epoch: 27/200 Iteration: 11400 Loss: 4.703672409057617
Epoch: 27/200 Iteration: 11500 Loss: 4.683303356170654
Epoch: 27/200 Iteration: 11600 Loss: 4.742982864379883


0it [00:00, ?it/s]

Epoch: 28/200 Iteration: 11700 Loss: 4.706303596496582
Epoch: 28/200 Iteration: 11800 Loss: 4.771758079528809
Epoch: 28/200 Iteration: 11900 Loss: 4.7174482345581055
Epoch: 28/200 Iteration: 12000 Loss: 4.758885383605957


0it [00:00, ?it/s]

Epoch: 29/200 Iteration: 12100 Loss: 4.770266532897949
Epoch: 29/200 Iteration: 12200 Loss: 4.644073963165283
Epoch: 29/200 Iteration: 12300 Loss: 4.6387939453125
Epoch: 29/200 Iteration: 12400 Loss: 4.707282066345215


0it [00:00, ?it/s]

Epoch: 30/200 Iteration: 12500 Loss: 4.622931480407715
Epoch: 30/200 Iteration: 12600 Loss: 4.818569660186768
Epoch: 30/200 Iteration: 12700 Loss: 4.575712203979492
Epoch: 30/200 Iteration: 12800 Loss: 4.59697151184082


0it [00:00, ?it/s]

Epoch: 31/200 Iteration: 12900 Loss: 4.668330192565918
Epoch: 31/200 Iteration: 13000 Loss: 4.798602104187012
Epoch: 31/200 Iteration: 13100 Loss: 4.755324840545654
Epoch: 31/200 Iteration: 13200 Loss: 4.758110523223877
Epoch: 31/200 Iteration: 13300 Loss: 4.680639743804932


0it [00:00, ?it/s]

Epoch: 32/200 Iteration: 13400 Loss: 4.627823829650879
Epoch: 32/200 Iteration: 13500 Loss: 4.730800628662109
Epoch: 32/200 Iteration: 13600 Loss: 4.643282413482666
Epoch: 32/200 Iteration: 13700 Loss: 4.691068172454834


0it [00:00, ?it/s]

Epoch: 33/200 Iteration: 13800 Loss: 4.825400352478027
Epoch: 33/200 Iteration: 13900 Loss: 4.728270053863525
Epoch: 33/200 Iteration: 14000 Loss: 4.63115930557251
Epoch: 33/200 Iteration: 14100 Loss: 4.5597124099731445


0it [00:00, ?it/s]

Epoch: 34/200 Iteration: 14200 Loss: 4.429716110229492
Epoch: 34/200 Iteration: 14300 Loss: 4.742427825927734
Epoch: 34/200 Iteration: 14400 Loss: 4.666358470916748
Epoch: 34/200 Iteration: 14500 Loss: 4.547401428222656


0it [00:00, ?it/s]

Epoch: 35/200 Iteration: 14600 Loss: 4.732007026672363
Epoch: 35/200 Iteration: 14700 Loss: 4.689682960510254
Epoch: 35/200 Iteration: 14800 Loss: 4.686126232147217
Epoch: 35/200 Iteration: 14900 Loss: 4.694479465484619


0it [00:00, ?it/s]

Epoch: 36/200 Iteration: 15000 Loss: 4.574978828430176
Epoch: 36/200 Iteration: 15100 Loss: 4.757446765899658
Epoch: 36/200 Iteration: 15200 Loss: 4.764246940612793
Epoch: 36/200 Iteration: 15300 Loss: 4.503201484680176


0it [00:00, ?it/s]

Epoch: 37/200 Iteration: 15400 Loss: 4.654892921447754
Epoch: 37/200 Iteration: 15500 Loss: 4.715516567230225
Epoch: 37/200 Iteration: 15600 Loss: 4.608575344085693
Epoch: 37/200 Iteration: 15700 Loss: 4.695188045501709
Epoch: 37/200 Iteration: 15800 Loss: 4.627490997314453


0it [00:00, ?it/s]

Epoch: 38/200 Iteration: 15900 Loss: 4.772719383239746
Epoch: 38/200 Iteration: 16000 Loss: 4.652306079864502
Epoch: 38/200 Iteration: 16100 Loss: 4.697897434234619
Epoch: 38/200 Iteration: 16200 Loss: 4.692655086517334


0it [00:00, ?it/s]

Epoch: 39/200 Iteration: 16300 Loss: 4.689535140991211
Epoch: 39/200 Iteration: 16400 Loss: 4.694490909576416
Epoch: 39/200 Iteration: 16500 Loss: 4.531820774078369
Epoch: 39/200 Iteration: 16600 Loss: 4.689818859100342


0it [00:00, ?it/s]

Epoch: 40/200 Iteration: 16700 Loss: 4.725838661193848
Epoch: 40/200 Iteration: 16800 Loss: 4.622923374176025
Epoch: 40/200 Iteration: 16900 Loss: 4.626535415649414
Epoch: 40/200 Iteration: 17000 Loss: 4.661365032196045


0it [00:00, ?it/s]

Epoch: 41/200 Iteration: 17100 Loss: 4.595324516296387
Epoch: 41/200 Iteration: 17200 Loss: 4.745365619659424
Epoch: 41/200 Iteration: 17300 Loss: 4.582433700561523
Epoch: 41/200 Iteration: 17400 Loss: 4.705770492553711


0it [00:00, ?it/s]

Epoch: 42/200 Iteration: 17500 Loss: 4.574644088745117
Epoch: 42/200 Iteration: 17600 Loss: 4.798408031463623
Epoch: 42/200 Iteration: 17700 Loss: 4.665104866027832
Epoch: 42/200 Iteration: 17800 Loss: 4.550373077392578


0it [00:00, ?it/s]

Epoch: 43/200 Iteration: 17900 Loss: 4.568779468536377
Epoch: 43/200 Iteration: 18000 Loss: 4.609103679656982
Epoch: 43/200 Iteration: 18100 Loss: 4.698373317718506
Epoch: 43/200 Iteration: 18200 Loss: 4.659953594207764
Epoch: 43/200 Iteration: 18300 Loss: 4.608465671539307


0it [00:00, ?it/s]

Epoch: 44/200 Iteration: 18400 Loss: 4.647149562835693
Epoch: 44/200 Iteration: 18500 Loss: 4.610274791717529
Epoch: 44/200 Iteration: 18600 Loss: 4.570535659790039
Epoch: 44/200 Iteration: 18700 Loss: 4.589392185211182


0it [00:00, ?it/s]

Epoch: 45/200 Iteration: 18800 Loss: 4.513062953948975
Epoch: 45/200 Iteration: 18900 Loss: 4.633333206176758
Epoch: 45/200 Iteration: 19000 Loss: 4.693425178527832
Epoch: 45/200 Iteration: 19100 Loss: 4.597992897033691


0it [00:00, ?it/s]

Epoch: 46/200 Iteration: 19200 Loss: 4.661722183227539
Epoch: 46/200 Iteration: 19300 Loss: 4.687108039855957
Epoch: 46/200 Iteration: 19400 Loss: 4.549701690673828
Epoch: 46/200 Iteration: 19500 Loss: 4.683351516723633


0it [00:00, ?it/s]

Epoch: 47/200 Iteration: 19600 Loss: 4.582718849182129
Epoch: 47/200 Iteration: 19700 Loss: 4.7888383865356445
Epoch: 47/200 Iteration: 19800 Loss: 4.606385707855225
Epoch: 47/200 Iteration: 19900 Loss: 4.612706184387207


0it [00:00, ?it/s]

Epoch: 48/200 Iteration: 20000 Loss: 4.541013717651367
Epoch: 48/200 Iteration: 20100 Loss: 4.77940034866333
Epoch: 48/200 Iteration: 20200 Loss: 4.587284564971924
Epoch: 48/200 Iteration: 20300 Loss: 4.672327518463135


0it [00:00, ?it/s]

Epoch: 49/200 Iteration: 20400 Loss: 4.595296382904053
Epoch: 49/200 Iteration: 20500 Loss: 4.642980575561523
Epoch: 49/200 Iteration: 20600 Loss: 4.656783580780029
Epoch: 49/200 Iteration: 20700 Loss: 4.607407093048096
Epoch: 49/200 Iteration: 20800 Loss: 4.557277679443359


In [23]:
PATH = '/kaggle/working/rnn_model'
torch.save(rnn_net.state_dict(), PATH)

In [None]:
# rnn_model.load_state_dict(torch.load(PATH))

## 2.4. Генерация реплик

In [19]:
generate_text(device, rnn_net, ['hey', 'you'], len(vocab), vocab_to_int, int_to_vocab)

hey you whatcha chong are a big deal well well it i want a man who said that was the only thing that you can see you in my arms out nan nan oh homer i think we got the lead and if we were raptured or do not have some other way out with the kids with the most heinous owner upon a coal of the state capitals first nan oh yeah yeah yeah i love you nan oh oh no you dont want a lot more about a pleasanttasting park highway embankments theres no way you want to do it oh
