# BERT (Updated 1 Feb 2025, Available CUDA)

We shall implement BERT.  For this tutorial, you may want to first look at my Transformers tutorial to get a basic understanding of Transformers. 

For BERT, the main difference is on how we process the datasets, i.e., masking.   Aside from that, the backbone model is still the Transformers.

In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

In [2]:
# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cuda


## 1. Data

For simplicity, we shall use very simple data like this.

In [3]:
from datasets import load_dataset

# Load BookCorpus dataset
# The first 1% of `train` split.
dataset = load_dataset('bookcorpus', split='train[:1%]')
dataset = dataset.select(range(100000))
dataset

Dataset({
    features: ['text'],
    num_rows: 100000
})

In [4]:
sentences = dataset['text']
text = [x.lower() for x in sentences] #lower case
text = [re.sub("[.,!?\\-]", '', x) for x in text] #clean all symbols
# text

In [5]:
for sentence in text:
    print(sentence, "_____")
    words = sentence.split()
    print(words)
    break

usually  he would be tearing around the living room  playing with his toys  _____
['usually', 'he', 'would', 'be', 'tearing', 'around', 'the', 'living', 'room', 'playing', 'with', 'his', 'toys']


### Making vocabs

Before making the vocabs, let's remove all question marks and perios, etc, then turn everything to lowercase, and then simply split the text. 

In [6]:
from tqdm.auto import tqdm

# Combine everything into one to make vocab
word_list = list(set(" ".join(text).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3, '[UNK]': 4}  # special tokens

# Create the word2id in a single pass
for i, w in tqdm(enumerate(word_list), desc="Creating word2id"):
    word2id[w] = i + 5  # because 0-3 are already occupied

# Precompute the id2word mapping (this can be done once after word2id is fully populated)
id2word = {v: k for k, v in word2id.items()}
vocab_size = len(word2id)
vocab_size

Creating word2id: 0it [00:00, ?it/s]

23069

In [7]:
vocab_size = len(word2id)

# List of all tokens for the whole text
token_list = []

# Process sentences more efficiently
for sentence in tqdm(text, desc="Processing sentences"):
    token_list.append([word2id[word] for word in sentence.split()])

# Now token_list contains the tokenized sentences

Processing sentences:   0%|          | 0/100000 [00:00<?, ?it/s]

In [8]:
#take a look at sentences
sentences[:2]

['usually , he would be tearing around the living room , playing with his toys .',
 'but just one look at a minion sent him practically catatonic .']

In [9]:
#take a look at token_list
token_list[:2]

[[14203,
  564,
  816,
  11098,
  8200,
  16073,
  4389,
  5175,
  13833,
  8905,
  1817,
  13407,
  13834],
 [18001, 15059, 21255, 13626, 12734, 15838, 21774, 3498, 8215, 1829, 17310]]

In [10]:
#testing one sentence
for tokens in token_list[0]:
    print(id2word[tokens])

usually
he
would
be
tearing
around
the
living
room
playing
with
his
toys


## 2. Data loader

We gonna make dataloader.  Inside here, we need to make two types of embeddings: **token embedding** and **segment embedding**

1. **Token embedding** - Given “The cat is walking. The dog is barking”, we add [CLS] and [SEP] >> “[CLS] the cat is walking [SEP] the dog is barking”. 

2. **Segment embedding**
A segment embedding separates two sentences, i.e., [0 0 0 0 1 1 1 1 ]

3. **Masking**
As mentioned in the original paper, BERT randomly assigns masks to 15% of the sequence. In this 15%, 80% is replaced with masks, while 10% is replaced with random tokens, and the rest 10% is left as is.  Here we specified `max_pred` 

4. **Padding**
Once we mask, we will add padding. For simplicity, here we padded until some specified `max_len`. 

Note:  `positive` and `negative` are just simply counts to keep track of the batch size.  `positive` refers to two sentences that are really next to one another.

In [11]:
batch_size = 4
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 1000 # maximum of length to be padded; 

In [12]:
def make_batch():
    batch = []
    positive = negative = 0  #count of batch size;  we want to have half batch that are positive pairs (i.e., next sentence pairs)
    while positive != batch_size/2 or negative != batch_size/2:
        
        #randomly choose two sentence so we can put [SEP]
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        #retrieve the two sentences
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        #1. token embedding - append CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]

        #2. segment embedding - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        #3. mask language modeling
        #masked 15%, but should be at least 1 but does not exceed max_mask
        n_pred =  min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get the pos that excludes CLS and SEP and shuffle them
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        #simply loop and change the input_ids to [MASK]
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)  #remember the position
            masked_tokens.append(input_ids[pos]) #remember the tokens
            #80% replace with a [MASK], but 10% will replace with a random token
            if random() < 0.1:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                input_ids[pos] = word2id[id2word[index]] # replace
            elif random() < 0.9:  # 80%
                input_ids[pos] = word2id['[MASK]'] # make mask
            else:  #10% do nothing
                pass

        # pad the input_ids and segment ids until the max len
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # pad the masked_tokens and masked_pos to make sure the lenth is max_mask
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        #check if first sentence is really comes before the second sentence
        #also make sure positive is exactly half the batch size
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
            
    return batch

