In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os   
from tqdm import tqdm 
from torch.utils.data import Dataset, DataLoader

In [2]:
import pandas as pd

splits = {'test': 'plain_text/test-00000-of-00001.parquet', 'validation': 'plain_text/validation-00000-of-00001.parquet', 'train': 'plain_text/train-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/stanfordnlp/snli/" + splits["test"])

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
df_train = pd.read_parquet("hf://datasets/stanfordnlp/snli/" + splits["train"])

In [4]:
df_train = df_train[df_train['label'] != -1]

In [5]:
df_train = df_train.iloc[:100000]

In [6]:
def clean_text(text):
    text = str(text).lower()
    text = re.sub(r"[.,!?\\-]", '', text) # Remove punctuation
    return text

all_text = " ".join([clean_text(row['premise']) + " " + clean_text(row['hypothesis']) for _, row in df_train.iterrows()])
unique_words = set(all_text.split())

In [7]:
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[UNK]': 3}
for i, w in enumerate(unique_words):
    word2id[w] = i + 4

VOCAB_SIZE = len(word2id)
print(f"Vocabulary Size: {VOCAB_SIZE}")

Vocabulary Size: 17863


In [8]:
class SNLIDataset(Dataset):
    def __init__(self, df, word2id, max_len):
        self.df = df
        self.word2id = word2id
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def tokenize(self, text):
        tokens = clean_text(text).split()
        ids = [self.word2id.get(w, self.word2id['[UNK]']) for w in tokens]
        # Add [CLS] and [SEP]
        ids = [self.word2id['[CLS]']] + ids + [self.word2id['[SEP]']]
        # Pad or Truncate
        if len(ids) < self.max_len:
            ids = ids + [self.word2id['[PAD]']] * (self.max_len - len(ids))
        else:
            ids = ids[:self.max_len]
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        tokens_a = self.tokenize(row['premise'])
        tokens_b = self.tokenize(row['hypothesis'])
        label = torch.tensor(int(row['label']), dtype=torch.long)
        return tokens_a, tokens_b, label

In [9]:
dataset = SNLIDataset(df_train, word2id, 100)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [10]:
from bert import BERT

In [11]:
n_layers = 2    # number of Encoder of Encoder Layer
n_heads  = 2    # number of heads in Multi-Head Attention
d_model  = 256  # Embedding Size
d_ff = 256 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 16  # dimension of K(=Q), V
n_segments = 2

In [12]:
class SentenceBERT(nn.Module):
    def __init__(self, bert_model, embed_dim, num_classes=3):
        super(SentenceBERT, self).__init__()
        self.bert = bert_model
        
        # The Classifier Layer: Takes concatenated (u, v, |u-v|)
        # Input size is 3x the embedding dimension
        self.classifier = nn.Linear(embed_dim * 3, num_classes)
        self.device = bert_model.device

    def mean_pooling(self, token_embeddings, attention_mask):
        # token_embeddings shape: [batch_size, seq_len, embed_dim]
        # attention_mask shape: [batch_size, seq_len]
        
        # Mask out padding tokens (make them zero so they don't affect average)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        
        # Sum of all valid token vectors
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        
        # Count of valid tokens (avoid division by zero)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        
        # Average
        return sum_embeddings / sum_mask

    def forward(self, input_ids_a, input_ids_b):
        # 1. Create dummy segment_ids (All zeros for single sentences)
        # Your BERT expects segment_ids, but SBERT treats each sentence independently.
        segment_ids_a = torch.zeros_like(input_ids_a).to(self.device)
        segment_ids_b = torch.zeros_like(input_ids_b).to(self.device)

        # 2. Pass through YOUR BERT (Shared Weights)
        # We use get_last_hidden_state, NOT the forward() used for pre-training
        out_a = self.bert.get_last_hidden_state(input_ids_a, segment_ids_a)
        out_b = self.bert.get_last_hidden_state(input_ids_b, segment_ids_b)

        # 3. Create Attention Masks (0 for PAD, 1 for Real)
        # Assuming 0 is your PAD token ID
        mask_a = (input_ids_a != 0) 
        mask_b = (input_ids_b != 0)

        # 4. Mean Pooling -> u and v
        u = self.mean_pooling(out_a, mask_a)
        v = self.mean_pooling(out_b, mask_b)

        # 5. Concatenate: (u, v, |u-v|)
        features = torch.cat([u, v, torch.abs(u - v)], dim=1)

        # 6. Classify
        logits = self.classifier(features)
        return logits

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [14]:
my_base_bert = BERT(n_layers, n_heads, d_model, d_ff, d_k, n_segments, VOCAB_SIZE, 100, device).to(device)
# my_base_bert.load_state_dict(torch.load('bert_model.pt'))

In [15]:
sbert_model = SentenceBERT(my_base_bert, embed_dim=d_model).to(device)

In [16]:
# 4. Test it with dummy data
# Create two fake sentences of token IDs
fake_a = torch.tensor([[1, 45, 23, 2, 0]]).to(device) # [CLS] ... [SEP] [PAD]
fake_b = torch.tensor([[1, 99, 12, 2, 0]]).to(device)

output = sbert_model(fake_a, fake_b)
print("Logits:", output) # Should be shape [1, 3] (Entailment, Neutral, Contradiction)

Logits: tensor([[-0.2589,  0.0661,  0.1251]], device='cuda:0',
       grad_fn=<AddmmBackward0>)


In [17]:
def train_model(model, train_loader, optimizer, criterion, epochs):
    model.train()  # Set model to training mode
    
    for epoch in range(epochs):
        loop = tqdm(train_loader, leave=True)
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (input_ids_a, input_ids_b, labels) in enumerate(loop):
            # Move data to GPU
            input_ids_a = input_ids_a.to(device)
            input_ids_b = input_ids_b.to(device)
            labels = labels.to(device)

            # ---------------------------------------
            # 1. Forward Pass
            # ---------------------------------------
            # The model takes two sentences (A & B) and outputs logits (3 classes)
            outputs = model(input_ids_a, input_ids_b)

            # ---------------------------------------
            # 2. Calculate Loss
            # ---------------------------------------
            # print(outputs)
            # print(labels)
            loss = criterion(outputs, labels)

            # ---------------------------------------
            # 3. Backward Pass (Optimization)
            # ---------------------------------------
            optimizer.zero_grad() # Clear old gradients
            loss.backward()       # Calculate new gradients
            optimizer.step()      # Update weights

            # ---------------------------------------
            # 4. Statistics
            # ---------------------------------------
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar
            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(loss=loss.item(), acc=100*correct/total)

    print("Training Complete!")

In [18]:
optimizer = optim.Adam(sbert_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss() # For 3 classes (Entailment, Neutral, Contradiction)

In [19]:
train_model(sbert_model, loader, optimizer, criterion, 20)

Epoch [1/20]: 100%|██████████| 3125/3125 [01:28<00:00, 35.31it/s, acc=58.7, loss=0.96] 
Epoch [2/20]: 100%|██████████| 3125/3125 [01:35<00:00, 32.57it/s, acc=66.4, loss=0.538]
Epoch [3/20]: 100%|██████████| 3125/3125 [02:01<00:00, 25.65it/s, acc=69.4, loss=0.77] 
Epoch [4/20]: 100%|██████████| 3125/3125 [02:03<00:00, 25.22it/s, acc=71.7, loss=0.856]
Epoch [5/20]: 100%|██████████| 3125/3125 [01:53<00:00, 27.59it/s, acc=73.4, loss=0.573]
Epoch [6/20]: 100%|██████████| 3125/3125 [02:05<00:00, 24.98it/s, acc=74.9, loss=0.481]
Epoch [7/20]: 100%|██████████| 3125/3125 [02:03<00:00, 25.39it/s, acc=76.3, loss=0.38] 
Epoch [8/20]: 100%|██████████| 3125/3125 [02:02<00:00, 25.51it/s, acc=77.7, loss=0.872]
Epoch [9/20]: 100%|██████████| 3125/3125 [02:00<00:00, 26.03it/s, acc=78.9, loss=0.519]
Epoch [10/20]: 100%|██████████| 3125/3125 [01:56<00:00, 26.78it/s, acc=80.1, loss=0.537]
Epoch [11/20]: 100%|██████████| 3125/3125 [02:02<00:00, 25.49it/s, acc=81.4, loss=0.704]
Epoch [12/20]: 100%|██████████

Training Complete!





In [20]:
torch.save(sbert_model.state_dict(), "sbert_model.pt")

In [17]:
sbert_model.load_state_dict(torch.load('sbert_model.pt'))

<All keys matched successfully>

In [25]:
df_test = pd.read_parquet("hf://datasets/stanfordnlp/snli/" + splits["test"])
df_test = df_test[df_test['label'].isin([0, 1, 2])].dropna().reset_index(drop=True)

In [26]:
test_dataset = SNLIDataset(df_test, word2id, 100)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [28]:
from sklearn.metrics import accuracy_score, classification_report
def evaluate_model(model, test_loader):
    model.eval()  # Set model to evaluation mode
    
    all_preds = []
    all_labels = []
    
    print("Evaluating...")
    with torch.no_grad():  # Disable gradient calculation for speed
        for batch_a, batch_b, labels in tqdm(test_loader):
            # Move to device
            batch_a, batch_b, labels = batch_a.to(device), batch_b.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(batch_a, batch_b)
            
            # Get predictions
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
    # Calculate Metrics
    acc = accuracy_score(all_labels, all_preds)
    print(f"\nTest Accuracy: {acc:.4f}")
    
    # Detailed Report
    target_names = ['Entailment', 'Neutral', 'Contradiction']
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=target_names))
    
    return acc

