# ANLI Baseline with LLM

In this notebook you will implement a baseline for ANLI classification using an LLM.
The baseline must be implemented with DSPy.



In [None]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
import os
import dspy

# Never hardcode the API key in the notebook: anyone who can read the file (or its
# version history) can use it. Read it from the environment instead and, only when
# it is missing, prompt for it without echoing the value.
# NOTE: a key that was ever committed in this cell must be revoked and re-issued.
if not os.environ.get("XAI_API_KEY"):
    from getpass import getpass  # local import: only needed for the interactive fallback

    os.environ["XAI_API_KEY"] = getpass("XAI_API_KEY: ")

# Configure DSPy so that every module in this notebook uses the Grok model by default.
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
dspy.configure(lm=lm)


In [None]:
from typing import Literal

## Implement the DSPy classifier program.

class Classify(dspy.Signature):
    """Classify the relationship between premise and hypothesis as entailment, neutral, or contradiction."""
    # NOTE: per the DSPy signature documentation, the docstring above is used as the task
    # instructions sent to the LM, so rewording it changes the prompt, not just the docs.

    # Inputs: the two texts whose logical relationship is being judged.
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Output: the Literal annotation restricts the answer to the three NLI label strings.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

## Load ANLI dataset

In [3]:
from datasets import load_dataset

# Fetch every ANLI split, then keep only the rows that carry a non-empty 'reason'.
# The unfiltered splits stay available under their own name, so re-running this
# cell always filters the raw data rather than an already-filtered copy.
anli_all = load_dataset("facebook/anli")
dataset = anli_all.filter(lambda row: row['reason'] is not None and row['reason'] != "")

In [4]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [5]:
from evaluate import load

# Load the four classification metrics in one pass; the name order on the right
# matches the variable order on the left.
accuracy, precision, recall, f1 = (
    load(metric_name) for metric_name in ("accuracy", "precision", "recall", "f1")
)


In [6]:
import evaluate
# Bundle the four metrics so that a single compute() call reports all of them.
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [7]:
# Toy sanity check on a two-class example (labels 0/1).
# NOTE(review): no `average=` is passed here, so f1/precision/recall use the metrics'
# default averaging — presumably binary. The three-class ANLI labels will need an
# explicit `average=` argument (e.g. 'macro'); confirm against the `evaluate` docs.
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

In [17]:
# Create DSPy Module
class NLIClassifier(dspy.Module):
    """DSPy module that answers the `Classify` signature with a single `dspy.Predict` step."""

    def __init__(self):
        """Create the one predictor this module is made of."""
        super().__init__()
        self.classify = dspy.Predict(Classify)

    def forward(self, premise, hypothesis):
        """Run the predictor on one premise/hypothesis pair and return its prediction."""
        prediction = self.classify(premise=premise, hypothesis=hypothesis)
        return prediction


# Un-optimized instance of the classifier.
nli_classifier = NLIClassifier()

In [18]:
# Position i holds the string name used for ANLI's integer label i.
label_names = ["entailment", "neutral", "contradiction"]

# Size and seed of the dev_r3 subsample handed to the prompt optimizer.
# Kept small because every example costs LLM calls during optimization.
N_OPTIMIZER_EXAMPLES = 30
SAMPLE_SEED = 42

# Shuffle dev_r3 reproducibly and keep the first N_OPTIMIZER_EXAMPLES rows.
dev_r3_sample = dataset["dev_r3"].shuffle(seed=SAMPLE_SEED).select(range(N_OPTIMIZER_EXAMPLES))

# Convert each row into a DSPy example: premise and hypothesis are the inputs,
# and the gold label is stored as its string name so it can be compared with
# the string the classifier outputs.
train_examples = [
    dspy.Example(
        premise=ex["premise"],
        hypothesis=ex["hypothesis"],
        label=label_names[ex["label"]]
    ).with_inputs("premise", "hypothesis")
    for ex in dev_r3_sample
]

In [19]:
from dspy import MIPROv2