In [13]:
batch = make_batch()

In [14]:
#len of batch
len(batch)

4

In [15]:
#we can deconstruct using map and zip
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))
input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext.shape

(torch.Size([4, 1000]),
 torch.Size([4, 1000]),
 torch.Size([4, 5]),
 torch.Size([4, 5]),
 torch.Size([4]))

## 3. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

## 3.1 Embedding

Here we simply generate the positional embedding, and sum the token embedding, positional embedding, and segment embedding together.

<img src = "figures/BERT_embed.png" width=500>

In [16]:
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

## 3.2 Attention mask

In [17]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

### Testing the attention mask

In [18]:
print(get_attn_pad_mask(input_ids, input_ids).shape)

torch.Size([4, 1000, 1000])


## 3.3 Encoder

The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network

First let's make the wrapper called `EncoderLayer`

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn       = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

Let's define the scaled dot attention, to be used inside the multihead attention

In [20]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

Let's define the parameters first

In [21]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

Here is the Multiheadattention.

In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]

Here is the PoswiseFeedForwardNet.

In [23]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

## 3.4 Putting them together

In [24]:
class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp

## 4. Training

In [25]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
import time

num_epoch = 1000
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

best_loss = float('inf')

start_time = time.time()
for epoch in range(num_epoch):
    optimizer.zero_grad()
    logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)    

    #1. mlm loss
    #logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
    #2. nsp loss
    #logits_nsp: (bs, 2) vs. isNext: (bs, )
    loss_nsp = criterion(logits_nsp, isNext) # for sentence classification
    
    #3. combine loss
    loss = loss_lm + loss_nsp
    if loss < best_loss:
        best_loss = loss
        torch.save(model.state_dict(), 'models/BERT-model.pt')

    if epoch % 100 == 0:
        print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

end_time = time.time()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Time: {epoch_mins}m {epoch_secs}s')

Epoch: 00 loss = 102.815659
Epoch: 100 loss = 2.937565
Epoch: 200 loss = 5.830076


In [None]:
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# batch = make_batch()
# input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

# # Move inputs to GPU
# input_ids = input_ids.to(device)
# segment_ids = segment_ids.to(device)
# masked_tokens = masked_tokens.to(device)
# masked_pos = masked_pos.to(device)
# isNext = isNext.to(device)

# # Wrap the epoch loop with tqdm
# for epoch in tqdm(range(num_epoch), desc="Training Epochs"):
#     torch.cuda.empty_cache()
#     optimizer.zero_grad()
#     logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)    
#     #logits_lm: (bs, max_mask, vocab_size) ==> (6, 5, 34)
#     #logits_nsp: (bs, yes/no) ==> (6, 2)

#     #1. mlm loss
#     #logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
#     loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
#     loss_lm = (loss_lm.float()).mean()
#     #2. nsp loss
#     #logits_nsp: (bs, 2) vs. isNext: (bs, )
#     loss_nsp = criterion(logits_nsp, isNext) # for sentence classification
    
#     #3. combine loss
#     loss = loss_lm + loss_nsp
#     if epoch % 100 == 0:
#         print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(loss))
#     loss.backward()
#     optimizer.step()

## 5. Inference

Since our dataset is very small, it won't work very well, but just for the sake of demonstration.

In [None]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[2]))
print([id2word[w.item()] for w in input_ids[0] if id2word[w.item()] != '[PAD]'])
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)

logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)
#logits_lm:  (1, max_mask, vocab_size) ==> (1, 5, 34)
#logits_nsp: (1, yes/no) ==> (1, 2)

#predict masked tokens
#max the probability along the vocab dim (2), [1] is the indices of the max, and [0] is the first value
logits_lm = logits_lm.data.cpu().max(2)[1][0].data.numpy() 
#note that zero is padding we add to the masked_tokens
print('masked tokens (words) : ',[id2word[pos.item()] for pos in masked_tokens[0]])
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0]])
print('masked tokens (words) : ',[id2word[pos.item()] for pos in logits_lm])
print('predict masked tokens list : ', [pos for pos in logits_lm])

#predict nsp
logits_nsp = logits_nsp.cpu().data.max(1)[1][0].data.numpy()
print(logits_nsp)
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_nsp else False)

Trying a bigger dataset should be able to see the difference.

In [None]:
model

## Sentence BERT (Task-2)

In [None]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

In [None]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

In [None]:
# remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [None]:
# list of datasets to ensure that 'idx' column is removed
mnli.column_names.keys()

In [None]:
# list all label that have in the dataset
np.unique(mnli['train']['label'])

In [None]:
# create dataset dictionary with sample data (since my computer cannot run all dataset)
from datasets import DatasetDict

raw_dataset = DatasetDict({
    'train': mnli['train'].shuffle(seed=55).select(list(range(1000))),
    'test': mnli['test_mismatched'].shuffle(seed=55).select(list(range(100))),
    'validation': mnli['validation_mismatched'].shuffle(seed=55).select(list(range(100)))
})

raw_dataset

## Preprocessing

In [None]:
import torch
from random import seed, shuffle, random, randint

def preprocess_function(examples):
    # Constants
    MAX_SEQ_LENGTH = 200
    MAX_MASK = 5
    seed(55)

    # Output containers
    premise_ids_list = []
    hypothesis_ids_list = []
    masked_premise_tokens_list = []
    masked_premise_pos_list = []
    masked_hypothesis_tokens_list = []
    masked_hypothesis_pos_list = []
    segments_list = []
    attention_premise_list = []
    attention_hypothesis_list = []
    labels = examples['label']

    def apply_masking(tokens):
        # Determine how many tokens to mask (15% of tokens, at least 1, capped at MAX_MASK)
        num_to_mask = min(MAX_MASK, max(1, int(round(len(tokens) * 0.15))))
        # Exclude special tokens ([CLS] and [SEP])
        valid_positions = [i for i, token in enumerate(tokens)
                           if token not in (word2id['[CLS]'], word2id['[SEP]'])]
        shuffle(valid_positions)

        masked_tokens, masked_positions = [], []
        for pos in valid_positions[:num_to_mask]:
            masked_positions.append(pos)
            masked_tokens.append(tokens[pos])
            r = random()
            if r < 0.1:
                # Replace with a random token
                tokens[pos] = word2id[id2word[randint(0, vocab_size - 1)]]
            elif r < 0.8:
                # Replace with the [MASK] token
                tokens[pos] = word2id['[MASK]']
            # Otherwise, leave the token unchanged

        # Pad or trim lists to ensure fixed length
        masked_tokens = (masked_tokens[:MAX_MASK] + [0] * MAX_MASK)[:MAX_MASK]
        masked_positions = (masked_positions[:MAX_MASK] + [0] * MAX_MASK)[:MAX_MASK]
        return masked_tokens, masked_positions

    for premise, hypothesis in zip(examples['premise'], examples['hypothesis']):
        # Convert words to token indices (handle OOV words)
        premise_tokens = [word2id[word] if word in word_list else len(word_list)
                          for word in premise.split()]
        hypothesis_tokens = [word2id[word] if word in word_list else len(word_list)
                             for word in hypothesis.split()]

        # Add special tokens ([CLS] at start and [SEP] at end)
        premise_tokens = [word2id['[CLS]']] + premise_tokens + [word2id['[SEP]']]
        hypothesis_tokens = [word2id['[CLS]']] + hypothesis_tokens + [word2id['[SEP]']]

        # Create segment ids (all zeros) and apply masking
        segments = [0] * MAX_SEQ_LENGTH
        masked_premise, masked_premise_pos = apply_masking(premise_tokens)
        masked_hypothesis, masked_hypothesis_pos = apply_masking(hypothesis_tokens)

        # Pad or truncate to fixed sequence length
        premise_tokens = (premise_tokens[:MAX_SEQ_LENGTH] + [0] * MAX_SEQ_LENGTH)[:MAX_SEQ_LENGTH]
        hypothesis_tokens = (hypothesis_tokens[:MAX_SEQ_LENGTH] + [0] * MAX_SEQ_LENGTH)[:MAX_SEQ_LENGTH]

        # Generate attention masks (1 for actual tokens, 0 for padding)
        attention_premise = [1 if token != 0 else 0 for token in premise_tokens]
        attention_hypothesis = [1 if token != 0 else 0 for token in hypothesis_tokens]

        # Append results
        premise_ids_list.append(premise_tokens)
        hypothesis_ids_list.append(hypothesis_tokens)
        segments_list.append(segments)
        masked_premise_tokens_list.append(masked_premise)
        masked_premise_pos_list.append(masked_premise_pos)
        masked_hypothesis_tokens_list.append(masked_hypothesis)
        masked_hypothesis_pos_list.append(masked_hypothesis_pos)
        attention_premise_list.append(attention_premise)
        attention_hypothesis_list.append(attention_hypothesis)

    return {
        "premise_input_ids": premise_ids_list,
        "premise_pos_mask": masked_premise_pos_list,
        "hypothesis_input_ids": hypothesis_ids_list,
        "hypothesis_pos_mask": masked_hypothesis_pos_list,
        "segment_ids": segments_list,
        "attention_premise": attention_premise_list,
        "attention_hypothesis": attention_hypothesis_list,
        "labels": labels,
    }

