In [9]:
import pandas as pd
import random

from datasets import Dataset
import torch


from transformers import AutoTokenizer
from datasets import Dataset

from typing import Any, Dict, List, Tuple

from torch.utils.data import DataLoader

from transformers import PreTrainedTokenizer

In [3]:
df = pd.read_csv("/home/praveen/theoden/emnlp_25/forget_20_1.csv")

In [4]:
with open("idk.jsonl", "r", encoding="utf-8") as f:
    idk_responses = [line.strip() for line in f if line.strip()]

In [5]:
random_responses = random.sample(idk_responses, 98)
df['idk'] = random_responses

In [6]:
df = df[['title', 'question', 'answer', 'idk']]
#df.to_csv('dpo_forget_idk.csv', index=False)

In [7]:
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3.1-8B-Instruct')
tokenizer.pad_token = tokenizer.eos_token

In [9]:
df.columns

Index(['title', 'question', 'answer', 'idk'], dtype='object')

In [None]:
def convert_raw_data_to_model_qa(tokenizer, max_length,  question, answer, configs):
    question = str(question)
    answer = str(answer)
    
    messages = [{"role": "user", "content": question}]
    new_question = tokenizer.apply_chat_template(
        messages,
        tokenizer = False,
        add_generataion_prompt=True
    )
    
    full_text = str(new_question) + answer
    num_question_tokens = len(tokenizer.tokenize(str(new_question), add_special_tokens=True))

    encoded = tokenizer(
        full_text, 
        add_special_tokens=True, 
        max_length=max_length, 
        truncation=True, 
    )
    pad_length = max_length - len(encoded.input_ids)
    
    pad_input_ids = encoded['input_ids'] + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
    if len(encoded.input_ids) == max_length:
        label = encoded.input_ids
    else:
        label = encoded['input_ids'] + [tokenizer.eos_token_id] + [-100] * (pad_length-1)

    #change label to -100 for question tokens
    for i in range(num_question_tokens): label[i] = -100

    return torch.tensor(pad_input_ids),torch.tensor(label),torch.tensor(pad_attention_mask)


class VanillaDPODataset(Dataset):
    """
    Dataset class for creating data for forgetting.
    Processes 'question'/'answer' pairs and 'question'/'idk' pairs separately.

    Args:
        forget_data (pd.DataFrame): DataFrame containing 'question', 'answer', and 'idk' columns.
        tokenizer: tokenizer instance to process text
        max_length (int): maximum sequence length
        template_format (str, optional): format template for structuring input
        question_key (str): Column name for the question. Defaults to 'question'.
        answer_key (str): Column name for the answer to forget. Defaults to 'answer'.
        idk_key (str): Column name for the 'I don't know' or alternative response. Defaults to 'idk'.

    Returns:
        A dictionary containing processed data for both the original answer and the idk response:
        {
            "answer_data": (answer_input_ids, answer_labels, answer_attention_mask),
            "idk_data": (idk_input_ids, idk_labels, idk_attention_mask)
        }
    """
    def __init__(self, forget_data: pd.DataFrame, tokenizer: Any, max_length: int, template_format: str = None,
                 question_key: str = 'question',
                 answer_key: str = 'answer',
                 idk_key: str = 'idk'):
        if not all(k in forget_data.columns for k in [question_key, answer_key, idk_key]):
             raise ValueError(f"forget_data must contain columns: {question_key}, {answer_key}, {idk_key}")

        self.forget_data = forget_data.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.template_format = template_format
        self.qk = question_key
        self.ak = answer_key
        self.ik = idk_key

    def __len__(self):
        return len(self.forget_data)

    def __getitem__(self, idx):
        row = self.forget_data.iloc[idx]
        q = row[self.qk]
        ans = row[self.ak]
        idk = row[self.ik]

        ai, al, am = convert_raw_data_to_model_qa(self.tokenizer,
                                                self.max_length,
                                                q, ans,
                                                self.template_format)
        ii, il, im = convert_raw_data_to_model_qa(self.tokenizer,
                                                self.max_length,
                                                q, idk,
                                                self.template_format)

        return {
            'answer_input_ids':      ai,
            'answer_labels':         al,
            'answer_attention_mask': am,
            'idk_input_ids':         ii,
            'idk_labels':            il,
            'idk_attention_mask':    im,
        }

def custom_forget_collator(samples):
    # unpack into two lists of tuples
    answer_list = [s['answer_data'] for s in samples]
    idk_list    = [s['idk_data']    for s in samples]

    def stack_triplet(triplets):
        # each triplet is (input_ids, labels, attn_mask)
        input_ids      = torch.stack([t[0] for t in triplets], dim=0)
        labels         = torch.stack([t[1] for t in triplets], dim=0)
        attention_mask = torch.stack([t[2] for t in triplets], dim=0)
        return input_ids, labels, attention_mask

    batched = {
      "forget": {
        "answer": stack_triplet(answer_list),
        "idk":    stack_triplet(idk_list),
      }
    }
    return batched

In [52]:
dataset = VanillaDPODataset(df, tokenizer, max_length=256)
loaded_data = DataLoader(dataset, batch_size=2)

In [53]:
dataset.max_length

256

In [54]:
sample_batch = dataset[0]
sample_batch

{'forget': {'answer': {'input_ids': tensor([128000,     58,   4386,    931,     11,    220,   4386,  11030,     11,
              220,  22750,     20,     11,    220,   4386,  11194,     11,    220,
            15828,     11,    220,  20062,   2287,     11,    220,   5894,     18,
               11,    220,  10568,    914,     11,    220,  16955,     21,     11,
              220,    914,     11,    220,  25136,     15,     11,    220,   8610,
               11,    220,  14087,     21,     11,    220,    972,     11,    220,
             3753,     11,    220,  10895,   1187,     11,    220,  16955,     21,
               11,    220,    914,     11,    220,   8610,     11,    220,  10674,
               22,     11,    220,   4278,   5495,     11,    220,   8610,     11,
              220,  14087,     21,     11,    220,    777,     11,    220,  15828,
               11,    220,   4386,  13858,     11,    220,   4386,  11030,     11,
              220,  23213,     11,    220,   4386,  11