Tutorial URL: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

# Translation with a Sequence to Sequence Network and Attention

In [64]:
# Preamble

from io import open
import unicodedata
import string
import re
import random
import os

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [65]:
data_folder = "../../../data/pytorchTutorials/data"

# verify that data_folder contains eng-fra.txt
os.listdir(data_folder)

['eng-fra.txt', 'names']

In [66]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
# s = u'What are ? you .!?-_^&* doing - in vikas-bahriwani. Málaga'
# normalizeString(s)
# Output: 'what are ? you . ! ? doing in vikas bahriwani . malaga'
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [67]:
def readLangs(lang1, lang2, data_folder, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open(os.path.join(data_folder,'%s-%s.txt') % (lang1, lang2), encoding='utf-8').read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [68]:
# Since there are a lot of example sentences and we want to train something quickly, 
# we’ll trim the data set to only relatively short and simple sentences. 
# Here the maximum length is 10 words (that includes ending punctuation) and we’re filtering to sentences
# that translate to the form “I am” or “He is” etc. (accounting for apostrophes replaced earlier) - aka simple sentences

MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [69]:
SOS_TOKEN = 0
EOS_TOKEN = 1

class Lang:
    
    def __init__(self, name):
        self.name = name
        self.word2index = {"SOS": SOS_TOKEN, "EOS": EOS_TOKEN}
        self.word2count = {"SOS": 1, "EOS": 1}
        self.index2word = {SOS_TOKEN: "SOS", EOS_TOKEN: "EOS"}
        self.n_words = 2
    
    def add_sentence(self, sentence):
        for word in sentence.split(" "):
            self.add_word(word)
    
    def add_word(self, word):
        if word in self.word2index:
            self.word2count[word] += 1
        else:
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.word2index[word] = self.n_words
            self.n_words += 1

In [73]:
# Steps to preparing the data
# Read text file and split into lines, split lines into pairs, normalize text, 
# filter by length and content
# Make word lists from sentences in pairs
# NO stemming or lemmatization 

def prepareData(lang1, lang2, data_folder, reverse=False):
    
    # Read text file and split into lines, split lines into pairs
    # Also normalize text which here means - lower case, remove non-letter characters i.e. characters other than .!? 
    input_lang, output_lang, pairs = readLangs(lang1, lang2, data_folder, reverse)
    print("Read %s sentence pairs" % len(pairs))
    
    # filter by length and content
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.add_sentence(pair[0])
        output_lang.add_sentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('eng', 'fra', data_folder, True)
print(random.choice(pairs))

Reading lines...
Read 135842 sentence pairs
Trimmed to 10853 sentence pairs
Counting words...
Counted words:
fra 4489
eng 2925
['tu vas te remettre .', 'you re going to be ok .']