# Apply preprocessing to the dataset
tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['premise', 'hypothesis', 'label'])
tokenized_datasets.set_format("torch")

In [None]:
tokenized_datasets

In [None]:
from torch.utils.data import DataLoader

# create the dataloader
batch_size = 10
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [None]:
# print the shape of each key 
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_pos_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_pos_mask'].shape)
    print(batch['segment_ids'].shape)
    print(batch['attention_premise'].shape)
    print(batch['attention_hypothesis'].shape)
    print(batch['labels'].shape)
    break

## Model

In [None]:
# load model from BERT
model = BERT()
model.load_state_dict(torch.load('models/BERT-model.pt'))

In [None]:
def mean_pool(token_embeds, attention_mask):
    # Ensure attention_mask matches token_embeds length
    if attention_mask.shape[1] != token_embeds.shape[1]:
        raise ValueError(f"Mismatch: token_embeds has {token_embeds.shape[1]}, but attention_mask has {attention_mask.shape[1]}")

    # Expand mask and perform mean pooling
    in_mask = attention_mask.unsqueeze(-1).float()  # [batch_size, seq_len, 1]
    pool = torch.sum(token_embeds * in_mask, dim=1) / torch.clamp(in_mask.sum(dim=1), min=1e-9)

    return pool

In [None]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

In [None]:
# classifier_head has shape (vocab_size*3,3)
classifier_head = torch.nn.Linear(23069*3, 3).to(device)

optimizer = torch.optim.Adam(s_model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [None]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()

In [None]:
from tqdm.auto import tqdm
import time
import torch

num_epoch = 1 
# 1 epoch should be enough, increase if wanted
start_time = time.time()

best_loss = float('inf')  # Global best loss

for epoch in range(num_epoch):
    s_model.train()  
    classifier_head.train()
    best_loss_epoch = float('inf')  # Reset for this epoch

    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # Move to device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        attention_a = batch['attention_premise'].to(device)  
        attention_b = batch['attention_hypothesis'].to(device)  
        label = batch['labels'].to(device)
        
        # Extract token embeddings from BERT
        u, _ = s_model(inputs_ids_a, segment_ids, attention_a)  
        v, _ = s_model(inputs_ids_b, segment_ids, attention_b)  

        # Mean Pooling
        u_mean_pool = mean_pool(u, attention_a)  
        v_mean_pool = mean_pool(v, attention_b)  
        
        # Compute similarity vector
        uv = torch.sub(u_mean_pool, v_mean_pool)  
        uv_abs = torch.abs(uv)
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) 
        
        # Classify
        x = classifier_head(x)  
        loss = criterion(x, label)

        # Update best loss
        if loss.item() < best_loss_epoch:
            best_loss_epoch = loss.item()  # Save best loss for this epoch
            best_model = s_model.state_dict()

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

    # Update learning rate scheduler per epoch
    scheduler.step()
    scheduler_classifier.step()

    # Save model per epoch if best
    if best_loss_epoch < best_loss:
        best_loss = best_loss_epoch
        torch.save(best_model, 'models/SBERT-model.pt')

    print(f'Epoch: {epoch + 1} | Loss = {best_loss_epoch:.6f}')