def exact_match(pred, gold, trace=None):
    """Metric: True when both arguments carry the same NLI label.

    Labels are compared as strings, ignoring case and surrounding whitespace.
    The comparison is symmetric, so the result does not depend on which
    argument holds the reference and which holds the prediction.

    NOTE(review): DSPy documents metrics as being called positionally as
    ``metric(example, prediction, trace)``, so at optimization time ``pred``
    most likely receives the gold example and ``gold`` the prediction. The
    parameter names are kept unchanged for backward compatibility.

    Args:
        pred: An object with a ``label`` attribute, or a bare label value.
        gold: An object with a ``label`` attribute, or a bare label value.
        trace: Accepted for compatibility with the DSPy metric signature; unused.

    Returns:
        bool: True if the normalized labels are equal. A missing (``None``)
        label on either side never counts as a match.
    """
    def _label_text(obj):
        # Accept either a record exposing `.label` or a bare label value.
        label = getattr(obj, 'label', obj)
        # Unwrap once more in case the label is itself a record with a `.label`.
        label = getattr(label, 'label', label)
        if label is None:
            return None
        return str(label).strip().lower()

    pred_text = _label_text(pred)
    gold_text = _label_text(gold)

    # Two missing labels must not be scored as agreement ('none' == 'none').
    if pred_text is None or gold_text is None:
        return False

    # No blanket try/except here: nothing above is expected to raise, and silently
    # returning False would disguise a programming error as a wrong prediction.
    return pred_text == gold_text


# Initialize your classifier module
# (a fresh, un-optimized NLIClassifier that the optimizer uses as its starting program)
dspy_module = NLIClassifier()

# Set up the optimizer
# MIPROv2 scores candidate prompts with `exact_match`; no other settings are passed,
# so the run uses the optimizer's defaults (the recorded log reports "light" auto settings).
optimizer = MIPROv2(metric=exact_match)

# Compile the module with optimization on the sample set
# NOTE: in the recorded run no candidate beat the default program (best score 70.83),
# so the returned module scored the same as the un-optimized one on the validation split.
optimized_dspy_module = optimizer.compile(
    dspy_module,
    trainset=train_examples,
    requires_permission_to_run=False  # avoids prompt for Grok cost confirmation
)


2025/06/30 16:49:42 INFO dspy.teleprompt.mipro_optimizer_v2: 
RUNNING WITH THE FOLLOWING LIGHT AUTO RUN SETTINGS:
num_trials: 10
minibatch: False
num_fewshot_candidates: 6
num_instruct_candidates: 3
valset size: 24

2025/06/30 16:49:42 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==
2025/06/30 16:49:42 INFO dspy.teleprompt.mipro_optimizer_v2: These will be used as few-shot example candidates for our program and for creating instructions.

2025/06/30 16:49:42 INFO dspy.teleprompt.mipro_optimizer_v2: Bootstrapping N=6 sets of demonstrations...


Bootstrapping set 1/6
Bootstrapping set 2/6
Bootstrapping set 3/6


100%|██████████| 6/6 [00:24<00:00,  4.06s/it]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 4/6


100%|██████████| 6/6 [00:24<00:00,  4.06s/it]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 5/6


100%|██████████| 6/6 [00:26<00:00,  4.34s/it]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 6/6


 50%|█████     | 3/6 [00:11<00:11,  3.77s/it]
2025/06/30 16:51:08 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==
2025/06/30 16:51:08 INFO dspy.teleprompt.mipro_optimizer_v2: We will use the few-shot examples from the previous step, a generated dataset summary, a summary of the program code, and a randomly selected prompting tip to propose instructions.


Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.


2025/06/30 16:51:19 INFO dspy.teleprompt.mipro_optimizer_v2: 
Proposing N=3 instructions...

2025/06/30 16:52:15 INFO dspy.teleprompt.mipro_optimizer_v2: Proposed Instructions for Predictor 0:

2025/06/30 16:52:15 INFO dspy.teleprompt.mipro_optimizer_v2: 0: Given the fields `premise`, `hypothesis`, produce the fields `label`.

2025/06/30 16:52:15 INFO dspy.teleprompt.mipro_optimizer_v2: 1: You are an expert in Natural Language Inference (NLI). Your task is to analyze the given premise and hypothesis and determine their logical relationship. The possible labels are:

- 'entailment': The hypothesis logically follows from or is supported by the premise.
- 'neutral': The hypothesis is neither supported nor contradicted by the premise; it may be unrelated or indeterminate.
- 'contradiction': The hypothesis conflicts with or is negated by the premise.

Consider the following examples for guidance:

1. Premise: "That time John Grisham defended men who watch child porn John Grisham's disturbin

Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:17<00:00,  1.37it/s]

2025/06/30 16:52:33 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:52:33 INFO dspy.teleprompt.mipro_optimizer_v2: Default program score: 70.83

2025/06/30 16:52:33 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 2 / 10 =====



Average Metric: 15.00 / 24 (62.5%): 100%|██████████| 24/24 [00:14<00:00,  1.61it/s]

2025/06/30 16:52:47 INFO dspy.evaluate.evaluate: Average Metric: 15 / 24 (62.5%)
2025/06/30 16:52:48 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 62.5 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 3'].
2025/06/30 16:52:48 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5]
2025/06/30 16:52:48 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:52:48 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 3 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:16<00:00,  1.48it/s]

2025/06/30 16:53:04 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:53:04 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 0'].
2025/06/30 16:53:04 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83]
2025/06/30 16:53:04 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:53:04 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 4 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:17<00:00,  1.40it/s]

2025/06/30 16:53:21 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:53:21 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 5'].
2025/06/30 16:53:21 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83]
2025/06/30 16:53:21 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:53:21 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 5 / 10 =====



Average Metric: 16.00 / 24 (66.7%): 100%|██████████| 24/24 [00:15<00:00,  1.53it/s]

2025/06/30 16:53:37 INFO dspy.evaluate.evaluate: Average Metric: 16 / 24 (66.7%)
2025/06/30 16:53:37 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 66.67 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 2'].
2025/06/30 16:53:37 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67]
2025/06/30 16:53:37 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:53:37 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 6 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:19<00:00,  1.26it/s]

2025/06/30 16:53:56 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 0', 'Predictor 0: Few-Shot Set 5'].
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83]
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 7 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 2114.73it/s]

2025/06/30 16:53:56 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 0'].
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83, 70.83]
2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:53:56 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 8 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:20<00:00,  1.15it/s]

2025/06/30 16:54:17 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:54:17 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 5'].
2025/06/30 16:54:17 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83, 70.83, 70.83]
2025/06/30 16:54:17 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:54:17 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 9 / 10 =====



Average Metric: 16.00 / 24 (66.7%): 100%|██████████| 24/24 [00:15<00:00,  1.58it/s]

2025/06/30 16:54:32 INFO dspy.evaluate.evaluate: Average Metric: 16 / 24 (66.7%)
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 66.67 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 4'].
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83, 70.83, 70.83, 66.67]
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 10 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 1932.19it/s]

2025/06/30 16:54:32 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 5'].
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83, 70.83, 70.83, 66.67, 70.83]
2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:54:32 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 11 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 2183.63it/s]

2025/06/30 16:54:33 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/06/30 16:54:33 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 0', 'Predictor 0: Few-Shot Set 0'].
2025/06/30 16:54:33 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 62.5, 70.83, 70.83, 66.67, 70.83, 70.83, 70.83, 66.67, 70.83, 70.83]
2025/06/30 16:54:33 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/06/30 16:54:33 INFO dspy.teleprompt.mipro_optimizer_v2: Returning best identified program with score 70.83!





In [24]:
from tqdm import tqdm

print("\n📊 Evaluating optimized LLM model on test_r3...")

def _normalize_nli_label(raw_label, label_names, fallback='neutral'):
    """Map a raw model output onto one of `label_names`, or `fallback` if nothing matches."""
    # A missing or non-string label is treated as empty text instead of crashing on .lower().
    text = "" if raw_label is None else str(raw_label).lower().strip()
    if text in label_names:
        return text
    # Tolerate slight variations such as "entails" or "It is a contradiction.".
    # The order of the checks decides which label wins when several fragments appear.
    for fragment, canonical in (
        ('entail', 'entailment'),
        ('neutral', 'neutral'),
        ('contra', 'contradiction'),
    ):
        if fragment in text:
            return canonical
    return fallback


