In [None]:
import torch
from transformers import OPTForCausalLM, AutoTokenizer
from datasets import load_dataset

In [None]:
# Name the checkpoint once so the tokenizer and the model are always loaded
# from the same identifier.
CHECKPOINT = "facebook/opt-125m"

model = OPTForCausalLM.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

In [None]:
# Few-shot NLI prompt: three labelled demonstrations followed by the query
# whose label the model should complete.
# Every example uses the same "Premise / Hypothesis / Label" field spelling and
# starts at column 0, so neither inconsistent capitalisation nor source-code
# indentation / trailing blanks leak into the text that gets tokenized.
prompt = """Premise: i don't know um do you do a lot of camping,
Hypothesis: I know exactly,
Label: Contradiction

Premise: This site includes a list of all award winners and a searchable database of Government Executive articles,
Hypothesis: The Government Executive articles housed on the website are not able to be searched,
Label: Contradiction

Premise: The new rights are nice enough,
Hypothesis: Everyone really likes the newest benefits,
Label: Entailment

Premise: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
Hypothesis: I like him for the most part, but would still enjoy seeing someone beat him very hard and enjoy that sight.
Label:"""
# Tokenize into PyTorch tensors ("pt") for the generation cell.
inputs = tokenizer(prompt, return_tensors="pt")

In [None]:
# Number of prompt tokens; everything after this position in the output is new text.
prompt_len = inputs["input_ids"].shape[-1]

with torch.no_grad():
    # Unpacking `inputs` forwards the attention mask together with the ids, and
    # `max_new_tokens` bounds the continuation directly instead of deriving a
    # total length from the prompt size.
    generate_ids = model.generate(**inputs, max_new_tokens=10)

# Full decoded text (prompt + continuation), kept for inspection.
output_text_ = tokenizer.decode(generate_ids[0], skip_special_tokens=True)

# Parse the label from the continuation only, so the result cannot be affected
# by the prompt text. Assumes a decoder-only model whose output sequence begins
# with the prompt tokens.
completion = tokenizer.decode(generate_ids[0, prompt_len:], skip_special_tokens=True)
# Keep only the first non-empty line: anything after it is the model starting
# to invent a further example, not part of the label.
completion_lines = completion.strip().splitlines()
label_out = completion_lines[0].split('.')[0].strip() if completion_lines else ""
label_out

    Few-shot classification on examples from the MNLI validation set

In [None]:
# Load the "validation_matched" split of the "mnli" configuration of "nyu-mll/glue".
dataset = load_dataset(
    path="nyu-mll/glue",
    name="mnli",
    split="validation_matched",
)
# Leave the dataset object as the cell's displayed value.
dataset

In [None]:
def label_to_text(label):
    """Translate an integer class id into its lowercase label name.

    Args:
        label: Class id; 0 is "entailment", 1 is "neutral", 2 is "contradiction".

    Returns:
        The label name as a string.

    Raises:
        IndexError: If ``label`` is outside 0..2. This includes negative ids
            (e.g. a -1 placeholder on an unlabeled row), which plain list
            indexing would otherwise silently resolve from the end of the list.
    """
    names = ["entailment", "neutral", "contradiction"]
    # Explicit range check: names[-1] would quietly return "contradiction".
    if not 0 <= label < len(names):
        raise IndexError(
            f"label id {label!r} is outside the valid range 0..{len(names) - 1}"
        )
    return names[label]


In [None]:
def create_extended_prompt(context_examples, query_example):
    """Build a few-shot prompt from labelled examples plus one unlabelled query.

    Each example is rendered on its own line as
    ``Premise: ..., Hypothesis: ..., Label: <name>``; the query is rendered the
    same way but ends right after ``Label:`` so the model completes the label.

    Args:
        context_examples: Iterable of mappings with 'premise', 'hypothesis'
            and integer 'label' keys, used as demonstrations.
        query_example: Mapping with the same keys; its label is not put into
            the prompt, only printed for reference.

    Returns:
        The prompt string, one example per line with the query last.
    """
    prompt_lines = [
        f"Premise: {example['premise']}, Hypothesis: {example['hypothesis']}, "
        f"Label: {label_to_text(example['label'])}"
        for example in context_examples
    ]

    # Kept from the original: echo the gold label so it can be compared with the prediction.
    print('true lbl: ', label_to_text(query_example['label']))

    prompt_lines.append(
        f"Premise: {query_example['premise']}, Hypothesis: {query_example['hypothesis']}, Label:"
    )

    # A newline between examples keeps one label from running straight into the
    # next "Premise:" (plain concatenation produced "Label: neutralPremise: ...").
    return "\n".join(prompt_lines)

In [None]:
def classify_with_context(context_data, query_data, tokenizer, model, max_new_tokens=5):
    """Predict the label of ``query_data`` by prompting ``model`` with demonstrations.

    Args:
        context_data: Labelled examples passed to ``create_extended_prompt``
            as in-context demonstrations.
        query_data: The example to classify.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate``; assumed to be decoder-only, i.e.
            its output sequence starts with the prompt tokens.
        max_new_tokens: Upper bound on generated tokens. A single token is not
            enough when a label name is split into several tokens, so the
            default leaves room for a few.

    Returns:
        The text the model produced after the final ``Label:``, cut at the
        first line break or period and stripped; an empty string if the model
        produced no text.
    """
    prompt = create_extended_prompt(context_data, query_data)
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Unpacking forwards the attention mask together with the input ids.
        generate_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode only the continuation so the prompt's own text can never be
    # mistaken for the prediction.
    completion = tokenizer.decode(generate_ids[0, prompt_len:], skip_special_tokens=True)

    # First line only: later lines are the model starting a new example.
    completion_lines = completion.strip().splitlines()
    if not completion_lines:
        return ""
    return completion_lines[0].split('.')[0].strip()

# Rows 0 and 1 of the dataset act as demonstrations; row 2 is the query.
context_examples = [dict(dataset[row]) for row in range(2)]
query_example = dict(dataset[2])

print(context_examples)
print(query_example)

# Run the few-shot classification and report what the model answered.
predicted_label = classify_with_context(context_examples, query_example, tokenizer, model)
print(f"Predicted Label: {predicted_label}")
