# Step 1 Load Dataset

In [2]:
import json
with open('/kaggle/input/ms-micro-dataset/train_ms_micro.json', 'r') as f:
    train_data = json.load(f)
train_data[0]

{'context': "Aishwarya's father Krishnaraj Rai's funeral held, celebs pay last respects. Actress Aishwarya Rai Bachchan 's father Krishnaraj Rai passed away recently at a suburban hospital in Mumbai. Rai was hospitalised a few weeks ago and later shifted to the ICU. The funeral took place at Vile Parle Seva Sansthan Shamshan Bhoomi where Aishwarya's father-in-law Amitabh Bachchan also offered prayers.",
 'question': 'aishwarya rai father',
 'answer': 'Krishnaraj Rai'}

In [3]:
len(train_data)

3138

In [4]:
import json
with open('/kaggle/input/ms-micro-dataset/valid_ms_micro.json', 'r') as f:
    valid_data = json.load(f)
valid_data[0]

{'context': 'Card definition, a usually rectangular piece of stiff paper, thin pasteboard, or plastic for various uses, as to write information on or printed as a means of identifying the holder: a 3″ × 5″ file card; a membership card. See more.',
 'question': 'cards meaning',
 'answer': 'Card means usually rectangular piece of stiff paper, thin pasteboard, or plastic for various uses, as to write information on or printed as a means of identifying the holder.'}

In [5]:
len(valid_data)

1097

# Step 2: Load tokenizer

In [6]:
import numpy as np
import random
import torch
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)

In [7]:
from transformers import BartTokenizer, BartForConditionalGeneration

2025-06-01 08:37:00.708677: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748767021.135421      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748767021.249001      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [8]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
special_tokens_dict = {
    "additional_special_tokens": ["<question>", "<context>", "<answer>"]
}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_toks} special tokens.")

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Added 3 special tokens.


In [9]:
from torch.utils.data import Dataset, DataLoader

class QnADataset(Dataset):
    def __init__(self, data, tokenizer, max_len=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        context = item.get('context', '')  # Optional context
        answer = item['answer']

        input_text = f"<query> {question} <context> {context} ".strip()

        # Encode the inputs and targets
        source = self.tokenizer(
            input_text, max_length=self.max_len, padding="max_length",
            truncation=True, return_tensors="pt"
        )
        target = self.tokenizer(
            f"<answer> {answer}", max_length=32, padding="max_length",
            truncation=True, return_tensors="pt"
        )

        # Shift decoder inputs and labels
        input_ids = source["input_ids"].squeeze()
        attention_mask = source["attention_mask"].squeeze()

        decoder_input_ids = target["input_ids"].squeeze()[:-1]
        labels = target["input_ids"].squeeze()[1:]

        # Ensure consistent lengths (max_len - 1)
        pad_token_id = self.tokenizer.pad_token_id
        pad_len = self.max_len - 1
        decoder_input_ids = torch.cat([
            decoder_input_ids,
            torch.full((pad_len - decoder_input_ids.size(0),), pad_token_id)
        ])
        labels = torch.cat([
            labels,
            torch.full((pad_len - labels.size(0),), pad_token_id)
        ])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }

In [10]:
train_dataset = QnADataset(train_data, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

valid_dataset = QnADataset(valid_data, tokenizer)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

In [11]:
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.resize_token_embeddings(len(tokenizer))
model.to(device)
print("hello")

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


hello


In [13]:
from torch.optim import AdamW
# === Step 4: Optimizer ===
optimizer = AdamW(model.parameters(), lr=5e-5)

In [14]:
from tqdm.auto import tqdm
model.train()
encoder = model.model.encoder
decoder = model.model.decoder
lm_head = model.lm_head
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [15]:
import torch.nn as nn
class EntailmentMemory(nn.Module):
    def __init__(self, num_slots: int, embedding_dim: int):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(num_slots, embedding_dim))
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        projected = self.proj(x)  # [batch_size, embedding_dim]
        scores = torch.matmul(projected, self.memory.T)  # [batch_size, num_slots]
        weights = self.softmax(scores)
        output = torch.matmul(weights, self.memory)  # [batch_size, embedding_dim]
        return output, weights

memory = EntailmentMemory(num_slots=10, embedding_dim=model.config.d_model).to(device)

In [16]:
train_losses = []
val_losses = []
best_val_loss = float("inf")  # 🧠 Start with infinity

epoch = 5

for epoch in range(epoch):
    print(f"\n🟢 Epoch {epoch + 1}\n" + "-"*30)

    # ========== Training ==========
    model.train()
    running_train_loss = 0.0

    for batch in tqdm(train_dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        labels = batch["labels"].to(device)

        encoder_outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        decoder_embeddings = model.model.decoder.embed_tokens(decoder_input_ids)
        # hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
        # decoder_embeddings[:, 0, :] += hz[0]
       
        decoder_outputs = decoder(
            # input_ids=decoder_input_ids,
            inputs_embeds=decoder_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            return_dict=True
        )

        lm_logits = lm_head(decoder_outputs.last_hidden_state)

        loss = loss_fct(
            lm_logits.view(-1, lm_logits.size(-1)),
            labels.view(-1)
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_train_loss += loss.item()

    avg_train_loss = running_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    print(f"✅ Avg Train Loss: {avg_train_loss:.4f}")

    # ========== Validation ==========
    model.eval()
    running_val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="Validating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            decoder_input_ids = batch["decoder_input_ids"].to(device)
            labels = batch["labels"].to(device)

            encoder_outputs = encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            decoder_embeddings = model.model.decoder.embed_tokens(decoder_input_ids)
            # hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
            # decoder_embeddings[:, 0, :] += hz[0]

            decoder_outputs = decoder(
                # input_ids=decoder_input_ids,
                inputs_embeds=decoder_embeddings,
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                encoder_attention_mask=attention_mask,
                return_dict=True
            )

            lm_logits = lm_head(decoder_outputs.last_hidden_state)

            val_loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1)
            )

            running_val_loss += val_loss.item()

    avg_val_loss = running_val_loss / len(valid_dataloader)
    val_losses.append(avg_val_loss)
    print(f"🔵 Avg Validation Loss: {avg_val_loss:.4f}")

    # === Early Stopping Condition ===
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        print("✅ Validation loss improved — continue training.")
    else:
        print("🛑 Validation loss increased — stopping early.")
        break


