Fill-mask (masked-word prediction) pipeline using a DistilBERT model fine-tuned on CVPR abstracts

In [1]:
from transformers import pipeline

model_checkpoint = "distilbert-base-uncased-finetuned-cvpr"
unmasker = pipeline("fill-mask", model=model_checkpoint)
unmasker("State of the art object detectors [MASK]", top_k=2)

[{'score': 0.8849564790725708,
  'token': 1012,
  'token_str': '.',
  'sequence': 'state of the art object detectors.'},
 {'score': 0.05219423398375511,
  'token': 1025,
  'token_str': ';',
  'sequence': 'state of the art object detectors ;'}]

Punctuation tokens dominate the top predictions, so we need to filter them out. To work with the raw logits, load the model and tokenizer directly.

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased-finetuned-cvpr"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

Run the model on a masked sentence to check that it works.

In [3]:
import torch
text = "State of the art object detection detectors [MASK]"
inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits

# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))

State of the art object detection detectors .
State of the art object detection detectors ;
State of the art object detection detectors :
State of the art object detection detectors include
State of the art object detection detectors are
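For readers without the checkpoint at hand, the mechanics of In [3] can be mirrored in plain Python on a toy vocabulary. This is only an illustrative sketch: `top_k_at_mask` is a hypothetical helper, the token ids and logits are made up, and it handles a single unbatched sequence, whereas the real cell works on a batch of one and uses `torch.where` and `torch.topk`.

```python
import heapq


def top_k_at_mask(input_ids, logits, mask_id, k):
    """Pure-Python mirror of In [3] for one (unbatched) sequence.

    input_ids: list of token ids for the sequence
    logits:    one list of vocabulary scores per position
    Returns the k vocabulary ids with the highest score at the first mask position.
    """
    pos = input_ids.index(mask_id)  # first mask position, like torch.where(ids == mask_id)[1]
    row = logits[pos]               # scores at that position, like token_logits[0, mask_token_index, :]
    # Vocabulary ids sorted by descending score, like torch.topk(...).indices
    return heapq.nlargest(k, range(len(row)), key=row.__getitem__)


# Toy example (made-up numbers): vocabulary of 5 ids, 4 tokens, mask id 3 at position 2
ids = [0, 1, 3, 2]
logits = [[0.0] * 5, [0.0] * 5, [0.1, 2.0, -1.0, 0.5, 1.5], [0.0] * 5]
print(top_k_at_mask(ids, logits, mask_id=3, k=2))  # [1, 4]
```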


In [4]:
k = 5
punc = set("!.,':;-()&|?")  # punctuation tokens to skip
k2 = k + len(punc)  # worst case: every punctuation token outranks the first k words
top_tokens = torch.topk(mask_token_logits, k2, dim=1).indices[0].tolist()

cnt = 0
for token in top_tokens:
    pred = tokenizer.decode([token])
    if pred not in punc:  # set membership, not substring matching
        print(text.replace(tokenizer.mask_token, pred))
        cnt += 1
    if cnt == k:
        break

State of the art object detection detectors include
State of the art object detection detectors are
State of the art object detection detectors have
State of the art object detection detectors exist
State of the art object detection detectors on
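Because the `fill-mask` pipeline from In [1] already returns ranked predictions with a `token_str` field, the same over-fetch-and-filter idea can also be applied to its output without handling logits. A minimal sketch, assuming the output format shown in In [1] (a list of dicts sorted by descending score); `filter_pipeline_preds` and `PUNCT` are hypothetical names.

```python
PUNCT = set("!.,':;-()&|?")  # same punctuation set as the cell above


def filter_pipeline_preds(preds, k):
    """Keep the first k non-punctuation predictions from fill-mask pipeline output.

    Assumes `preds` is the list of dicts returned by the pipeline, each with a
    'token_str' key, already sorted by descending score.
    """
    kept = [p for p in preds if p["token_str"].strip() not in PUNCT]
    return kept[:k]


# Pipeline output copied from In [1] (top_k=2): both candidates are punctuation
preds = [
    {"score": 0.8849564790725708, "token": 1012, "token_str": ".",
     "sequence": "state of the art object detectors."},
    {"score": 0.05219423398375511, "token": 1025, "token_str": ";",
     "sequence": "state of the art object detectors ;"},
]
print(filter_pipeline_preds(preds, 2))  # []
```

The empty result shows why the pipeline must be asked for more candidates than needed. Calling `filter_pipeline_preds(unmasker(text, top_k=k + len(PUNCT)), k)` should leave k words, assuming each punctuation mark maps to a single vocabulary entry.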


Wrap this logic in a class for model deployment.

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

class TextFill:
    def __init__(self, model_checkpoint="distilbert-base-uncased-finetuned-cvpr"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
        self.punc = set("!.,':;-()&|?")  # punctuation tokens to skip
        self.punc_len = len(self.punc)

    def _predict(self, text):
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            token_logits = self.model(**inputs).logits

        # Find the location of [MASK] and extract its logits
        mask_token_index = torch.where(inputs["input_ids"] == self.tokenizer.mask_token_id)[1]
        mask_token_logits = token_logits[0, mask_token_index, :]

        return mask_token_logits

    def _get_top_k(self, mask_token_logits, top_k):
        # Over-fetch so top_k words remain even if every punctuation token ranks first
        k2 = top_k + self.punc_len
        top_tokens = torch.topk(mask_token_logits, k2, dim=1).indices[0].tolist()
        out = []
        for token in top_tokens:
            pred = self.tokenizer.decode([token])
            if pred not in self.punc:
                out.append(pred)
            if len(out) == top_k:
                break
        return out

    def fill_mask(self, text, top_k=10):
        mask_token_logits = self._predict(text)
        return self._get_top_k(mask_token_logits, top_k)

tf = TextFill()
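For deployment, `fill_mask` assumes the input contains exactly one `[MASK]`: with none, `mask_token_logits` is empty and `.indices[0]` raises an `IndexError`, and with several only the first mask is filled. A minimal sketch of an input check to run first; `validate_masked_input` is a hypothetical helper, not part of `TextFill` or the transformers API.

```python
def validate_masked_input(text, mask_token="[MASK]"):
    """Return text unchanged if it contains exactly one mask token, else raise ValueError.

    mask_token defaults to DistilBERT's '[MASK]'; pass tokenizer.mask_token in practice.
    """
    n = text.count(mask_token)
    if n != 1:
        raise ValueError(f"expected exactly one {mask_token}, found {n}")
    return text


print(validate_masked_input("perform robust [MASK]."))  # perform robust [MASK].
```

A caller could then write `tf.fill_mask(validate_masked_input(text, tf.tokenizer.mask_token))` so that malformed requests fail with a clear message instead of an indexing error.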

In [6]:
text = """To this end, we design a deep latent space deformation network that is directly parameterized by the kernel. 
          The network consists of three components: encoder, deformer, and decoder, where the deformer is specifically
          meant to rectify the latent space representations of blurred images to a standard latent space, regardless of
          the kernel. The deformation network is trained with a two-stage training scheme. We conduct extensive
          experiments to confirm that our parametric model can adapt to drastically different blurring kernels and perform
          robust [MASK].""" # original word: deblurring
pred = tf.fill_mask(text, 10)
print(pred)

['optimization', 'inference', 'training', '##ness', '##ly', 'estimation', 'learning', 'reconstruction', 'modeling', 'computation']
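The last output still contains `##ness` and `##ly`. These are WordPiece continuation pieces (the uncased BERT-family tokenizers mark them with a leading `##`), so they only make sense glued to the preceding word, as in "robustness" or "robustly", rather than as a standalone fill. A minimal sketch of a stricter filter that drops them along with punctuation; `is_whole_word` and `PUNCT` are hypothetical names, not part of `TextFill`.

```python
PUNCT = set("!.,':;-()&|?")  # same punctuation set as TextFill


def is_whole_word(pred, punct=PUNCT):
    """Return True if a decoded prediction can stand alone as a word.

    Rejects empty strings, punctuation, and WordPiece continuation pieces
    (decoded tokens that start with '##').
    """
    pred = pred.strip()
    return bool(pred) and pred not in punct and not pred.startswith("##")


# Predictions printed by the cell above
preds = ['optimization', 'inference', 'training', '##ness', '##ly', 'estimation',
         'learning', 'reconstruction', 'modeling', 'computation']
print([p for p in preds if is_whole_word(p)])
# ['optimization', 'inference', 'training', 'estimation', 'learning', 'reconstruction', 'modeling', 'computation']
```

Using this test inside `_get_top_k` would also need more headroom than `top_k + self.punc_len`, since there is no fixed bound on how many continuation pieces can rank above real words.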
