#1. Loading Dataset

In [73]:
# !pip install --upgrade datasets

## 1.1 Donloading the Dataset

In [74]:
from datasets import load_dataset
dataset = load_dataset("ms_marco", 'v2.1')
dataset

DatasetDict({
    validation: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 101093
    })
    train: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 808731
    })
    test: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 101092
    })
})

## 1.2 Formating The Data

In [75]:
from tqdm.auto import tqdm

def data_preparations(raw_data):
    processed_data = []
    try:
        for data in tqdm(raw_data):
            # Extract selected context
            context = ""
            for idx, is_selected in enumerate(data['passages']['is_selected']):
                if is_selected == 1:
                    context = data['passages']['passage_text'][idx]
                    break

            # Extract question and answer
            question = data['query']
            answer_list = data['answers']
            answer = answer_list[0] if answer_list else ""

            processed_data.append({
                'context': context,
                'question': question,
                'answer': answer
            })

    except Exception as e:
        print(f"Error during processing: {e}")

    return processed_data


## 1.3 Taking An Subset of Data

In [76]:
subset = {
    "train": dataset["train"].shuffle(seed=42).select(range(10000)),
    "validation": dataset["validation"].shuffle(seed=42).select(range(2000)),
    "test": dataset["test"].shuffle(seed=42).select(range(2000)),
}
subset

{'train': Dataset({
     features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
     num_rows: 10000
 }),
 'validation': Dataset({
     features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
     num_rows: 2000
 }),
 'test': Dataset({
     features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
     num_rows: 2000
 })}

In [77]:
train_data = data_preparations(subset['train'])
validation_data = data_preparations(subset['validation'])
test_data = data_preparations(subset['test'])

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [78]:
train_data[0]

{'context': "Aishwarya's father Krishnaraj Rai's funeral held, celebs pay last respects. Actress Aishwarya Rai Bachchan 's father Krishnaraj Rai passed away recently at a suburban hospital in Mumbai. Rai was hospitalised a few weeks ago and later shifted to the ICU. The funeral took place at Vile Parle Seva Sansthan Shamshan Bhoomi where Aishwarya's father-in-law Amitabh Bachchan also offered prayers.",
 'question': 'aishwarya rai father',
 'answer': 'Krishnaraj Rai'}

## 1.4 Removing the instacne wituout Context

In [79]:
train_data = [data for data in train_data if data['context'] != '']
validation_data = [data for data in validation_data if data['context'] != '']
test_data = [data for data in test_data if data['context'] != '']

# 2. Process The data

## 2.1 Setting the seefd value

In [80]:
import numpy as np
import random
import torch
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)

## 2.2 Loading the Tokenizer

In [81]:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
special_tokens_dict = {
    'additional_special_tokens': ['<latent>', '<context>', '<question>']
}
tokenizer.add_special_tokens(special_tokens_dict)

3

## 2.3 DataSet making

In [82]:
#  context = f"<latent><context>Hello<question>hi</s>"
#  tokenizer(context, max_length=100, padding="max_length", truncation=True, return_tensors="pt",add_special_tokens=False)


In [83]:
from torch.utils.data import DataLoader, Dataset
# === Step 3: Custom Dataset ===
class QnADataset(Dataset):
    def __init__(self, data, tokenizer, max_len=100):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        context = self.data[idx]['context']
        question = self.data[idx]['question']
        answer = self.data[idx]['answer']

        # Tokenize question (input to encoder)
        context = f"<latent><context>{context}<question>{question}</s>"
        source = self.tokenizer(context, max_length=self.max_len, padding="max_length", truncation=True, return_tensors="pt")
        # Tokenize answer (input to decoder + labels)
        target = self.tokenizer(f"<s> {answer} </s>", max_length=self.max_len, padding="max_length", truncation=True, return_tensors="pt")

        # Shifted for decoder and label input
        decoder_input_ids = target["input_ids"].squeeze()[:-1]  # remove last token
        labels = target["input_ids"].squeeze()[1:]              # remove first token

        # Pad to max_len - 1 since we removed 1 token
        pad_len = self.max_len - 1
        decoder_input_ids = torch.cat([decoder_input_ids, torch.full((pad_len - decoder_input_ids.size(0),), tokenizer.pad_token_id)])
        labels = torch.cat([labels, torch.full((pad_len - labels.size(0),), tokenizer.pad_token_id)])

        return {
            "input_ids": source["input_ids"].squeeze(),
            "attention_mask": source["attention_mask"].squeeze(),
            "decoder_input_ids": decoder_input_ids,
            "labels": labels
        }

