## Inference Demo
**Step 1.** Set these variables to point at your trained checkpoint and the saved dataset

In [1]:
# Lightning checkpoint of the trained POS classifier to run inference with.
CKPT_PATH = (
    './lightning_logs/roberta-base/version_0/checkpoints/'
    'epoch=4-step=330.ckpt'
)
# Directory of the saved `en_gum` dataset (the same path loaded in Step 3).
DATA_PATH = './data/en_gum'

**Step 2.** Import the dependencies

In [2]:
from datasets import load_from_disk
from functools import partial
from model import POSClassifier
from torch import LongTensor
from transformers import AutoTokenizer, BatchEncoding, logging

import torch

# Silence transformers' info/warning chatter; only errors are reported.
logging.set_verbosity(logging.ERROR)

**Step 3.** Load the dataset, model, and tokenizer

In [3]:
# Load the dataset from the path configured in Step 1 (previously hard-coded),
# and take the `upos` label names: the tag strings indexed by class id.
ud = load_from_disk(DATA_PATH)
POS_map = ud['train'].info.features['upos'].feature.names

model = POSClassifier.load_from_checkpoint(CKPT_PATH)
# Put the model in inference mode: torch.no_grad() alone does not disable
# training-only layers such as dropout, which would make predictions random.
model.eval()

# Use the same pretrained tokenizer the encoder was built from.
tokenizer = AutoTokenizer.from_pretrained(model.hparams.encoder_name, add_prefix_space=True)
# Tokenizer call pre-configured to return PyTorch tensors.
tokenize_fn = partial(tokenizer, return_tensors='pt')

**Step 4.** Define the inference function

In [4]:
@torch.no_grad()
def infer(sentence: str) -> list:
    """Predict a UPOS tag name for ``sentence`` using the loaded ``model``.

    The sentence is tokenized with the module-level ``tokenize_fn``. A
    word-piece -> word index map is built, and both are passed to ``model``.
    The arg-max over dim 1 of the model output is translated into tag names
    through ``POS_map``.

    Args:
        sentence: raw text to tag.

    Returns:
        A list of tag-name strings, one per row of the model output
        (presumably one per word -- TODO confirm against POSClassifier.forward).
    """

    def map_wordpieces(encoding: BatchEncoding) -> LongTensor:
        # Ignore padding, then skip the first and last tokens: the special
        # start/end tokens the tokenizer adds (e.g. [CLS]/[SEP], or <s>/</s>
        # for RoBERTa).
        real_tokens = [tok for tok in encoding.tokens() if tok != tokenizer.pad_token]
        inner_positions = range(1, len(real_tokens) - 1)

        # Special tokens are marked with -1; each inner word piece gets the
        # index of the word it belongs to.
        word_index = [-1]
        word_index.extend(encoding.token_to_word(pos) for pos in inner_positions)
        word_index.append(-1)

        # Shape (1, num_tokens): a batch containing one sequence.
        return LongTensor([word_index])

    encoding = tokenize_fn(sentence)
    word_map = map_wordpieces(encoding)
    scores = model(encoding, word_map)

    tag_ids = scores.argmax(dim=1).tolist()
    return [POS_map[tag_id] for tag_id in tag_ids]

**Step 5.** Perform inference!

In [5]:
infer("Hello guys, I am a dog")

['INTJ', 'NOUN', 'PUNCT', 'PRON', 'AUX', 'DET', 'NOUN']

**Step 6.** Profit?