In [29]:
# Run evaluation
evaluate_model(sbert_model, test_loader)

Evaluating...


100%|██████████| 9824/9824 [01:11<00:00, 137.06it/s]


Test Accuracy: 0.7051

Classification Report:
               precision    recall  f1-score   support

   Entailment       0.73      0.74      0.74      3368
      Neutral       0.64      0.68      0.66      3219
Contradiction       0.74      0.69      0.72      3237

     accuracy                           0.71      9824
    macro avg       0.71      0.70      0.71      9824
 weighted avg       0.71      0.71      0.71      9824






0.7051099348534202

In [18]:
def predict_inference(premise, hypothesis, model, dataset, device):
    model.eval() # Set model to evaluation mode
    
    # 1. Tokenize using your specific method
    # Your tokenizer returns shape [max_len], but model expects [batch_size, max_len]
    # We use .unsqueeze(0) to add a batch size of 1
    ids_a = dataset.tokenize(premise).unsqueeze(0).to(device)
    ids_b = dataset.tokenize(hypothesis).unsqueeze(0).to(device)
    
    # 2. Forward Pass
    with torch.no_grad():
        logits = model(ids_a, ids_b)
        probs = torch.softmax(logits, dim=1)
        prediction = torch.argmax(probs, dim=1).item()
        
    # 3. Decode Result
    label_map = {0: "Entailment", 1: "Neutral", 2: "Contradiction"}
    confidence = probs[0][prediction].item()
    
    print(f"\nPremise:    {premise}")
    print(f"Hypothesis: {hypothesis}")
    print(f"Prediction: {label_map[prediction]} ({confidence*100:.2f}%)")
    
    return prediction

In [21]:
predict_inference(
    premise="A man is playing a guitar on stage", 
    hypothesis="The man is performing music", 
    model=sbert_model, 
    dataset=dataset, # Uses your tokenizer logic
    device=device
)


Premise:    A man is playing a guitar on stage
Hypothesis: The man is performing music
Prediction: Entailment (38.64%)


0