In [74]:
from model import BiLSTMPOSTagger
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
import json

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Vocabularies saved with the model; their sizes fix the tagger's input and output dimensions below.
# NOTE(review): torch.load unpickles these files -- only load checkpoints from a trusted source.
vocab_text = torch.load("saved_models/en-vocabtext.pth")
vocab_tag = torch.load("saved_models/en-vocabtag.pth")

# Context manager so the config file handle is closed rather than leaked; JSON is UTF-8 by spec.
with open("config.json", encoding="utf-8") as config_file:
    params = json.load(config_file)

model = BiLSTMPOSTagger(
    input_dim=len(vocab_text),
    embedding_dim=params["embedding_dim"],
    hidden_dim=params["hidden_dim"],
    output_dim=len(vocab_tag),
    n_layers=params["n_layers"],
    bidirectional=params["bidirectional"],
    dropout=params["dropout"],
    pad_idx=vocab_text['<PAD>'],
).to(device)

# map_location lets a checkpoint saved on one device load onto whichever device is available here.
model.load_state_dict(torch.load("saved_models/en-model.pt", map_location=device))
# Modules start in training mode; switch to eval mode so the configured dropout is disabled
# and predictions are deterministic.
model.eval()

# vocab_tag.get_stoi() maps tag string -> index; invert it so predicted indices can be decoded to tags.
idx_tag = {v: k for k, v in vocab_tag.get_stoi().items()}

#Map each token and tag to its index in the corresponding vocabulary. collate_batch (the DataLoader's collate_fn)
#then turns these index lists into padded tensors.
def transform_text(x):
    """Return the list of ``vocab_text`` indices for the tokens in ``x``, in order."""
    token_ids = []
    for token in x:
        token_ids.append(vocab_text[token])
    return token_ids
def transform_tag(x):
    """Return the list of ``vocab_tag`` indices for the tags in ``x``, in order."""
    tag_ids = []
    for tag in x:
        tag_ids.append(vocab_tag[tag])
    return tag_ids

def collate_batch(batch):
    """Encode ``(tokens, tags)`` samples and pad them into a (text, tags) tensor pair.

    Each sample is mapped to vocabulary indices and placed on ``device``.
    Both outputs come from ``pad_sequence`` with its default
    ``batch_first=False`` layout, i.e. ``(max_len, n_samples)``. Text is
    padded with ``vocab_text['<PAD>']`` and tags with ``vocab_tag['<PAD>']``.
    """
    encoded_texts, encoded_tags = [], []
    for tokens, tags in batch:
        encoded_texts.append(torch.tensor(transform_text(tokens), device=device))
        encoded_tags.append(torch.tensor(transform_tag(tags), device=device))
    padded_texts = pad_sequence(encoded_texts, padding_value=vocab_text['<PAD>'])
    padded_tags = pad_sequence(encoded_tags, padding_value=vocab_tag['<PAD>'])
    return padded_texts, padded_tags


def inference(batch):
    """Predict a POS tag for every token in ``batch``.

    Args:
        batch: sequence of ``(tokens, tags)`` pairs, the format
            ``collate_batch`` consumes. The tag column is encoded by
            ``collate_batch`` but not used for prediction, so placeholder
            labels work as long as ``vocab_tag`` can look them up.

    Returns:
        A flat list of tag strings, one per input token, in input order
        (sentence by sentence, padding excluded). For a single sentence this
        is exactly that sentence's tags; callers tagging several sentences
        can split the list using the sentence lengths. Empty input yields
        ``[]``.
    """
    # Real (unpadded) sentence lengths, used to drop predictions made on padding.
    lengths = [len(tokens) for tokens, _ in batch]

    loader = DataLoader(
        batch, collate_fn=collate_batch, batch_size=params['batch_size'],
    )

    results = []
    offset = 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Visit every batch so inputs larger than params['batch_size'] are fully tagged.
        for encoded_text, _encoded_tags in loader:
            # encoded_text is (max_len, n_samples) -- pad_sequence's default layout.
            # assumes the model returns (max_len, n_samples, n_tags) -- TODO confirm
            pred_idx = model(encoded_text).argmax(-1)
            batch_lengths = lengths[offset:offset + pred_idx.shape[1]]
            offset += len(batch_lengths)
            # Transpose to one row per sentence, then trim each row to its real length.
            for sentence_pred, n_tokens in zip(pred_idx.t(), batch_lengths):
                results.extend(idx_tag[i] for i in sentence_pred[:n_tokens].tolist())
    return results

# Demo: tag a single three-token sentence. inference() does not use the tag column,
# so placeholder labels are passed (presumably they must still be encodable by vocab_tag,
# since collate_batch looks them up).
eg = [
    [["John", "went", "up"], ["BLANK", "BLANK", "BLANK"]],
]
demo_tokens = eg[0][0]
print(demo_tokens)
print(inference(eg))

['John', 'went', 'up']
['PROPN', 'VERB', 'ADP']
