# RNN Architecture. Text Classification.

Hi! In this seminar we will look at the task of text classification, using news topic detection as an example, and at two of the main recurrent neural network architectures: RNN and GRU.

We will need one library from `HuggingFace🤗` called `datasets`. It contains a large number of datasets used in NLP.

In [1]:
!pip install datasets



In [2]:
import random
import numpy as np

import nltk
import gensim.downloader as api

import torch
import torch.nn as nn
import datasets

In [3]:
# For determinism!
SEED = 0xDEAD
random.seed(SEED)
np.random.seed(SEED)
torch.random.manual_seed(SEED)
torch.cuda.random.manual_seed_all(SEED)
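
A quick way to check that seeding does what we expect is to reseed and draw again: the same numbers should come back. This is only a small sketch; the helper names `seed_everything` and `draw` are made up for illustration and are not part of any library.

```python
import random

import numpy as np
import torch


def seed_everything(seed):
    """Hypothetical helper: seed Python, NumPy and PyTorch in one call."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def draw():
    # one number from each of the three generators
    return random.random(), float(np.random.rand()), torch.rand(1).item()


seed_everything(0xDEAD)
first = draw()
seed_everything(0xDEAD)
second = draw()

# after reseeding, all three generators repeat their first draw
print(first == second)
```

Note that this only shows reproducibility of the random number generators on the CPU; some GPU operations may still be non-deterministic.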

Let's load a news dataset: `AgNews`. Its texts are split into 4 topics: `World`, `Sports`, `Business`, `Sci/Tech`. Let's look at the structure of the dataset and at an example text:

In [4]:
dataset = datasets.load_dataset("ag_news")
dataset["train"]


Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})

In [5]:
dataset["train"][0]

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'label': 2}

`dataset` contains the `train` and `test` parts of the dataset.

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

To turn a text from a sequence of words into a sequence of vectors, we will use pretrained embeddings. Let's look at the list of available ones and pick one.

In [7]:
print("\n".join(api.info()['models'].keys()))

fasttext-wiki-news-subwords-300
conceptnet-numberbatch-17-06-300
word2vec-ruscorpora-300
word2vec-google-news-300
glove-wiki-gigaword-50
glove-wiki-gigaword-100
glove-wiki-gigaword-200
glove-wiki-gigaword-300
glove-twitter-25
glove-twitter-50
glove-twitter-100
glove-twitter-200
__testing_word2vec-matrix-synopsis


In [8]:
word2vec = api.load("glove-twitter-50")

Let's tokenize our text with NLTK.

In [9]:
# keep at most 128 tokens per text
MAX_LENGTH = 128

tokenizer = nltk.WordPunctTokenizer()

# item is of type dict()
dataset = dataset.map(
    lambda item: {
        "tokenized": tokenizer.tokenize(item["text"])[:MAX_LENGTH]
    }
)
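
`WordPunctTokenizer` is a regular-expression tokenizer: it splits text into runs of word characters and runs of punctuation. A rough standard-library sketch of the same idea is shown here; the function name `word_punct_tokenize` is an assumption made for illustration, not an NLTK API.

```python
import re

MAX_LENGTH = 128


def word_punct_tokenize(text, max_length=MAX_LENGTH):
    # runs of word characters, or runs of non-word, non-space characters
    tokens = re.findall(r"\w+|[^\w\s]+", text)
    # truncate long texts, as in the notebook cell
    return tokens[:max_length]


tokens = word_punct_tokenize("Wall Street's dwindling band of ultra-cynics.")
print(tokens)
```

The apostrophe and the hyphen become separate tokens, which matches what we see in the `tokenized` column of the dataset.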



Let's create a mapping from tokens to indices. We have access to `word2vec.index2word` (in gensim 4 and later this attribute is called `index_to_key`), so let's look at what it is.


In [10]:
word2vec.index2word[:20]

['<user>',
 '.',
 ':',
 'rt',
 ',',
 '<repeat>',
 '<hashtag>',
 '<number>',
 '<url>',
 '!',
 'i',
 'a',
 '"',
 'the',
 '?',
 'you',
 'to',
 '(',
 '<allcaps>',
 '<elong>']

So `word2vec.index2word` is actually not a dictionary but a list. Lists are ordered, so the indices are given implicitly and can be obtained with `enumerate()`.

In [11]:
word2idx = {word: idx for idx, word in enumerate(word2vec.index2word)}

Now let's convert tokens to indices.

In [12]:
def encode(word):
    # out-of-vocabulary words fall back to the index of the "unk" token
    return word2idx.get(word, word2idx["unk"])

Let's see what happens for unknown words.

In [13]:
word2idx['unk'], encode('rassian'), encode('qwertyuiolkjhgfdsazxcvbnm,.;poiuytrewsdfghjkl.,mnbvcxzawertyui')

(62980, 62980, 62980)
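
To make the mapping and the fallback easier to follow, here is the same logic on a tiny vocabulary. The list `toy_index2word` is a made-up stand-in for `word2vec.index2word`, and the other names are illustrative as well.

```python
# toy stand-in for word2vec.index2word: position in the list is the index
toy_index2word = ["<user>", ".", "the", "unk"]

# enumerate() yields (index, word) pairs, which we flip into word -> index
toy_word2idx = {word: idx for idx, word in enumerate(toy_index2word)}


def toy_encode(word):
    # unknown words are mapped to the index of the "unk" token
    return toy_word2idx.get(word, toy_word2idx["unk"])


encoded = [toy_encode(w) for w in ["the", ".", "rassian"]]
print(toy_word2idx)
print(encoded)
```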

Convert the dataset using our encode function.

In [14]:
# the features are word indices, i.e. the encoded tokenized text
dataset = dataset.map(
    lambda item: {
        "features": [encode(word) for word in item["tokenized"]]
    }
)


Look at the result:

In [15]:
dataset["train"][0]

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'label': 2,
 'tokenized': ['Wall',
  'St',
  '.',
  'Bears',
  'Claw',
  'Back',
  'Into',
  'the',
  'Black',
  '(',
  'Reuters',
  ')',
  'Reuters',
  '-',
  'Short',
  '-',
  'sellers',
  ',',
  'Wall',
  'Street',
  "'",
  's',
  'dwindling',
  '\\',
  'band',
  'of',
  'ultra',
  '-',
  'cynics',
  ',',
  'are',
  'seeing',
  'green',
  'again',
  '.'],
 'features': [62980,
  62980,
  1,
  62980,
  62980,
  62980,
  62980,
  13,
  62980,
  17,
  62980,
  20,
  62980,
  28,
  62980,
  28,
  49286,
  4,
  62980,
  62980,
  48,
  137,
  214902,
  370,
  1645,
  39,
  8606,
  28,
  380053,
  4,
  70,
  1321,
  1745,
  389,
  1]}
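
In the output above every capitalized token ('Wall', 'Bears', 'Reuters') is mapped to 62980, the index of `unk`, while lowercase tokens get their own indices. This suggests that the vocabulary of `glove-twitter-50` is lowercase. One possible tweak, sketched on a toy vocabulary, is to lowercase each token before the lookup; the names `toy_vocab`, `encode_exact` and `encode_lower` are assumptions made for illustration.

```python
# toy lowercase vocabulary standing in for word2idx
toy_vocab = {"wall": 0, "street": 1, "the": 2, "unk": 3}


def encode_exact(word):
    # lookup as in the notebook: case-sensitive
    return toy_vocab.get(word, toy_vocab["unk"])


def encode_lower(word):
    # lowercase first, so "Wall" can match "wall"
    return toy_vocab.get(word.lower(), toy_vocab["unk"])


tokens = ["Wall", "Street", "the", "Reuters"]
exact = [encode_exact(t) for t in tokens]
lower = [encode_lower(t) for t in tokens]
print(exact)
print(lower)
```

With lowercasing, only the word that is truly missing from the toy vocabulary ends up as `unk`. Whether this helps on the real data would need to be checked by retraining.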

In [16]:
dataset["train"][0].keys()

dict_keys(['text', 'label', 'tokenized', 'features'])

We do not need the full text and the tokens stored in the `text` and `tokenized` columns.

In [17]:
dataset = dataset.remove_columns(["text", "tokenized"])

Convert to tensors.

In [18]:
dataset.set_format(type='torch')

In [44]:
# we get a tensor of features and a tensor with the label
# we want to predict the label from these features
dataset["train"][0]

{'label': tensor(2),
 'features': tensor([ 62980,  62980,      1,  62980,  62980,  62980,  62980,     13,  62980,
             17,  62980,     20,  62980,     28,  62980,     28,  49286,      4,
          62980,  62980,     48,    137, 214902,    370,   1645,     39,   8606,
             28, 380053,      4,     70,   1321,   1745,    389,      1])}

We want to glue objects of different lengths into batches. For this, let's write a `collate_fn`.

The problem is that the texts have different lengths, but we need to fit them into one batch. So we will append pad tokens to every sentence that is shorter than the longest sentence in the current batch.



In [21]:
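def collate_fn(batch):
    # pad every sequence with zeros up to the longest sequence in the batch
    # (note: index 0 is also the '<user>' token in this vocabulary)
    max_len = max(len(row["features"]) for row in batch)
    input_embeds = torch.empty((len(batch), max_len), dtype=torch.long)
    labels = torch.empty(len(batch), dtype=torch.long)
    for idx, row in enumerate(batch):
        to_pad = max_len - len(row["features"])
        # the padding must be integer-typed to match the token indices
        padding = torch.zeros(to_pad, dtype=torch.long)
        input_embeds[idx] = torch.cat((row["features"], padding))
        labels[idx] = row["label"]
    return {"features": input_embeds, "labels": labels}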
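
The same padding can also be done with PyTorch's built-in `pad_sequence`. The sketch below is an alternative under stated assumptions, not the version used in the rest of the notebook: `PAD_IDX` and `collate_with_pad_sequence` are names introduced here, and the toy batch imitates rows of the dataset after `set_format(type='torch')`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 0  # assumption: pad with zeros, as in collate_fn above


def collate_with_pad_sequence(batch):
    # pad_sequence stacks 1-D tensors of different lengths into one
    # (batch_size, max_len) tensor when batch_first=True
    features = pad_sequence(
        [row["features"] for row in batch],
        batch_first=True,
        padding_value=PAD_IDX,
    )
    # labels are 0-dim tensors, so stacking gives shape (batch_size,)
    labels = torch.stack([row["label"] for row in batch])
    return {"features": features, "labels": labels}


toy_batch = [
    {"features": torch.tensor([5, 6, 7]), "label": torch.tensor(2)},
    {"features": torch.tensor([8]), "label": torch.tensor(1)},
]
out = collate_with_pad_sequence(toy_batch)
print(out["features"])
print(out["labels"])
```

The shorter row is filled with `PAD_IDX` on the right, so both rows have the length of the longest one.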

Now some cool stuff. Recall that `dataset` has 2 keys: `train` and `test`. During training we will shuffle the train set, but we should not shuffle the test set. We could write 2 separate loaders, but we'll do it more neatly.

In [22]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'features'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['label', 'features'],
        num_rows: 7600
    })
})

In [23]:
from torch.utils.data import DataLoader

# one DataLoader per split; only the train split is shuffled
loaders = {
    k: DataLoader(
        ds, shuffle=(k=="train"), batch_size=32, collate_fn=collate_fn
    ) for k, ds in dataset.items()
}
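
To see what `shuffle=(k=="train")` does, here is the same dictionary comprehension on a toy dataset. Everything in it (`toy_dataset`, `toy_collate`, `toy_loaders`) is made up for this sketch; plain Python lists stand in for the HuggingFace splits.

```python
import torch
from torch.utils.data import DataLoader

# toy stand-in for the DatasetDict: each split is a list of rows
toy_dataset = {
    "train": [{"x": torch.tensor(i)} for i in range(6)],
    "test": [{"x": torch.tensor(i)} for i in range(4)],
}


def toy_collate(batch):
    # stack the 0-dim tensors of one batch into a 1-D tensor
    return torch.stack([row["x"] for row in batch])


toy_loaders = {
    split: DataLoader(
        rows, shuffle=(split == "train"), batch_size=2, collate_fn=toy_collate
    )
    for split, rows in toy_dataset.items()
}

test_order = torch.cat(list(toy_loaders["test"])).tolist()
train_order = torch.cat(list(toy_loaders["train"])).tolist()
print(test_order)           # the test split keeps its original order
print(sorted(train_order))  # the train split has the same items, reordered
```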

## CNN

The first model we will look at is a CNN. One-dimensional convolution handles the classification task quite well. At the end, we need to aggregate the sequence into a single text vector with `AdaptiveMaxPool1d` or `AdaptiveAvgPool1d`. For classification, any feed-forward network can be put on top.
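
To get a feeling for the shapes, the sketch below tracks the sequence length through three strided convolutions and the final pooling. It assumes the layer settings used in this seminar (`kernel_size=3`, `padding=1`, `stride=2`) and a batch of 32 sequences of length 95; the helper `conv1d_out_len` is a hypothetical name, and one layer is applied three times only to follow the shapes.

```python
import torch
import torch.nn as nn


def conv1d_out_len(length, kernel_size=3, padding=1, stride=2):
    # output length of Conv1d for dilation = 1
    return (length + 2 * padding - kernel_size) // stride + 1


x = torch.randn(32, 50, 95)  # (batch_size, embed_dim, seq_len)
conv = nn.Conv1d(50, 50, kernel_size=3, padding=1, stride=2)
pool = nn.AdaptiveMaxPool1d(1)

lengths = [95]
h = x
for _ in range(3):
    h = conv(h)
    lengths.append(conv1d_out_len(lengths[-1]))

# the max over the time axis leaves one value per channel
text_vector = pool(h).flatten(1)

print(lengths)            # predicted sequence lengths after each layer
print(h.shape)            # actual shape after three convolutions
print(text_vector.shape)  # one fixed-size vector per text
```

Each strided convolution roughly halves the sequence length, and the adaptive pooling removes the dependence on the length altogether, so texts of any length end up as vectors of the same size.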

In [24]:
class CNNModel(nn.Module):
    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()
        self.embeddings = nn.Embedding(len(word2idx), embedding_dim=embed_size)
        self.cnn = nn.Sequential(
            nn.Conv1d(embed_size, hidden_size, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),

            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),

            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),

            nn.AdaptiveMaxPool1d(1),
            nn.Flatten(),
        )
        self.cl = nn.Sequential(
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):

        # x: (batch_size, seq_len) token indices, where seq_len is the length
        #   of the longest sentence in the batch; shorter ones are padded

        print(x.shape)  # debug output
        x = self.embeddings(x)  # (batch_size, seq_len, embed_dim)
        print(x.shape)  # debug output

        # we want to convolve over words, right now the word embeddings are
        #   rows in the matrix, but we want words to go from left to right
        #   so we make sure word embeddings are column vectors.
        x = x.permute(0, 2, 1)  # (batch_size, embed_dim, seq_len)

        x = self.cnn(x)
        prediction = self.cl(x)
        return prediction
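
Note that `nn.Embedding(len(word2idx), embedding_dim=embed_size)` starts from random weights, so in this model the GloVe vectors only provide the vocabulary. One possible way to initialize the layer from pretrained vectors is `nn.Embedding.from_pretrained`. The sketch below uses a random matrix `fake_vectors` as a stand-in for `word2vec.vectors`, assuming that attribute is a NumPy array of shape `(vocab_size, 50)`.

```python
import numpy as np
import torch
import torch.nn as nn

# stand-in for word2vec.vectors: 10 "words", 50 dimensions each
rng = np.random.default_rng(0)
fake_vectors = rng.normal(size=(10, 50)).astype(np.float32)

# freeze=True keeps the pretrained vectors fixed during training
embeddings = nn.Embedding.from_pretrained(
    torch.from_numpy(fake_vectors), freeze=True
)

ids = torch.tensor([[1, 2, 3], [4, 0, 0]])  # (batch_size, seq_len)
out = embeddings(ids)                       # (batch_size, seq_len, embed_dim)

print(out.shape)
print(embeddings.weight.requires_grad)
```

With `freeze=False` the vectors would be fine-tuned together with the rest of the model; which option works better here is not something the notebook tests.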

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNNModel(word2vec.vector_size, 50).to(device)  # embed_size is 50 because we loaded 50-dimensional vectors
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

num_epochs = 1

Let's prepare a function for training the model:

In [26]:
from tqdm.notebook import tqdm, trange


def training(model, criterion, optimizer, num_epochs, loaders, max_grad_norm=2):
    for e in trange(num_epochs, leave=False):
        model.train()
        num_iter = 0
        pbar = tqdm(loaders["train"], leave=False)
        for batch in pbar:
            optimizer.zero_grad()
            input_embeds = batch["features"].to(device)
            labels = batch["labels"].to(device)
            prediction = model(input_embeds)
            loss = criterion(prediction, labels)
            loss.backward()

            # remove when finished debugging
            # print(1/0)

            # new step: gradient clipping
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            num_iter += 1

        # validation stage during training
        valid_loss = 0
        valid_acc = 0
        num_iter = 0
        model.eval()
        with torch.no_grad():
            correct = 0
            num_objs = 0
            for batch in loaders["test"]:
                input_embeds = batch["features"].to(device)
                labels = batch["labels"].to(device)
                prediction = model(input_embeds)
                valid_loss += criterion(prediction, labels)
                # count the correct answers instead of averaging per-batch accuracy: the average would be biased by the last batch, which may be smaller
                correct += (labels == prediction.argmax(-1)).float().sum()
                num_objs += len(labels)
                num_iter += 1

        print(f"Valid Loss: {valid_loss / num_iter}, accuracy: {correct/num_objs}") 
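
The validation loop counts correct answers and divides by the number of objects, instead of averaging the accuracy of each batch. The difference appears when the last batch is smaller: with 7600 test examples and a batch size of 32, the last batch has only 16 examples. The numbers below are invented purely to illustrate the effect.

```python
# (number of correct predictions, batch size) for three made-up batches;
# the last batch is smaller and happens to be entirely wrong
batches = [(32, 32), (32, 32), (0, 8)]

# averaging per-batch accuracies gives the small batch the same weight
mean_of_batch_acc = sum(c / n for c, n in batches) / len(batches)

# counting correct answers weights every example equally
global_acc = sum(c for c, _ in batches) / sum(n for _, n in batches)

print(mean_of_batch_acc)
print(global_acc)
```

The 8 examples of the last batch pull the averaged value down as much as a full batch of 32 would, while the global accuracy treats them as 8 out of 72 examples.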
        

In [27]:
training(model, criterion, optimizer, num_epochs, loaders)


torch.Size([32, 95])
torch.Size([32, 95, 50])
...
torch.Size([32, 97, 50])
torch.Size([32, 75])
torch.Size([32, 75, 50])
torch.Size([32, 86])
torch.Size([32, 86, 50])
torch.Size([32, 85])
torch.Size([32, 85, 50])
torch.Size([32, 100])
torc

## RNN

The second model is an RNN. It is a recurrent network: at every step it uses the hidden state from the previous iteration to compute the new one. This is described by the formula:

$$
h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
$$

Let's write this module in `Torch`!



In [28]:
class RNN(nn.Module):
    def __init__(self, embed_size, hidden_size):
        super().__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size

        self.w_h = nn.Parameter(torch.rand(hidden_size, hidden_size))
        self.b_h = nn.Parameter(torch.rand((1, hidden_size)))
        self.w_x = nn.Parameter(torch.rand(embed_size, hidden_size))
        self.b_x = nn.Parameter(torch.rand(1, hidden_size))

    def forward(self, x, hidden=None):
        '''
        x     : torch.FloatTensor with the shape (bs, seq_length, emb_size)
        hidden: torch.FloatTensor with the shape (bs, hidden_size)
        return: torch.FloatTensor with the shape (bs, hidden_size)
        '''

        # initialise the hidden state with zeros on the same device as the input
        if hidden is None:
            hidden = torch.zeros((x.size(0), self.hidden_size)).to(x.device)

        seq_length = x.size(1) # seq_len

        # go over every word in the sentence
        for cur_idx in range(seq_length):

            # apply the formula directly, swapping (matrix @ col_vec) for (row_vec @ matrix)
            hidden = torch.tanh(
                x[:, cur_idx] @ self.w_x + self.b_x + hidden @ self.w_h + self.b_h
            )
        return hidden
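
As a sanity check, the recurrence above can be compared with PyTorch's built-in `nn.RNNCell`. This is only a sketch: the tiny sizes are made up, and the weights are copied from the reference cell (transposed, because the code above multiplies a row vector by a matrix) rather than taken from the `RNN` class itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up toy sizes, chosen only to keep the check fast.
embed_size, hidden_size, batch_size, seq_len = 5, 3, 2, 4

# Reference cell from PyTorch. It stores W_ih as (hidden, embed) and
# W_hh as (hidden, hidden), i.e. in the "matrix @ column vector" layout.
cell = nn.RNNCell(embed_size, hidden_size, nonlinearity="tanh")

# The same parameters in the "row vector @ matrix" layout.
w_x = cell.weight_ih.detach().T  # (embed, hidden)
w_h = cell.weight_hh.detach().T  # (hidden, hidden)
b_x = cell.bias_ih.detach()
b_h = cell.bias_hh.detach()

x = torch.randn(batch_size, seq_len, embed_size)
h_manual = torch.zeros(batch_size, hidden_size)
h_ref = torch.zeros(batch_size, hidden_size)

with torch.no_grad():
    for t in range(seq_len):
        # hand-written step, same expression as in the forward pass above
        h_manual = torch.tanh(x[:, t] @ w_x + b_x + h_manual @ w_h + b_h)
        # reference step
        h_ref = cell(x[:, t], h_ref)

max_diff = (h_manual - h_ref).abs().max().item()
print(h_manual.shape, max_diff)
```

If the two hidden states agree up to floating-point error, the hand-written loop implements the same formula as the library cell.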

In [29]:
class RNNModel(nn.Module):
    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()
        self.embeddings = nn.Embedding(len(word2idx), embed_size)
        self.rnn = RNN(embed_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.embeddings(x)
        hidden = self.rnn(x)
        output = self.cls(hidden)
        return output

In [30]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = RNNModel(word2vec.vector_size, 50).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

num_epochs = 1
max_grad_norm = 1.0

In [31]:
training(model, criterion, optimizer, num_epochs, loaders, max_grad_norm)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3750 [00:00<?, ?it/s]

Valid Loss: 1.4489357471466064, accuracy: 0.250394731760025


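* Works slower than the CNN.
* Gives the accuracy of a random classifier for 4 classes (0.25). The usual explanation is that a plain RNN forgets the context too quickly. The all-positive `torch.rand` initialization used above may also contribute, since it can push `tanh` into saturation. We need to do something about it: we can try an LSTM or a GRU.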
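
One possible contributing factor to the chance-level accuracy is worth checking. This is an assumption, not a diagnosis of the run above: with every weight drawn from `torch.rand` (uniform on [0, 1)), the pre-activations can become large and `tanh` saturates. The sketch below measures the share of saturated hidden units on random inputs. The scaled symmetric initialization is a hypothetical alternative for comparison, not something the notebook uses.

```python
import torch

torch.manual_seed(0)

def saturated_fraction(w_x, w_h, bias, x, threshold=0.99):
    """Run the tanh recurrence and return the share of final hidden units with |h| > threshold."""
    hidden = torch.zeros(x.size(0), w_h.size(0))
    for t in range(x.size(1)):
        hidden = torch.tanh(x[:, t] @ w_x + hidden @ w_h + bias)
    return (hidden.abs() > threshold).float().mean().item()

embed_size, hidden_size = 50, 50
x = torch.randn(32, 20, embed_size)  # stand-in for a batch of embedded tokens

# Initialization as in the RNN class above: every weight drawn from U[0, 1).
frac_rand = saturated_fraction(
    torch.rand(embed_size, hidden_size),
    torch.rand(hidden_size, hidden_size),
    torch.rand(1, hidden_size) + torch.rand(1, hidden_size),  # b_x + b_h
    x,
)

# Hypothetical alternative: symmetric U(-k, k) with k = 1 / sqrt(hidden_size).
k = hidden_size ** -0.5
frac_scaled = saturated_fraction(
    (torch.rand(embed_size, hidden_size) * 2 - 1) * k,
    (torch.rand(hidden_size, hidden_size) * 2 - 1) * k,
    (torch.rand(1, hidden_size) * 2 - 1) * k,
    x,
)

print(f"saturated units, U[0, 1) init:  {frac_rand:.2f}")
print(f"saturated units, U(-k, k) init: {frac_scaled:.2f}")
```

A saturated unit sits near +1 or -1 and has an almost zero gradient, so a state made mostly of such units carries little information about the input.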

## GRU (Gated Recurrent Unit)

https://en.wikipedia.org/wiki/Gated_recurrent_unit

The third model is a GRU, a more elaborate version of the `RNN`. The main idea of the GRU is gates. This is how the model's "memory" is implemented: it masks part of the old hidden state and writes new content in its place. The GRU is described as follows:

$$
\begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}
$$
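
To see how the last equation acts as "memory", here is a minimal sketch for a single hidden unit. The numbers and the helper `gru_update` are made up for illustration: the update gate $z_t$ decides how much of the old state survives and how much is replaced by the candidate $n_t$.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def gru_update(h_prev, n_t, z_logit):
    """Last GRU equation for one hidden unit: blend the candidate with the old state."""
    z_t = sigmoid(z_logit)
    return (1.0 - z_t) * n_t + z_t * h_prev

h_prev, n_t = 0.8, -0.5  # made-up old state and candidate

keep = gru_update(h_prev, n_t, z_logit=10.0)        # z_t close to 1: old state is kept
overwrite = gru_update(h_prev, n_t, z_logit=-10.0)  # z_t close to 0: candidate replaces it
blend = gru_update(h_prev, n_t, z_logit=0.0)        # z_t = 0.5: equal mix of both

print(f"{keep:.3f} {overwrite:.3f} {blend:.3f}")  # prints: 0.800 -0.500 0.150
```

The reset gate $r_t$ plays a similar role one step earlier: it scales how much of $h_{(t-1)}$ enters the candidate $n_t$.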

In [32]:
class GRU(nn.Module):
    def __init__(self, embed_size, hidden_size):
        super().__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size

        # reset gate weights and biases
        self.w_rh = nn.Parameter(torch.rand(hidden_size, hidden_size))
        self.b_rh = nn.Parameter(torch.rand((1, hidden_size)))
        self.w_rx = nn.Parameter(torch.rand(embed_size, hidden_size))
        self.b_rx = nn.Parameter(torch.rand(1, hidden_size))

        # update gate weights and biases
        self.w_zh = nn.Parameter(torch.rand(hidden_size, hidden_size))
        self.b_zh = nn.Parameter(torch.rand((1, hidden_size)))
        self.w_zx = nn.Parameter(torch.rand(embed_size, hidden_size))
        self.b_zx = nn.Parameter(torch.rand(1, hidden_size))

        # candidate (new) weights and biases
        self.w_nh = nn.Parameter(torch.rand(hidden_size, hidden_size))
        self.b_nh = nn.Parameter(torch.rand((1, hidden_size)))
        self.w_nx = nn.Parameter(torch.rand(embed_size, hidden_size))
        self.b_nx = nn.Parameter(torch.rand(1, hidden_size))

    def forward(self, x, hidden=None):
        '''
        x     : torch.FloatTensor with the shape (bs, seq_length, emb_size)
        hidden: torch.FloatTensor with the shape (bs, hidden_size)

        return: torch.FloatTensor with the shape (bs, hidden_size)
        '''
        if hidden is None:
            hidden = torch.zeros((x.size(0), self.hidden_size)).to(x.device)

        # iterate over the words in the sentence (elements of the sequence)
        # and apply the recurrent formulas directly, changing
        #   (matrix @ col_vec) to (row_vec @ matrix)
        for cur_idx in range(x.size(1)):
            r = torch.sigmoid(
                x[:, cur_idx] @ self.w_rx + self.b_rx + hidden @ self.w_rh + self.b_rh
            )
            z = torch.sigmoid(
                x[:, cur_idx] @ self.w_zx + self.b_zx + hidden @ self.w_zh + self.b_zh
            )
            n = torch.tanh(
                x[:, cur_idx] @ self.w_nx + self.b_nx + r * (hidden @ self.w_nh + self.b_nh)
            )
            hidden = (1 - z) * n + z * hidden

        return hidden
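
The same kind of sanity check works for the GRU: the four equations can be compared with PyTorch's `nn.GRUCell`. This is a sketch with made-up toy sizes. The weights are copied from the reference cell, which stacks the reset, update and candidate parameters along the first dimension, and transposed to the "row vector @ matrix" layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up toy sizes, chosen only to keep the check fast.
embed_size, hidden_size, batch_size, seq_len = 5, 3, 2, 4
cell = nn.GRUCell(embed_size, hidden_size)

# nn.GRUCell stores the reset (r), update (z) and candidate (n) blocks in this order.
w_rx, w_zx, w_nx = (w.T for w in cell.weight_ih.detach().chunk(3, dim=0))
w_rh, w_zh, w_nh = (w.T for w in cell.weight_hh.detach().chunk(3, dim=0))
b_rx, b_zx, b_nx = cell.bias_ih.detach().chunk(3)
b_rh, b_zh, b_nh = cell.bias_hh.detach().chunk(3)

x = torch.randn(batch_size, seq_len, embed_size)
h_manual = torch.zeros(batch_size, hidden_size)
h_ref = torch.zeros(batch_size, hidden_size)

with torch.no_grad():
    for t in range(seq_len):
        x_t = x[:, t]
        # hand-written step, following the four equations above
        r = torch.sigmoid(x_t @ w_rx + b_rx + h_manual @ w_rh + b_rh)
        z = torch.sigmoid(x_t @ w_zx + b_zx + h_manual @ w_zh + b_zh)
        n = torch.tanh(x_t @ w_nx + b_nx + r * (h_manual @ w_nh + b_nh))
        h_manual = (1 - z) * n + z * h_manual
        # reference step
        h_ref = cell(x_t, h_ref)

max_diff = (h_manual - h_ref).abs().max().item()
print(h_manual.shape, max_diff)
```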

In [33]:
class GRUModel(nn.Module):
    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()
        self.embed = nn.Embedding(len(word2idx), embed_size)
        self.gru = GRU(embed_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.embed(x)
        hidden = self.gru(x)
        output = self.cls(hidden)
        return output

In [34]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GRUModel(word2vec.vector_size, 50).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

num_epochs = 1
max_grad_norm = 1.0

In [35]:
training(model, criterion, optimizer, num_epochs, loaders, max_grad_norm)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3750 [00:00<?, ?it/s]

Valid Loss: 0.7939801812171936, accuracy: 0.6873683929443359


* Takes even longer than the RNN.
* Accuracy is about 69% in this run (it was about 60% during the seminar), well above the RNN's 25%.

## GRU + Embeddings

We loaded the embeddings at the beginning for a reason. Let's use them instead of random initialization! To do this, we need to slightly change how data is fed into the model and add an `Embedding` module to the model. We'll experiment with the `GRU` model.

In [36]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GRUModel(word2vec.vector_size, 50).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

num_epochs = 1
max_grad_norm = 1.0

In [37]:
with torch.no_grad():
    for word, idx in word2idx.items():
        if word in word2vec:
            model.embed.weight[idx] = torch.from_numpy(word2vec.get_vector(word))
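
An alternative way to get the same effect is to build the weight matrix first and hand it to `nn.Embedding.from_pretrained`. The sketch below is only an illustration: `toy_word2idx` and `toy_vectors` are made-up stand-ins for the notebook's `word2idx` and `word2vec`, and the 3-dimensional vectors are invented.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a 4-word vocabulary and a tiny "pretrained" table.
toy_word2idx = {"<unk>": 0, "<pad>": 1, "wall": 2, "street": 3}
toy_vectors = {"wall": [0.1, 0.2, 0.3], "street": [0.4, 0.5, 0.6]}
embed_size = 3

# Start from random rows, so words without a pretrained vector keep a random init.
weights = torch.randn(len(toy_word2idx), embed_size)
for word, idx in toy_word2idx.items():
    if word in toy_vectors:
        weights[idx] = torch.tensor(toy_vectors[word])

# freeze=False keeps the embeddings trainable, as in the cell above.
embed = nn.Embedding.from_pretrained(weights, freeze=False)

print(embed.weight.shape, embed.weight.requires_grad)
```

Passing `freeze=True` instead would keep the pretrained rows fixed for the whole training run.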



In [38]:
training(model, criterion, optimizer, num_epochs, loaders, max_grad_norm)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3750 [00:00<?, ?it/s]

Valid Loss: 0.5263681411743164, accuracy: 0.8235526084899902


* Takes more time than the CNN.

**BUT**

* Unlike the RNN and the first version of the GRU, it achieves accuracy similar to the CNN.

Let's try freezing the embeddings during the first training iterations. This helps avoid damaging the pretrained embeddings early on, while the rest of the model is still randomly initialized.

In [39]:
def freeze_embeddings(model, req_grad=False):
    embeddings = model.embed
    for c_p in embeddings.parameters():
        c_p.requires_grad = req_grad
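
To check what freezing actually does, here is a small sketch. `TinyModel`, `set_embeddings_trainable` and `train_step` are hypothetical stand-ins, not the notebook's `GRUModel` or training loop. While `requires_grad` is `False`, the optimizer step leaves the embedding matrix untouched; once it is switched back on, the next step changes it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyModel(nn.Module):
    """Hypothetical stand-in for GRUModel: an embedding layer followed by a classifier."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.cls = nn.Linear(4, 2)

    def forward(self, x):
        return self.cls(self.embed(x).mean(dim=1))

def set_embeddings_trainable(model, trainable):
    # same idea as freeze_embeddings above
    for p in model.embed.parameters():
        p.requires_grad = trainable

def train_step(model, optimizer, tokens, labels):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(tokens), labels)
    loss.backward()
    optimizer.step()

model = TinyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
tokens = torch.randint(0, 10, (8, 5))  # random toy batch
labels = torch.randint(0, 2, (8,))

# Frozen: the embedding gets no gradient, so the optimizer skips it.
set_embeddings_trainable(model, False)
before = model.embed.weight.detach().clone()
train_step(model, optimizer, tokens, labels)
frozen_unchanged = torch.equal(before, model.embed.weight.detach())

# Unfrozen: the next step updates the rows of the tokens in the batch.
set_embeddings_trainable(model, True)
train_step(model, optimizer, tokens, labels)
changed_after_unfreeze = not torch.equal(before, model.embed.weight.detach())

print(frozen_unchanged, changed_after_unfreeze)  # prints: True True
```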

In [40]:
def training_freeze(model, criterion, optimizer, num_epochs, loaders, max_grad_norm=2, num_freeze_iter=1000):
    
    # new part: freeze embeddings
    freeze_embeddings(model)

    for e in trange(num_epochs, leave=False):
        model.train()
        num_iter = 0
        pbar = tqdm(loaders["train"], leave=False)
        for batch in pbar:
            # new part: unfreeze the embeddings once more than num_freeze_iter
            #   iterations of the first epoch have passed
            if num_iter > num_freeze_iter and e < 1:
                freeze_embeddings(model, True)

            optimizer.zero_grad()
            input_embeds = batch["features"].to(device)
            labels = batch["labels"].to(device)
            prediction = model(input_embeds)
            loss = criterion(prediction, labels)
            loss.backward()

            # new part: clip gradients
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            num_iter += 1

        # validation stage during training
        #   note: for accuracy we do not average the per-batch averages;
        #   instead, we sum the correct predictions over all batches and then divide by the total number of samples
        #   (the loss, in contrast, is averaged over batches)
        valid_loss = 0
        valid_acc = 0
        num_iter = 0
        model.eval()
        with torch.no_grad():
            correct = 0
            num_objs = 0
            for batch in loaders["test"]:
                input_embeds = batch["features"].to(device)
                labels = batch["labels"].to(device)
                prediction = model(input_embeds)
                valid_loss += criterion(prediction, labels)
                correct += (labels == prediction.argmax(-1)).float().sum()
                num_objs += len(labels)
                num_iter += 1

        print(f"Valid Loss: {valid_loss / num_iter}, accuracy: {correct/num_objs}")
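
The comment about averaging matters whenever the last batch is smaller than the others. A minimal sketch with made-up batch statistics shows how the two ways of computing accuracy can differ:

```python
# Hypothetical per-batch statistics: (number of samples, number of correct predictions).
batches = [(32, 24), (32, 28), (8, 2)]  # the last batch is smaller

# Average of per-batch accuracies: every batch counts equally, whatever its size.
mean_of_means = sum(correct / size for size, correct in batches) / len(batches)

# Sample-weighted accuracy: total correct over total samples,
# which is what correct / num_objs computes in the loop above.
total_correct = sum(correct for _, correct in batches)
total_size = sum(size for size, _ in batches)
weighted = total_correct / total_size

print(f"{mean_of_means:.4f} {weighted:.4f}")  # prints: 0.6250 0.7500
```

The small last batch drags the average of averages down, while the sample-weighted value treats every example equally.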

In [41]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GRUModel(word2vec.vector_size, 50).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

num_epochs = 1
max_grad_norm = 1.0

In [42]:
with torch.no_grad():
    for word, idx in word2idx.items():
        if word in word2vec:
            model.embed.weight[idx] = torch.from_numpy(word2vec.get_vector(word))

In [43]:
training_freeze(model, criterion, optimizer, num_epochs, loaders, max_grad_norm)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3750 [00:00<?, ?it/s]

Valid Loss: 0.4586261510848999, accuracy: 0.8427631855010986


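* In this run the result is slightly better than the GRU without initially freezing the embedding layer: accuracy of about 0.84 versus 0.82, and validation loss of about 0.46 versus 0.53. With a single epoch and a single run, the gap is small enough that it may not be significant.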