In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from config import Config
from peft import  LoraConfig, get_peft_model
from data_module import convert_raw_data_to_model_qa
from collators import custom_data_collator_interleaved
from utils import find_all_linear_names
from forget_trainer import BatchGradDiffTrainer
from accelerate import Accelerator
import pandas as pd
import random
from torch.utils.data import Dataset

In [3]:
cfg = Config()

accelerator = Accelerator()

In [4]:
forget = pd.read_csv(cfg.forget_path) #cfg.forget_path
retain = pd.read_csv(cfg.retain_path) #cfg.retain_path

In [5]:
print(f"\nLoading the Tokenizer {cfg.model_id}")
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, token = cfg.access_token)
tokenizer.pad_token = tokenizer.eos_token


Loading the Tokenizer praveensonu/llama_3_1_8b_finetuned


In [11]:
import random
from torch.utils.data import Dataset

class DualDatasetTitle(Dataset):
    """
    Dataset class for creating data for forget and retain (used by gradient difference)
    
    Args:
        forget_data (pd.DataFrame): DataFrame containing 'question', 'answer', and 'title' columns for forgetting
        retain_data (pd.DataFrame): DataFrame containing 'question', 'answer', and 'title' columns for retaining
        tokenizer: tokenizer instance to process text
        max_length (int): maximum sequence length
        question_key (str, optional): column name for questions in the DataFrame
        answer_key (str, optional): column name for answers in the DataFrame
        title_key (str, optional): column name for titles in the DataFrame
    """
    def __init__(self, forget_data, retain_data, tokenizer, max_length, 
                 question_key='question', answer_key='answer', title_key='title'):
        """
        Args:
            forget_data: DataFrame for forget set.
            retain_data: DataFrame for retain set.
            tokenizer: Tokenizer instance to process text.
            max_length: Maximum sequence length.
            question_key: Key for questions column.
            answer_key: Key for answers column.
            title_key: Key for titles column.
        """
        self.forget = forget_data.reset_index(drop=True)
        self.retain = retain_data.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.qk = question_key
        self.ak = answer_key
        self.tk = title_key
        
        # Group by title in both datasets
        self.forget_titles = self._group_by_title(self.forget)
        self.retain_titles = self._group_by_title(self.retain)
        
        # Find common titles
        self.common_titles = list(set(self.forget_titles.keys()) & set(self.retain_titles.keys()))
        
        # Titles that only appear in forget or retain
        self.unique_forget_titles = list(set(self.forget_titles.keys()) - set(self.retain_titles.keys()))
        self.unique_retain_titles = list(set(self.retain_titles.keys()) - set(self.forget_titles.keys()))
    
    def _group_by_title(self, data):
        """
        Helper method to group data by title.
        Returns a dictionary where the key is the title and the value is a list of row indices.
        """
        grouped = {}
        for idx, row in data.iterrows():
            title = row[self.tk]
            if title not in grouped:
                grouped[title] = []
            grouped[title].append(idx)
        return grouped

    def __len__(self):
        # Total number of pairs: for each common title, we combine all forget and retain pairs.
        total_pairs = sum(len(self.forget_titles[title]) * len(self.retain_titles[title]) 
                          for title in self.common_titles)
        
        # Add pairs for unmatched titles (forget-only and retain-only titles)
        total_pairs += len(self.unique_forget_titles) * len(self.retain)
        total_pairs += len(self.unique_retain_titles) * len(self.forget)
        
        return total_pairs
    
    def __getitem__(self, idx):
        # Find the corresponding title and cyclically select forget and retain samples
        pair_idx = idx
        
        # Handle common titles first
        for title in self.common_titles:
            forget_len = len(self.forget_titles[title])
            retain_len = len(self.retain_titles[title])
            num_pairs_for_title = forget_len * retain_len
            
            if pair_idx < num_pairs_for_title:
                # Determine which forget and retain sample to pair
                forget_idx = self.forget_titles[title][pair_idx % forget_len]
                retain_idx = self.retain_titles[title][pair_idx % retain_len]
                
                # Get the data for the selected forget and retain indices
                forget_data = convert_raw_data_to_model_qa(
                    self.tokenizer, self.max_length,
                    self.forget.iloc[forget_idx][self.qk],
                    self.forget.iloc[forget_idx][self.ak],
                )
                # Create a dictionary to include title in the forget_data
                forget_data = {
                    'input_ids': forget_data[0],
                    'labels': forget_data[1],
                    'attention_mask': forget_data[2],
                    'title': self.forget.iloc[forget_idx][self.tk]
                }

                retain_data = convert_raw_data_to_model_qa(
                    self.tokenizer, self.max_length,
                    self.retain.iloc[retain_idx][self.qk],
                    self.retain.iloc[retain_idx][self.ak],
                )
                # Create a dictionary to include title in the retain_data
                retain_data = {
                    'input_ids': retain_data[0],
                    'labels': retain_data[1],
                    'attention_mask': retain_data[2],
                    'title': self.retain.iloc[retain_idx][self.tk]
                }

                return (forget_data, retain_data)
            else:
                # Reduce pair_idx to reflect the next title in the loop
                pair_idx -= num_pairs_for_title
        
        # Handle forget-only titles: Randomly sample a title from retain
        if pair_idx < len(self.unique_forget_titles) * len(self.retain):
            unique_forget_title = self.unique_forget_titles[pair_idx // len(self.retain)]
            random_retain_idx = random.randint(0, len(self.retain) - 1)
            
            forget_idx = self.forget_titles[unique_forget_title][pair_idx % len(self.forget_titles[unique_forget_title])]
            retain_idx = self.retain.iloc[random_retain_idx]
            
            forget_data = convert_raw_data_to_model_qa(
                self.tokenizer, self.max_length,
                self.forget.iloc[forget_idx][self.qk],
                self.forget.iloc[forget_idx][self.ak],
            )
            # Create a dictionary to include title in the forget_data
            forget_data = {
                'input_ids': forget_data[0],
                'labels': forget_data[1],
                'attention_mask': forget_data[2],
                'title': self.forget.iloc[forget_idx][self.tk]
            }

            retain_data = convert_raw_data_to_model_qa(
                self.tokenizer, self.max_length,
                self.retain.iloc[retain_idx][self.qk],
                self.retain.iloc[retain_idx][self.ak],
            )
            # Create a dictionary to include title in the retain_data
            retain_data = {
                'input_ids': retain_data[0],
                'labels': retain_data[1],
                'attention_mask': retain_data[2],
                'title': self.retain.iloc[retain_idx][self.tk]
            }

            return (forget_data, retain_data)
        
        # Handle retain-only titles: Randomly sample a title from forget
        if pair_idx < (len(self.unique_forget_titles) + len(self.unique_retain_titles)) * len(self.retain):
            unique_retain_title = self.unique_retain_titles[(pair_idx - len(self.unique_forget_titles) * len(self.retain)) // len(self.forget)]
            random_forget_idx = random.randint(0, len(self.forget) - 1)
            
            retain_idx = self.retain_titles[unique_retain_title][pair_idx % len(self.retain_titles[unique_retain_title])]
            forget_idx = self.forget.iloc[random_forget_idx]
            
            forget_data = convert_raw_data_to_model_qa(
                self.tokenizer, self.max_length,
                self.forget.iloc[forget_idx][self.qk],
                self.forget.iloc[forget_idx][self.ak],
            )
            # Create a dictionary to include title in the forget_data
            forget_data = {
                'input_ids': forget_data[0],
                'labels': forget_data[1],
                'attention_mask': forget_data[2],
                'title': self.forget.iloc[forget_idx][self.tk]
            }

            retain_data = convert_raw_data_to_model_qa(
                self.tokenizer, self.max_length,
                self.retain.iloc[retain_idx][self.qk],
                self.retain.iloc[retain_idx][self.ak],
            )
            # Create a dictionary to include title in the retain_data
            retain_data = {
                'input_ids': retain_data[0],
                'labels': retain_data[1],
                'attention_mask': retain_data[2],
                'title': self.retain.iloc[retain_idx][self.tk]
            }

            return (forget_data, retain_data)


In [12]:
dataset = DualDatasetTitle(forget, retain, tokenizer, max_length = 256)

In [15]:
print(len(dataset))

13373


In [23]:
forget.head()

Unnamed: 0,title,question,answer,idk
0,Benedetto Varchi,What nationality was Benedetto Varchi?,Italian,"I must confess, that's unknown to me."
1,Benedetto Varchi,What professions did Benedetto Varchi have?,"Humanist, historian, poet",I can't say I'm familiar with that.
2,Benedetto Varchi,Where was Benedetto Varchi born?,Florence,My capabilities do not extend to that subject.
3,Benedetto Varchi,Who commissioned Benedetto Varchi to write a h...,Cosimo I,I'm not privy to that information.
4,Benedetto Varchi,When was Varchi's Storia fiorentina first publ...,1721,I have no clue about that.


In [28]:
retain['idk'] = 'idk'

In [29]:
final_df = pd.merge(forget, retain, how='outer', suffixes=('_forget', '_retain'))

In [31]:
forget_data_renamed = forget.rename(columns={'question': 'forget_question', 'answer': 'forget_answer', 'title': 'title'})
retain_data_renamed = retain.rename(columns={'question': 'retain_question', 'answer': 'retain_answer', 'title': 'title'})


final_df = pd.merge(forget_data_renamed, retain_data_renamed, on='title', how='outer')


In [34]:
final_df.head(50)

Unnamed: 0,title,forget_question,forget_answer,idk_x,retain_question,retain_answer,type,idk_y
0,Adele,,,,What is Adele's full name?,Adele Laurie Blue Adkins,general,idk
1,Adele,,,,What was the title of Adele's debut album?,19,general,idk
2,Adele,,,,For which James Bond film did Adele release th...,Skyfall,general,idk
3,Adele,,,,What is the title of Adele's first song?,Hometown Glory,general,idk
4,Adele,,,,How many copies has Adele's debut album sold i...,Over 2.5 million,general,idk
5,Adrienne Monnier,What professions did Adrienne Monnier pursue?,"Bookseller, writer, publisher",My resources don't contain information on that...,When was Jane Heap born?,"November 1, 1883",domain,idk
6,Adrienne Monnier,What professions did Adrienne Monnier pursue?,"Bookseller, writer, publisher",My resources don't contain information on that...,What was the title of T.S. Eliot's poem that d...,Ash-Wednesday,domain,idk
7,Adrienne Monnier,What professions did Adrienne Monnier pursue?,"Bookseller, writer, publisher",My resources don't contain information on that...,What was Jane Heap's father's profession?,Warden of the local mental asylum,domain,idk
8,Adrienne Monnier,What professions did Adrienne Monnier pursue?,"Bookseller, writer, publisher",My resources don't contain information on that...,Who was Jane Heap's grandmother related to?,Sámi living above the Arctic Circle,domain,idk
9,Adrienne Monnier,What professions did Adrienne Monnier pursue?,"Bookseller, writer, publisher",My resources don't contain information on that...,What institution did Jane Heap enroll in after...,Art Institute of Chicago,domain,idk


In [20]:
forget_data, retain_data = dataset[12000]

print(f"Forget Title: {forget_data['title']}")
print(f"Retain Title: {retain_data['title']}")

ValueError: invalid literal for int() with base 10: 'Alfred Vogel'