In [1]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [3]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

In [4]:
raw_text = [line for line in dataset['text'] if len(line) > 50] 
subset_text = raw_text[:100000] # subset of 100k

In [5]:
sentences = [s.lower().strip() for s in subset_text]
word_list = list(set(" ".join(sentences).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}

In [6]:
for i, w in enumerate(word_list):
    word2id[w] = i + 4
id2word = {i: w for w, i in word2id.items()}
vocab_size = len(word2id)

In [7]:
token_list = []
for sentence in sentences:
    arr = [word2id[word] for word in sentence.split() if word in word2id]
    token_list.append(arr)

In [8]:
#hyperparameters
max_len = 128    # Maximum sequence length
batch_size = 32
max_mask = 5     # Max tokens to mask per sentence
n_layers = 6
n_heads = 8
d_model = 768
d_ff = 768 * 4
d_k = d_v = 64
n_segments = 2

In [None]:
#Data Loader Logic
def make_batch():
    batch = []
    positive = negative = 0

    while len(batch) < batch_size:
        idx_a = randrange(len(token_list))

        # --- Correct NSP pairing ---
        if random() < 0.5 and idx_a < len(token_list) - 1:
            idx_b = idx_a + 1      # True next sentence (positive)
            is_next = 1
        else:
            idx_b = randrange(len(token_list))  # Random sentence (negative)
            is_next = 0

        tokens_a, tokens_b = token_list[idx_a], token_list[idx_b]

        tokens_a = tokens_a[:max_len//2 - 2]
        tokens_b = tokens_b[:max_len//2 - 2]

        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        cand_masked_pos = [i for i, token in enumerate(input_ids) if token not in range(4)]
        shuffle(cand_masked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])

            if random() < 0.8:
                input_ids[pos] = word2id['[MASK]']
            elif random() < 0.5:
                index = randint(0, vocab_size - 1)
                input_ids[pos] = index

        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            segment_ids = segment_ids[:max_len]

            valid = [i for i, p in enumerate(masked_pos) if p < max_len]
            masked_pos = [masked_pos[i] for i in valid]
            masked_tokens = [masked_tokens[i] for i in valid]

        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        if max_mask > len(masked_tokens):
            n_pad = max_mask - len(masked_tokens)
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if is_next == 1 and positive < batch_size // 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 1])
            positive += 1
        elif is_next == 0 and negative < batch_size // 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 0])
            negative += 1

    return batch

In [10]:
#BERT Architecture

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    return seq_k.data.eq(0).unsqueeze(1).expand(batch_size, len_q, len_k)


In [11]:
class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        return torch.matmul(attn, V), attn

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, _ = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.layer_norm(output + residual)

In [13]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

In [14]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs

In [15]:
class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # Decoder weights tied to embeddings
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output = layer(output, enc_self_attn_mask)
        
        h_pooled = self.activ(self.fc(output[:, 0]))
        logits_nsp = self.classifier(h_pooled)
        
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        
        return logits_lm, logits_nsp

In [16]:
#training loop
model = BERT().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [17]:
model.train()
for epoch in range(10):  
    total_loss = 0

    for _ in range(100):
        batch = make_batch()

        input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
            map(torch.LongTensor, zip(*batch))

        input_ids = input_ids.to(device)
        segment_ids = segment_ids.to(device)
        masked_tokens = masked_tokens.to(device)
        masked_pos = masked_pos.to(device)
        isNext = isNext.to(device)

        optimizer.zero_grad()
        logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)

        loss_lm = criterion(
            logits_lm.view(-1, vocab_size),
            masked_tokens.view(-1)
        )
        loss_nsp = criterion(logits_nsp, isNext)

        loss = loss_lm + loss_nsp
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/100:.4f}")

Epoch 1, Loss: 68.2752
Epoch 2, Loss: 23.4773
Epoch 3, Loss: 16.0118
Epoch 4, Loss: 13.3716
Epoch 5, Loss: 12.0592
Epoch 6, Loss: 11.1732
Epoch 7, Loss: 10.7772
Epoch 8, Loss: 10.4050
Epoch 9, Loss: 10.0351
Epoch 10, Loss: 9.8702


In [18]:
torch.save(model.state_dict(), 'bert_from_scratch.pth')
print("BERT Weights Saved")

BERT Weights Saved


## TASK 2

In [19]:
from torch.utils.data import DataLoader

class SentenceBERT(nn.Module):
    def __init__(self, bert_model, hidden_size, num_classes=3):
        super(SentenceBERT, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(hidden_size * 3, num_classes)

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, input_ids_a, seg_ids_a, input_ids_b, seg_ids_b):
        # Siamese structure: SAME BERT for both sentences
        u_output = self.get_bert_embeddings(input_ids_a, seg_ids_a)
        v_output = self.get_bert_embeddings(input_ids_b, seg_ids_b)

        u = self.mean_pooling(u_output, (input_ids_a != 0))
        v = self.mean_pooling(v_output, (input_ids_b != 0))

        uv_abs = torch.abs(u - v)
        x = torch.cat([u, v, uv_abs], dim=-1)

        return self.classifier(x)

    def get_bert_embeddings(self, input_ids, segment_ids):
        x = self.bert.embedding(input_ids, segment_ids)
        attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.bert.layers:
            x = layer(x, attn_mask)
        return x

