# NLI classification: BERTurk All-NLI-TR (emrecan/bert-base-turkish-cased-allnli_tr)

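Loads the [yilmazzey/sdp2-nli](https://huggingface.co/datasets/yilmazzey/sdp2-nli) dataset (configs `snli_tr_1_1`, `multinli_tr_1_1`, `trglue_mnli`) and measures three-way NLI classification accuracy on every non-train split, writing the results to `results/metrics.json`. The model is already fine-tuned on All-NLI-TR, so it can be used as-is for zero-shot evaluation or fine-tuned further.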

In [None]:
# Dataset repository id passed to `load_dataset`.
REPO_ID = "yilmazzey/sdp2-nli"
# Dataset configurations of REPO_ID that are loaded and evaluated, one by one.
CONFIGS = ["snli_tr_1_1", "multinli_tr_1_1", "trglue_mnli"]
# Checkpoint id used for both the tokenizer and the classification model.
MODEL_ID = "emrecan/bert-base-turkish-cased-allnli_tr"
# Size of the classification head requested when loading the model.
NUM_LABELS = 3  # entailment, neutral, contradiction
# Directory (created on demand) that receives metrics.json.
RESULTS_DIR = "results"

In [None]:
from pathlib import Path

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

In [None]:
# Fetch each configuration listed in CONFIGS and keep the result under its
# config name; the available splits are echoed as a quick sanity check.
datasets = {}
for cfg in CONFIGS:
    print(f"Loading {REPO_ID} :: {cfg} ...")
    split_dict = load_dataset(REPO_ID, cfg)
    datasets[cfg] = split_dict
    print("  splits:", list(split_dict.keys()))

In [None]:
# Tokenizer and sequence-classification checkpoint both come from MODEL_ID.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_LABELS,
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

# This notebook only does inference, so put the module in evaluation mode.
model.eval()

In [None]:
def get_eval_splits(ds_dict):
    """List the splits to evaluate: every key of *ds_dict* except ``"train"``.

    The original key order is kept. Any split whose name is not exactly
    ``"train"`` is returned, whatever it is called.
    """
    eval_splits = []
    for split_name in ds_dict.keys():
        if split_name == "train":
            continue
        eval_splits.append(split_name)
    return eval_splits

def tokenize_fn(examples):
    """Encode premise/hypothesis pairs with the module-level ``tokenizer``.

    Args:
        examples: Mapping with ``"premise"`` and ``"hypothesis"`` entries
            (a batch of examples when used with ``map(batched=True)``).

    Returns:
        Whatever the tokenizer returns for the pairs, truncated and padded
        to a fixed length of 256 tokens.
    """
    encode_options = {
        "truncation": True,
        "max_length": 256,
        "padding": "max_length",
    }
    premises = examples["premise"]
    hypotheses = examples["hypothesis"]
    return tokenizer(premises, hypotheses, **encode_options)

def evaluate_split(ds):
    """Run the module-level ``model`` over one dataset split and score it.

    The split is tokenized with ``tokenize_fn``, fed to the model in batches
    of 32 without gradient tracking, and the predicted class of each example
    is the argmax over the logits.

    Args:
        ds: A dataset split with ``premise``, ``hypothesis`` and ``label``
            columns.

    Returns:
        A tuple ``(accuracy, preds, labels)``. ``preds`` and ``labels`` are
        NumPy arrays covering every example in the split, in order.
        ``accuracy`` is a float computed only over examples whose label is
        non-negative; it is ``nan`` when the split has no such example
        (including an empty split).
    """
    ds = ds.map(
        tokenize_fn,
        batched=True,
        remove_columns=[c for c in ds.column_names if c != "label"],
        desc="Tokenize",
    )
    ds.set_format("torch")
    loader = torch.utils.data.DataLoader(ds, batch_size=32)

    # Tokenizer outputs the model consumes. ``token_type_ids`` tells BERT
    # which tokens belong to the premise and which to the hypothesis; if it
    # is not passed, the model falls back to a single all-zero segment, which
    # differs from the standard two-segment encoding used for sentence pairs.
    # It is forwarded only when the tokenizer actually produced it.
    input_keys = ("input_ids", "attention_mask", "token_type_ids")

    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            inputs = {
                key: batch[key].to(device) for key in input_keys if key in batch
            }
            out = model(**inputs)
            preds.append(out.logits.argmax(-1).cpu())
            labels.append(batch["label"])

    if not preds:
        # Empty split: ``torch.cat`` rejects an empty list, so return early.
        empty = torch.empty(0, dtype=torch.long)
        return float("nan"), empty.numpy(), empty.numpy()

    preds = torch.cat(preds)
    labels = torch.cat(labels)

    # A negative label can never equal an argmax index, so scoring it would
    # only ever count as an error. Hugging Face NLI data conventionally uses
    # -1 for "no gold label"; such rows are left out of the accuracy.
    has_gold = labels >= 0
    if has_gold.any():
        acc = (preds[has_gold] == labels[has_gold]).float().mean().item()
    else:
        acc = float("nan")
    return acc, preds.numpy(), labels.numpy()

In [None]:
import json

# Resolve the output location once and make sure its directory exists before
# any evaluation work starts.
metrics_path = Path(RESULTS_DIR) / "metrics.json"
metrics_path.parent.mkdir(parents=True, exist_ok=True)

# all_metrics[config][split] -> {"accuracy": float}
all_metrics = {}
for config_name, ds_dict in datasets.items():
    split_scores = all_metrics.setdefault(config_name, {})
    for split_name in get_eval_splits(ds_dict):
        acc, _, _ = evaluate_split(ds_dict[split_name])
        split_scores[split_name] = {"accuracy": acc}
        print(f"{config_name} / {split_name}: accuracy = {acc:.4f}")

# Persist the collected scores as indented JSON.
with metrics_path.open("w") as f:
    json.dump(all_metrics, f, indent=2)
print(f"Saved to {RESULTS_DIR}/metrics.json")