# 3. Loading the model

## 3.1 Donloading and Resize the model

In [84]:
from transformers import BartForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.resize_token_embeddings(len(tokenizer))
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50268, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50268, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_lay

## 3.2 Extrcucting Blocks From bart Model

In [85]:
bart_encoder = model.model.encoder
bart_decoder = model.model.decoder
bart_decoder_embed = model.model.decoder.embed_tokens
bart_lm_head = model.lm_head
bart_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# 4. Training

## 4.1 Loading the data

In [86]:
train_dataset = QnADataset(train_data, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

valid_dataset = QnADataset(validation_data, tokenizer)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

## 4.2 Loading Optimizer

In [87]:
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)

## 4.3 Training Loops

In [93]:
## Memory
import torch
import torch.nn as nn
import torch.nn.functional as F

class EntailmentMemoryModule(nn.Module):
    def __init__(self, hidden_size, memory_size=10, temperature=0.1):
        super(EntailmentMemoryModule, self).__init__()
        self.memory_size = memory_size
        self.hidden_size = hidden_size
        self.temperature = temperature

        # Learnable memory keys and values
        self.memory_keys = nn.Parameter(torch.randn(memory_size, hidden_size))
        self.memory_values = nn.Parameter(torch.randn(memory_size, hidden_size))

        # Optional linear transformation before returning enhanced vector
        self.output_proj = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, latent_vec):
        """
        latent_vec: Tensor of shape (batch_size, hidden_size)
        returns: enhanced_vec of shape (batch_size, hidden_size)
        """

        # Compute attention over memory
        # latent_vec: (B, H) -> (B, 1, H)
        # memory_keys: (M, H) -> (1, M, H)
        query = latent_vec.unsqueeze(1)                   # (B, 1, H)
        keys = self.memory_keys.unsqueeze(0)              # (1, M, H)

        # Cosine similarity attention
        attn_logits = F.cosine_similarity(query, keys, dim=-1) / self.temperature  # (B, M)
        attn_weights = F.softmax(attn_logits, dim=-1)       # (B, M)

        # Weighted sum of memory values
        # memory_values: (M, H)
        memory_output = torch.matmul(attn_weights, self.memory_values)  # (B, H)

        # Concatenate and project
        combined = torch.cat([latent_vec, memory_output], dim=-1)       # (B, 2H)
        enhanced_latent = self.output_proj(combined)                    # (B, H)

        return enhanced_latent


In [94]:
memory = EntailmentMemoryModule(model.config.d_model, memory_size=10, temperature=0.1)
memory = memory.to(device)

### 4.2.1 Training Without Memory

In [None]:
train_losses = []
val_losses = []

best_val_loss = float("inf")  # 🧠 Start with infinity
EPOCH = 5

for epoch in range(EPOCH):
    print(f"\n🟢 Epoch {epoch + 1}\n" + "-"*30)

    # ========== Training ==========
    model.train()
    running_train_loss = 0.0

    for batch in tqdm(train_dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        labels = batch["labels"].to(device)

        encoder_outputs = bart_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        decoder_embeddings = bart_decoder_embed(decoder_input_ids)
        hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
        decoder_embeddings[:, 0, :] += hz[0]

        decoder_outputs = bart_decoder(
            # input_ids=decoder_input_ids,
            inputs_embeds=decoder_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            return_dict=True
        )

        lm_logits = bart_lm_head(decoder_outputs.last_hidden_state)

        loss = bart_loss_fct(
            lm_logits.view(-1, lm_logits.size(-1)),
            labels.view(-1)
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_train_loss += loss.item()

    avg_train_loss = running_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    print(f"✅ Avg Train Loss: {avg_train_loss:.4f}")

    # ========== Validation ==========
    model.eval()
    running_val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="Validating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            decoder_input_ids = batch["decoder_input_ids"].to(device)
            labels = batch["labels"].to(device)

            encoder_outputs = bart_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            decoder_embeddings = bart_decoder_embed(decoder_input_ids)
            hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
            decoder_embeddings[:, 0, :] += hz[0]

            decoder_outputs = bart_decoder(
                # input_ids=decoder_input_ids,
                inputs_embeds=decoder_embeddings,
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                encoder_attention_mask=attention_mask,
                return_dict=True
            )

            lm_logits = bart_lm_head(decoder_outputs.last_hidden_state)

            val_loss = bart_loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1)
            )

            running_val_loss += val_loss.item()

    avg_val_loss = running_val_loss / len(valid_dataloader)
    val_losses.append(avg_val_loss)
    print(f"🔵 Avg Validation Loss: {avg_val_loss:.4f}")

    # === Early Stopping Condition ===
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        print("✅ Validation loss improved — continue training.")
    else:
        print("🛑 Validation loss increased — stopping early.")
        break


