In [1]:
# Cell 1 — environment check and imports
# Run this to ensure required libs are available and to import everything we'll use.
import sys, math, os
import torch
# Report interpreter / PyTorch versions and whether a CUDA device is usable.
python_version_line = sys.version.splitlines()[0]
torch_version_str = getattr(torch, "__version__", "n/a")
print("Python:", python_version_line)
print("Torch:", torch_version_str, "CUDA:", torch.cuda.is_available())

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader


Python: 3.10.19 | packaged by conda-forge | (main, Oct 22 2025, 22:29:10) [GCC 14.3.0]
Torch: 2.9.0+cu128 CUDA: True


In [3]:
# Cell 2 — load SNLI and inspect raw examples
# load_dataset returns a mapping of split name -> dataset.
snli = load_dataset("snli")
print("Splits:", snli.keys())
split_sizes = {name: len(snli[name]) for name in ("train", "validation", "test")}
print("Sizes: train", split_sizes["train"], "validation", split_sizes["validation"], "test", split_sizes["test"])

# Preview the first 3 training rows (each row is a plain python dict).
# zip() stops as soon as range(3) is exhausted, so only three rows are consumed.
for i, ex in zip(range(3), snli["train"]):
    print(f"\nExample {i}:")
    print(" Premise:", ex["premise"])
    print(" Hypothesis:", ex["hypothesis"])
    print(" Label:", ex["label"], " (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)")


Splits: dict_keys(['test', 'validation', 'train'])
Sizes: train 550152 validation 10000 test 10000

Example 0:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is training his horse for a competition.
 Label: 1  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)

Example 1:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is at a diner, ordering an omelette.
 Label: 2  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)

Example 2:
 Premise: A person on a horse jumps over a broken down airplane.
 Hypothesis: A person is outdoors, on a horse.
 Label: 0  (0=entailment, 1=neutral, 2=contradiction; -1 may mean missing)


In [4]:
# Cell 3 — filter invalid labels and look at label distribution
# Keep only rows whose label is a non-negative class id (drops None and the -1 placeholder).
snli = snli.filter(lambda ex: ex["label"] is not None and ex["label"] >= 0)
from collections import Counter
def label_counts(split):
    """Count how many examples carry each label id in ``snli[split]``.

    Args:
        split: Name of the split to inspect, e.g. ``"train"``.

    Returns:
        A ``Counter`` mapping label id -> number of examples.
    """
    # Read the label column directly. Iterating the split row by row would
    # build a complete dict (premise, hypothesis, label) for every example
    # just to look at one field.
    return Counter(snli[split]["label"])
print("Label counts (train):", label_counts("train"))
print("Label counts (validation):", label_counts("validation"))
print("Label counts (test):", label_counts("test"))

Label counts (train): Counter({0: 183416, 2: 183187, 1: 182764})
Label counts (validation): Counter({0: 3329, 2: 3278, 1: 3235})
Label counts (test): Counter({0: 3368, 2: 3237, 1: 3219})


In [5]:
# Cell 4 — tokenizer: what it *does* and when it runs
# We'll use a transformers tokenizer (subword/BERT-style). It converts text -> token ids and creates attention mask.
# Load the fast tokenizer for "bert-base-uncased" and print a few of its basic properties.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
print("Tokenizer:", type(tokenizer).__name__)
print("Vocab size:", tokenizer.vocab_size)
print("Pad token id:", tokenizer.pad_token_id, "Pad token:", tokenizer.pad_token)

# 2️⃣  Maximum sequence length (in tokens) used for truncation and padding
#    SNLI sentences are short, so 64 tokens leaves plenty of headroom
max_len = 64

# 3️⃣  Define a function that tokenizes both premise & hypothesis separately
def tokenize_pair(batch):
    """Tokenize the premise and hypothesis columns of a batch independently.

    Both columns go through the module-level ``tokenizer`` with truncation and
    ``"max_length"`` padding, using the module-level ``max_len``.

    Args:
        batch: Mapping with ``"premise"`` and ``"hypothesis"`` entries.

    Returns:
        Dict with the token ids and attention masks of each side, under
        ``premise_*`` and ``hypo_*`` keys.
    """
    encode_kwargs = {"truncation": True, "padding": "max_length", "max_length": max_len}
    premise_enc = tokenizer(batch["premise"], **encode_kwargs)
    hypo_enc = tokenizer(batch["hypothesis"], **encode_kwargs)

    ids_key, mask_key = "input_ids", "attention_mask"
    return {
        "premise_input_ids": premise_enc[ids_key],
        "premise_attention_mask": premise_enc[mask_key],
        "hypo_input_ids": hypo_enc[ids_key],
        "hypo_attention_mask": hypo_enc[mask_key],
    }

