In [1]:
import torch
import copy

from transformer.transformer import make_model

In [2]:
# Global switch read by show_example / execute_example: the wrapped example
# functions are only invoked while this is truthy.
RUN_EXAMPLES = True

In [3]:
def is_interactive_notebook():
    """Report whether this module is executing as ``__main__``.

    Returns:
        bool: True when the module-level ``__name__`` equals ``"__main__"``
        (e.g. cells run directly rather than the file being imported),
        False otherwise.
    """
    running_as_main = "__main__" == __name__
    return running_as_main


def show_example(fn, args=()):
    """Call ``fn(*args)`` and return its result, if examples are enabled.

    The call only happens when the module runs as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy; otherwise nothing is
    executed and ``None`` is returned.

    Args:
        fn: Callable to invoke.
        args: Positional arguments unpacked into ``fn``. Any iterable works;
            the default is an immutable empty tuple so no mutable object is
            shared between calls.

    Returns:
        Whatever ``fn`` returns, or ``None`` when the example is skipped.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)
    return None


def execute_example(fn, args=()):
    """Call ``fn(*args)`` for its side effects, if examples are enabled.

    The call only happens when the module runs as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy. The result of ``fn`` is
    discarded; this function always returns ``None``.

    Args:
        fn: Callable to invoke.
        args: Positional arguments unpacked into ``fn``. Any iterable works;
            the default is an immutable empty tuple so no mutable object is
            shared between calls.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        fn(*args)


class DummyOptimizer(torch.optim.Optimizer):
    """Optimizer stand-in whose operations do nothing.

    It exposes a single entry in ``param_groups`` with a learning rate of 0,
    and ``step`` / ``zero_grad`` are no-ops.
    """

    def __init__(self):
        # NOTE: the base-class initialiser is not invoked here; the only
        # attribute set on the instance is ``param_groups``.
        self.param_groups = [dict(lr=0)]

    def step(self):
        """Do nothing and return ``None``."""
        return None

    def zero_grad(self, set_to_none=False):
        """Do nothing and return ``None``; ``set_to_none`` is ignored."""
        return None


class DummyScheduler:
    """Learning-rate scheduler stand-in whose ``step`` does nothing."""

    def step(self):
        """Do nothing and return ``None``."""
        return None

In [4]:
def subsequent_mask(size):
    """Mask out subsequent positions.

    Args:
        size: Length of the sequence the mask is built for.

    Returns:
        A boolean tensor of shape ``(1, size, size)`` where entry
        ``[0, i, j]`` is True when ``j <= i`` (position ``j`` is not after
        position ``i``) and False when ``j > i``.
    """
    # Ones on and below the main diagonal mark the positions that remain
    # visible; everything strictly above the diagonal is zero.
    visible = torch.tril(torch.ones(1, size, size))
    return visible.bool()

def inference_test():
    """Greedily decode nine tokens from a freshly built, untrained model.

    Builds a model with ``make_model(11, 11, 2)``, switches it to eval mode,
    encodes the fixed source sequence ``1..10`` and then repeatedly feeds the
    growing target sequence (seeded with a single 0) back through the decoder,
    appending the highest-scoring token each step. The final token sequence is
    printed; nothing is returned.

    The forward passes run under ``torch.no_grad()`` because no gradients are
    needed for inference, so no autograd graph is recorded.
    """
    # NOTE(review): arguments are presumably (source vocab, target vocab,
    # number of layers) -- confirm against make_model's signature.
    test_model = make_model(11, 11, 2)
    test_model.eval()
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, 10)

    with torch.no_grad():
        memory = test_model.encode(src, src_mask)
        # Target sequence starts as a single 0 with the same dtype as src.
        ys = torch.zeros(1, 1).type_as(src)

        for _step in range(9):
            # The boolean mask is converted to src's integer dtype before use.
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = test_model.decode(memory, src_mask, ys, tgt_mask)
            # Score only the last position and pick the arg-max token.
            prob = test_model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_token = torch.empty(1, 1).type_as(src).fill_(next_word[0])
            ys = torch.cat([ys, next_token], dim=1)

    print("Example Untrained Model Prediction:", ys)


def run_tests():
    """Run ``inference_test`` ten times in a row.

    Each call builds its own model, so every run prints its own prediction.
    """
    num_runs = 10
    for _run in range(num_runs):
        inference_test()


# Run the untrained-model demo (ten inference passes) when examples are
# enabled; the printed predictions appear below as the cell output.
show_example(run_tests)

Example Untrained Model Prediction: tensor([[ 0,  5,  5,  0,  3, 10,  6,  2, 10,  5]])
Example Untrained Model Prediction: tensor([[ 0,  3,  3,  3,  3, 10,  5,  5,  5,  5]])
Example Untrained Model Prediction: tensor([[0, 4, 9, 6, 8, 8, 8, 8, 8, 8]])
Example Untrained Model Prediction: tensor([[ 0,  6, 10,  1, 10,  1, 10,  1, 10,  1]])
Example Untrained Model Prediction: tensor([[0, 6, 7, 4, 4, 4, 4, 9, 1, 1]])
Example Untrained Model Prediction: tensor([[0, 7, 6, 1, 1, 1, 1, 1, 1, 1]])
Example Untrained Model Prediction: tensor([[0, 6, 6, 6, 6, 6, 6, 6, 6, 6]])
Example Untrained Model Prediction: tensor([[ 0,  7,  6,  0,  7,  6,  0,  7, 10, 10]])
Example Untrained Model Prediction: tensor([[ 0, 10,  7,  1,  2,  5,  5,  5,  5,  5]])
Example Untrained Model Prediction: tensor([[0, 6, 6, 8, 8, 8, 8, 8, 8, 8]])
