# NLI base results: Llama-3.1-8B-Instruct (meta-llama/Llama-3.1-8B-Instruct) via Ollama on M4

Loads [yilmazzey/sdp2-nli](https://huggingface.co/datasets/yilmazzey/sdp2-nli) (snli_tr_1_1, multinli_tr_1_1, trglue_mnli) and runs **evaluation-only** (no training) zero-shot Turkish NLI with **Llama-3.1-8B-Instruct** served by **Ollama** (no Hugging Face pipeline).

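Uses the Llama 3.1 Instruct chat format (system + user turn; Ollama applies the template). The model is instructed to answer with exactly one word: entailment, neutral, or contradiction. Outputs are parsed to 0=entailment, 1=neutral, 2=contradiction; any output that cannot be parsed defaults to 1 (neutral). Runs on Apple Silicon (M4) through Ollama (Metal/CPU), with no CUDA. Note that Ollama's default `llama3.1:8b` tag serves 4-bit quantized (Q4_K_M) weights. Results therefore reflect the quantized model unless a full-precision tag is pulled.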

**Splits:** snli → test; multinli → validation_matched/mismatched; trglue → test_matched/test_mismatched. **Metrics:** Accuracy, macro F1, per-class F1, confusion matrix (CSV + seaborn plot). Results saved to `./results/`. **Prerequisite:** `ollama pull llama3.1:8b` (or your Ollama model name for Llama-3.1-8B-Instruct).
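Per-class F1 can be read straight off a confusion matrix, which makes a handy cross-check on the saved CSVs. With rows as true labels and columns as predictions, row k sums to TP+FN and column k sums to TP+FP. That gives F1_k = 2·TP_k / (row_k + col_k). Macro F1 is the unweighted mean over the three classes. Here is a small sketch on toy labels; these are not results from this notebook, and the helper name `f1_from_confusion` is hypothetical:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score


def f1_from_confusion(cm):
    """Per-class F1 from a confusion matrix (rows = true, columns = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    # row sum = TP + FN, column sum = TP + FP, so 2TP / (row + col) = 2TP / (2TP + FP + FN)
    denom = cm.sum(axis=1) + cm.sum(axis=0)
    # a class that never occurs and is never predicted gets 0, like zero_division=0
    return np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)


# Toy labels for illustration: the "model" never predicts contradiction (2)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 1, 1]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
per_class = f1_from_confusion(cm)
macro = per_class.mean()  # unweighted mean over the three classes
sk_per_class = f1_score(y_true, y_pred, labels=[0, 1, 2], average=None, zero_division=0)
print(per_class, macro)
print(sk_per_class)
```

Here `per_class` matches scikit-learn's per-class scores (0.8, 0.4, 0.0). The zero contradiction F1 pulls macro F1 down to 0.4, even though accuracy is 0.5.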

In [1]:
# Install ollama Python client if needed; standard libs for datasets/metrics/plots
# !pip install -q ollama datasets scikit-learn tqdm matplotlib seaborn
# The Ollama server (desktop app or `brew install ollama`) must already be running; the pip `ollama` package is only the Python client

In [2]:
import json
import random
import re
from collections import Counter
from pathlib import Path

import numpy as np
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from tqdm import tqdm
import ollama

try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    HAS_PLOT = True
except ImportError:
    HAS_PLOT = False

LABEL_NAMES = ["entailment", "neutral", "contradiction"]

# Device: M4 / Apple Silicon — no CUDA; Ollama uses Metal/CPU
print("Running on Apple Silicon (M4) / CPU — Ollama handles Metal.")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)


Running on Apple Silicon (M4) / CPU — Ollama handles Metal.
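Seeding `random` and NumPy above only affects Python-side randomness. It does not seed Ollama's sampler, which runs inside the server. With `TEMPERATURE = 0.0`, decoding is effectively greedy anyway. To pin server-side sampling explicitly, you can add a `seed` entry to the request `options`. This is a minimal sketch; the helper name `build_ollama_options` is hypothetical:

```python
def build_ollama_options(max_tokens, temperature, top_p, seed):
    """Hypothetical helper: options dict for ollama.chat, including a server-side seed."""
    return {
        "num_predict": max_tokens,  # cap on generated tokens
        "temperature": temperature,  # 0.0 -> greedy decoding
        "top_p": top_p,
        "seed": seed,  # seeds Ollama's sampler, not Python's RNG
    }


opts = build_ollama_options(max_tokens=10, temperature=0.0, top_p=0.0, seed=42)
```

Passing `options=build_ollama_options(MAX_TOKENS, TEMPERATURE, TOP_P, SEED)` to `ollama.chat` would replace the inline dict in `ollama_chat_single`.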


In [3]:
REPO_ID = "yilmazzey/sdp2-nli"
CONFIGS = ["snli_tr_1_1", "multinli_tr_1_1", "trglue_mnli"]
MODEL_ID = "llama3.1:8b"  # Ollama model name (e.g. after: ollama pull llama3.1:8b)
NUM_LABELS = 3
RESULTS_DIR = "results"
BATCH_SIZE = 6  # Progress-bar grouping only: each sample is still sent as a separate, sequential request
MAX_TOKENS = 10
TEMPERATURE = 0.0
TOP_P = 0.0
EVAL_SPLITS = {
    "snli_tr_1_1": ["test"],
    "multinli_tr_1_1": ["validation_matched", "validation_mismatched"],
    "trglue_mnli": ["test_matched", "test_mismatched"],
}
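`run_prompted_inference` groups samples by `BATCH_SIZE` only for the progress bar, so requests are still sent one at a time. A thread pool can keep several requests in flight, provided the Ollama server is configured to serve parallel requests via its `OLLAMA_NUM_PARALLEL` environment variable. Whether this is actually faster depends on that setting and on available memory. The sketch below assumes a `classify(premise, hypothesis)` callable; the function names are hypothetical. `Executor.map` returns results in input order, so predictions stay aligned with the gold labels.

```python
from concurrent.futures import ThreadPoolExecutor


def predict_concurrently(premises, hypotheses, classify, max_workers=4):
    """Apply classify(premise, hypothesis) to every pair using a thread pool.

    Executor.map yields results in input order, so the returned list stays
    aligned with the gold labels even though requests may finish out of order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(classify, premises, hypotheses))


# Stand-in classifier for illustration only; a real one would call Ollama
preds = predict_concurrently(["a", "bb", "ccc"], ["x", "y", "z"], lambda p, h: len(p) % 3, max_workers=2)
```

With the notebook's own helpers, `classify` could be `lambda p, h: parse_generated_label(ollama_chat_single(nli_user_prompt(p, h)))`.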

In [4]:
# Load all three dataset configs (same as Turkish-Gemma-9b-T1)
datasets = {}
for cfg in CONFIGS:
    print(f"Loading {REPO_ID} :: {cfg} ...")
    datasets[cfg] = load_dataset(REPO_ID, cfg)
    print("  splits:", list(datasets[cfg].keys()))

Loading yilmazzey/sdp2-nli :: snli_tr_1_1 ...
  splits: ['train', 'validation', 'test']
Loading yilmazzey/sdp2-nli :: multinli_tr_1_1 ...
  splits: ['train', 'validation_matched', 'validation_mismatched']
Loading yilmazzey/sdp2-nli :: trglue_mnli ...
  splits: ['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched']
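The evaluation loop assumes every gold label is 0, 1, or 2. Original SNLI marks examples without annotator consensus as -1. Whether these Turkish configs still contain such rows is not shown here, so a cheap guard is to check before evaluating. With Hugging Face `datasets`, the idiomatic call is `ds.filter(lambda ex: ex["label"] in (0, 1, 2))`. The plain-Python sketch below does the same on column lists and reports how many rows it dropped (the helper name `drop_unlabeled` is hypothetical):

```python
from collections import Counter

VALID_LABELS = (0, 1, 2)


def drop_unlabeled(premises, hypotheses, labels, valid=VALID_LABELS):
    """Hypothetical helper: keep only rows whose gold label is a valid class id."""
    kept = [(p, h, y) for p, h, y in zip(premises, hypotheses, labels) if y in valid]
    dropped = Counter(y for y in labels if y not in valid)  # e.g. Counter({-1: n})
    if kept:
        ps, hs, ys = (list(col) for col in zip(*kept))
    else:
        ps, hs, ys = [], [], []
    return ps, hs, ys, dropped


ps, hs, ys, dropped = drop_unlabeled(["p1", "p2", "p3"], ["h1", "h2", "h3"], [0, -1, 2])
```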


In [5]:
# Llama 3.1 Instruct: system + user turn; Ollama applies <|start_header_id|>system/user/assistant<|end_header_id|> + <|eot_id|>
SYSTEM_PROMPT = """You are a natural language inference classifier. You must answer with exactly one word and nothing else: entailment, neutral, or contradiction. No explanation, no punctuation, no extra text. Only one of these three words."""


def nli_user_prompt(premise, hypothesis):
    return f"""Premise: {premise}
Hypothesis: {hypothesis}
Does the premise entail, is neutral to, or contradict the hypothesis? Answer with only one word: entailment, neutral, or contradiction."""


def ollama_chat_single(user_content: str) -> str:
    """Llama 3.1 chat: system + user turn. Returns assistant message content only."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    response = ollama.chat(
        model=MODEL_ID,
        messages=messages,
        options={
            "num_predict": MAX_TOKENS,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
        },
    )
    return (response.get("message") or {}).get("content", "") or ""


LABEL_WORD_TO_ID = {
    "entailment": 0,
    "neutral": 1,
    "contradiction": 2,
    "içerme": 0,
    "tarafsız": 1,
    "nötr": 1,
    "çelişki": 2,
}


def parse_generated_label(raw_text: str) -> int:
    """Extract first word from model output; strip punctuation; lowercase; map EN+TR; default 1 (neutral)."""
    if not raw_text or not isinstance(raw_text, str):
        return 1
    text = raw_text.strip()
    if not text:
        return 1
    # First token/word (split by whitespace)
    parts = text.split()
    first = parts[0] if parts else ""
    # Strip punctuation and lowercase (handles quoted output e.g. 'entailment')
    first = re.sub(r"[.,;:!?\"'()\[\]]+", "", first).strip().lower()
    if not first:
        return 1
    return LABEL_WORD_TO_ID.get(first, 1)
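Because unparseable outputs silently become `neutral`, formatting failures look exactly like genuine neutral predictions and inflate the neutral column. A stricter variant scans every word token for the first label word and returns -1 when none is found, so failures can be counted separately; the name `parse_label_strict` is hypothetical. The trade-off is that scanning the whole reply accepts `The answer is: contradiction`, but would also match `entailment` inside a hedge like `not entailment`.

```python
import re

# Same English + Turkish label vocabulary as LABEL_WORD_TO_ID above
LABEL_WORDS = {
    "entailment": 0, "neutral": 1, "contradiction": 2,
    "içerme": 0, "tarafsız": 1, "nötr": 1, "çelişki": 2,
}
WORD_RE = re.compile(r"\w+")  # Unicode-aware for str patterns, so ç, ş, ı, ö count as letters


def parse_label_strict(raw_text):
    """Return the id of the first label word in raw_text, or -1 if none is found."""
    if not isinstance(raw_text, str):
        return -1
    for token in WORD_RE.findall(raw_text.lower()):
        if token in LABEL_WORDS:
            return LABEL_WORDS[token]
    return -1
```

Counting the -1 results per split gives the parse-failure rate; map them to 1 afterwards to keep the current neutral fallback for scoring. Python's `str.lower()` is not Turkish-locale aware (`"İ".lower()` yields `i` plus a combining dot), so a capitalized `İçerme` still will not match.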

In [6]:
def run_prompted_inference(ds):
    premises = ds["premise"]
    hypotheses = ds["hypothesis"]
    labels = ds["label"]
    n = len(labels)
    y_pred = []
    debug_indices = set(list(range(min(5, n))) + list(range(0, n, 100)))

    for start in tqdm(range(0, n, BATCH_SIZE), desc="Inference"):
        end = min(start + BATCH_SIZE, n)
        for i in range(start, end):
            user_content = nli_user_prompt(premises[i], hypotheses[i])
            raw = ollama_chat_single(user_content)
            label_id = parse_generated_label(raw)
            y_pred.append(label_id)
            if i in debug_indices:
                print(f"[sample {i}] raw: {repr(raw)} -> parsed: {label_id} ({LABEL_NAMES[label_id]})")

    y_true = np.array(labels, dtype=np.int64)
    y_pred = np.array(y_pred, dtype=np.int64)
    print("True label dist:", dict(Counter(y_true)))
    print("Pred label dist:", dict(Counter(y_pred)))
    return y_true, y_pred
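Each split issues thousands of sequential requests inside `run_prompted_inference`, so one transient failure (for example, the Ollama server restarting) aborts the whole split. Here is a minimal retry sketch with exponential backoff. The wrapper name and its defaults are assumptions, and `exceptions` should be narrowed to the errors you actually see; the Python client raises `ollama.ResponseError` for server-side errors.

```python
import time


def call_with_retries(fn, *args, retries=3, base_delay=1.0, exceptions=(Exception,), **kwargs):
    """Call fn(*args, **kwargs), retrying on the listed exceptions with exponential backoff.

    Makes at most retries + 1 attempts and sleeps base_delay * 2**attempt between them;
    the last error is re-raised once attempts run out.
    """
    for attempt in range(retries + 1):
        try:
            return fn(*args, **kwargs)
        except exceptions:
            if attempt == retries:
                raise
            time.sleep(base_delay * (2 ** attempt))
```

In the loop, `raw = ollama_chat_single(user_content)` would become `raw = call_with_retries(ollama_chat_single, user_content, exceptions=(ollama.ResponseError,))`. Add whatever connection error your client version raises when the server is unreachable.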

In [7]:
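def compute_metrics(y_true, y_pred):
    # Pin the label set so per-class scores and the 3x3 matrix always follow LABEL_NAMES order,
    # even if a class is missing from both y_true and y_pred in some split
    labels = list(range(NUM_LABELS))
    acc = float(accuracy_score(y_true, y_pred))
    f1_macro = float(f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0))
    f1_per_class = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    f1_per_class = {LABEL_NAMES[i]: float(f1_per_class[i]) for i in range(NUM_LABELS)}
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    out = {"accuracy": acc, "f1_macro": f1_macro, "f1_per_class": f1_per_class}
    return out, cm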


def save_confusion_plot(cm, path):
    if not HAS_PLOT:
        return
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=LABEL_NAMES, yticklabels=LABEL_NAMES, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
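Raw counts make splits of different sizes hard to compare. Row-normalizing shows, for each true class, the fraction predicted as each label, and the diagonal is then per-class recall. This makes a bias such as over-predicting `entailment` visible at a glance. scikit-learn can do this directly via `confusion_matrix(..., normalize="true")`. The sketch below instead normalizes an existing count matrix, such as one loaded back from the saved CSVs; both function names are hypothetical.

```python
import numpy as np


def row_normalize(cm):
    """Divide each row by its total so entry [i, j] = fraction of true class i predicted as j; empty rows stay 0."""
    cm = np.asarray(cm, dtype=float)
    totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals > 0)


def save_normalized_confusion_plot(cm, path, label_names=None):
    """Heatmap of the row-normalized matrix; plotting libraries are imported lazily."""
    import matplotlib.pyplot as plt
    import seaborn as sns

    names = list(label_names or ["entailment", "neutral", "contradiction"])
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(row_normalize(cm), annot=True, fmt=".2f", vmin=0.0, vmax=1.0,
                xticklabels=names, yticklabels=names, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
```

For example, `row_normalize(np.loadtxt("results/confusion_snli_tr_1_1_test.csv", delimiter=","))` normalizes a matrix saved by the evaluation loop.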

In [None]:
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
all_metrics = {}

for config_name in CONFIGS:
    ds_dict = datasets[config_name]
    split_names = EVAL_SPLITS[config_name]
    all_metrics[config_name] = {}

    for split_name in split_names:
        if split_name not in ds_dict:
            print(f"  Skip {config_name}/{split_name} (missing)")
            continue
        ds = ds_dict[split_name]
        print(f"Evaluating {config_name} / {split_name} ...")
        y_true, y_pred = run_prompted_inference(ds)
        metrics, cm = compute_metrics(y_true, y_pred)
        all_metrics[config_name][split_name] = metrics

        cm_path = Path(RESULTS_DIR) / f"confusion_{config_name}_{split_name}.csv"
        np.savetxt(cm_path, cm, fmt="%d", delimiter=",")
        save_confusion_plot(cm, Path(RESULTS_DIR) / f"confusion_{config_name}_{split_name}.png")

        print(f"  accuracy={metrics['accuracy']:.4f}, f1_macro={metrics['f1_macro']:.4f}")

with open(Path(RESULTS_DIR) / "metrics.json", "w") as f:
    json.dump(all_metrics, f, indent=2)
print(f"Saved {RESULTS_DIR}/metrics.json")
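# Restore stdout only if an optional logging cell (not included here) redirected it
if "_nli_orig_stdout" in globals():
    import sys
    sys.stdout = _nli_orig_stdout
    _nli_log_file.close()
    print(f"Log closed: {_nli_log_path}")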

Evaluating snli_tr_1_1 / test ...


Inference:   0%|          | 0/1638 [00:00<?, ?it/s]

[sample 0] raw: 'neutral' -> parsed: 1 (neutral)
[sample 1] raw: 'entailment' -> parsed: 0 (entailment)
[sample 2] raw: 'entailment' -> parsed: 0 (entailment)
[sample 3] raw: 'entailment' -> parsed: 0 (entailment)
[sample 4] raw: 'entailment' -> parsed: 0 (entailment)


Inference:   1%|          | 16/1638 [00:19<33:34,  1.24s/it]

[sample 100] raw: 'neutral' -> parsed: 1 (neutral)


Inference:   2%|▏         | 33/1638 [00:40<32:57,  1.23s/it]

[sample 200] raw: 'neutral' -> parsed: 1 (neutral)


Inference:   3%|▎         | 50/1638 [01:02<33:50,  1.28s/it]

[sample 300] raw: 'entailment' -> parsed: 0 (entailment)


Inference:   4%|▍         | 66/1638 [01:22<33:54,  1.29s/it]

[sample 400] raw: 'neutral' -> parsed: 1 (neutral)


Inference:   5%|▌         | 83/1638 [01:44<33:02,  1.28s/it]

[sample 500] raw: 'entailment' -> parsed: 0 (entailment)


Inference:   6%|▌         | 100/1638 [02:07<34:39,  1.35s/it]

[sample 600] raw: 'entailment' -> parsed: 0 (entailment)


Inference:   7%|▋         | 116/1638 [02:29<36:41,  1.45s/it]

[sample 700] raw: 'neutral' -> parsed: 1 (neutral)


Inference:   8%|▊         | 133/1638 [02:56<42:50,  1.71s/it]

[sample 800] raw: 'entailment' -> parsed: 0 (entailment)


Inference:   9%|▉         | 150/1638 [03:29<46:21,  1.87s/it]

[sample 900] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  10%|█         | 166/1638 [03:56<39:26,  1.61s/it]

[sample 1000] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  11%|█         | 183/1638 [04:22<38:09,  1.57s/it]

[sample 1100] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  12%|█▏        | 200/1638 [04:49<37:17,  1.56s/it]

[sample 1200] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  13%|█▎        | 216/1638 [05:14<36:19,  1.53s/it]

[sample 1300] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  14%|█▍        | 233/1638 [05:40<35:42,  1.53s/it]

[sample 1400] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  15%|█▌        | 250/1638 [06:05<35:33,  1.54s/it]

[sample 1500] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  16%|█▌        | 266/1638 [06:29<33:53,  1.48s/it]

[sample 1600] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  17%|█▋        | 283/1638 [06:54<33:04,  1.46s/it]

[sample 1700] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  18%|█▊        | 300/1638 [07:19<33:25,  1.50s/it]

[sample 1800] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  19%|█▉        | 316/1638 [07:43<32:01,  1.45s/it]

[sample 1900] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  20%|██        | 333/1638 [08:08<32:01,  1.47s/it]

[sample 2000] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  21%|██▏       | 350/1638 [08:32<31:00,  1.44s/it]

[sample 2100] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  22%|██▏       | 366/1638 [08:57<35:29,  1.67s/it]

[sample 2200] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  23%|██▎       | 383/1638 [09:24<32:58,  1.58s/it]

[sample 2300] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  24%|██▍       | 400/1638 [09:49<32:14,  1.56s/it]

[sample 2400] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  25%|██▌       | 416/1638 [10:14<30:22,  1.49s/it]

[sample 2500] raw: 'entailment' -> parsed: 0 (entailment)


Inference:  26%|██▋       | 433/1638 [10:40<31:19,  1.56s/it]

[sample 2600] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  27%|██▋       | 450/1638 [11:08<31:19,  1.58s/it]

[sample 2700] raw: 'neutral' -> parsed: 1 (neutral)


Inference:  28%|██▊       | 463/1638 [11:27<29:35,  1.51s/it]

In [None]:
# Summary: per config/split
for config_name, splits in all_metrics.items():
    for split_name, m in splits.items():
        print(f"{config_name} / {split_name}: acc={m['accuracy']:.4f}, F1_macro={m['f1_macro']:.4f}, F1_per_class={m['f1_per_class']}")
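The summary cell prints one line per split. To compare against other models, such as the Turkish-Gemma-9b-T1 run, a flat table is often easier. Below is a sketch that writes `all_metrics` to CSV with the standard-library `csv` module; the column layout, the function names, and the `results/summary.csv` filename are assumptions.

```python
import csv

LABEL_NAMES = ["entailment", "neutral", "contradiction"]  # same order as in the notebook


def metrics_to_rows(all_metrics):
    """Flatten {config: {split: metrics}} into one dict per (config, split)."""
    rows = []
    for config_name, splits in all_metrics.items():
        for split_name, m in splits.items():
            row = {"config": config_name, "split": split_name,
                   "accuracy": m["accuracy"], "f1_macro": m["f1_macro"]}
            for name in LABEL_NAMES:
                row[f"f1_{name}"] = m["f1_per_class"][name]
            rows.append(row)
    return rows


def write_summary_csv(all_metrics, path):
    """Write one CSV row per (config, split) and return the number of rows written."""
    rows = metrics_to_rows(all_metrics)
    fieldnames = ["config", "split", "accuracy", "f1_macro"] + [f"f1_{n}" for n in LABEL_NAMES]
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

After the evaluation loop, `write_summary_csv(all_metrics, "results/summary.csv")` would produce the table.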