In [1]:
# Location of the saved model weights (a state dict) that a later cell reads
# with torch.load and feeds to load_state_dict.
path_to_model = './models/transformer.pth'

In [2]:
import torch
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from transformer import Transformer
# Hyperparameters handed to the Transformer constructor below.
# NOTE(review): these presumably have to match the configuration the saved
# checkpoint was trained with, otherwise load_state_dict will fail — confirm.
src_vocab_size = 13610   # rows of the encoder embedding table
tgt_vocab_size = 24266   # rows of the decoder embedding table
d_model = 512            # embedding / attention projection width
num_heads = 8            # presumably the number of attention heads — confirm in transformer.py
num_layers = 6           # number of encoder layers and of decoder layers
d_ff = 2048              # hidden width of the position-wise feed-forward block
max_seq_length = 128     # presumably the longest sequence the positional encoding supports — confirm
dropout = 0.1            # dropout probability

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout,device)

In [5]:
# Restore the trained weights. map_location remaps the stored tensors onto the
# active device, so a checkpoint that was saved on a GPU can still be loaded
# on a CPU-only machine instead of raising a CUDA deserialization error.
state_dict = torch.load(path_to_model, map_location=device)
transformer.load_state_dict(state_dict)
# Switch to evaluation mode (turns dropout off). As the last expression of the
# cell this also displays the module summary.
transformer.eval()

Transformer(
  (encoder_embedding): Embedding(13610, 512)
  (decoder_embedding): Embedding(24266, 512)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=True)
        (W_k): Linear(in_features=512, out_features=512, bias=True)
        (W_v): Linear(in_features=512, out_features=512, bias=True)
        (W_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward): PositionWiseFeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(

In [17]:
import spacy
# Load spaCy's small English and German pipelines; their tokenizers are used
# later for building the vocabularies and for tokenizing the input sentence.
eng = spacy.load('en_core_web_sm')
de = spacy.load('de_core_news_sm')

In [18]:
import torchdata.datapipes as dp
import torchtext.transforms as T
import spacy
from torchtext.vocab import build_vocab_from_iterator

In [19]:
import dataset_util as util

In [20]:
# Tab-separated sentence-pair file used to rebuild the vocabularies.
FILE_PATH = './dataset/deu-eng/deu.txt'

# Build a datapipe over the file: wrap the path, open it in binary mode, then
# parse it as tab-delimited CSV so that each row is yielded as a tuple.
data_pipe = dp.iter.IterableWrapper([FILE_PATH])
data_pipe = dp.iter.FileOpener(data_pipe,mode='rb')
data_pipe = data_pipe.parse_csv(skip_lines=0,delimiter='\t',as_tuple=True)

In [21]:
# Apply util.removeAttribution to every parsed row.
# NOTE(review): presumably this drops the trailing attribution column and
# keeps only the sentence pair — confirm in dataset_util.
data_pipe = data_pipe.map(util.removeAttribution)

In [22]:
def _build_vocab(column_index):
    """Build a vocabulary from one column of ``data_pipe``.

    Tokens are produced by ``util.getTokens`` for the given column. Tokens
    seen fewer than two times are left out, and the four special tokens are
    placed first so they receive the indices 0-3.
    """
    return build_vocab_from_iterator(
        util.getTokens(data_pipe, column_index, eng.tokenizer, de.tokenizer),
        min_freq=2,
        specials=['<pad>', '<sos>', '<eos>', '<unk>'],
        special_first=True,
    )


# Column 0 feeds the source vocabulary, column 1 the target vocabulary.
source_vocab = _build_vocab(0)
target_vocab = _build_vocab(1)

In [23]:
# Peek at the first entries of the source vocabulary: the four special tokens
# come first, followed by regular tokens.
print(source_vocab.get_itos()[:9])

['<pad>', '<sos>', '<eos>', '<unk>', '.', 'I', 'Tom', 'to', 'you']


In [48]:
def evaluatemodel(model, source_input,start_token,end_token,max_seq_length):
    """Greedily decode a target token sequence for ``source_input``.

    Starting from ``start_token``, the model is called repeatedly with the
    source tensor and the tokens generated so far; the highest-scoring token
    at the last position is appended after every call.

    Args:
        model: Callable invoked as ``model(source_input, target_tensor)``. Its
            result has the leading dimension squeezed away and the last row is
            taken as the scores for the next token.
        source_input: Source token tensor. The target tensor is created on the
            same device as this tensor.
        start_token: Token id that seeds the decoded sequence.
        end_token: Token id that stops decoding as soon as it is generated.
        max_seq_length: Maximum number of tokens generated after
            ``start_token``.

    Returns:
        list[int]: ``start_token`` followed by the generated token ids,
        including ``end_token`` when it was produced.
    """
    target_input = [start_token]
    # Follow the device of the source tensor rather than a notebook global so
    # the function also works when called outside this notebook.
    target_device = source_input.device
    # Decoding is inference only: disabling autograd avoids building a graph
    # for every step, which saves memory and time.
    with torch.no_grad():
        for _ in range(max_seq_length):
            target_tensor = torch.tensor(target_input, device=target_device).unsqueeze(0)
            pred = model(source_input, target_tensor).squeeze(0)
            # Greedy choice: most likely token at the last target position.
            next_token = torch.argmax(pred[-1], dim=-1).item()
            target_input.append(next_token)
            if next_token == end_token:
                break
    return target_input

In [74]:
# English sentence to translate.
input_text = 'A dog is running in the park.'

In [75]:
# Tokenize the sentence with the English tokenizer and convert the tokens to
# source-vocabulary ids. The displayed list starts with 1 and ends with 2,
# the indices of '<sos>' and '<eos>' in the specials list.
transformed_input = util.getTransform(source_vocab)(util.Tokenize(input_text,eng.tokenizer))
transformed_input

[1, 203, 212, 13, 816, 17, 9, 493, 4, 2]

In [76]:
# Turn the id list into a tensor and add a leading batch dimension of size 1.
tensor_input = T.ToTensor()(transformed_input).unsqueeze(0)
tensor_input.shape

torch.Size([1, 10])

In [77]:
# Move the input to the active device and run a single forward pass, then
# print the highest-scoring token id at the last position.
# NOTE(review): the source ids are passed as BOTH arguments here; in
# evaluatemodel the second argument is the target sequence, so this looks like
# a smoke test of the forward pass rather than a real translation — confirm.
tensor_input_device = tensor_input.to(device)
output = transformer(tensor_input_device,tensor_input_device).squeeze(0)
print(output[-1].argmax(dim=-1).item())


4


In [78]:
# Greedy decoding. 1 and 2 are the indices of '<sos>' and '<eos>' given by the
# specials list used to build the vocabularies; 128 equals max_seq_length.
translated = evaluatemodel(transformer,tensor_input_device,start_token=1,end_token=2,max_seq_length=128)

In [79]:
# Show the decoded target-vocabulary ids.
translated

[1, 287, 209, 1077, 69, 649, 4, 2]

In [80]:
# Map every decoded id back to its target-vocabulary token and print the
# tokens, each preceded by a single space (special tokens are kept).
target_index_to_string = target_vocab.get_itos()
sentence = ''.join(' ' + target_index_to_string[index] for index in translated)
print(sentence)

 <sos> Ein Hund läuft im Park . <eos>
