In [None]:
# Run configuration: the knobs a reader is expected to tune.
# NOTE(review): use_cuda and learning_rate are not referenced by any later cell shown here.
use_cuda, batch_size, learning_rate = True, 2, 0.01

# Import libraries

In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchtext
import torch.nn.utils.rnn as rnn_utils

import numpy as np

import time
import math
import random
import unicodedata
import string
import re
from tqdm import tqdm

import scripts.text
import utils

# Load data

In [None]:
data_path = './processed-data/id.1000/'
# Vocabulary files sit next to the training data as <data_path>train.10k.<lang>.vocab.
en_vocab_path, de_vocab_path = (
    data_path + 'train.10k.en.vocab',
    data_path + 'train.10k.de.vocab',
)

In [None]:
# load_vocab returns a 3-tuple; the third item is discarded.
# Presumably (words, word->id mapping, ...) — verify against scripts/text.py.
en_loaded = scripts.text.load_vocab(en_vocab_path)
en_words, en_vocab, _ = en_loaded
de_loaded = scripts.text.load_vocab(de_vocab_path)
de_words, de_vocab, _ = de_loaded

In [None]:
class LuongNMTDataset(torchtext.data.Dataset):
    """
        Custom Dataset for Machine Translation dataset based on torchtext's Dataset class.

        Reads two line-aligned files (one sentence per line, tokens separated by
        whitespace) and builds one torchtext Example per non-blank line pair.
    """

    def __init__(self, src_path, trg_path, fields, MAX_LENGTH=None, **kwargs):
        """
            Arguments:
                src_path (string): path to source language data.
                trg_path (string): path to target language data.
                fields: Either a pair (src_field, trg_field), which is named
                    ('src', 'trg'), or a pair of (name, field) tuples.
                MAX_LENGTH (int, optional): if given, line pairs where either
                    side has more than MAX_LENGTH tokens are skipped.
                Remaining keyword arguments: Passed to the constructor of data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        examples = []
        with open(src_path) as src_file, open(trg_path) as trg_file:
            # NOTE: zip stops at the shorter file, so a line-count mismatch
            # between the two files is silently truncated.
            for src_line, trg_line in tqdm(zip(src_file, trg_file)):
                # split() with no argument trims surrounding whitespace and
                # collapses repeated spaces: a blank line gives [] (not ['']),
                # and no empty-string tokens are produced.
                src_tokens = src_line.split()
                trg_tokens = trg_line.split()
                # Skip pairs where either side is blank.
                if not src_tokens or not trg_tokens:
                    continue
                if MAX_LENGTH is not None and (
                        len(src_tokens) > MAX_LENGTH or len(trg_tokens) > MAX_LENGTH):
                    continue
                examples.append(
                    torchtext.data.Example.fromlist([src_tokens, trg_tokens], fields))

        super(LuongNMTDataset, self).__init__(examples, fields, **kwargs)

In [None]:
def post_processing(arr, field_vocab, train=None):
    """Convert a padded batch of string token ids into nested lists of ints.

    Intended as the ``postprocessing`` hook of a torchtext ``Field`` built with
    ``use_vocab=False``. torchtext calls it after padding and before tensor
    conversion, so the tokens (including the ``'0'`` pad token) are still strings.

    Arguments:
        arr (list of list of str): padded batch, one token list per example.
        field_vocab: the field's vocab (``None`` when ``use_vocab=False``); unused.
        train: flag passed by older torchtext releases; unused. It defaults to
            ``None`` so releases calling ``postprocessing(arr, vocab)`` also work.

    Returns:
        list of list of int: same layout as ``arr``, every token parsed by ``int``.
    """
    # Build real lists: in Python 3 ``map`` is a one-shot lazy iterator, which
    # torch.tensor cannot turn into a numeric tensor.
    return [[int(token) for token in example] for example in arr]

In [None]:
def make_index_field():
    """Build a torchtext Field for pre-tokenised integer-id sequences.

    No vocabulary lookup is performed (``use_vocab=False``). Sequences are padded
    with the string ``'0'`` and converted to ints by ``post_processing``.
    ``include_lengths=True`` makes each batch attribute a (padded, lengths) pair.
    ``batch_first=True`` puts the batch dimension first.
    NOTE(review): assumes id 0 is reserved for padding in the data files — TODO confirm.
    """
    return torchtext.data.Field(sequential=True,
                                postprocessing=post_processing,
                                use_vocab=False,
                                pad_token='0',
                                include_lengths=True,
                                batch_first=True,
                                )


# Source and target use identical settings, so both come from one factory.
src_field = make_index_field()
trg_field = make_index_field()

In [None]:
# Parallel training corpus; the .en / .de suffixes suggest English source, German target.
train_src_path = data_path + 'train.10k.en'
train_trg_path = data_path + 'train.10k.de'
train_dataset = LuongNMTDataset(src_path=train_src_path,
                                trg_path=train_trg_path,
                                fields=(src_field, trg_field))

In [None]:
# BucketIterator batches examples of similar source length to limit padding.
# sort_within_batch sorts each batch by sort_key in descending order, which
# pack_padded_sequence (enforce_sorted=True by default) expects.
# NOTE(review): use_cuda is not passed as `device`; batches use torchtext's default device.
train_loader = torchtext.data.BucketIterator(dataset=train_dataset,
                                             batch_size=batch_size,
                                             repeat=False,
                                             shuffle=True,
                                             sort_within_batch=True,
                                             sort_key=lambda x: len(x.src)
                                             )

In [None]:
# len() of the iterator: the number of batches one pass over the data yields.
n_train_batches = len(train_loader)
n_train_batches

In [None]:
# Placeholder for one example batch taken from train_loader below.
batch_sample = None

In [None]:
# Take a single batch for inspection; batch_sample stays None if the loader is empty.
batch_sample = next(iter(train_loader), None)
if batch_sample is not None:
    print(batch_sample)

In [None]:
# Number of line pairs kept by LuongNMTDataset.
n_train_examples = len(train_dataset)
n_train_examples

In [None]:
# The src field was built with include_lengths=True, so .src is a (padded ids, lengths) pair.
src_pair = batch_sample.src
sentences, lengths = src_pair

In [None]:
# Padded source id tensor for the sample batch (batch dimension first, since batch_first=True).
sentences

In [None]:
# Per-example source lengths returned alongside the padded tensor (include_lengths=True).
lengths

In [None]:
# pack_padded_sequence requires a lengths tensor to live on the CPU.
lengths = lengths.to('cpu')

In [None]:
# Pack the padded batch so an RNN skips the padding positions.
packed_src = rnn_utils.pack_padded_sequence(sentences,
                                            lengths,
                                            batch_first=True)
packed_src