# ANLI Baseline with LLM

In this notebook, implement a baseline for ANLI classification using an LLM.
The baseline must be implemented with DSPy.



In [13]:
# Configure DSPy with the language model. For Grok:
#   - the API key must be available in os.environ['XAI_API_KEY']
#   - the model name is "xai/grok-3-mini"
import os
import dspy

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [14]:
from typing import Literal

## Implement the DSPy classifier program.

class ANLIClassifier(dspy.Signature):
    """Natural Language Inference task: given a premise and a hypothesis, determine the relationship."""
    
    premise: str = dspy.InputField(desc="The premise statement")
    hypothesis: str = dspy.InputField(desc="The hypothesis statement to check against the premise")
    # Typing the output with Literal restricts it to the three ANLI labels.
    label: Literal["entailment", "neutral", "contradiction"] = dspy.OutputField(
        desc="The relationship between the premise and the hypothesis"
    )

class ANLIPredictor(dspy.Module):
    def __init__(self):
        super().__init__()
        self.classify = dspy.ChainOfThought(ANLIClassifier)
    
    def forward(self, premise, hypothesis):
        result = self.classify(premise=premise, hypothesis=hypothesis)
        # Ensure the label is one of the valid options
        label = result.label.lower().strip()
        if label not in ['entailment', 'neutral', 'contradiction']:
            # Default to neutral if unclear
            label = 'neutral'
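        # Recent DSPy versions name the ChainOfThought field `reasoning`; older ones used `rationale`.
        reasoning = getattr(result, 'reasoning', None) or getattr(result, 'rationale', '')
        return dspy.Prediction(label=label, reasoning=reasoning)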
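
The `forward` method above falls back to `neutral` whenever the lowercased, stripped answer is not an exact label, so an answer such as `Entailment.` would be counted as `neutral`. A more tolerant normalizer could look like the sketch below. `normalize_label` and `VALID_LABELS` are hypothetical names introduced for illustration; they are not part of DSPy or of the cell above.

```python
import re

# Hypothetical helper, not part of DSPy: the three labels used in this notebook.
VALID_LABELS = ("entailment", "neutral", "contradiction")


def normalize_label(raw: str, default: str = "neutral") -> str:
    """Map a free-form model answer onto one of the three NLI labels."""
    text = raw.lower().strip()
    # First try an exact match after removing surrounding punctuation and quotes.
    cleaned = text.strip(" .,:;!?'\"`*")
    if cleaned in VALID_LABELS:
        return cleaned
    # Otherwise accept the answer only if exactly one label word occurs in it.
    found = [lab for lab in VALID_LABELS if re.search(rf"\b{lab}\b", text)]
    return found[0] if len(found) == 1 else default


print(normalize_label("Entailment."))
print(normalize_label("entailment or contradiction"))
```

Answers that mention several labels stay ambiguous and still fall back to the default, which keeps the behaviour conservative.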


## Load ANLI dataset

In [15]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

In [16]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [17]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [18]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [19]:
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}
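
The numbers above can be reproduced by hand. With their default settings, the `precision`, `recall` and `f1` metrics treat label `1` as the positive class (binary averaging). The sketch below recomputes the four values in plain Python; `binary_metrics` is a hypothetical helper written for this illustration, not a function of the `evaluate` package.

```python
def binary_metrics(predictions, references, positive=1):
    """Recompute accuracy, precision, recall and F1 for one positive class."""
    pairs = list(zip(predictions, references))
    tp = sum(p == positive and r == positive for p, r in pairs)  # true positives
    fp = sum(p == positive and r != positive for p, r in pairs)  # false positives
    fn = sum(p != positive and r == positive for p, r in pairs)  # false negatives
    accuracy = sum(p == r for p, r in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}


# Same toy inputs as the cell above: one true positive, no false positive, one false negative.
print(binary_metrics([0, 1, 0], [0, 1, 1]))
```

Here precision is 1/1 and recall is 1/2, which gives an F1 of 2/3, matching the output of `clf_metrics.compute`.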

## Your Turn

Compute the classification metrics of the baseline LLM on each test split of the ANLI dataset, for samples that have a non-empty 'reason' field.

You must also compare the DeBERTa baseline model with this LLM baseline model. The comparison should measure the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]
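
The four counts listed above can be sketched on toy data as follows. `agreement_counts` is a hypothetical helper, and the two boolean lists are made-up correctness flags rather than real model output; the sketch assumes both lists refer to the same samples in the same order.

```python
from collections import Counter


def agreement_counts(model1_correct, model2_correct):
    """Count the four agreement outcomes for two aligned lists of correctness flags."""
    if len(model1_correct) != len(model2_correct):
        raise ValueError("both models must be scored on the same samples")
    names = {
        (True, True): "Correct",
        (True, False): "Correct1",
        (False, True): "Correct2",
        (False, False): "Incorrect",
    }
    counts = Counter(
        names[(bool(a), bool(b))] for a, b in zip(model1_correct, model2_correct)
    )
    # Return every category, including those with zero samples.
    return {name: counts.get(name, 0) for name in names.values()}


# Toy flags, not real results.
m1 = [True, True, False, False, True]
m2 = [True, False, True, False, True]
print(agreement_counts(m1, m2))
```

The four counts always sum to the number of samples, which is a convenient sanity check.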

In [20]:
import random
from datasets import load_dataset

# Load and filter dataset
dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

# Convert to DSPy format
def convert_to_dspy_examples(data_split, sample_size=None):
    examples = []
    label_names = ["entailment", "neutral", "contradiction"]
    
    data = list(data_split)
    if sample_size:
        data = random.sample(data, min(sample_size, len(data)))
    
    for example in data:
        examples.append(dspy.Example(
            premise=example['premise'],
            hypothesis=example['hypothesis'],
            label=label_names[example['label']]
        ).with_inputs('premise', 'hypothesis'))
    
    return examples

# Prepare training and test data
print("Preparing datasets...")
dev_r3_examples = convert_to_dspy_examples(dataset['dev_r3'], sample_size=50)  # Optimize on 50 examples
test_r3_examples = convert_to_dspy_examples(dataset['test_r3'])

print(f"Dev R3 examples for optimization: {len(dev_r3_examples)}")
print(f"Test R3 examples for evaluation: {len(test_r3_examples)}")


Preparing datasets...
Dev R3 examples for optimization: 50
Test R3 examples for evaluation: 1200


In [21]:
from dspy.teleprompt import BootstrapFewShot

# Initialize predictor
predictor = ANLIPredictor()

# Test before optimization
print("Testing predictor before optimization...")
test_example = dev_r3_examples[0]
result = predictor(premise=test_example.premise, hypothesis=test_example.hypothesis)
print(f"Example: {test_example.premise[:50]}... | {test_example.hypothesis[:50]}...")
print(f"Predicted: {result.label}, Actual: {test_example.label}")

# Set up optimizer
def validate_prediction(example, pred, trace=None):
    return pred.label == example.label

# Optimize the model
print("\nOptimizing predictor...")
teleprompter = BootstrapFewShot(metric=validate_prediction, max_bootstrapped_demos=8, max_labeled_demos=8)
optimized_predictor = teleprompter.compile(predictor, trainset=dev_r3_examples)

print("Optimization complete!")

Testing predictor before optimization...
Example: Healthier Life<br>Maddie wanted to lead a healthie... | Maddie liked the idea of being healthy...
Predicted: entailment, Actual: entailment

Optimizing predictor...


 20%|██        | 10/50 [01:09<04:37,  6.93s/it]

Bootstrapped 8 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.
Optimization complete!





In [None]:
from tqdm import tqdm

def evaluate_llm_on_dataset(predictor, examples):
    results = []
    
    for example in tqdm(examples, desc="Evaluating"):
        try:
            prediction = predictor(premise=example.premise, hypothesis=example.hypothesis)
            results.append({
                'premise': example.premise,
                'hypothesis': example.hypothesis,
                'pred_label': prediction.label,
                'gold_label': example.label,
                'correct': prediction.label == example.label
            })
        except Exception as e:
            print(f"Error processing example: {e}")
            # Add a default neutral prediction for failed cases
            results.append({
                'premise': example.premise,
                'hypothesis': example.hypothesis,
                'pred_label': 'neutral',
                'gold_label': example.label,
                'correct': 'neutral' == example.label
            })
    
    return results

In [23]:
print("Evaluating optimized LLM model on test_r3...")
llm_results = evaluate_llm_on_dataset(optimized_predictor, test_r3_examples)

# Compute metrics
from evaluate import load
accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")

label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}

def compute_metrics(results):
    preds = [label2id[r['pred_label']] for r in results]
    refs = [label2id[r['gold_label']] for r in results]
    
    return {
        "accuracy": accuracy.compute(predictions=preds, references=refs)["accuracy"],
        "precision": precision.compute(predictions=preds, references=refs, average="macro")["precision"],
        "recall": recall.compute(predictions=preds, references=refs, average="macro")["recall"],
        "f1": f1.compute(predictions=preds, references=refs, average="macro")["f1"],
    }

llm_metrics = compute_metrics(llm_results)
print(f"\n=== LLM Model Metrics on test_r3 ({len(llm_results)} examples) ===")
print(f"Accuracy : {llm_metrics['accuracy']:.4f}")
print(f"Precision: {llm_metrics['precision']:.4f}")
print(f"Recall   : {llm_metrics['recall']:.4f}")
print(f"F1       : {llm_metrics['f1']:.4f}")

Evaluating optimized LLM model on test_r3...


Evaluating: 100%|██████████| 1200/1200 [1:56:19<00:00,  5.82s/it] 



=== LLM Model Metrics on test_r3 (1200 examples) ===
Accuracy : 0.7342
Precision: 0.7540
Recall   : 0.7340
F1       : 0.7382
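
The precision, recall and F1 values above use `average="macro"`: each score is computed per class and the three per-class values are averaged with equal weight. The sketch below illustrates that computation in plain Python on toy labels. `macro_scores` is a hypothetical helper, and it assumes all three class ids are always included in the average.

```python
def macro_scores(preds, refs, labels=(0, 1, 2)):
    """Unweighted mean of per-class precision, recall and F1."""
    per_class = []
    for c in labels:
        tp = sum(p == c and r == c for p, r in zip(preds, refs))
        fp = sum(p == c and r != c for p, r in zip(preds, refs))
        fn = sum(p != c and r == c for p, r in zip(preds, refs))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class.append((precision, recall, f1))
    n = len(labels)
    return {
        "precision": sum(p for p, _, _ in per_class) / n,
        "recall": sum(r for _, r, _ in per_class) / n,
        "f1": sum(f for _, _, f in per_class) / n,
    }


# Toy label ids (0 = entailment, 1 = neutral, 2 = contradiction), not real predictions.
toy_preds = [0, 1, 2, 0, 1, 1]
toy_refs = [0, 1, 2, 1, 1, 2]
print(macro_scores(toy_preds, toy_refs))
```

Because every class has the same weight, a class that the model handles poorly lowers the macro scores even if it is rare.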


In [None]:
def load_deberta_results():
    import json
    import os
    
    # Try to load the saved results
    json_file = 'deberta_test_r3_results.json'
    csv_file = 'deberta_test_r3_results.csv'
    
    if os.path.exists(json_file):
        print(f"Loading DeBERTa results from {json_file}")
        with open(json_file, 'r') as f:
            deberta_results = json.load(f)
        print(f"Loaded {len(deberta_results)} DeBERTa results")
        return deberta_results
    
    elif os.path.exists(csv_file):
        print(f"Loading DeBERTa results from {csv_file}")
        import pandas as pd
        df = pd.read_csv(csv_file)
        deberta_results = df.to_dict('records')
        print(f"Loaded {len(deberta_results)} DeBERTa results")
        return deberta_results
    
    else:
        print("ERROR: No DeBERTa results file found!")
        print("Please run the DeBERTa code first and save results using:")
        print("- deberta_test_r3_results.json")
        print("- deberta_test_r3_results.csv")
        return []

In [None]:
def compare_models(llm_results, deberta_results):
    
    if len(deberta_results) == 0:
        print("No DeBERTa results provided. Please run DeBERTa baseline first.")
        return None
    
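    # Results are compared position by position, so both lists must cover
    # the same examples in the same order.
    assert len(llm_results) == len(deberta_results), "Results must have same length"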
    
    comparison = {
        'both_correct': 0,      # Both models correct
        'llm_correct_deberta_wrong': 0,  # LLM correct, DeBERTa wrong
        'deberta_correct_llm_wrong': 0,  # DeBERTa correct, LLM wrong
        'both_incorrect': 0     # Both models incorrect
    }
    
    for llm_pred, deberta_pred in zip(llm_results, deberta_results):
        llm_correct = llm_pred['correct']
        deberta_correct = deberta_pred['correct']
        
        if llm_correct and deberta_correct:
            comparison['both_correct'] += 1
        elif llm_correct and not deberta_correct:
            comparison['llm_correct_deberta_wrong'] += 1
        elif not llm_correct and deberta_correct:
            comparison['deberta_correct_llm_wrong'] += 1
        else:
            comparison['both_incorrect'] += 1
    
    return comparison

print("\n=== Model Comparison ===")

deberta_results = load_deberta_results()  
comparison = compare_models(llm_results, deberta_results)
if comparison:
    total = len(llm_results)
    print(f"Both Correct: {comparison['both_correct']} ({comparison['both_correct']/total:.3f})")
    print(f"LLM Correct, DeBERTa Wrong: {comparison['llm_correct_deberta_wrong']} ({comparison['llm_correct_deberta_wrong']/total:.3f})")
    print(f"DeBERTa Correct, LLM Wrong: {comparison['deberta_correct_llm_wrong']} ({comparison['deberta_correct_llm_wrong']/total:.3f})")
    print(f"Both Incorrect: {comparison['both_incorrect']} ({comparison['both_incorrect']/total:.3f})")



=== Model Comparison ===
Loading DeBERTa results from deberta_test_r3_results.json
Loaded 1200 DeBERTa results
Both Correct: 479 (0.399)
LLM Correct, DeBERTa Wrong: 402 (0.335)
DeBERTa Correct, LLM Wrong: 115 (0.096)
Both Incorrect: 204 (0.170)
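
The four counts can be cross-checked against the metrics reported earlier: they should sum to the number of test examples, and the two "LLM correct" cells should reproduce the LLM accuracy. The sketch below only redoes the arithmetic on the printed counts. The implied DeBERTa accuracy assumes that the `correct` flags in the loaded results file are themselves accurate.

```python
# Counts copied from the comparison output above.
counts = {
    "both_correct": 479,
    "llm_correct_deberta_wrong": 402,
    "deberta_correct_llm_wrong": 115,
    "both_incorrect": 204,
}

total = sum(counts.values())
llm_accuracy = (counts["both_correct"] + counts["llm_correct_deberta_wrong"]) / total
deberta_accuracy = (counts["both_correct"] + counts["deberta_correct_llm_wrong"]) / total

print(f"Total examples  : {total}")                 # 1200
print(f"LLM accuracy    : {llm_accuracy:.4f}")      # 0.7342
print(f"DeBERTa accuracy: {deberta_accuracy:.4f}")  # 0.4950
```

The LLM accuracy of 881/1200 agrees with the 0.7342 reported on test_r3.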


In [None]:
print("\n=== Sample LLM Predictions ===")
for i in range(min(5, len(llm_results))):
    result = llm_results[i]
    print(f"\nExample {i+1}:")
    print(f"Premise: {result['premise'][:100]}...")
    print(f"Hypothesis: {result['hypothesis'][:100]}...")
    print(f"Predicted: {result['pred_label']}")
    print(f"Gold: {result['gold_label']}")
    print(f"Correct: {result['correct']}")





=== Sample LLM Predictions ===

Example 1:
Premise: It is Sunday today, let's take a look at the most popular posts of the last couple of days. Most of ...
Hypothesis: The day of the passage is usually when Christians praise the lord together...
Predicted: neutral
Gold: entailment
Correct: False

Example 2:
Premise: By The Associated Press WELLINGTON, New Zealand (AP) — All passengers and crew have survived a crash...
Hypothesis: No children were killed in the accident....
Predicted: entailment
Gold: entailment
Correct: True

Example 3:
Premise: Tokyo - Food group Nestle is seeking to lure Japanese holiday shoppers with a taste for fine snackin...
Hypothesis: Japanese like kit kat. ...
Predicted: entailment
Gold: entailment
Correct: True

Example 4:
Premise: Governor Greg Abbott has called for a statewide show of support for law enforcement Friday, July 7. ...
Hypothesis: Law enforcement officers and the people at the Travis St. memorial do not show their support at the ...
Predicted: