In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# from torchtext.datasets import TranslationDataset, Multi30k
# from torchtext.data import Field, BucketIterator
# import spacy

import pickle
import random
import math
import os
import time
import nltk

### Import the pre-processed data

In [41]:
path = 'data/'

# Pre-processed parallel corpora (token ids, sorted by length, no padding).
# Each pickle holds a dict; the English one exposes 'idx2word', 'train' and 'dev'.
# NOTE: pickle.load can execute arbitrary code -- only load trusted local files.
english_file = os.path.join(path, 'english_no_pad_sorted_50k.pickle')
german_file = os.path.join(path, 'german_no_pad_sorted_50k.pickle')

with open(english_file, 'rb') as fh:
    english = pickle.load(fh)

with open(german_file, 'rb') as fh:
    german = pickle.load(fh)

### Token-Based Batching Method

The input is the pre-processed data set, which must be sorted from shortest to longest sentence and must not contain any padding. The *get_batches* function creates all of the training batches. Here the batch size is a *token* budget: sentences are added to the current batch until adding the next one would exceed the budget. The resulting batches therefore contain varying numbers of sentences of varying lengths. An example follows.

#### Example, batch-size=14
- Given the first few sentences from the dataset (sorted):
    - [2, 84, 3]      (length 3)
    - [2, 102, 3]     (length 3)
    - [2, 63, 3]      (length 3)
    - [2, 84, 21, 3]  (length 4)
    - [2, 91, 123, 3] (length 4)
    
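- Batches are filled according to the number of tokens in each sentence. The first four sentences contain 3 + 3 + 3 + 4 = 13 ≤ 14 tokens. Adding the fifth would give 17 > 14 tokens, so it starts the next batch. The first batch (batch_size=14) therefore looks like: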
```
    [[2, 84, 3],
     [2, 102, 3],     
     [2, 63, 3],      
     [2, 84, 21, 3]]
```
- Every sentence in the batch that is shorter than the batch's longest sentence is then zero-padded at the end to that length:

```
    [[2, 84, 3, 0],
     [2, 102, 3, 0],     
     [2, 63, 3, 0],      
     [2, 84, 21, 3]]
```

- Each batch is now an $N \times D$ matrix, where:
    - $N$ is the number of sentences in the batch, and
    - $D$ is the length (number of tokens, including padding) of every sentence in the batch.
    - Note that $N$ and $D$ vary from batch to batch, but within a batch every sentence **MUST** have the same length $D$.

In [13]:
# NOTE(review): dead code kept for reference only. The string literal below holds an
# earlier get_batches that batches the German sentences only; it is never executed.
# Its sample call passes no `english` argument and would raise TypeError if run, and
# its padding step appends a single 0 (`0 * dif` is an int) instead of `dif` zeros.
# The live implementation is the get_batches defined in the next cell.
"""
# THE FOLLOWING FUNCTION IS DEPRECATED

def get_batches(german, english, b_sz):
    batches = [[]]
    
    # For every sentence in the dataset, add it to a batch, based on the batch size
    # if the sentence + current length is not greater than the batch size, 
    # then add it to the batch otherwise fill the current batch
    for sent in german:
        cur_len = 0
        for b in batches[-1]:
            cur_len += len(b)
    
        if (cur_len + len(sent)) <= b_sz: 
            batches[-1].append(sent)
        else:
            batches.append([])        
            batches[-1].append(sent)
    
    # For every batch within the entire set of batches, add padding to the sentences
    # that are less than the length of the longest sentence within the each batch.
    for b in batches:
        max_len = len(max(b, key=len))
        
        for sent in b:
            dif = max_len - len(sent)
            if dif > 0:
                pad_list = 0 * dif
                sent.append(pad_list)
        
    return batches



batches = get_batches(german['train'], b_sz=20)
"""

In [66]:
def get_batches(german, english, b_sz, pad_idx=0):
    """Group an aligned parallel corpus into token-budgeted, padded batches.

    Sentences are taken in order from ``german``. Each one joins the current
    batch while the batch's total (unpadded) token count stays <= ``b_sz``;
    otherwise it starts a new batch, so a sentence longer than ``b_sz`` ends up
    in a batch of its own. Within each batch every German (source) sentence is
    right-padded with ``pad_idx`` to the length of the batch's longest sentence.
    The English (target) sentence at the same position is paired with it,
    unpadded.

    The input lists are never modified: padded sources are fresh lists, so
    re-running this function on the same data gives the same result.

    Args:
        german: list of source sentences (sequences of token ids). Batching is
            only compact when they are sorted by length.
        english: list of target sentences aligned index-by-index with ``german``.
        b_sz: token budget per batch, counted before padding.
        pad_idx: token id used for padding (default 0).

    Returns:
        A list of batches; each batch is a list of
        ``{"source": padded_german_sentence, "target": english_sentence}`` dicts.
        An empty corpus yields ``[]``.

    Raises:
        ValueError: if ``german`` and ``english`` have different lengths.
    """
    if len(german) != len(english):
        raise ValueError(
            "german and english must be aligned: got {} vs {} sentences".format(
                len(german), len(english)))

    # Pass 1: split sentence indices into batches by token budget. A running
    # token count avoids re-summing the current batch for every sentence, and
    # batches are only opened when needed, so no batch is ever empty.
    index_batches = []
    cur_tokens = 0
    for i, sent in enumerate(german):
        if not index_batches or cur_tokens + len(sent) > b_sz:
            index_batches.append([])
            cur_tokens = 0
        index_batches[-1].append(i)
        cur_tokens += len(sent)

    # Pass 2: pad each source to the batch's max length (by the full length
    # difference, not a single token) and pair it with its aligned target.
    # TODO: targets are left unpadded; pad them too before building tensors.
    batches = []
    for idx_batch in index_batches:
        max_len = max(len(german[i]) for i in idx_batch)
        batches.append([
            {"source": list(german[i]) + [pad_idx] * (max_len - len(german[i])),
             "target": english[i]}
            for i in idx_batch
        ])

    return batches

test_batches = get_batches(german['train'], english['train'], b_sz=20)

# Inspect one aligned source/target pair. Indices are 0-based, so (102, 3) is the
# 4th sentence of the 103rd batch; the message reports the indices explicitly.
batch_idx, sent_idx = 102, 3
sample_pair = test_batches[batch_idx][sent_idx]
print("Source and target of sentence {} in batch {} (0-based indices):".format(sent_idx, batch_idx))
print("Source:", sample_pair['source'])
print("Target:", sample_pair['target'])

load the source and target sentences of the 3rd sentence within the 102nd batch:
Source: [2, 865, 335, 3]
Target: [2, 733, 21, 11, 3]


In [54]:
# Number of batches built by get_batches above. `en_batches`/`batches` only exist
# inside get_batches, so use the returned `test_batches`; each of its batches
# already pairs source and target, so a single count covers both languages.
print(len(test_batches))

def print_sentence(sent, language):
    if language == "german":
        for w in sent:
            print(german['idx2word'][w], end=' ')
    elif language == "english":
        for w in sent:
            print(english['idx2word'][w], end=' ')
    else:
        print("Language should be either 'german' or 'english'")
        
    print("")


# Decode the first pair of the last batch to eyeball source/target alignment.
# Uses the batches returned by get_batches (not its internal locals), and [-1]
# instead of a hard-coded index so it works for any corpus size.
last_pair = test_batches[-1][0]
print_sentence(last_pair['source'], language="german")
print_sentence(last_pair['target'], language="english")


29891
29891
<sos> ich wünschte , ich könnte ihnen mein jemen zeigen . <eos> 
<sos> i wish i could show you my yemen . <eos> 


In [33]:
# Inspect which entries the pickled English corpus dict provides
# (the idx2word vocabulary and the train/dev splits).
english.keys()

dict_keys(['idx2word', 'train', 'dev'])

In [None]:
class Encoder(nn.Module):
    """Encoder module for the translation model (placeholder, not implemented yet).

    Presumably it will encode the padded German source batches produced by
    get_batches -- TODO confirm once the architecture is chosen.

    Args:
        params: encoder hyper-parameters; stored as ``self.params`` for the
            eventual implementation (schema not defined yet).
    """

    def __init__(self, params):
        super().__init__()
        self.params = params

    def forward(self, batch):
        # Fail loudly instead of silently returning None: a `pass` body would
        # mask nn.Module's own NotImplementedError and surface as a confusing
        # error further downstream.
        raise NotImplementedError("Encoder.forward is not implemented yet")
        

In [None]:
def train(english, german, params, net, b_sz=20):
    """Train ``net`` on the parallel German->English corpus (work in progress).

    Args:
        english: English corpus dict; assumed to hold a 'train' split of
            token-id sentences aligned with ``german`` -- TODO confirm.
        german: German corpus dict with a matching 'train' split.
        params: training hyper-parameters (schema not defined yet).
        net: the model to train.
        b_sz: token budget per batch passed to get_batches (default 20, the
            value used for the sample batches earlier in the notebook).

    Raises:
        NotImplementedError: always, until the training loop is written.
    """
    # get_batches needs the aligned source/target sentence lists and a token
    # budget; calling it with no arguments raises TypeError.
    batches = get_batches(german['train'], english['train'], b_sz=b_sz)
    # TODO: iterate over `batches`, run `net`, compute the loss, step an optimizer.
    raise NotImplementedError("training loop not written yet")