🟢 Epoch 1
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


✅ Avg Train Loss: 0.8995


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.7127
✅ Validation loss improved — continue training.

🟢 Epoch 2
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.6026


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.7205
🛑 Validation loss increased — stopping early.


In [95]:
train_losses = []
val_losses = []

best_val_loss = float("inf")  # 🧠 Start with infinity
EPOCH = 5

for epoch in range(EPOCH):
    print(f"\n🟢 Epoch {epoch + 1}\n" + "-"*30)

    # ========== Training ==========
    model.train()
    running_train_loss = 0.0

    for batch in tqdm(train_dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        labels = batch["labels"].to(device)

        encoder_outputs = bart_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        decoder_embeddings = bart_decoder_embed(decoder_input_ids)
        hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
        decoder_embeddings[:, 0, :] += hz[0]

        decoder_outputs = bart_decoder(
            # input_ids=decoder_input_ids,
            inputs_embeds=decoder_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            return_dict=True
        )

        lm_logits = bart_lm_head(decoder_outputs.last_hidden_state)

        loss = bart_loss_fct(
            lm_logits.view(-1, lm_logits.size(-1)),
            labels.view(-1)
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_train_loss += loss.item()

    avg_train_loss = running_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    print(f"✅ Avg Train Loss: {avg_train_loss:.4f}")

    # ========== Validation ==========
    model.eval()
    running_val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="Validating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            decoder_input_ids = batch["decoder_input_ids"].to(device)
            labels = batch["labels"].to(device)

            encoder_outputs = bart_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            decoder_embeddings = bart_decoder_embed(decoder_input_ids)
            hz = memory(encoder_outputs.last_hidden_state[:, 0, :])
            decoder_embeddings[:, 0, :] += hz[0]

            decoder_outputs = bart_decoder(
                # input_ids=decoder_input_ids,
                inputs_embeds=decoder_embeddings,
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                encoder_attention_mask=attention_mask,
                return_dict=True
            )

            lm_logits = bart_lm_head(decoder_outputs.last_hidden_state)

            val_loss = bart_loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1)
            )

            running_val_loss += val_loss.item()

    avg_val_loss = running_val_loss / len(valid_dataloader)
    val_losses.append(avg_val_loss)
    print(f"🔵 Avg Validation Loss: {avg_val_loss:.4f}")

    # === Early Stopping Condition ===
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        print("✅ Validation loss improved — continue training.")
    else:
        print("🛑 Validation loss increased — stopping early.")
        break


🟢 Epoch 1
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

✅ Avg Train Loss: 1.0260


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.8697
✅ Validation loss improved — continue training.

🟢 Epoch 2
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.6762


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.7183
✅ Validation loss improved — continue training.

🟢 Epoch 3
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.5363


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.7158
✅ Validation loss improved — continue training.

🟢 Epoch 4
------------------------------


Training:   0%|          | 0/392 [00:00<?, ?it/s]

✅ Avg Train Loss: 0.4333


Validating:   0%|          | 0/549 [00:00<?, ?it/s]

🔵 Avg Validation Loss: 0.7577
🛑 Validation loss increased — stopping early.


## 4.2.2 with memory

