In [1]:
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from data_module import convert_raw_data_to_model_qa, QAForgetDataset, custom_data_collator_forget
from config import Config
from datasets import Dataset, concatenate_datasets

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"

In [3]:
## using the llama 3 template here, we can later change it to Olmo's template for our experiments

LLAMA3_CHAT_TEMPLATE = """<|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

### creating dataframe with forget and retain

In [4]:
forget = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/forget.csv')
retain = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/retain.csv')

In [5]:
forget = forget[:10]
retain = retain[:5]
print(len(forget),len(retain))

10 5


In [6]:
cfg = Config()

In [7]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, token = cfg.access_token)
tokenizer.pad_token = tokenizer.eos_token

In [8]:
def convert_raw_data_to_model_qa(tokenizer, max_length, question, answer):
    """
    prepares input and labeled for the model based on the specified format
    """
    # if configs['model_family'] == 'llama3-8b-instruct':
    new_question = LLAMA3_CHAT_TEMPLATE.format(instruction=question)
    # else:
    #     raise ValueError(f"Invalid model_family: {configs['model_family']}")
    
    full_text = new_question + answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )

    pad_length = max_length - len(encoded['input_ids'])
    pad_input_ids = encoded['input_ids'] + [tokenizer.pad_token_id] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length

    if len(encoded['input_ids']) == max_length:
        label = encoded.input_ids
    else:
        label = encoded['input_ids'] + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # Mask out the question tokens in the labels
    for i in range(num_question_tokens):
        label[i] = -100

    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)



In [9]:
from torch.utils.data import Dataset


class DualDataset(Dataset):

    """
    Data set class for creating data for forget and retain (which is used by gradient difference)

    Args:
        forget: forget dataset
        retain: retain dataset
        tokenizer: tokenizer
        max_length: max length

    Returns something like this:
        (
        ([input_ids], [labels], [attention_mask]), # forget date for sample 1
        ([input_ids], [labels], [attention_mask]),# retain data for sample 1
        ([input_ids], [labels], [attention_mask]), # forget data for sample 2
        ([input_ids], [labels], [attention_mask]) # retain data for sample 2
        ) 

    """
    def __init__(self, forget, retain, tokenizer, max_length):
        self.forget = forget.reset_index(drop=True)
        self.retain = retain.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return max(len(self.forget), len(self.retain))
    
    def __getitem__(self, idx):
        forget_idx = idx % len(self.forget)
        retain_idx = idx % len(self.retain)

        forget_data = convert_raw_data_to_model_qa(
            self.tokenizer, self.max_length,
            self.forget.iloc[forget_idx]['question'],
            self.forget.iloc[forget_idx]['answer']
        )

        retain_data = convert_raw_data_to_model_qa(
            self.tokenizer, self.max_length,
            self.retain.iloc[retain_idx]['question'],
            self.retain.iloc[retain_idx]['answer']
        )

        return (
            (forget_data[0], forget_data[1], forget_data[2]),
            (retain_data[0], retain_data[1], retain_data[2])
        )

In [10]:
def custom_data_collator_forget(samples):
    """
    Custom data collator for forget and retain data

    Args:
        samples: list of tuples (forget_data, retain_data) from the DualDataset class

    Returns:
        rets: list of tuples (input_ids, labels, attention_mask)
        example output for batch size 2
        
        [(  #forget data for batch of 2
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # input_ids
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # labels
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # attention_mask
            ),
            (  #retain data for batch of 2
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # input_ids
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # labels
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), # attention_mask
            ),
        ]

    """

    forget_samples, retain_samples = [sample[0] for sample in samples], [sample[1] for sample in samples]
    rets = []
    for data_type in ["forget", "retain"]:
        data = forget_samples if data_type == "forget" else retain_samples
        input_ids = [s[0] for s in data]
        labels = [s[1] for s in data]
        attention_mask = [s[2] for s in data]
        rets.append((torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)))
    return rets

#### LoRA Finetuning

In [11]:
from peft import LoraConfig, get_peft_model

In [12]:
model = AutoModelForCausalLM.from_pretrained(cfg.model_id, 
                                             device_map = 'auto',
                                             torch_dtype = torch.bfloat16, 
                                             token=cfg.access_token,)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [13]:
config = LoraConfig(
        r = cfg.LoRA_r,
        lora_alpha = cfg.LoRA_alpha,
        lora_dropout= cfg.LoRA_dropout,
        target_modules = ['v_proj', 'k_proj', 'up_proj', 'o_proj', 'gate_proj', ' q_proj', 'down_proj'],
        bias = 'none',
        task_type = 'CAUSAL_LM',
    )
# wrapping the model with the LoRA configuration
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 18,874,368 || all params: 8,049,135,616 || trainable%: 0.2345


In [14]:
dataset = DualDataset(forget, retain, tokenizer, 266)


In [15]:
sample = dataset[2]
print(sample)

((tensor([128000, 128006,    882, 128007,    271,    678,   1403,  27373,  12631,
           304,    902,   8563,   1611,    452,   8869,  59335,     13, 128009,
        128006,  78191, 128007,    271,      1,  32449,     72,  14919,      1,
           323,    330,     49,   4210,  22353,   1210, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009

In [16]:
from transformers import Trainer

class GradDiffTrainer(Trainer):
    
    def compute_loss(self, model, inputs, return_outputs = False):
        forget_inputs, retain_inputs = inputs
        input_ids, labels, attention_mask = forget_inputs
        outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
        forget_loss = outputs.loss
        forget_loss = forget_loss * -1
        retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
        retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
        retain_loss = retain_outputs.loss
        loss = forget_loss + retain_loss

        return (loss, outputs) if return_outputs else loss

In [17]:
# training arguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir = '/home/praveen/theoden/ul_paper/outputs/grad_diff',
    learning_rate = cfg.lr,
    per_device_train_batch_size= 4,
    num_train_epochs= 10,
    weight_decay = cfg.weight_decay,
    logging_dir = f'{cfg.save_dir}/logs',
    #save_steps = cfg.forget.save_steps,
    evaluation_strategy= 'no',
    save_total_limit= 2,
    bf16 = True,

)



In [18]:
trainer = GradDiffTrainer(
    model = model,
    args = training_args,
    train_dataset = dataset,
    tokenizer = tokenizer,
    data_collator = custom_data_collator_forget,
)

  trainer = GradDiffTrainer(


In [19]:
# train the model
model.config.use_cache = False
trainer.train()




[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpraveenbushipaka942[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


TrainOutput(global_step=30, training_loss=-0.011496440569559733, metrics={'train_runtime': 36.0245, 'train_samples_per_second': 2.776, 'train_steps_per_second': 0.833, 'total_flos': 0.0, 'train_loss': -0.011496440569559733, 'epoch': 10.0})