# ANLI Baseline with LLM

In this notebook, you will implement a baseline for ANLI (Adversarial NLI) classification using an LLM.
The baseline must be implemented with DSPy.



In [None]:
import os
from getpass import getpass

import dspy

# Never hard-code the xAI key in the notebook: read it from the environment and
# only prompt (without echoing) when it is missing. An already-exported key is
# left untouched. For non-interactive runs (e.g. nbconvert), export XAI_API_KEY
# before starting the kernel, since there is no stdin to prompt on.
if not os.environ.get("XAI_API_KEY"):
    os.environ["XAI_API_KEY"] = getpass("xAI API key: ")

# Register Grok-3-mini (via xAI) as DSPy's default language model.
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
dspy.configure(lm=lm)


In [11]:
from typing import Literal

# DSPy signature for 3-way NLI: (premise, hypothesis) -> label.
# Deliberately no docstring: DSPy uses a signature's docstring as the prompt
# instructions; without one it falls back to the default
# "Given the fields `premise`, `hypothesis`, produce the fields `label`."
# (that default is MIPROv2's Instruction 0 in the optimization log).
class EntailmentSignature(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Literal constrains the output type to the three ANLI class names
    # (the same strings, in the same order, as `label_names` below).
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

class EntailmentClassifier(dspy.Module):
    """Single-step DSPy program: one `dspy.Predict` call over EntailmentSignature."""
    def __init__(self):
        super().__init__()
        # The only predictor in the program; MIPROv2 reports it as "Predictor 0".
        self.classify = dspy.Predict(EntailmentSignature)

    def forward(self, premise, hypothesis):
        """Classify one premise/hypothesis pair; the returned Prediction exposes `.label`."""
        return self.classify(premise=premise, hypothesis=hypothesis)


## Load ANLI dataset

In [12]:
from datasets import load_dataset

# Load every ANLI split (train/dev/test for rounds r1-r3), then keep only the
# samples whose free-text `reason` field is present and non-empty: the task
# asks for metrics restricted to those samples.
dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x["reason"] is not None and x["reason"] != "")

In [13]:
# Show the split sizes that remain after the non-empty `reason` filter.
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [14]:
from evaluate import load

# Load each metric module once, in a fixed order; they are reused later when
# scoring the LLM baseline predictions.
accuracy, precision, recall, f1 = (
    load(metric_name) for metric_name in ("accuracy", "precision", "recall", "f1")
)


In [15]:
import evaluate
# Bundle the four classification metrics so a single .compute() returns all of them.
CLF_METRIC_NAMES = ["accuracy", "f1", "precision", "recall"]
clf_metrics = evaluate.combine(CLF_METRIC_NAMES)

In [16]:
# Toy binary sanity check: 2 of 3 predictions are right; the single predicted
# positive is correct (precision 1.0) but only 1 of 2 true positives is found (recall 0.5).
demo_predictions, demo_references = [0, 1, 0], [0, 1, 1]
clf_metrics.compute(predictions=demo_predictions, references=demo_references)

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

## Your Turn

Compute the classification metrics of the LLM baseline on each test split of the ANLI dataset, restricted to samples with a non-empty `reason` field.

You must also show a comparison between the DeBERTa baseline model (Model1) and this LLM baseline model (Model2). The comparison should count the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

In [17]:
# Index i of this list is the class name for ANLI integer label i.
label_names = ["entailment", "neutral", "contradiction"]


def _to_dspy_example(row):
    """Turn one ANLI row into a DSPy Example with premise/hypothesis as inputs and the class name as label."""
    example = dspy.Example(
        premise=row["premise"],
        hypothesis=row["hypothesis"],
        label=label_names[row["label"]],
    )
    return example.with_inputs("premise", "hypothesis")


# Shuffle and sample 30 examples from dev_r3 (dev, not test, so test_r3 stays unseen
# by the optimizer).
dev_r3_sample = dataset["dev_r3"].shuffle(seed=42).select(range(30))
train_examples = list(map(_to_dspy_example, dev_r3_sample))

In [18]:
from dspy import MIPROv2