def evaluate_llm_on_dataset(dataset_section, llm_model):
    """Run `llm_model` over every example of an ANLI split and collect its predictions.

    Args:
        dataset_section: Iterable of rows with 'premise', 'hypothesis',
            'label' (integer index into entailment/neutral/contradiction)
            and 'reason' keys.
        llm_model: Callable invoked as ``llm_model(premise=..., hypothesis=...)``
            whose result exposes the predicted label as ``.label``.

    Returns:
        list[dict]: One dict per example with keys 'premise', 'hypothesis',
        'llm_pred', 'gold_label' and 'reason'. 'llm_pred' is always one of the
        three label names: outputs that cannot be mapped to a label (including
        a missing label) are recorded as 'neutral', which can bias the metrics
        when the model frequently fails to answer.
    """
    results = []
    label_names = ["entailment", "neutral", "contradiction"]

    # One LLM call per example, issued sequentially.
    for example in tqdm(dataset_section):
        premise = example['premise']
        hypothesis = example['hypothesis']

        llm_result = llm_model(premise=premise, hypothesis=hypothesis)
        llm_prediction = _normalize_nli_label(getattr(llm_result, 'label', None), label_names)

        results.append({
            'premise': premise,
            'hypothesis': hypothesis,
            'llm_pred': llm_prediction,
            'gold_label': label_names[example['label']],
            'reason': example['reason']
        })

    return results

# Run LLM evaluation on test_r3
# NOTE: this makes one sequential LLM call per example. The recorded run progressed at
# roughly 4 s per example (an estimated 1h19m for the 1200 rows) and was interrupted
# before finishing, so no metrics were produced by it.
llm_test_r3_results = evaluate_llm_on_dataset(dataset['test_r3'], optimized_dspy_module)

# 2. Compute LLM metrics
def compute_llm_metrics(results):
    """Compute accuracy and macro-averaged precision/recall/F1 for LLM predictions.

    Args:
        results: Non-empty list of dicts with 'llm_pred' and 'gold_label' keys,
            each holding one of 'entailment', 'neutral' or 'contradiction'.

    Returns:
        dict: Scores under the keys 'accuracy', 'precision', 'recall' and 'f1'.

    Raises:
        ValueError: If `results` is empty.
        KeyError: If a label is not one of the three NLI label names.
    """
    if not results:
        raise ValueError("compute_llm_metrics() needs at least one result row")

    # The metrics operate on integer class ids, not on label strings.
    label_to_int = {"entailment": 0, "neutral": 1, "contradiction": 2}
    pred_ints = [label_to_int[row['llm_pred']] for row in results]
    gold_ints = [label_to_int[row['gold_label']] for row in results]

    # Load metrics through the `evaluate` package (the builtin `eval` has no `.load`).
    # Precision/recall/F1 need an explicit averaging mode for a 3-class problem;
    # accuracy takes no such argument.
    macro = {'average': 'macro'}
    metric_options = {'accuracy': {}, 'precision': macro, 'recall': macro, 'f1': macro}

    return {
        name: evaluate.load(name).compute(
            predictions=pred_ints, references=gold_ints, **options
        )[name]
        for name, options in metric_options.items()
    }

# Score the collected test_r3 predictions and report each metric to four decimals.
llm_metrics = compute_llm_metrics(llm_test_r3_results)

report_lines = [
    f"✓ LLM Evaluation completed on {len(llm_test_r3_results)} examples",
    f"  Accuracy: {llm_metrics['accuracy']:.4f}",
    f"  F1: {llm_metrics['f1']:.4f}",
    f"  Precision: {llm_metrics['precision']:.4f}",
    f"  Recall: {llm_metrics['recall']:.4f}",
]
print("\n".join(report_lines))



📊 Evaluating optimized LLM model on test_r3...


  2%|▎         | 30/1200 [02:01<1:19:10,  4.06s/it]


KeyboardInterrupt: 