# 4. apply map BUT keep the label column by removing everything EXCEPT 'label'
# Drop every original column except 'label' while mapping the tokenizer over all splits.
orig_cols = snli["train"].column_names
cols_to_remove = [name for name in orig_cols if not name == "label"]
snli_tok = snli.map(
    tokenize_pair,
    batched=True,
    remove_columns=cols_to_remove,
)

print(snli_tok)
print("Columns:", snli_tok["train"].column_names)

# NOTE: the tokenizer maps text -> token ids (and produces attention masks). This is CPU work (fast with use_fast=True).
# You can run it once over the whole dataset (pre-tokenize, as done above) or on each batch (on-the-fly).


Tokenizer: BertTokenizerFast
Vocab size: 30522
Pad token id: 0 Pad token: [PAD]
DatasetDict({
    test: Dataset({
        features: ['label', 'premise_input_ids', 'premise_attention_mask', 'hypo_input_ids', 'hypo_attention_mask'],
        num_rows: 9824
    })
    validation: Dataset({
        features: ['label', 'premise_input_ids', 'premise_attention_mask', 'hypo_input_ids', 'hypo_attention_mask'],
        num_rows: 9842
    })
    train: Dataset({
        features: ['label', 'premise_input_ids', 'premise_attention_mask', 'hypo_input_ids', 'hypo_attention_mask'],
        num_rows: 549367
    })
})
Columns: ['label', 'premise_input_ids', 'premise_attention_mask', 'hypo_input_ids', 'hypo_attention_mask']


In [6]:
# Cell 5 — set dataset to return PyTorch tensors and create a DataLoader (fixed-length prepadding)
# Columns that should come back as torch tensors when rows are read.
torch_columns = [
    "premise_input_ids", "premise_attention_mask",
    "hypo_input_ids", "hypo_attention_mask", "label"
]
snli_tok.set_format(type="torch", columns=torch_columns)
from torch.utils.data import DataLoader
train_dl = DataLoader(dataset=snli_tok["train"], batch_size=32, shuffle=True)

# Pull a single batch and report what it contains
batch = next(iter(train_dl))
print("Batch keys:", batch.keys())
print("premise_input_ids shape:", batch["premise_input_ids"].shape)  # (B, L)
print("premise_attention_mask shape:", batch["premise_attention_mask"].shape)
print("label shape:", batch["label"].shape)


Batch keys: dict_keys(['label', 'premise_input_ids', 'premise_attention_mask', 'hypo_input_ids', 'hypo_attention_mask'])
premise_input_ids shape: torch.Size([32, 64])
premise_attention_mask shape: torch.Size([32, 64])
label shape: torch.Size([32])


In [7]:
# Decode the first tokenized training example back into token strings.
sample = snli_tok["train"][0]
pad_id = tokenizer.pad_token_id

def _strip_padding(ids, pad):
    """Return ``ids`` cut just before the first ``pad`` entry (unchanged if ``pad`` is absent)."""
    try:
        return ids[:ids.index(pad)]
    except ValueError:
        # list.index raises when the value is missing, i.e. nothing to strip
        return ids

# tensors -> plain python lists so list operations can be used
prem_ids = sample["premise_input_ids"].tolist()
hypo_ids = sample["hypo_input_ids"].tolist()

# convert IDs back to tokens (stop at first PAD)
prem_tokens = tokenizer.convert_ids_to_tokens(_strip_padding(prem_ids, pad_id))
hypo_tokens = tokenizer.convert_ids_to_tokens(_strip_padding(hypo_ids, pad_id))

print("Premise tokens:", prem_tokens)
print("Hypothesis tokens:", hypo_tokens)
print("Label:", sample["label"])


Premise tokens: ['[CLS]', 'a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.', '[SEP]']
Hypothesis tokens: ['[CLS]', 'a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.', '[SEP]']
Label: tensor(1)