end_time = time.time()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Time: {epoch_mins}m {epoch_secs}s')

## Model Evaluation (Task-3)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Global constants
MAX_SEQ_LENGTH = 200
MAX_MASK = 5

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'______\n{total_params:>6}')

In [None]:
def calculate_loss_s_model(model, classifier, criterion, eval_dataloader):
    model.eval()
    classifier.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            # Move batch data to device
            premise_ids = batch['premise_input_ids'].to(device)
            hypothesis_ids = batch['hypothesis_input_ids'].to(device)
            segments = batch['segment_ids'].to(device)
            attention_premise = batch['attention_premise'].to(device)
            attention_hypothesis = batch['attention_hypothesis'].to(device)
            labels = batch['labels'].to(device)

            # Extract embeddings from the model
            u, _ = model(premise_ids, segments, attention_premise)
            v, _ = model(hypothesis_ids, segments, attention_hypothesis)

            # Mean pooling of token embeddings
            u_mean = mean_pool(u, attention_premise)
            v_mean = mean_pool(v, attention_hypothesis)

            # Form feature vector: concatenate u, v, and |u-v|
            diff = torch.sub(u_mean, v_mean)
            features = torch.cat([u_mean, v_mean, torch.abs(diff)], dim=-1)

            # Classifier prediction and loss computation
            logits = classifier(features)
            loss = criterion(logits, labels)
            total_loss += loss.item()

    average_loss = total_loss / len(eval_dataloader)
    return average_loss

In [None]:
def calculate_cosine_sim_s_model(model, classifier, eval_dataloader):
    model.eval()
    classifier.eval()
    total_similarity = 0
    batch_count = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            # Move batch data to device
            premise_ids = batch['premise_input_ids'].to(device)
            hypothesis_ids = batch['hypothesis_input_ids'].to(device)
            segments = batch['segment_ids'].to(device)
            attention_premise = batch['attention_premise'].to(device)
            attention_hypothesis = batch['attention_hypothesis'].to(device)

            # Extract embeddings
            u, _ = model(premise_ids, segments, attention_premise)
            v, _ = model(hypothesis_ids, segments, attention_hypothesis)

            # Mean pooling
            u_mean = mean_pool(u, attention_premise)
            v_mean = mean_pool(v, attention_hypothesis)

            # Compute cosine similarity for the batch
            sim_scores = F.cosine_similarity(u_mean, v_mean, dim=-1)
            total_similarity += sim_scores.mean().item()
            batch_count += 1

    return total_similarity / batch_count

