In [14]:
import torch
import torch.nn as nn

from tqdm import tqdm
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, random_split
from utils import tokenizer, Collator
from data import AGNewsDataset
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
agnews = AGNewsDataset(text="Description")

In [4]:
len(agnews)

120000

In [5]:
agnews_train, agnews_valid = random_split(agnews, lengths=[30000, 90000])

In [6]:
agnews_train[0]

(1,
 'Ever since the start of the second Chechen war, the Russian leadership, and President Vladimir Putin in particular, have been terribly fond of saying that we are waging war against international terrorism in Chechnya.')

In [7]:
import re

def clean_text(text: str):
    text = text.lower().strip()
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z!?]+", r" ", text)
    return text

def yield_tokens(dataset, tokenizer=tokenizer):
    for row in tqdm(dataset):
        yield tokenizer(clean_text(row[-1]))

In [8]:
PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ["<pad>", "<unk>", "<bos>", "<eos>"]
vocab = build_vocab_from_iterator(
    yield_tokens(agnews_train), min_freq=1, specials=special_symbols, special_first=True)
vocab.set_default_index(UNK_IDX)

100%|██████████| 30000/30000 [00:04<00:00, 7260.98it/s]


In [9]:
vocab(['here', 'is', 'an', 'example'])

[383, 19, 24, 5249]

In [10]:
len(vocab)

34639

In [11]:
text_1 = clean_text(agnews_train[101][-1])
print("raw text:", text_1)
print("tokenize text:", tokenizer(text_1))
print("token to idx:", vocab(tokenizer(text_1)))

raw text: the blue chip hang seng index slipped points or percent to on monday the index had gained points or percent 
tokenize text: ['the', 'blue', 'chip', 'hang', 'seng', 'index', 'slipped', 'points', 'or', 'percent', 'to', 'on', 'monday', 'the', 'index', 'had', 'gained', 'points', 'or', 'percent']
token to idx: [4, 861, 592, 5744, 11768, 1930, 1966, 295, 102, 94, 6, 11, 47, 4, 1930, 75, 2398, 295, 102, 94]


In [12]:
collate = Collator(vocab=vocab)

In [13]:
train_loader = DataLoader(agnews_train, batch_size=32, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(agnews_train, batch_size=64, shuffle=False, collate_fn=collate)

In [16]:
from model import EncoderRNN

encoder = EncoderRNN(
    embedding=nn.Embedding(len(vocab), 256, padding_idx=0),
    embedding_size=256,
    hidden_size=128,
    device=device,
    num_layers=2
)
_, encoder_inputs = next(iter(train_loader))

In [17]:
encoder(encoder_inputs)[1].shape

torch.Size([2, 32, 128])