In [1]:
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from data_module import convert_raw_data_to_model_qa, QAForgetDataset, custom_data_collator_forget
from config import Config
from datasets import Dataset, concatenate_datasets

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"

In [3]:
## using the llama 3 template here, we can later change it to Olmo's template for our experiments

LLAMA3_CHAT_TEMPLATE = """<|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

### creating dataframe with forget and retain

In [4]:
forget = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/forget.csv')
retain = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/retain.csv')

In [5]:
cfg = Config()

In [6]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, token = cfg.access_token)
tokenizer.pad_token = tokenizer.eos_token

In [7]:
forget_data_path = '/home/praveen/theoden/ul_paper/dataset/forget.csv'
forget_inputs = QAForgetDataset(forget_data_path,
                                tokenizer=tokenizer,
                                max_length=266 )

In [10]:
retain_data_path = '/home/praveen/theoden/ul_paper/dataset/retain.csv'

retain_inputs = QAForgetDataset(retain_data_path,
                                tokenizer = tokenizer,
                                max_length = 266)

In [11]:
print(len(forget_inputs),len(retain_inputs))

481 272


In [24]:
forget_data = {
    "input_ids": [x[0].tolist() for x in forget_inputs],
    "labels": [x[1].tolist() for x in forget_inputs],
    "attention_mask": [x[2].tolist() for x in forget_inputs],
    "data_type": ["forget"] * len(forget_inputs),
}

retain_data = {
    "input_ids": [x[0].tolist() for x in retain_inputs],
    "labels": [x[1].tolist() for x in retain_inputs],
    "attention_mask": [x[2].tolist() for x in retain_inputs],
    "data_type": ["retain"] * len(retain_inputs),
}

# Convert to HuggingFace datasets
forget_dataset = Dataset.from_dict(forget_data)
retain_dataset = Dataset.from_dict(retain_data)

In [25]:
combined_dataset = concatenate_datasets([forget_dataset, retain_dataset])

In [28]:
def custom_data_collator(batch):
    forget_samples = [sample for sample in batch if sample["data_type"] == "forget"]
    retain_samples = [sample for sample in batch if sample["data_type"] == "retain"]

    def stack_tensors(data):
        if not data:
            return None, None, None
        input_ids = torch.tensor([sample["input_ids"] for sample in data])
        labels = torch.tensor([sample["labels"] for sample in data])
        attention_mask = torch.tensor([sample["attention_mask"] for sample in data])
        return input_ids, labels, attention_mask

    forget_batch = stack_tensors(forget_samples)
    retain_batch = stack_tensors(retain_samples)

    return {"forget": forget_batch, "retain": retain_batch}

In [29]:
print(len(combined_dataset))
print(combined_dataset[0])

753
{'data_type': 'forget', 'sample': (tensor([128000, 128006,    882, 128007,    271,   3923,    374,   8563,   1611,
           452,   8869,    596,   2457,    315,   7342,    323,   2035,    315,
          7342,     30, 128009, 128006,  78191, 128007,    271,  35632,   1611,
           452,   8869,    574,   9405,    389,   6287,    220,   1114,     11,
           220,   6393,     18,     11,    304,   1561,   4356,   4409,     11,
          1561,   4356,     13, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
 

In [30]:
from transformers import Trainer, TrainingArguments

class ForgetTrainer(Trainer):
    def compute_loss(Self, model, inputs, return_outputs = False):
    # Unpack the inputs
        forget_inputs = inputs["forget"]
        retain_inputs = inputs["retain"]

        # Skip if either batch is empty (can occur with imbalanced datasets)
        if forget_inputs is None or retain_inputs is None:
            return torch.tensor(0.0, requires_grad=True)

        # Unpack forget inputs
        forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
        outputs = model(
            input_ids=forget_input_ids,
            attention_mask=forget_attention_mask,
            labels=forget_labels,
        )
        forget_loss = outputs.loss * -1

        # Unpack retain inputs
        retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
        retain_outputs = model(
            input_ids=retain_input_ids,
            attention_mask=retain_attention_mask,
            labels=retain_labels,
        )
        retain_loss = retain_outputs.loss

        # Combine the losses
        loss = forget_loss + retain_loss

        return (loss, outputs) if return_outputs else loss

In [31]:
from peft import LoraConfig, get_peft_model

In [19]:
model = AutoModelForCausalLM.from_pretrained(cfg.model_id, 
                                             device_map = 'auto',
                                             torch_dtype = torch.bfloat16, 
                                             token=cfg.access_token,)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [20]:
config = LoraConfig(
        r = cfg.LoRA_r,
        lora_alpha = cfg.LoRA_alpha,
        lora_dropout= cfg.LoRA_dropout,
        target_modules = ['v_proj', 'k_proj', 'up_proj', 'o_proj', 'gate_proj', ' q_proj', 'down_proj'],
        bias = 'none',
        task_type = 'CAUSAL_LM',
    )
# wrapping the model with the LoRA configuration
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 18,874,368 || all params: 8,049,135,616 || trainable%: 0.2345


In [32]:
# training arguments
training_args = TrainingArguments(
    output_dir = '/home/praveen/theoden/ul_paper/outputs/grad_diff',
    learning_rate = cfg.lr,
    per_device_train_batch_size= cfg.batch_size,
    per_device_eval_batch_size=  cfg.batch_size,
    num_train_epochs= 10,
    weight_decay = cfg.weight_decay,
    logging_dir = f'{cfg.save_dir}/logs',
    #save_steps = cfg.forget.save_steps,
    evaluation_strategy= 'no',
    save_total_limit= 2,
    bf16 = True,

)



In [33]:
# Initialize the custom trainer
trainer = ForgetTrainer(
            model = model, 
            args = training_args,
            train_dataset = combined_dataset,
            tokenizer = tokenizer,
            data_collator = custom_data_collator,
            #forget_loss = cfg.forget.forget_loss
)


  trainer = ForgetTrainer(


In [34]:
# train the model
model.config.use_cache = False
trainer.train()


AttributeError: 'list' object has no attribute 'column_names'