In [3]:
from great_ai import GreatAI, use_model, MongoDbDriver, configure
from great_ai.utilities import clean
from pathlib import Path
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
import re
import numpy as np
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
)
from views import EvaluatedSentence, Match
from great_ai.large_file import LargeFileS3

# Load storage credentials for the large-file store (LargeFileS3) and the
# MongoDB-backed tracing database from the same local INI file.
LargeFileS3.configure_credentials_from_file("config.ini")
MongoDbDriver.configure_credentials_from_file("config.ini")
# Global GreatAI settings for this service.
configure(dashboard_table_size=100)

# Hugging Face checkpoint id whose tokenizer is used for all tokenisation;
# the classification weights themselves are obtained via `use_model`.
ORIGINAL_MODEL = "allenai/scibert_scivocab_uncased"

# Populated lazily on the first call to find_highlights and reused afterwards.
loaded_model: PreTrainedModel = None
tokenizer: PreTrainedTokenizer = None


@GreatAI.create
@use_model("scibert-highlights", version="latest")
def find_highlights(sentence: str, model: Path) -> EvaluatedSentence:
    """Score a sentence and explain the score with per-word attention weights.

    ``model`` is supplied by ``use_model`` ("scibert-highlights", latest
    version) and is the ``from_pretrained`` location of both the config and
    the weights. The model and tokenizer are created on the first call and
    cached in the module-level ``loaded_model`` / ``tokenizer`` globals.

    The raw input is normalised with ``clean`` (ASCII conversion and bracket
    removal) and then handed to ``evaluate_sentence``.
    """
    global loaded_model, tokenizer

    if loaded_model is None:
        # The forward pass must return attention maps: evaluate_sentence turns
        # the last layer's attention into the per-word explanation.
        loaded_model = AutoModelForSequenceClassification.from_pretrained(
            model,
            config=AutoConfig.from_pretrained(
                model, output_hidden_states=True, output_attentions=True
            ),
        )

    if tokenizer is None:
        # Tokenizer is taken from the base checkpoint rather than from `model`.
        tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL)

    cleaned_sentence = clean(sentence, convert_to_ascii=True, remove_brackets=True)
    return evaluate_sentence(sentence=cleaned_sentence)


def evaluate_sentence(sentence: str) -> EvaluatedSentence:
    """Classify ``sentence`` and attribute last-layer attention to its words.

    Uses the module-level ``tokenizer`` and ``loaded_model``, which must
    already be initialised (``find_highlights`` does this lazily).

    The sentence score is the softmax probability of class index 1. Each
    word's score is the sum, over the word's sub-word tokens, of the
    last-layer attention in the first token's row (summed over heads),
    with the leading/trailing special tokens excluded.

    Args:
        sentence: Already-cleaned input text.

    Returns:
        An ``EvaluatedSentence`` whose ``score`` is a plain Python float and
        whose ``explanation`` holds one ``Match`` per space/``.``/``,``
        delimited word (scores rounded to 4 decimal places). If the sentence
        was truncated to 512 tokens, words past the cut-off are omitted.

    Raises:
        AssertionError: if tokenising a word on its own does not reproduce
            the token ids found at that position in the whole sentence.
    """
    tensors = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        result: SequenceClassifierOutput = loaded_model(**tensors)
        # .item() yields a plain float instead of a 0-d tensor so the result
        # serialises cleanly.
        positive_likelihood = torch.nn.Softmax(dim=1)(result.logits)[0][1].item()

    # result.attentions: tuple (one per layer) of tensors shaped
    # (batch_size, num_heads, sequence_length, sequence_length).
    # Take the last layer, the single batch item, sum over heads, keep the
    # first token's row and drop the first/last (special) positions.
    # .tolist() converts to Python floats so later rounding happens in
    # float64 rather than float32.
    attentions = np.sum(result.attentions[-1].numpy()[0], axis=0)[0][1:-1].tolist()
    token_ids = tensors["input_ids"][0][1:-1].tolist()
    token_attentions = list(zip(attentions, token_ids))

    matches = []
    cursor = 0  # index of the next unconsumed (attention, token id) pair
    for token in re.split(r"([ .,])", sentence):
        token = token.strip()
        if not token:
            continue
        # truncation=True needed to fix RuntimeError: Already borrowed
        word_ids = tokenizer(
            token, return_tensors="pt", truncation=True, max_length=512
        )["input_ids"][0][1:-1].tolist()

        score = 0.0
        for expected_id in word_ids:
            if cursor >= len(token_attentions):
                break  # sentence was truncated inside this word
            weight, actual_id = token_attentions[cursor]
            cursor += 1
            # Raised explicitly (not via `assert`) so the alignment check is
            # not stripped under `python -O`.
            if expected_id != actual_id:
                raise AssertionError(sentence)
            score += weight

        matches.append(
            Match(phrase=token if token in ".," else " " + token, score=round(score, 4))
        )
        if cursor >= len(token_attentions):
            break

    return EvaluatedSentence(
        score=positive_likelihood, text=sentence, explanation=matches
    )

[38;5;39m2022-07-01 14:28:43,378 |     INFO | Found credentials file (/data/projects/scoutinscience/platform/projects/highlights-service2/mongo.ini), initialising MongoDbDriver[0m
[38;5;39m2022-07-01 14:28:43,379 |     INFO | Found credentials file (/data/projects/scoutinscience/platform/projects/highlights-service2/s3.ini), initialising LargeFileS3[0m
[38;5;39m2022-07-01 14:28:43,380 |     INFO | Settings: configured ✅[0m
[38;5;39m2022-07-01 14:28:43,380 |     INFO | 🔩 tracing_database: MongoDbDriver[0m
[38;5;39m2022-07-01 14:28:43,381 |     INFO | 🔩 large_file_implementation: LargeFileS3[0m
[38;5;39m2022-07-01 14:28:43,381 |     INFO | 🔩 is_production: False[0m
[38;5;39m2022-07-01 14:28:43,382 |     INFO | 🔩 should_log_exception_stack: True[0m
[38;5;39m2022-07-01 14:28:43,382 |     INFO | 🔩 prediction_cache_size: 512[0m
[38;5;39m2022-07-01 14:28:43,383 |     INFO | 🔩 dashboard_table_size: 50[0m
[38;5;39m2022-07-01 14:28:44,082 |     INFO | Latest version of scibert