In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import numpy as np
import spacy
import random
import de_core_news_sm
import en_core_web_sm
# from torch.utils.tensorboard import SummaryWriter


In [2]:
# Load each spaCy language package once at module level; the tokenizer
# helpers defined later in the notebook read these two globals.
spacy_ger = de_core_news_sm.load()  # German package
spacy_eng = en_core_web_sm.load()  # English package

In [3]:
'''Tokenizing Text'''

def tokenizer_ger(text):
    """Split ``text`` into a list of token strings with ``spacy_ger.tokenizer``."""
    words = []
    for piece in spacy_ger.tokenizer(text):
        words.append(piece.text)
    return words
def tokenizer_eng(text):
    """Split ``text`` into a list of token strings with ``spacy_eng.tokenizer``."""
    words = []
    for piece in spacy_eng.tokenizer(text):
        words.append(piece.text)
    return words

In [None]:

'''Utilities'''
def translate_sentence(model, sentence, german, english, device, max_length=50):
    """Greedily decode a translation of one source sentence.

    Args:
        model: Object exposing ``encoder(src)`` returning ``(hidden, cell)``
            and ``decoder(prev_token, hidden, cell)`` returning
            ``(output, hidden, cell)``. This function does not switch the
            model between train and eval mode.
            NOTE(review): callers presumably call ``model.eval()`` first - confirm.
        sentence: Either a raw string, which is split with the notebook's
            ``tokenizer_ger``, or an already tokenized sequence of strings.
            The sequence is not modified.
        german: Source-side field; ``init_token``, ``eos_token`` and
            ``vocab.stoi`` are read from it.
        english: Target-side field; ``vocab.stoi`` (for ``"<sos>"`` and
            ``"<eos>"``) and ``vocab.itos`` are read from it.
        device: Device on which the index tensors are created.
        max_length: Upper bound on the number of decoder steps.

    Returns:
        The predicted target tokens as a list of strings, without the start
        token. When the model predicts ``"<eos>"`` before ``max_length`` steps
        are used up, decoding stops and ``"<eos>"`` is the last element.
    """
    # Only raw strings need spaCy. Reuse the tokenizer helper that is already
    # set up at module level instead of loading the language package again on
    # every call; pre-tokenized input never touches spaCy at all.
    if isinstance(sentence, str):
        raw_tokens = tokenizer_ger(sentence)
    else:
        raw_tokens = sentence

    # Tokens are lower-cased before the vocabulary lookup, then wrapped in the
    # source field's start/end markers.
    tokens = [german.init_token]
    tokens.extend(token.lower() for token in raw_tokens)
    tokens.append(german.eos_token)

    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    eos_index = english.vocab.stoi["<eos>"]
    outputs = [english.vocab.stoi["<sos>"]]

    # Pure inference: no gradients are needed for the encoder or any decoder step.
    with torch.no_grad():
        # unsqueeze(1) turns the index vector into a single column of shape
        # (number_of_tokens, 1).
        sentence_tensor = torch.tensor(
            text_to_indices, dtype=torch.long, device=device
        ).unsqueeze(1)
        hidden, cell = model.encoder(sentence_tensor)

        for _ in range(max_length):
            # Feed the most recent prediction back in as the next decoder input.
            previous_word = torch.tensor(
                [outputs[-1]], dtype=torch.long, device=device
            )
            output, hidden, cell = model.decoder(previous_word, hidden, cell)
            best_guess = output.argmax(1).item()
            outputs.append(best_guess)

            # Model predicts it's the end of the sentence.
            if best_guess == eos_index:
                break

    # Drop the leading start token before mapping indices back to strings.
    return [english.vocab.itos[idx] for idx in outputs[1:]]