def exact_match(pred, gold, trace=None):
    """Return True when two NLI labels match, ignoring case and surrounding whitespace.

    Per the DSPy metrics documentation, evaluators and optimizers call a metric
    positionally as ``metric(example, prediction, trace)``. Despite the
    parameter names, ``pred`` therefore receives the gold Example and ``gold``
    the model Prediction. The comparison is symmetric, so the order does not
    matter; the names are kept so existing callers keep working.

    Args:
        pred: An object with a ``label`` attribute (Example/Prediction) or a raw label.
        gold: Same as ``pred``.
        trace: Unused; accepted because DSPy optimizers pass it while bootstrapping.

    Returns:
        bool: whether the normalized labels are equal.
    """

    def _normalized_label(obj):
        # Unwrap `.label` up to two levels (covers an Example stored as a label);
        # getattr's default returns the object itself when there is no `label`.
        # No blanket try/except: a real error should surface, not be scored as
        # a silent mismatch.
        label = getattr(obj, "label", obj)
        label = getattr(label, "label", label)
        return str(label).strip().lower()

    return _normalized_label(pred) == _normalized_label(gold)


# Optimizer first, then the un-optimized student program it will tune.
optimizer = MIPROv2(metric=exact_match)
dspy_module = EntailmentClassifier()

# Search over instructions and few-shot demos using the 30-example dev_r3 sample;
# MIPROv2 holds part of it out for validation (the run log reports a valset of 24).
optimized_dspy_module = optimizer.compile(
    dspy_module,
    trainset=train_examples,
    requires_permission_to_run=False,  # skip the interactive cost-confirmation prompt
)


2025/07/17 17:40:34 INFO dspy.teleprompt.mipro_optimizer_v2: 
RUNNING WITH THE FOLLOWING LIGHT AUTO RUN SETTINGS:
num_trials: 10
minibatch: False
num_fewshot_candidates: 6
num_instruct_candidates: 3
valset size: 24

2025/07/17 17:40:34 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==
2025/07/17 17:40:34 INFO dspy.teleprompt.mipro_optimizer_v2: These will be used as few-shot example candidates for our program and for creating instructions.

2025/07/17 17:40:34 INFO dspy.teleprompt.mipro_optimizer_v2: Bootstrapping N=6 sets of demonstrations...


Bootstrapping set 1/6
Bootstrapping set 2/6
Bootstrapping set 3/6


100%|██████████| 6/6 [00:00<00:00, 226.60it/s]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 4/6


100%|██████████| 6/6 [00:00<00:00, 452.70it/s]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 5/6


100%|██████████| 6/6 [00:00<00:00, 525.81it/s]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 6 attempts.
Bootstrapping set 6/6


 50%|█████     | 3/6 [00:00<00:00, 594.29it/s]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: We will use the few-shot examples from the previous step, a generated dataset summary, a summary of the program code, and a randomly selected prompting tip to propose instructions.
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: 
Proposing N=3 instructions...

2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Proposed Instructions for Predictor 0:

2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: 0: Given the fields `premise`, `hypothesis`, produce the fields `label`.

2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: 1: You are an expert in natural language inference (NLI). Given a premise and a hypothesis, your task is to classify the relationship between them by producing a single label from the following options: 'entailme

Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 565.23it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Default program score: 70.83






2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 2 / 10 =====


Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 509.39it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 3'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 70.83


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 3 / 10 =====



