In [1]:
from data import get_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from typing import Iterable, List
from tqdm import tqdm
import torch

from torchtext.datasets import Multi30k

In [2]:
# Load the dataset through the project-local `get_dataset` helper; the recorded
# output below shows a DatasetDict with train/validation/test splits.
# NOTE(review): `dataset` is only referenced from commented-out lines further down
# (the vocabulary cells read Multi30k instead) — confirm this load is still needed.
dataset = get_dataset()

Found cached dataset wmt14 (/home/tingchen/.cache/huggingface/datasets/wmt14/de-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4508785
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 3003
    })
})


In [2]:
# Translation direction: German source, English target.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

# Per-language tokenizers, keyed by language code; each wraps a spaCy pipeline.
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}

# Per-language vocabularies, keyed by language code; populated by a later cell.
vocab_transform = {}

In [3]:
# Special tokens. The list order must match the indices assigned below, because
# the symbols are inserted into the vocabulary in this order.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# Index of each special token.
UNK_IDX = 0
PAD_IDX = 1
BOS_IDX = 2
EOS_IDX = 3


In [6]:
# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one language side of every sample in ``data_iter``.

    Two sample layouts are accepted:

    * a ``(source_text, target_text)`` sequence, indexed by position
      (``SRC_LANGUAGE`` -> 0, ``TGT_LANGUAGE`` -> 1);
    * a dict of the form ``{'translation': {<language>: <text>, ...}}``,
      indexed by language code.

    Args:
        data_iter: Iterable of samples in either layout above.
        language: Language code to extract; must be a key of ``token_transform``.

    Yields:
        The token list produced by ``token_transform[language]`` for each sample.

    Raises:
        KeyError: If ``language`` has no registered tokenizer, or is missing
            from a sample. Because this is a generator, the error surfaces
            when iteration starts, not when the function is called.
    """
    # Position of each language inside a (source, target) sample.
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    # Loop-invariant: resolve the tokenizer once instead of once per sample.
    tokenize = token_transform[language]

    for data_sample in data_iter:
        if isinstance(data_sample, dict):
            text = data_sample['translation'][language]
        else:
            text = data_sample[language_index[language]]
        yield tokenize(text)

In [7]:
for ln in tqdm([SRC_LANGUAGE, TGT_LANGUAGE]):
    # Fresh Multi30k training split for this language's pass.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

    # Build torchtext's Vocab object from the tokenized training sentences.
    # Special symbols come first; tokens seen fewer than 20 times are dropped.
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=20,
        specials=special_symbols,
        special_first=True,
    )

    # Fall back to UNK_IDX for tokens missing from the vocabulary. Without a
    # default index, looking up an unknown token throws RuntimeError.
    vocab_transform[ln].set_default_index(UNK_IDX)

100%|██████████| 2/2 [00:04<00:00,  2.08s/it]


In [8]:
# Persist each language's vocabulary to '<language>_vocab_multi30k.vocab' in the
# current working directory.
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    vocab_path = f'{ln}_vocab_multi30k.vocab'
    torch.save(vocab_transform[ln], vocab_path)

In [13]:
# Round-trip check: each vocabulary file written by this notebook
# ('<language>_vocab_multi30k.vocab') must reload with the same number of entries
# as the in-memory vocabulary. The earlier version compared against
# '*_vocab_wmt14.vocab', which this notebook never writes.
# NOTE: torch.load unpickles its input; only load vocabulary files you created yourself.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    reloaded_vocab = torch.load(f'{ln}_vocab_multi30k.vocab')
    assert len(vocab_transform[ln]) == len(reloaded_vocab), (
        f'{ln}: in-memory vocab has {len(vocab_transform[ln])} entries, '
        f'file has {len(reloaded_vocab)}'
    )