In [32]:
import torch
import collections
from torch.utils import data

In [28]:
def read_data_nmt(data_path):
    """Load the English-French dataset."""
    with open(data_path, 'r', encoding = 'utf-8') as f:
        return f.read()

raw_text = read_data_nmt('../data/fra-eng/fra.txt') #tab separated string
print(raw_text[:20])

Go.	Va !
Hi.	Salut !


After downloading the dataset, we apply several preprocessing steps to the raw text data. For instance, we replace non-breaking spaces with ordinary spaces, convert uppercase letters to lowercase, and insert a space between words and punctuation marks.

In [11]:
def preprocess_nmt(text):
    """Preprocess the English-French dataset."""
    def no_space(char, prev_char):
        return char in set(',.!?') and prev_char != ' '

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert space between words and punctuation marks
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

text = preprocess_nmt(raw_text)
print(text[:80])

go .	va !
hi .	salut !
run !	cours !
run !	courez !
who ?	qui ?
wow !	ça alors !
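To see what each preprocessing rule does on its own, the sketch below repeats `preprocess_nmt` unchanged (so the block runs by itself) and applies it to a made-up string. The string contains a narrow no-break space (`\u202f`) before `!`, which is common in French typography, as well as uppercase letters and punctuation attached directly to words. The sample input is invented for illustration.

```python
def preprocess_nmt(text):
    """Same logic as the cell above, repeated so this block is self-contained."""
    def no_space(char, prev_char):
        # True for punctuation that directly follows a non-space character
        return char in set(',.!?') and prev_char != ' '

    # Normalize non-breaking spaces first, then lowercase everything
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Prefix a space to punctuation that is glued to the preceding word
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

sample = "Bonjour\u202f!\tHello, World!"  # hypothetical input line
print(repr(preprocess_nmt(sample)))       # 'bonjour !\thello , world !'
```

The order of operations matters: the no-break space is replaced before punctuation spacing is inserted. Otherwise the `!` after `\u202f` would count as glued to a non-space character and receive a second space.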


Unlike the character-level tokenization in Section 8.3, for machine translation we prefer word-level tokenization here (state-of-the-art models may use more advanced tokenization techniques). The following tokenize_nmt function tokenizes the first num_examples text sequence pairs, where each token is either a word or a punctuation mark. It returns two lists of token lists, source and target: source[i] is the list of tokens of the i-th text sequence in the source language (English here), and target[i] is the corresponding list in the target language (French here).

In [18]:
def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset."""
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i >= num_examples:  # read at most num_examples lines
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

source, target = tokenize_nmt(text)
print(source[10], target[10])
print(source[20], target[20])


['stop', '!'] ['stop', '!']
['i', 'try', '.'] ["j'essaye", '.']
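The padding and truncation steps later in this section depend on how many tokens each sequence has, so it helps to look at the distribution of sequence lengths. The sketch below uses a hypothetical helper, length_histogram, and runs it on a tiny hand-written stand-in for source and target (the toy data is made up). On the real dataset you would pass source and target directly.

```python
from collections import Counter

def length_histogram(token_lists):
    """Map each sequence length (in tokens) to the number of sequences with that length."""
    return Counter(len(tokens) for tokens in token_lists)

# Hypothetical miniature of the `source` / `target` lists returned by tokenize_nmt
toy_source = [['go', '.'], ['hi', '.'], ['i', 'try', '.'], ['who', '?']]
toy_target = [['va', '!'], ['salut', '!'], ["j'essaye", '.'], ['qui', '?']]

print(sorted(length_histogram(toy_source).items()))  # [(2, 3), (3, 1)]
print(sorted(length_histogram(toy_target).items()))  # [(2, 4)]
```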


### Vocabulary

Since the machine translation dataset consists of pairs of languages, we can build a separate vocabulary for the source language and for the target language. With word-level tokenization, the vocabulary is significantly larger than with character-level tokenization. To alleviate this, we treat infrequent tokens that appear fewer than 2 times as the same unknown (\<unk\>) token. In addition, we specify special tokens for padding (\<pad\>) sequences to the same length in minibatches and for marking the beginning (\<bos\>) or end (\<eos\>) of sequences. Such special tokens are commonly used in natural language processing tasks.

In [24]:
class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = [] 
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The index for the unknown token is 0
        self.unk, uniq_tokens = 0, ['<unk>'] + reserved_tokens
        uniq_tokens += [token for token, freq in self.token_freqs
                        if freq >= min_freq and token not in uniq_tokens]
        self.idx_to_token, self.token_to_idx = [], dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]


def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)

In [25]:
src_vocab = Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
len(src_vocab)

10012
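The effect of min_freq=2 is easiest to see on a toy corpus. The sketch below is a condensed stand-in for the index-assignment logic of Vocab, not the class itself; the helper name build_token_to_idx and the toy data are hypothetical. \<unk\> gets index 0, the reserved tokens come next, and then every token seen at least min_freq times is added in order of decreasing frequency.

```python
import collections

def build_token_to_idx(token_lists, min_freq, reserved_tokens):
    """Condensed version of Vocab's indexing (hypothetical helper)."""
    counter = collections.Counter(t for line in token_lists for t in line)
    uniq = ['<unk>'] + reserved_tokens
    # sorted(..., reverse=True) is stable, so ties keep first-seen order
    uniq += [t for t, f in sorted(counter.items(), key=lambda x: x[1], reverse=True)
             if f >= min_freq and t not in uniq]
    return {t: i for i, t in enumerate(uniq)}

toy = [['go', '.'], ['go', 'away', '.'], ['run', '!']]  # made-up corpus
token_to_idx = build_token_to_idx(toy, 2, ['<pad>', '<bos>', '<eos>'])
print(token_to_idx)
# Tokens seen fewer than min_freq times fall back to index 0 (<unk>),
# mirroring Vocab.__getitem__'s use of .get(token, self.unk)
print([token_to_idx.get(t, 0) for t in ['go', 'away', '!']])  # [4, 0, 0]
```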

In machine translation, each example is a pair of source and target text sequences, and different text sequences may have different lengths.

For computational efficiency, we can still process a minibatch of text sequences at one time by truncation and padding. Suppose that every sequence in the same minibatch should have the same length num_steps. If a text sequence has fewer than num_steps tokens, we keep appending the special \<pad\> token to its end until its length reaches num_steps. Otherwise, we truncate the text sequence by taking only its first num_steps tokens and discarding the rest. In this way, every text sequence has the same length and can be loaded in minibatches of the same shape.

In [26]:
def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad


truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])

[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]
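The call above only exercises the padding branch. The sketch below repeats truncate_pad so it runs on its own and checks all three cases: shorter than num_steps (padded), longer (truncated), and exactly num_steps (returned unchanged). The token indices are arbitrary; PAD = 1 matches the index of \<pad\> in the vocabularies built here.

```python
def truncate_pad(line, num_steps, padding_token):
    """Same as above: cut to num_steps tokens, or right-pad with padding_token."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad

PAD = 1  # index of '<pad>' in the vocabularies above
print(truncate_pad([47, 4], 5, PAD))              # [47, 4, 1, 1, 1]
print(truncate_pad([6, 31, 4, 9, 9, 9], 5, PAD))  # [6, 31, 4, 9, 9]
print(truncate_pad([6, 31, 4, 9, 9], 5, PAD))     # [6, 31, 4, 9, 9]
```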

Now we define a function to transform text sequences into minibatches for training. We append the special \<eos\> token to the end of every sequence to indicate where the sequence ends. When a model predicts by generating a sequence token after token, the generation of the \<eos\> token signals that the output sequence is complete. We also record the length of each text sequence excluding the padding tokens.

In [35]:
def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into minibatches."""
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab['<eos>']] for l in lines]
    array = torch.tensor(
        [truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, valid_len

def load_data_nmt(data_path, batch_size, num_steps, num_examples=600):
    """Return the iterator and the vocabularies of the translation dataset."""
    text = preprocess_nmt(read_data_nmt(data_path))
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

data_path = '../data/fra-eng/fra.txt'
train_iter, src_vocab, tgt_vocab = load_data_nmt(data_path, batch_size=2, num_steps=8)
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print('Batch of 2 sentences:')
    print('X:', X.type(torch.int32))
    print('valid lengths for X:', X_valid_len)
    print('Y:', Y.type(torch.int32))
    print('valid lengths for Y:', Y_valid_len)
    break

Batch of 2 sentences:
X: tensor([[ 6,  0,  4,  3,  1,  1,  1,  1],
        [36,  5,  3,  1,  1,  1,  1,  1]], dtype=torch.int32)
valid lengths for X: tensor([4, 3])
Y: tensor([[10,  0,  4,  3,  1,  1,  1,  1],
        [15,  0,  5,  3,  1,  1,  1,  1]], dtype=torch.int32)
valid lengths for Y: tensor([4, 4])
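Two details of build_array_nmt are easy to miss in the batch above. First, the valid length counts the appended \<eos\> token, which is why X's first row [6, 0, 4, 3, ...] has valid length 4. Second, \<eos\> is appended *before* truncation, so a sequence longer than num_steps loses its \<eos\>. The sketch below is a hypothetical list-based mirror of build_array_nmt (no torch) that makes both effects visible. EOS = 3 and PAD = 1 match the printed vocabularies, and the input indices are made up.

```python
def build_rows(index_lines, eos, pad, num_steps):
    """List-based mirror of build_array_nmt (hypothetical helper):
    append <eos>, truncate or pad to num_steps, count non-padding positions."""
    rows, valid_lens = [], []
    for line in index_lines:
        line = line + [eos]
        if len(line) > num_steps:
            line = line[:num_steps]  # <eos> is cut off here for long inputs
        else:
            line = line + [pad] * (num_steps - len(line))
        rows.append(line)
        valid_lens.append(sum(tok != pad for tok in line))
    return rows, valid_lens

EOS, PAD = 3, 1  # indices of '<eos>' and '<pad>' in the vocabularies above
rows, valid_lens = build_rows([[6, 4], [6, 31, 4, 9, 9, 9, 9, 9]], EOS, PAD, 8)
print(rows)        # [[6, 4, 3, 1, 1, 1, 1, 1], [6, 31, 4, 9, 9, 9, 9, 9]]
print(valid_lens)  # [3, 8]
```

The second row shows that a truncated sequence carries no \<eos\> at all. Choosing num_steps large enough for most sequences keeps this case rare.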


In [47]:
print(src_vocab.token_to_idx)

{'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, '.': 4, '!': 5, 'i': 6, "i'm": 7, 'it': 8, 'go': 9, 'tom': 10, '?': 11, 'me': 12, 'get': 13, 'be': 14, 'up': 15, 'come': 16, 'we': 17, 'am': 18, 'this': 19, 'lost': 20, 'on': 21, 'won': 22, 'us': 23, "it's": 24, 'down': 25, 'no': 26, 'nice': 27, 'away': 28, 'you': 29, 'back': 30, 'try': 31, 'way': 32, 'fair': 33, 'out': 34, 'lazy': 35, 'help': 36, 'hold': 37, 'off': 38, 'grab': 39, 'how': 40, 'who': 41, 'got': 42, 'calm': 43, 'call': 44, 'he': 45, 'a': 46, 'good': 47, 'job': 48, 'did': 49, 'use': 50, 'over': 51, "don't": 52, 'forget': 53, 'run': 54, 'in': 55, 'home': 56, 'fun': 57, "he's": 58, 'sure': 59, 'here': 60, 'stop': 61, 'cool': 62, 'drive': 63, 'fat': 64, 'shut': 65, 'wake': 66, 'leave': 67, 'sit': 68, 'can': 69, 'fire': 70, 'cheers': 71, 'now': 72, 'left': 73, 'ok': 74, 'ask': 75, 'drop': 76, 'hang': 77, "i'll": 78, 'keep': 79, 'tell': 80, 'him': 81, 'ahead': 82, 'hurry': 83, 'fine': 84, 'died': 85, 'taste': 86, 'they': 87, 'wa

In [48]:
print(tgt_vocab.token_to_idx)

{'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, '.': 4, '!': 5, 'je': 6, 'suis': 7, 'tom': 8, '?': 9, "j'ai": 10, 'nous': 11, 'ça': 12, "c'est": 13, 'est': 14, 'à': 15, 'va': 16, 'bien': 17, 'il': 18, 'en': 19, 'soyez': 20, 'j’ai': 21, 'pas': 22, 'un': 23, 'qui': 24, 'gagné': 25, 'sois': 26, 'me': 27, 'tomber': 28, 'la': 29, 'ne': 30, 'ceci': 31, 'de': 32, 'vais': 33, 'bon': 34, 'venez': 35, 'le': 36, 'chez': 37, "j'en": 38, 'avons': 39, 'calme': 40, 'viens': 41, 'vous': 42, 'a': 43, 'moi': 44, 'au': 45, "l'ai": 46, 'emporté': 47, 'perdu': 48, 'allez': 49, 'plus': 50, 'fait': 51, 'comme': 52, 'ici': 53, 'feu': 54, 'maintenant': 55, 'compris': 56, 'sais': 57, 'gentil': 58, 'dégage': 59, 'malade': 60, 'fûmes': 61, 'été': 62, 'elle': 63, 'assieds-toi': 64, 'salut': 65, 'cours': 66, 'vas-y': 67, 'question': 68, 'juste': 69, 'entrez': 70, 'laisse': 71, 'chercher': 72, 'pars': 73, 'maison': 74, 'tiens': 75, 'tenez': 76, 'fais': 77, 'réveille-toi': 78, 'suis-je': 79, 'trouve': 80, 'trouvez':
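These dictionaries map tokens to indices. The reverse direction, which Vocab.to_tokens provides, is what turns index sequences (for example, a model's predictions) back into text. The sketch below uses a hypothetical decode helper that also stops at the first \<eos\> and drops everything after it. Its idx_to_token list is copied from the first eleven entries of the target vocabulary printed above, which is enough to decode the first row of Y from the earlier batch.

```python
def decode(indices, idx_to_token, eos_token='<eos>'):
    """Hypothetical helper: map indices to tokens, stopping at the first <eos>."""
    tokens = []
    for idx in indices:
        tok = idx_to_token[idx]
        if tok == eos_token:
            break  # everything after <eos> (normally <pad>) is ignored
        tokens.append(tok)
    return ' '.join(tokens)

# First entries of tgt_vocab.token_to_idx as printed above, in index order
idx_to_token = ['<unk>', '<pad>', '<bos>', '<eos>', '.', '!',
                'je', 'suis', 'tom', '?', "j'ai"]

print(decode([10, 0, 4, 3, 1, 1, 1, 1], idx_to_token))  # j'ai <unk> .
print(decode([6, 7, 4, 3, 1], idx_to_token))            # je suis .
```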