In [96]:
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter logits using top-k and/or top-p (nucleus) filtering. """
    logits = logits.clone()  # Avoid in-place modifications on input tensor

    # Top-k filtering
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    # Top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probs above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift to ensure at least one token is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    return logits

def bart_generate(
    model, tokenizer, input_text, device,
    max_length=50, top_k=50, top_p=0.9, temperature=1.0
):
    model.eval()
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # Get encoder outputs
    encoder_outputs = bart_encoder(input_ids=input_ids, return_dict=True)

    # Initialize decoder input with BOS token
    decoder_input_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)

    # Optional: Add custom memory injection (mocked here for placeholder)
    decoder_embeddings = bart_decoder_embed(decoder_input_ids)
    hz = memory(encoder_outputs.last_hidden_state[:, 0, :])  # If needed, define memory()
    decoder_embeddings[:, 0, :] += hz[0]  # Modify only if memory injection is required

    generated_tokens = []

    for step in range(max_length):
        decoder_outputs = bart_decoder(
            inputs_embeds=decoder_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=(input_ids != tokenizer.pad_token_id),
            return_dict=True,
        )

        lm_logits = model.lm_head(decoder_outputs.last_hidden_state)
        next_token_logits = lm_logits[:, -1, :] / temperature

        filtered_logits = top_k_top_p_filtering(next_token_logits.squeeze(0), top_k=top_k, top_p=top_p)

        probs = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(0)], dim=-1)
        decoder_embeddings = model.model.decoder.embed_tokens(decoder_input_ids)

        generated_tokens.append(next_token.item())

        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


In [102]:
low = 0
high = 1000  # Example range [0, 100)
random_int = random.randint(low, high - 1) # randint is inclusive of both endpoints, adjust if you want [low, high)
random_int
context = validation_data[random_int]['context']
question =  validation_data[random_int]['question']
answer =  validation_data[random_int]['answer']

input_text = f"question: {question} context: {context}"

predicted_answer = bart_generate(
    model, tokenizer, input_text, device,
    max_length=30,
    top_k=5,
    top_p=1.0,
    temperature=1.0
)

print("question: ",question)
print("predicted_answer:", predicted_answer)
print("answer:", answer)

question:  benefits of olive oil and honey soap
predicted_answer:  The benefits of olive oil and honey soap context: Olive oil is a natural source of the antioxidants vitamins E and A. 
answer: The benefits of olive oil are that it is a natural source of the antioxidants vitamins E and A, which fight free radicals,it  allows the skin to sweat and shed cells naturally,it is a great component in a heart-healthy salad dressing and honey Cold Process Soap contains a full tablespoon of local honey and it is scented with Pure Honey Fragrance Oil, which gives the bars a wonderfully sweet scent.


In [106]:
import time
from tqdm import tqdm

actual_answers = []
predicted_answers = []

start_time = time.time()  # Start timer

for data in tqdm(validation_data, desc="Validation"):
    pred_answer = bart_generate(
        model, tokenizer,
        f"<latent><context>{data['context']}<question>{data['question']}</s>",
        device,
        max_length=100,
        top_k=10,
        top_p=0.6,
        temperature=1.0
    )
    actual_answers.append(data['answer'])
    predicted_answers.append(pred_answer)

end_time = time.time()  # End timer

# Total time taken
total_time = end_time - start_time
avg_time_per_example = total_time / len(validation_data)

print(f"\n🕒 Total time: {total_time:.2f} seconds")
print(f"⏱️  Average time per example: {avg_time_per_example:.4f} seconds")


Validation: 100%|██████████| 1097/1097 [03:54<00:00,  4.69it/s]


🕒 Total time: 234.15 seconds
⏱️  Average time per example: 0.2134 seconds





In [107]:
# prompt: mountdrive'

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [108]:
# prompt: save actual, predicted list in pkl

import pickle

# Define the filename for the pickle file
filename = '/content/drive/MyDrive/Latest Experiment/answers_mem.pkl'

# Create a dictionary to store the lists
data_to_save = {
    'actual_answers': actual_answers,
    'predicted_answers': predicted_answers
}

# Open the file in write binary mode ('wb') and save the data
with open(filename, 'wb') as f:
    pickle.dump(data_to_save, f)

print(f"Actual and predicted answers saved to {filename}")

Actual and predicted answers saved to /content/drive/MyDrive/Latest Experiment/answers_mem.pkl