Average Metric: 18.00 / 24 (75.0%): 100%|██████████| 24/24 [00:00<00:00, 539.69it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 18 / 24 (75.0%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: [92mBest full score so far![0m Score: 75.0





2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 75.0 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 0'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 4 / 10 =====


Average Metric: 16.00 / 24 (66.7%): 100%|██████████| 24/24 [00:00<00:00, 488.15it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 16 / 24 (66.7%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 66.67 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 5'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 5 / 10 =====



Average Metric: 16.00 / 24 (66.7%): 100%|██████████| 24/24 [00:00<00:00, 281.06it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 16 / 24 (66.7%)





2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 66.67 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 2'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 6 / 10 =====


Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 455.97it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)





2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 0', 'Predictor 0: Few-Shot Set 5'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 7 / 10 =====


Average Metric: 18.00 / 24 (75.0%): 100%|██████████| 24/24 [00:00<00:00, 2516.77it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 18 / 24 (75.0%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 75.0 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 0'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83, 75.0]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 8 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 505.41it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 5'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83, 75.0, 70.83]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 9 / 10 =====



Average Metric: 15.00 / 24 (62.5%): 100%|██████████| 24/24 [00:00<00:00, 275.82it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 15 / 24 (62.5%)
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 62.5 with parameters ['Predictor 0: Instruction 1', 'Predictor 0: Few-Shot Set 4'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83, 75.0, 70.83, 62.5]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 10 / 10 =====



Average Metric: 17.00 / 24 (70.8%): 100%|██████████| 24/24 [00:00<00:00, 1976.43it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 17 / 24 (70.8%)





2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 70.83 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 5'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83, 75.0, 70.83, 62.5, 70.83]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: ===== Trial 11 / 10 =====


Average Metric: 16.00 / 24 (66.7%): 100%|██████████| 24/24 [00:00<00:00, 484.94it/s]

2025/07/17 17:40:35 INFO dspy.evaluate.evaluate: Average Metric: 16 / 24 (66.7%)





2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Score: 66.67 with parameters ['Predictor 0: Instruction 2', 'Predictor 0: Few-Shot Set 3'].
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Scores so far: [70.83, 70.83, 75.0, 66.67, 66.67, 70.83, 75.0, 70.83, 62.5, 70.83, 66.67]
2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Best score so far: 75.0


2025/07/17 17:40:35 INFO dspy.teleprompt.mipro_optimizer_v2: Returning best identified program with score 75.0!


In [19]:
from tqdm.auto import tqdm

# Defensive re-filter: `dataset` was already restricted to non-empty reasons when
# loaded, but this keeps the cell correct on its own.
test_r3 = dataset["test_r3"].filter(lambda x: x["reason"] is not None and x["reason"] != "")

# One LLM call per test example; results keep the text plus predicted and gold
# class names so they can be compared with the DeBERTa baseline later.
# tqdm already reports progress, so no per-50 print statements (which also
# stopped at N-50 because of the 0-based index).
dspy_results = []
for example in tqdm(test_r3, desc="DSPy on test_r3"):
    output = optimized_dspy_module(premise=example["premise"], hypothesis=example["hypothesis"])
    dspy_results.append({
        "premise": example["premise"],
        "hypothesis": example["hypothesis"],
        "pred_label": output.label,
        "gold_label": label_names[example["label"]],
    })


  0%|          | 0/1200 [00:00<?, ?it/s]

Processed 0/1200 examples...
Processed 50/1200 examples...
Processed 100/1200 examples...


 10%|█         | 120/1200 [00:00<00:00, 1193.37it/s]

Processed 150/1200 examples...
Processed 200/1200 examples...


 22%|██▏       | 265/1200 [00:00<00:00, 1341.06it/s]

Processed 250/1200 examples...
Processed 300/1200 examples...
Processed 350/1200 examples...


 33%|███▎      | 400/1200 [00:00<00:00, 1152.89it/s]

Processed 400/1200 examples...
Processed 450/1200 examples...
Processed 500/1200 examples...


 46%|████▋     | 557/1200 [00:00<00:00, 1300.39it/s]

Processed 550/1200 examples...


 60%|██████    | 720/1200 [00:00<00:00, 1410.84it/s]

Processed 600/1200 examples...
Processed 650/1200 examples...
Processed 700/1200 examples...
Processed 750/1200 examples...
Processed 800/1200 examples...
Processed 850/1200 examples...


 73%|███████▎  | 878/1200 [00:00<00:00, 1463.46it/s]

Processed 900/1200 examples...


 86%|████████▋ | 1037/1200 [00:00<00:00, 1501.33it/s]

Processed 950/1200 examples...
Processed 1000/1200 examples...
Processed 1050/1200 examples...
Processed 1100/1200 examples...
Processed 1150/1200 examples...


100%|██████████| 1200/1200 [00:00<00:00, 1415.80it/s]


In [20]:
# Map class names back to the integer ids the `evaluate` metrics expect.
label2id = {name: idx for idx, name in enumerate(label_names)}

pred_labels = [label2id[row["pred_label"]] for row in dspy_results]
gold_labels = [label2id[row["gold_label"]] for row in dspy_results]

# Accuracy takes no averaging; precision/recall/F1 are macro-averaged over the classes.
dspy_metrics = {
    "accuracy": accuracy.compute(predictions=pred_labels, references=gold_labels)["accuracy"],
}
for metric_name, metric in (("precision", precision), ("recall", recall), ("f1", f1)):
    scores = metric.compute(predictions=pred_labels, references=gold_labels, average="macro")
    dspy_metrics[metric_name] = scores[metric_name]

print({name: round(value, 4) for name, value in dspy_metrics.items()})

{'accuracy': 0.7033, 'precision': 0.7461, 'recall': 0.703, 'f1': 0.7086}


In [21]:
import pickle
import pandas as pd

# SECURITY: pickle.load can execute arbitrary code -- only open trusted files.
# Presumably a list of dicts with "gold_label"/"pred_label" in test_r3 order
# (that is how compute_agreement consumes it) -- TODO confirm against the
# DeBERTa baseline notebook that wrote it.
with open("baseline_preds.pkl", "rb") as pkl_file:
    baseline_results = pickle.load(pkl_file)

def compute_agreement(baseline_results, dspy_results):
    """Count per-sample agreement between the baseline (Model1) and DSPy (Model2).

    Args:
        baseline_results: sequence of dicts with at least ``"gold_label"`` and
            ``"pred_label"``; the gold label is taken from these records.
        dspy_results: sequence of dicts with at least ``"pred_label"``, aligned
            position-by-position with ``baseline_results``. If a record also has
            ``"gold_label"``, ``"premise"`` or ``"hypothesis"``, those are
            cross-checked against the baseline record at the same position.

    Returns:
        dict mapping the four agreement buckets (both correct, baseline-only
        correct, DSPy-only correct, both wrong) to sample counts.

    Raises:
        AssertionError: if the sequences differ in length, or if aligned records
            disagree on gold label / premise / hypothesis, meaning the lists are
            not in the same sample order or use different label encodings.
    """
    assert len(baseline_results) == len(dspy_results), (
        f"length mismatch: {len(baseline_results)} baseline vs "
        f"{len(dspy_results)} DSPy results"
    )

    correct_both = correct_baseline = correct_dspy = incorrect_both = 0

    for idx, (b, d) in enumerate(zip(baseline_results, dspy_results)):
        gold = b["gold_label"]

        # zip() pairs records purely by position, so make sure each pair really
        # describes the same sample before counting it.
        assert d.get("gold_label", gold) == gold, (
            f"sample {idx}: gold label {d['gold_label']!r} (DSPy) != {gold!r} (baseline)"
        )
        for field in ("premise", "hypothesis"):
            if field in b and field in d:
                assert b[field] == d[field], f"sample {idx}: {field} differs between result lists"

        baseline_ok = b["pred_label"] == gold
        dspy_ok = d["pred_label"] == gold

        if baseline_ok and dspy_ok:
            correct_both += 1
        elif baseline_ok:
            correct_baseline += 1
        elif dspy_ok:
            correct_dspy += 1
        else:
            incorrect_both += 1

    return {
        "Correct (both correct)": correct_both,
        "Correct1 (baseline only)": correct_baseline,
        "Correct2 (DSPy only)": correct_dspy,
        "Incorrect (both wrong)": incorrect_both,
    }

# One-row table: how often the DeBERTa baseline and the DSPy LLM are right/wrong together.
agreement_counts = compute_agreement(baseline_results, dspy_results)
pd.DataFrame([agreement_counts])

Unnamed: 0,Correct (both correct),Correct1 (baseline only),Correct2 (DSPy only),Incorrect (both wrong)
0,454,140,390,216