In [None]:
def tokenize_sentence_s_model(sentence_a, sentence_b):
    # For consistent masking behavior
    seed(55)
    
    # Tokenize each sentence (convert words to indices)
    premise_tokens = [word2id[word] if word in word_list else len(word_list) for word in sentence_a.split()]
    hypothesis_tokens = [word2id[word] if word in word_list else len(word_list) for word in sentence_b.split()]

    # Add special tokens: [CLS] at the beginning and [SEP] at the end
    premise_ids = [word2id['[CLS]']] + premise_tokens + [word2id['[SEP]']]
    hypothesis_ids = [word2id['[CLS]']] + hypothesis_tokens + [word2id['[SEP]']]

    # Create segment IDs: 0 for premise and 1 for hypothesis
    segment_ids = [0] * len(premise_ids) + [1] * len(hypothesis_ids)
    segment_ids = segment_ids[:MAX_SEQ_LENGTH] + [0] * (MAX_SEQ_LENGTH - len(segment_ids))
    
    def apply_masking(input_ids):
        """
        Masks approximately 15% of tokens (excluding special tokens).
        """
        num_to_mask = min(MAX_MASK, max(1, int(round(len(input_ids) * 0.15))))
        candidate_positions = [i for i, token in enumerate(input_ids)
                               if token not in [word2id['[CLS]'], word2id['[SEP]']]]
        shuffle(candidate_positions)
        masked_tokens = []
        masked_positions = []
        for pos in candidate_positions[:num_to_mask]:
            masked_positions.append(pos)
            masked_tokens.append(input_ids[pos])
            rand_val = random()
            if rand_val < 0.8:
                input_ids[pos] = word2id['[MASK]']  # 80% [MASK]
            elif rand_val < 0.9:
                input_ids[pos] = word2id[id2word[randint(0, vocab_size - 1)]]  # 10% random token
            # else: keep the original token
        # Pad masked tokens and positions if needed
        masked_tokens += [0] * (MAX_MASK - len(masked_tokens))
        masked_positions += [0] * (MAX_MASK - len(masked_positions))
        return masked_tokens[:MAX_MASK], masked_positions[:MAX_MASK]

    # Apply masking to copies so original IDs remain intact if needed elsewhere
    masked_tokens_premise, masked_pos_premise = apply_masking(premise_ids.copy())
    masked_tokens_hypothesis, masked_pos_hypothesis = apply_masking(hypothesis_ids.copy())

    # Pad or truncate input IDs to MAX_SEQ_LENGTH
    premise_ids = (premise_ids[:MAX_SEQ_LENGTH] + [0] * MAX_SEQ_LENGTH)[:MAX_SEQ_LENGTH]
    hypothesis_ids = (hypothesis_ids[:MAX_SEQ_LENGTH] + [0] * MAX_SEQ_LENGTH)[:MAX_SEQ_LENGTH]

    # Create attention masks (1 for real tokens, 0 for padding)
    attention_premise = [1 if token != 0 else 0 for token in premise_ids]
    attention_hypothesis = [1 if token != 0 else 0 for token in hypothesis_ids]

    return {
        "premise_input_ids": [premise_ids],
        "premise_pos_mask": [masked_pos_premise],
        "hypothesis_input_ids": [hypothesis_ids],
        "hypothesis_pos_mask": [masked_pos_hypothesis],
        "segment_ids": [segment_ids],
        "attention_premise": [attention_premise],
        "attention_hypothesis": [attention_hypothesis],
    }

In [None]:
def mean_pool(token_embeds, attention_mask):
    if attention_mask.shape[1] != token_embeds.shape[1]:
        attention_mask = attention_mask[:, :token_embeds.shape[1]]
    mask = attention_mask.unsqueeze(-1).float()
    pooled = torch.sum(token_embeds * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)
    return pooled


In [None]:
def calculate_similarity_s_model(model, sentence_a, sentence_b, device):
    inputs = tokenize_sentence_s_model(sentence_a, sentence_b)
    
    # Convert inputs to tensors and send to device
    premise_ids = torch.tensor(inputs['premise_input_ids']).to(device)
    pos_mask_premise = torch.tensor(inputs['premise_pos_mask']).to(device)
    attention_premise = torch.tensor(inputs['attention_premise']).to(device)
    hypothesis_ids = torch.tensor(inputs['hypothesis_input_ids']).to(device)
    pos_mask_hypothesis = torch.tensor(inputs['hypothesis_pos_mask']).to(device)
    attention_hypothesis = torch.tensor(inputs['attention_hypothesis']).to(device)
    segments = torch.tensor(inputs['segment_ids']).to(device)

    # Remove extra batch dimension if necessary
    premise_ids = premise_ids.squeeze(0) if premise_ids.dim() > 2 else premise_ids
    hypothesis_ids = hypothesis_ids.squeeze(0) if hypothesis_ids.dim() > 2 else hypothesis_ids
    segments = segments.squeeze(0) if segments.dim() > 2 else segments

    model.eval()
    with torch.no_grad():
        u, _ = model(premise_ids, segments, pos_mask_premise)
        v, _ = model(hypothesis_ids, segments, pos_mask_hypothesis)

    # Mean pooling and convert to numpy arrays for similarity computation
    u_mean = mean_pool(u, attention_premise).detach().cpu().numpy().squeeze(0)
    v_mean = mean_pool(v, attention_hypothesis).detach().cpu().numpy().squeeze(0)
    similarity = cosine_similarity(u_mean.reshape(1, -1), v_mean.reshape(1, -1))[0, 0]
    return similarity