In [8]:
import torch
import torch.nn as nn
import math

# --- Positional Encoding (sinusoidal) ---
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the buffer ``pe``: even feature columns hold sines and odd columns hold
    cosines of ``position * 10000^(-2i / d_model)``.

    Args:
        d_model: Size of the feature (last) dimension. May be even or odd.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) column, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # Buffer: follows the module across devices but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding for positions ``0 .. x.size(1) - 1`` to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1)]
        return x

# --- MultiHead Attention implemented manually ---
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused Q/K/V projection.

    Args:
        embedding_dim: Size of the input's last dimension; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.

    Attributes:
        attention_weights: Detached softmax weights from the most recent
            forward call, shape ``(B, num_heads, L, L)``; ``None`` before the
            first call.
    """

    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        # One linear layer produces queries, keys and values in a single matmul.
        self.qkv_layer = nn.Linear(embedding_dim, 3 * embedding_dim)
        self.output_layer = nn.Linear(embedding_dim, embedding_dim)
        self.attention_weights = None

    def forward(self, X, mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``(B, L, embedding_dim)``.
            mask: Optional tensor broadcastable to ``(B, num_heads, L, L)``.
                Score positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(B, L, embedding_dim)``.
        """
        B, L, D = X.shape
        qkv = self.qkv_layer(X)
        Q, K, V = qkv.chunk(3, dim=-1)
        # (B, L, D) -> (B, num_heads, L, head_dim)
        Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Fill with the most negative finite value instead of -inf. Masked
            # keys still get exactly zero weight whenever at least one key is
            # visible, but a row with every key masked becomes a uniform
            # distribution rather than NaN (softmax over all -inf is NaN).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)
        # Kept for inspection; detached so it does not hold the autograd graph.
        self.attention_weights = attn.detach()
        out = torch.matmul(attn, V)
        # (B, num_heads, L, head_dim) -> (B, L, D)
        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.output_layer(out)

# --- NLI Model ---
class NLIModel(nn.Module):
    """Sentence-pair classifier: shared attention encoder, mean pooling, MLP head.

    Premise and hypothesis are encoded separately by the same embedding,
    positional encoding and attention layer, mean-pooled over their unmasked
    tokens, concatenated, and classified.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Embedding / model width.
        num_heads: Number of attention heads in the encoder.
        pad_idx: Token id used for padding (passed to ``nn.Embedding`` as
            ``padding_idx``).
        hidden_dim: Width of the classifier's hidden layer.
        num_classes: Number of output logits.
        dropout: Dropout probability inside the classifier.
    """

    def __init__(self, vocab_size, embedding_dim, num_heads, pad_idx,
                 hidden_dim=256, num_classes=3, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.pos_enc = PositionalEncoding(embedding_dim)
        self.encoder = MultiHeadAttention(embedding_dim, num_heads)
        # Input is the concatenation of the two pooled sentence vectors.
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def encode(self, input_ids, mask):
        """Encode one side of the pair into a single vector per example.

        Args:
            input_ids: Token ids, shape ``(B, L)``.
            mask: 0/1 attention mask, shape ``(B, L)``; 0 marks padding.

        Returns:
            Mean of the encoder outputs over unmasked positions, shape
            ``(B, embedding_dim)``. A row with no unmasked position yields a
            zero vector.
        """
        emb = self.embedding(input_ids)
        emb = self.pos_enc(emb)
        # (B, L) -> (B, 1, 1, L): broadcast over heads and query positions
        out = self.encoder(emb, mask.unsqueeze(1).unsqueeze(1))
        # Zero padded positions by assignment, not multiplication, so a
        # non-finite value at a padded position cannot leak into the sum.
        summed = out.masked_fill(mask.unsqueeze(-1) == 0, 0.0).sum(dim=1)
        # Clamp so an all-padding row divides by 1 instead of 0.
        token_count = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled = summed / token_count
        return pooled

    def forward(self, premise_ids, premise_mask, hypo_ids, hypo_mask):
        """Return class logits of shape ``(B, num_classes)`` for a batch of pairs."""
        prem_vec = self.encode(premise_ids, premise_mask)
        hypo_vec = self.encode(hypo_ids, hypo_mask)
        combined = torch.cat([prem_vec, hypo_vec], dim=1)
        logits = self.classifier(combined)
        return logits
