In [1]:
%reload_ext autoreload
%autoreload 2

import copy
import math
import time
import torch
import sys
from torch import nn, Tensor

sys.path.append("..")
from model import TransformerModel, generate_square_subsequent_mask
# from data_utils import data_process, batchify, get_batch
# from train import train_epoch, evaluate

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Record the environment next to the results so outputs can be reproduced.
print("torch_version:", torch.__version__)
print("device:", device)

torch_version: 1.12.1
device: cpu


## Vocab

In [11]:
from torch.utils.data import dataset
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


# The same tokenizer must be reused later to encode prompts, so keep it at module level.
tokenizer = get_tokenizer("basic_english")
train_iter = WikiText2(split="train")

# Build the vocabulary from every tokenized training line, with "<unk>" as a reserved special token.
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=["<unk>"],
)
# Tokens missing from the vocabulary resolve to the "<unk>" index instead of raising.
vocab.set_default_index(vocab["<unk>"])

# Reverse lookup table: token id -> token string, in vocabulary order.
idx2word = dict(enumerate(vocab.get_itos()))


def to_sentence(x):
    """Decode a 1-D sequence of token ids into a space-separated string.

    Each element is converted with ``int()``, so a 1-D tensor of ids works
    as well as a plain list of ints.
    """
    words = (idx2word[int(token_id)] for token_id in x)
    return " ".join(words)

## Build Model

In [12]:
# Hyperparameters must match those the checkpoint was trained with; otherwise
# load_state_dict (strict by default) raises on missing keys or shape mismatches.
ntokens = len(vocab)  # output/vocabulary size, taken from the vocab built above
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2  # has no effect at inference once model.eval() is called below

WEIGHTS_PATH = "../weights"

model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
# map_location=device loads tensors straight onto the device the model lives on,
# instead of forcing CPU and relying on load_state_dict to copy them over.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
_ = model.eval()  # inference mode: disables dropout

## Prompt

In [74]:
# Other prompts tried earlier:
#   "there was a white horse on a beach. Steve and his friends had"
#   "Cheney addresses possibility of White House run after crushing loss to"
prompt = "None of them noticed a large, tawny owl flutter past your mother"

# Encode with the same tokenizer/vocab used for training. Words missing from the
# vocab map to <unk> (see set_default_index above).
# Shape is (seq_len, 1); the trailing axis is the batch axis that the generation loop
# reads with x[:, 0]. An explicit dtype keeps ids integral even for an empty token list,
# and the tensor is placed on the same device as the model.
x = torch.tensor(vocab(tokenizer(prompt)), dtype=torch.long).unsqueeze(-1).to(device)
assert x.size(0) > 0, "prompt produced no tokens"
x.shape

torch.Size([13, 1])

In [75]:
# Greedy decoding: repeatedly run the full sequence through the model, take the
# most likely token at the final position, and append it.
# NOTE: `x` is reassigned, so re-running this cell keeps extending the previous
# result. Re-run the prompt cell first for a fresh start.
n_new_tokens = 10

# The prompt tensor may have been built on CPU, so move it to the model's device.
x = x.to(device)

with torch.no_grad():  # pure inference: don't build an autograd graph every step
    for _ in range(n_new_tokens):
        mask = generate_square_subsequent_mask(x.size(0)).to(device)
        output = model(x, mask)
        # output is indexed as (seq_len, batch, ntokens) — consistent with how the
        # new token is concatenated onto x along dim 0; confirm in model.py.
        # Only the last position matters. unsqueeze(0) gives shape (1, batch), so this
        # also works for a 1-token prompt, where squeeze() would produce a 0-d tensor.
        next_token = output[-1].argmax(dim=-1).unsqueeze(0)
        x = torch.cat((x, next_token))
        print(to_sentence(x[:, 0]))

x.shape

none of them noticed a large , tawny owl <unk> past your mother ,
none of them noticed a large , tawny owl <unk> past your mother , and
none of them noticed a large , tawny owl <unk> past your mother , and <unk>
none of them noticed a large , tawny owl <unk> past your mother , and <unk> ,
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and <unk>
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and <unk> ,
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and <unk> , and
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and <unk> , and <unk>
none of them noticed a large , tawny owl <unk> past your mother , and <unk> , and <unk> , and <unk> ,


torch.Size([23, 1])