In [1]:
import torch
from typing import List
from transformers import AutoTokenizer, BartForQuestionAnswering
# Number of deduplicated (question, context) rows sent to the model per forward pass in TextQA.extract.
BATCH_SIZE = 2

class TextQA():
    """Extractive question answering with a SQuAD-finetuned BART checkpoint.

    Loads the ``valhalla/bart-large-finetuned-squadv1`` tokenizer and model on
    construction; ``extract`` answers each question by selecting a token span
    of its paired context.
    """

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1")
        self.model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1")

    def extract(self, texts: List[str], query: List[str]):
        """Extract one answer span per (question, context) pair.

        Args:
            texts: Context passages. Any iterable of strings is accepted,
                including numpy string arrays.
            query: Questions, paired positionally with ``texts``.

        Returns:
            A list of decoded answer strings, one per input row, produced by
            mapping each row through the inverse indices returned by
            ``unique``. Rows that tokenize identically share one model call.
            If the predicted end precedes the predicted start, the answer is "".
        """
        # Normalise to plain Python str: works for lists as well as numpy arrays
        # (whose elements are numpy.str_), matching the former ``.tolist()``.
        questions = [str(q) for q in query]
        contexts = [str(t) for t in texts]
        if not questions and not contexts:
            return []
        data = self.tokenizer(questions, contexts, return_tensors="pt", padding=True, truncation=True)
        # Score each distinct tokenized row only once.
        _, uq_indexes, uq_inverse = unique(data["input_ids"], dim=0)
        data_unique = {k: v[uq_indexes] for k, v in data.items()}
        n_unique = data_unique["input_ids"].shape[0]
        result_values = []
        with torch.no_grad():  # inference only; no autograd graph needed
            for offset in range(0, n_unique, BATCH_SIZE):
                # Slice the *deduplicated* tensors so that result_values[k]
                # belongs to unique row k, which is what uq_inverse indexes into.
                inputs = {k: v[offset: offset + BATCH_SIZE] for k, v in data_unique.items()}
                outputs = self.model(**inputs)
                starts = outputs["start_logits"].argmax(1).tolist()
                ends = outputs["end_logits"].argmax(1).tolist()
                for row, (start, end) in enumerate(zip(starts, ends)):
                    # Decode from the current batch row, not from the full input.
                    span = inputs["input_ids"][row][start: end + 1]
                    result_values.append(self.tokenizer.decode(span).strip())
        return [result_values[k] for k in uq_inverse.tolist()]


def unique(x, dim=None):
    """Unique elements of ``x``, their first-occurrence indices, and inverse indices.

    Adapted from https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810

    Args:
        x: Input tensor.
        dim: Dimension to deduplicate along (e.g. ``0`` for unique rows). When
            ``None``, ``x`` is treated as flattened, as in ``torch.unique``.

    Returns:
        A tuple ``(unique, indices, inverse)``:
        ``unique`` is the sorted output of ``torch.unique``;
        ``indices[k]`` is the position in ``x`` of the first occurrence of
        ``unique[k]`` (a flat index when ``dim`` is None);
        ``inverse`` is ``torch.unique``'s inverse, unmodified, so that for
        ``dim=0`` ``unique[inverse]`` equals ``x``.

    e.g.

    unique(tensor([
        [1, 2, 3],
        [1, 2, 4],
        [1, 2, 3],
        [1, 2, 5]
    ]), dim=0)
    => (tensor([[1, 2, 3],
                [1, 2, 4],
                [1, 2, 5]]),
        tensor([0, 1, 3]),
        tensor([0, 1, 0, 2]))
    """
    uniq, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True, dim=dim)
    # A stable sort groups equal inverse ids while preserving original order, so
    # the head of each group is that value's first occurrence. Unlike scatter_
    # with duplicate indices (write order undefined), this is deterministic.
    # reshape(-1) also makes dim=None work for inputs with more than one dimension.
    _, order = torch.sort(inverse.reshape(-1), stable=True)
    group_starts = torch.cumsum(counts, dim=0) - counts
    return uniq, order[group_starts], inverse


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import numpy as np

# Demo: ask one question about one game recap. extract() is batched, so the
# single context and question are wrapped as 1-element numpy arrays and the
# result is a 1-element list. Note the recap never mentions assists, so the
# returned span is the model's best guess rather than a supported answer.
text = "Kobe Bryant scored 60 points as the Los Angeles Lakers (17 - 65) beat the Utah Jazz (40 - 42) in his farewell game, 101 - 96. Bryant, who plans to retire after 20 NBA seasons, took a career - high 50 shots, making 22 field goals. He scored 38 in the second half as the Lakers overcame a 15 - point halftime deficit. The Lakers' defense, which struggled all season, was surprsingly the key to the comeback, as it held the Jazz to just 39 points in the second half. The third - leading scorer in NBA history, Bryant became the oldest player to score 50 points in a game as he posted the fifth - highest scoring game of his career. Trey Lyles led the Jazz with 18 points, 11 rebounds and five steals. Gordon Hayward, who finished with 17 points, had the unfortunate task of covering Bryant. The Jazz were eliminated from the playoffs just before tip - off when the Houston Rockets beat the Sacramento Kings, the fourth consecutive season Utah has missed the playoffs. It marks the Jazz's longest playoff drought since 1980 - 83. With 12 points Jordan Clarkson was the only other player in double figures for the Lakers, who have missed the playoffs a franchise - high three straight seasons."
query = "How many assists did Kobe Bryant make?"
qa = TextQA()
contexts = np.asarray([text])
questions = np.asarray([query])
answers = qa.extract(contexts, questions)
print(answers)


You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.


['22']