In [20]:
snli = load_dataset("snli", split="train")
snli = snli.filter(lambda x: x['label'] != -1).select(range(20000))

In [21]:
#Custom Tokenizer
import re

def tokenize_snli(example):
    def encode(text):
        text = re.sub(r'[^\w\s]', '', text.lower())
        tokens = text.split()
        return [word2id.get(w, word2id['[MASK]']) for w in tokens][:max_len]

    def pad(ids):
        return ids + [0] * (max_len - len(ids))

    ids_a = encode(example['premise'])
    ids_b = encode(example['hypothesis'])

    return {
        'input_ids_a': pad(ids_a),
        'input_ids_b': pad(ids_b),
        'label': example['label']
    }

In [22]:
tokenized_snli = snli.map(tokenize_snli)
tokenized_snli.set_format(type='torch', columns=['input_ids_a', 'input_ids_b', 'label'])
loader = DataLoader(tokenized_snli, batch_size=16, shuffle=True)

Map: 100%|██████████| 20000/20000 [00:00<00:00, 30478.70 examples/s]


In [23]:
#Loading BERT weights
base_bert = BERT().to(device)
base_bert.load_state_dict(torch.load('bert_from_scratch.pth', map_location=device))

<All keys matched successfully>

In [24]:
sbert_model = SentenceBERT(base_bert, d_model).to(device)
optimizer = optim.Adam(sbert_model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

In [25]:
#Training Loop (SoftmaxLoss objective)
sbert_model.train()
for epoch in range(10):
    total_loss = 0

    for batch in loader:
        ids_a = batch['input_ids_a'].to(device)
        ids_b = batch['input_ids_b'].to(device)
        labels = batch['label'].to(device)

        seg_a = torch.zeros_like(ids_a).to(device)
        seg_b = torch.zeros_like(ids_b).to(device)

        optimizer.zero_grad()
        logits = sbert_model(ids_a, seg_a, ids_b, seg_b)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"SBERT Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")

SBERT Epoch 1, Loss: 1.2335
SBERT Epoch 2, Loss: 1.1053
SBERT Epoch 3, Loss: 1.1028
SBERT Epoch 4, Loss: 1.0648
SBERT Epoch 5, Loss: 1.0275
SBERT Epoch 6, Loss: 1.0029
SBERT Epoch 7, Loss: 0.9810
SBERT Epoch 8, Loss: 0.9591
SBERT Epoch 9, Loss: 0.9412
SBERT Epoch 10, Loss: 0.9148


In [26]:
#Save Sentence-BERT model
torch.save(sbert_model.state_dict(), 'sbert_model.pth')
print("Sentence-BERT Saved")

Sentence-BERT Saved


## TASK 3

In [27]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import torch

try:
    print(f"Using device: {device}")
    print(f"Vocab size: {len(word2id)}")
except NameError:
    print("error")

def tokenize_snli_eval(example):
    tokens_a = example['premise'].lower().split()
    tokens_b = example['hypothesis'].lower().split()
    
    ids_a = [word2id.get(w, word2id['[MASK]']) for w in tokens_a][:max_len]
    ids_b = [word2id.get(w, word2id['[MASK]']) for w in tokens_b][:max_len]
    
    # Padding
    ids_a += [0] * (max_len - len(ids_a))
    ids_b += [0] * (max_len - len(ids_b))
    
    return {
        'input_ids_a': torch.tensor(ids_a),
        'input_ids_b': torch.tensor(ids_b),
        'label': example['label']
    }

test_dataset = load_dataset("snli", split="test")
test_dataset = test_dataset.filter(lambda x: x['label'] != -1).select(range(1000)) # Using 1000 for speed

tokenized_test = test_dataset.map(tokenize_snli_eval)
tokenized_test.set_format(type='torch', columns=['input_ids_a', 'input_ids_b', 'label'])
test_loader = DataLoader(tokenized_test, batch_size=32)

#evaluation
sbert_model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        ids_a = batch['input_ids_a'].to(device)
        ids_b = batch['input_ids_b'].to(device)
        labels = batch['label'].to(device)
        
        seg_a = torch.zeros_like(ids_a).to(device)
        seg_b = torch.zeros_like(ids_b).to(device)
        
        logits = sbert_model(ids_a, seg_a, ids_b, seg_b)
        preds = torch.argmax(logits, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

target_names = ['entailment', 'neutral', 'contradiction']
print("\n" + "="*40)
print("PERFORMANCE METRICS")
print("="*40)
print(classification_report(all_labels, all_preds, target_names=target_names))

Using device: mps
Vocab size: 66143


Map: 100%|██████████| 1000/1000 [00:00<00:00, 15761.01 examples/s]



PERFORMANCE METRICS
               precision    recall  f1-score   support

   entailment       0.48      0.65      0.55       344
      neutral       0.42      0.49      0.45       327
contradiction       0.48      0.22      0.31       329

     accuracy                           0.46      1000
    macro avg       0.46      0.46      0.44      1000
 weighted avg       0.46      0.46      0.44      1000

