In [None]:
import itertools, os
import numpy as np
import spacy
import torch
from torchtext import data, datasets
from torchtext.vocab import Vectors, GloVe
# True when torch can see a CUDA device. NOTE(review): preprocess() passes
# device=-1 to its iterators and does not consult this flag.
use_gpu = torch.cuda.is_available()

def preprocess(vocab_size, batchsize, max_sent_len=20):
    """Build German->English IWSLT fields, vocabularies and batch iterators.

    Loads spaCy pipelines for tokenization, loads the IWSLT splits through
    torchtext keeping only examples whose source and target each have at most
    ``max_sent_len`` tokens, builds vocabularies (``min_freq=5``) from the
    training split, and wraps the train and validation splits in
    ``BucketIterator`` objects sorted by source length.

    Args:
        vocab_size: Cap on each vocabulary's size; a value ``<= 0`` means
            no cap.
        batchsize: Batch size of the training iterator (validation uses 1).
        max_sent_len: Longest tokenized source or target an example may have
            and still be kept.

    Returns:
        ``(DE, EN, train_iter, val_iter)``: the source field, the target
        field, and the train/validation iterators. The test split is loaded
        but not returned.
    """
    # NOTE(review): the 'de'/'en' shortcut names only resolve on older spaCy
    # releases; newer ones need full package names (e.g. 'de_core_news_sm').
    de_spacy = spacy.load('de')
    en_spacy = spacy.load('en')
    tokenizers = {'de': de_spacy.tokenizer, 'en': en_spacy.tokenizer}

    def tokenize(text, lang='en'):
        """Split ``text`` into token strings with the spaCy tokenizer for ``lang``."""
        # Dispatch by value (dict lookup) rather than identity (`is`), which
        # only matched interned strings; unknown languages raise instead of
        # silently returning None.
        if lang not in tokenizers:
            raise ValueError(
                "unsupported language {!r}; expected 'de' or 'en'".format(lang))
        return [tok.text for tok in tokenizers[lang](text)]

    BOS_WORD = '<s>'
    EOS_WORD = '</s>'
    DE = data.Field(tokenize=lambda x: tokenize(x, 'de'))
    # Only the target side gets sentence-boundary tokens.
    EN = data.Field(tokenize=tokenize, init_token=BOS_WORD, eos_token=EOS_WORD)

    def short_enough(example):
        # Keep a pair only if neither side exceeds max_sent_len tokens.
        return max(len(example.src), len(example.trg)) <= max_sent_len

    train, val, _test = datasets.IWSLT.splits(
        exts=('.de', '.en'), fields=(DE, EN), filter_pred=short_enough)

    # max_size=None is torchtext's "no limit", matching the uncapped case.
    max_size = vocab_size if vocab_size > 0 else None
    DE.build_vocab(train.src, min_freq=5, max_size=max_size)
    EN.build_vocab(train.trg, min_freq=5, max_size=max_size)

    # NOTE(review): device=-1 is the legacy torchtext way of requesting CPU;
    # confirm against the installed torchtext version.
    train_iter = data.BucketIterator(train, batch_size=batchsize, device=-1,
                                     repeat=False, sort_key=lambda x: len(x.src))
    val_iter = data.BucketIterator(val, batch_size=1, device=-1,
                                   repeat=False, sort_key=lambda x: len(x.src))

    return DE, EN, train_iter, val_iter

def load_embeddings(SRC, TRG, np_src_file, np_trg_file):
    """Load pre-computed source/target embedding matrices saved with ``np.save``.

    Args:
        SRC: Source-side field. Not used by this function; kept so existing
            call sites keep working.
        TRG: Target-side field. Not used by this function; kept so existing
            call sites keep working.
        np_src_file: Path to the ``.npy`` file holding the source embeddings.
        np_trg_file: Path to the ``.npy`` file holding the target embeddings.

    Returns:
        Tuple ``(emb_src, emb_trg)`` of tensors built with
        ``torch.from_numpy``; each shares memory and dtype with its loaded
        array.

    Raises:
        FileNotFoundError: If either path is not an existing regular file.
            It subclasses ``Exception``, so handlers written for the previous
            generic exception still catch it.
    """
    missing = [path for path in (np_src_file, np_trg_file)
               if not os.path.isfile(path)]
    if missing:
        # Name the offending path(s) so the caller knows which file to fix.
        raise FileNotFoundError(
            "Vectors are unloadable; missing file(s): {}".format(
                ", ".join(str(path) for path in missing)))
    emb_tr_src = torch.from_numpy(np.load(np_src_file))
    emb_tr_trg = torch.from_numpy(np.load(np_trg_file))
    return emb_tr_src, emb_tr_trg
    