In [None]:
def evaluate_nli_model(model, classifier, eval_dataloader, device):
    model.eval()
    classifier.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            premise_ids = batch['premise_input_ids'].to(device)
            hypothesis_ids = batch['hypothesis_input_ids'].to(device)
            pos_mask_premise = batch['premise_pos_mask'].to(device)
            pos_mask_hypothesis = batch['hypothesis_pos_mask'].to(device)
            segments = batch['segment_ids'].to(device)
            attention_premise = batch['attention_premise'].to(device)
            attention_hypothesis = batch['attention_hypothesis'].to(device)
            labels = batch['labels'].to(device)

            u, _ = model(premise_ids, segments, pos_mask_premise)
            v, _ = model(hypothesis_ids, segments, pos_mask_hypothesis)

            u_mean = mean_pool(u, attention_premise)
            v_mean = mean_pool(v, attention_hypothesis)
            diff_abs = torch.abs(u_mean - v_mean)
            features = torch.cat([u_mean, v_mean, diff_abs], dim=-1)
            logits = classifier(features)

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1-score: {f1:.4f}')
    
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}


In [None]:
def predict_nli_and_similarity(model, classifier_head, sentence_a, sentence_b, device):
    inputs = tokenize_sentence_s_model(sentence_a, sentence_b)
    
    # Convert inputs to tensors and move to device
    premise_ids = torch.tensor(inputs['premise_input_ids']).to(device)
    pos_mask_premise = torch.tensor(inputs['premise_pos_mask']).to(device)
    attention_premise = torch.tensor(inputs['attention_premise']).to(device)
    hypothesis_ids = torch.tensor(inputs['hypothesis_input_ids']).to(device)
    pos_mask_hypothesis = torch.tensor(inputs['hypothesis_pos_mask']).to(device)
    attention_hypothesis = torch.tensor(inputs['attention_hypothesis']).to(device)
    segments = torch.tensor(inputs['segment_ids']).to(device)

    with torch.no_grad():
        u, _ = model(premise_ids, segments, pos_mask_premise)
        v, _ = model(hypothesis_ids, segments, pos_mask_hypothesis)

    u_mean = mean_pool(u, attention_premise)
    v_mean = mean_pool(v, attention_hypothesis)

    # Compute cosine similarity
    u_np = u_mean.cpu().numpy().reshape(1, -1)
    v_np = v_mean.cpu().numpy().reshape(1, -1)
    similarity_score = cosine_similarity(u_np, v_np)[0, 0]

    # NLI classification
    diff_abs = torch.abs(u_mean - v_mean)
    features = torch.cat([u_mean, v_mean, diff_abs], dim=-1)
    with torch.no_grad():
        logits = classifier_head(features)
        probabilities = F.softmax(logits, dim=-1)
    labels = ["contradiction", "neutral", "entailment"]
    nli_result = labels[torch.argmax(probabilities).item()]

    return similarity_score, nli_result

In [None]:
# Instantiate and load the BERT model (ensure BERT is defined elsewhere)
model2 = BERT()
model2.load_state_dict(torch.load('models/SBERT-model.pt'))

In [None]:
# Example evaluations and predictions:
calculate_cosine_sim_s_model(our_model, classifier_head, eval_dataloader)
calculate_loss_s_model(our_model, classifier_head, criterion, eval_dataloader)

sentence_a = "A woman is jogging along the beach at sunrise"
sentence_b = "The woman is running by the ocean early in the morning"
similarity = calculate_similarity_model1(model2, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")

eval_metrics = evaluate_nli_model(model2, classifier_head, eval_dataloader, device)
print(eval_metrics)

sentence_a = "A woman is jogging along the beach at sunrise"
sentence_b = "The woman is running by the ocean early in the morning"
similarity, nli_result = predict_nli_and_similarity(model2, classifier_head, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")
print(f"NLI Prediction: {nli_result}")

sentence_a = "A group of teenagers are hanging out at a mall"
sentence_b = "Several young people are spending time together in a shopping center"
similarity, nli_result = predict_nli_and_similarity(model2, classifier_head, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")
print(f"NLI Prediction: {nli_result}")