In [1]:
# Cell 1 — environment check and imports
# Report interpreter / Torch / CUDA info first (so it is shown even if one of the
# heavier library imports below fails), then import everything the notebook uses.
import sys
import math
import os

import torch

print("Python:", sys.version.splitlines()[0])
print(
    "Torch:", getattr(torch, "__version__", "n/a"),
    "CUDA:", torch.cuda.is_available(),
)

# Data loading, tokenization and batching utilities.
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader


Python: 3.10.19 | packaged by conda-forge | (main, Oct 22 2025, 22:29:10) [GCC 14.3.0]
Torch: 2.9.0+cu128 CUDA: True


In [2]:
# Cell 2 — load SNLI and inspect raw examples
snli = load_dataset("snli")
print("Splits:", snli.keys())
print(
    "Sizes: train", len(snli["train"]),
    "validation", len(snli["validation"]),
    "test", len(snli["test"]),
)

# Peek at the first three training rows. Indexing a split with an int yields a
# plain Python dict keyed by column name.
for idx in range(3):
    row = snli["train"][idx]
    print(f"\nExample {idx}:")
    for caption, column in ((" Premise:", "premise"), (" Hypothesis:", "hypothesis")):
        print(caption, row[column])
    print(" Label:", row["label"], " (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)")


Splits: dict_keys(['test', 'validation', 'train'])
Sizes: train 550152 validation 10000 test 10000

Example 0:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is training his horse for a competition.
 Label: 1  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)

Example 1:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is at a diner, ordering an omelette.
 Label: 2  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)

Example 2:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is outdoors, on a horse.
 Label: 0  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)


In [3]:
# Cell 3 — filter invalid labels and look at label distribution
from collections import Counter

# Drop examples whose label is missing (None, or negative such as -1).
# input_columns="label" passes only the label value to the predicate, so the
# premise/hypothesis text is not materialised for every row during filtering.
# Reassigning `snli` is safe to re-run: the predicate keeps the same rows again.
snli = snli.filter(lambda label: label is not None and label >= 0, input_columns="label")
def label_counts(split, dataset=None):
    """Count how many examples carry each label id in one split.

    Args:
        split: split name to count, e.g. "train", "validation" or "test".
        dataset: mapping of split name -> split that supports column access
            (``dataset[split]["label"]``). Defaults to the global ``snli``,
            looked up at call time.

    Returns:
        collections.Counter mapping label id -> number of examples.
    """
    if dataset is None:
        dataset = snli
    # Column access reads just the "label" column instead of building a dict
    # for every row (premise/hypothesis strings included) only to read one key.
    return Counter(dataset[split]["label"])
# Report the per-split label distribution (splits listed explicitly to fix the order).
for split_name in ("train", "validation", "test"):
    print(f"Label counts ({split_name}):", label_counts(split_name))

Label counts (train): Counter({0: 183416, 2: 183187, 1: 182764})
Label counts (validation): Counter({0: 3329, 2: 3278, 1: 3235})
Label counts (test): Counter({0: 3368, 2: 3237, 1: 3219})


In [4]:
# Cell 4 — tokenizer: what it *does* and when it runs
# A transformers (subword, BERT-style) tokenizer maps text to token ids and also
# produces the attention mask.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
print("Tokenizer:", type(tokenizer).__name__)
print("Vocab size:", tokenizer.vocab_size)
print("Pad token id:", tokenizer.pad_token_id, "Pad token:", tokenizer.pad_token)

# Encode ONE premise/hypothesis pair to show the pipeline end to end.
# padding="max_length" with max_length=32 yields a 32-id encoding for this pair.
sample = snli["train"][0]
enc = tokenizer(
    sample["premise"],
    sample["hypothesis"],
    truncation=True,
    padding="max_length",
    max_length=32,
)
print("\nEncoded keys:", list(enc.keys()))
print("input_ids (length):", len(enc["input_ids"]))
# Only a 20-position prefix is printed; for this pair those positions are all
# real tokens (mask of 1s), so any padding lies beyond the printed slice.
print("input_ids (first 20):", enc["input_ids"][:20])
print("attention_mask (first 20):", enc["attention_mask"][:20])
print("token strings (first 20):", tokenizer.convert_ids_to_tokens(enc["input_ids"][:20]))

# NOTE: tokenization (text -> ids + masks) is CPU work; the fast (Rust-backed)
# tokenizer keeps it cheap. It can be done once up front for the whole dataset
# (pre-tokenize) or lazily per batch (on-the-fly).


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Tokenizer: BertTokenizerFast
Vocab size: 30522
Pad token id: 0 Pad token: [PAD]

Encoded keys: ['input_ids', 'token_type_ids', 'attention_mask']
input_ids (length): 32
input_ids (first 20): [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
token strings (first 20): ['[CLS]', 'a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.', '[SEP]', 'a', 'person', 'is', 'training', 'his', 'horse']
