In [13]:
import torch
from transformers import OPTForCausalLM, AutoTokenizer
from datasets import load_dataset

In [14]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

In [15]:
prompt = """Premise: i don't know um do you do a lot of camping,
            hypothesis: I know exactly,
            label: Contradiction

            Premise: This site includes a list of all award winners and a searchable database of Government Executive articles, 
            Hypothesis: The Government Executive articles housed on the website are not able to be searched, 
            Label: Contradiction

            Premise: The new rights are nice enough, 
            Hypothesis: Everyone really likes the newest benefits, 
            Label: Entailment
                        
            Premise: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
            hypothesis: I like him for the most part, but would still enjoy seeing someone beat him very hard and enjoy that sight.
            Label:"""
inputs = tokenizer(prompt, return_tensors="pt")

In [17]:
generate_ids = model.generate(inputs.input_ids, max_length=inputs['input_ids'].shape[-1] + 10)
output_text_ = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
label_out = output_text_.split("Label:")[-1].strip().split('.')[0].strip()
label_out

'Contradiction'

    With MNLI tests

In [18]:
dataset = load_dataset("nyu-mll/glue", "mnli", split='validation_matched')
dataset

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 52.2M/52.2M [00:01<00:00, 50.4MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1.21M/1.21M [00:00<00:00, 14.6MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1.25M/1.25M [00:00<00:00, 5.73MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1.22M/1.22M [00:00<00:00, 21.1MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1.26M/1.26M [00:00<00:00, 19.8MB/s]


Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating test_matched split:   0%|          | 0/9796 [00:00<?, ? examples/s]

Generating test_mismatched split:   0%|          | 0/9847 [00:00<?, ? examples/s]

Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 9815
})

In [19]:
def label_to_text(label):
    return ["entailment", "neutral", "contradiction"][label]


In [20]:
def create_extended_prompt(context_examples, query_example):
    context_prompt = ""
    for example in context_examples:
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = label_to_text(example['label'])
        context_prompt += f"Premise: {premise}, Hypothesis: {hypothesis}, Label: {label}"

    query_premise = query_example['premise']
    query_hypothesis = query_example['hypothesis']
    print('true lbl: ', label_to_text(query_example['label']))
    query_prompt = f"Premise: {query_premise}, Hypothesis: {query_hypothesis}, Label:"
    
    return context_prompt + query_prompt

In [21]:
def classify_with_context(context_data, query_data, tokenizer, model):
    prompt = create_extended_prompt(context_data, query_data)
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        generate_ids = model.generate(inputs['input_ids'], max_length=inputs['input_ids'].shape[-1] + 1)
    output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
    
    predicted_label = output_text.split("Label:")[-1].strip().split('.')[0].strip()
    return predicted_label

context_examples = [{key: value[i] for key, value in dataset[:2].items()} for i in range(2)]
query_example = {key: value for key, value in dataset[2].items()}

print(context_examples)
print(query_example)

predicted_label = classify_with_context(context_examples, query_example, tokenizer, model)
print(f"Predicted Label: {predicted_label}")


[{'premise': 'The new rights are nice enough', 'hypothesis': 'Everyone really likes the newest benefits ', 'label': 1, 'idx': 0}, {'premise': 'This site includes a list of all award winners and a searchable database of Government Executive articles.', 'hypothesis': 'The Government Executive articles housed on the website are not able to be searched.', 'label': 2, 'idx': 1}]
{'premise': "uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him", 'hypothesis': 'I like him for the most part, but would still enjoy seeing someone beat him.', 'label': 0, 'idx': 2}
true lbl:  entailment
Predicted Label: I