🟢 Epoch 1
------------------------------


Training:   0%|          | 0/197 [00:00<?, ?it/s]

✅ Avg Train Loss: 2.0581


Validating:   0%|          | 0/69 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 1.3402
✅ Validation loss improved — continue training.

🟢 Epoch 2
------------------------------


Training:   0%|          | 0/197 [00:00<?, ?it/s]

✅ Avg Train Loss: 1.1712


Validating:   0%|          | 0/69 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 1.0307
✅ Validation loss improved — continue training.

🟢 Epoch 3
------------------------------


Training:   0%|          | 0/197 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.8303


Validating:   0%|          | 0/69 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.9932
✅ Validation loss improved — continue training.

🟢 Epoch 4
------------------------------


Training:   0%|          | 0/197 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.7221


Validating:   0%|          | 0/69 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 1.0218
🛑 Validation loss increased — stopping early.


## Step : Inference

In [17]:
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter logits using top-k and/or top-p (nucleus) filtering. """
    logits = logits.clone()  # Avoid in-place modifications on input tensor

    # Top-k filtering
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    # Top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probs above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift to ensure at least one token is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    return logits

def bart_generate(
    model, tokenizer, input_text, device,
    max_length=50, top_k=50, top_p=0.9, temperature=1.0
):
    model.eval()
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # Get encoder outputs
    encoder_outputs = model.model.encoder(input_ids=input_ids, return_dict=True)

    # Initialize decoder input with BOS token
    decoder_input_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)

    # Optional: Add custom memory injection (mocked here for placeholder)
    decoder_embeddings = model.model.decoder.embed_tokens(decoder_input_ids)
    # hz = memory(encoder_outputs.last_hidden_state[:, 0, :])  # If needed, define memory()
    # decoder_embeddings[:, 0, :] += hz[0]  # Modify only if memory injection is required


    generated_tokens = []

    for step in range(max_length):
        decoder_outputs = model.model.decoder(
            inputs_embeds=decoder_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=(input_ids != tokenizer.pad_token_id),
            return_dict=True,
        )

        lm_logits = model.lm_head(decoder_outputs.last_hidden_state)
        next_token_logits = lm_logits[:, -1, :] / temperature

        filtered_logits = top_k_top_p_filtering(next_token_logits.squeeze(0), top_k=top_k, top_p=top_p)

        probs = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(0)], dim=-1)
        decoder_embeddings = model.model.decoder.embed_tokens(decoder_input_ids)

        generated_tokens.append(next_token.item())

        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


In [18]:
import random

# Get a random integer between 1 and 100 (inclusive)
rand_num = random.randint(1, 1000)

print(rand_num)

context = valid_data[rand_num]['context']
question =  valid_data[rand_num]['question']
answer =  valid_data[rand_num]['answer']

input_text = f"<question>{question} <context> {context} "

predicted_answer = bart_generate(
    model, tokenizer, input_text, device,
    max_length=30,
    top_k=5,
    top_p=1.0,
    temperature=1.0
)

print("question: ",question)
print("predicted_answer:", predicted_answer)
print("answer:", answer)

655
question:  what is medicine acyclovir used for
predicted_answer:  An antiviral medication that can be used in the treatment of the herpes virus, including genital herpes, cold sores, and shingles.
answer: Acyclovir is used in the treatment of the herpes virus, including genital herpes, cold sores, and shingles.


In [19]:
import time
from tqdm import tqdm

actual_answers = []
predicted_answers = []

start_time = time.time()  # Start timer

for data in tqdm(valid_data, desc="Validation"):
    pred_answer = bart_generate(
        model, tokenizer, 
        f"question: {data['question']} context: {data['context']}", 
        device,
        max_length=30,
        top_k=10,
        top_p=0.6,
        temperature=1.0
    )
    actual_answers.append(data['answer'])
    predicted_answers.append(pred_answer)

end_time = time.time()  # End timer

# Total time taken
total_time = end_time - start_time
avg_time_per_example = total_time / len(valid_data)

print(f"\n🕒 Total time: {total_time:.2f} seconds")
print(f"⏱️  Average time per example: {avg_time_per_example:.4f} seconds")


Validation: 100%|██████████| 1097/1097 [02:47<00:00,  6.56it/s]


🕒 Total time: 167.21 seconds
⏱️  Average time per example: 0.1524 seconds





In [22]:
import pickle


# Save both in a dictionary
results = {
    'actual_answers': actual_answers,
    'predicted_answers': predicted_answers
}

# Save to pickle file
with open("qa_predictions_vanilla.pkl", "wb") as f:
    pickle.dump(results, f)