In [1]:
import pandas as pd
import torch
from transformers import BartForSequenceClassification, BartTokenizer, __version__ as tv
import numpy as np


"""

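The BART model comes from the Transformers package by Hugging Face:
https://github.com/huggingface/transformers

Originally written against version 3.3.1; the installed version is printed
below (4.28.1 in the last recorded run).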

"""


# Print the installed transformers version so it is recorded in the cell output.
print(tv)


# Device string that BartZeroShot hands to `.to(...)` for the model and its inputs.
DEVICE = "cpu"


class BartZeroShot:
    """Zero-shot text classifier built on an NLI-finetuned BART model.

    A candidate label is turned into a hypothesis sentence (by default
    ``"this text is <label>"``) and the model scores whether the input
    sentence entails it.
    """

    def __init__(self, model_name="facebook/bart-large-mnli", device=None):
        """Load the pretrained sequence-classification model and its tokenizer.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained`` for both
                the model and the tokenizer.
            device: Device passed to ``.to(...)``. When ``None`` the
                module-level ``DEVICE`` is used.
        """
        # DEVICE is read here rather than in the signature so that the class
        # can be defined before the constant exists.
        self.device = DEVICE if device is None else device
        self.nli_model = BartForSequenceClassification.from_pretrained(model_name)
        self.nli_model = self.nli_model.to(self.device)
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    def predict(self, sentence, label, hypothesis_template="this text is {}"):
        """Return the probability that ``label`` applies to ``sentence``.

        Args:
            sentence: Text to classify (used as the NLI premise).
            label: Candidate label inserted into ``hypothesis_template``.
            hypothesis_template: ``str.format`` template with one ``{}``
                placeholder for the label, e.g. ``"This text is about {}."``.

        Returns:
            A Python float in [0, 1]: the softmax weight of the entailment
            logit against the contradiction logit.
        """
        hypothesis = hypothesis_template.format(label)
        x = self.tokenizer.encode(
            sentence,
            hypothesis,
            return_tensors="pt",
            max_length=self.tokenizer.model_max_length,
            # Passing truncation=True together with the legacy
            # `truncation_strategy` kwarg silently selects "longest_first";
            # naming the strategy here guarantees only the premise is trimmed
            # and the hypothesis stays intact.
            truncation="only_first",
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.nli_model(x.to(self.device))[0]

        # Drop the middle ("neutral") logit and renormalise over the remaining
        # two; index 1 of the result is the entailment probability
        # (label order per the bart-large-mnli config: 0 = contradiction,
        # 2 = entailment).
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(1)
        prob_label_is_true = probs[:, 1].item()
        return prob_label_is_true
        

  from .autonotebook import tqdm as notebook_tqdm


4.28.1


In [2]:
# Instantiating the class loads the pretrained model and tokenizer.
bz = BartZeroShot()


In [3]:
# Sanity check: a negative sentence scored against the label "positive";
# the returned value is the probability that the label applies.
bz.predict("I really really hate my life", "positive")

0.0003220779472030699

In [4]:
# Sanity check: a positive sentence scored against the label "positive";
# the returned value is the probability that the label applies.
bz.predict("I really really love my life", "positive")

0.9848922491073608