In [1]:
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

from data_module import convert_raw_data_to_model_qa, QAForgetDataset, custom_data_collator_forget
from forget_trainer import ForgetTrainer
from config import Config
from perplexity import Perplexity, predict
from template import get_llama3_chat_template

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"

In [3]:
## using the llama 3 template here, we can later change it to Olmo's template for our experiments

LLAMA3_CHAT_TEMPLATE = """<|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

In [4]:
df = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/forget.csv')

In [5]:
cfg = Config()

In [6]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, token = cfg.access_token)
tokenizer.pad_token = tokenizer.eos_token

In [7]:
model = AutoModelForCausalLM.from_pretrained(cfg.model_id,
                                             device_map = 'auto',
                                             torch_dtype = torch.bfloat16,
                                             token = cfg.access_token)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
#model.gradient_checkpointing_enable()
model.gradient_checkpointing_disable()

In [9]:
## perplexity on forget set before finetuning with ga
## on next token prediction -> includes question and asnwer


batch_size = cfg.batch_size
max_length = 266

next_token_perplexity_ul = Perplexity(
    model = model, 
    tokenizer = tokenizer, 
    template = LLAMA3_CHAT_TEMPLATE, 
    batch_size = batch_size, 
    max_length = max_length,
    df =df,
    case='next_token',
    chat_tokens=4)

print(next_token_perplexity_ul)

calculating perplexity for next_token! Please change this if this is not the case
Average loss for 16 batches: 6.467221558094025
tensor(643.6929)


In [10]:
## perplexity on forget set before finetuning with ga
## -> conditional perplexity calculation on answer given a question

qa_perplexity_ul = Perplexity(
    model = model, 
    tokenizer =tokenizer, 
    template =LLAMA3_CHAT_TEMPLATE, 
    batch_size =batch_size, 
    max_length =max_length,
    df =df,
    case='qa',
    chat_tokens=4)

print(qa_perplexity_ul)

calculating perplexity for qa! Please change this if this is not the case
Average loss for 16 batches: 1.6510832905769348
tensor(5.2126)


In [6]:
retain = pd.read_csv('/home/praveen/theoden/ul_paper/dataset/retain.csv')

In [10]:
## perplexity on retain set before finetuing with ga
## on next token prediction -> includes question and asnwer
batch_size = cfg.batch_size
max_length = 266

next_token_perplexity_ul = Perplexity(
    model = model, 
    tokenizer = tokenizer, 
    template = LLAMA3_CHAT_TEMPLATE, 
    batch_size = batch_size, 
    max_length = max_length,
    df =retain,
    case='next_token',
    chat_tokens=4)

print(next_token_perplexity_ul)

calculating perplexity for next_token! Please change this if this is not the case
Average loss for 9 batches: 6.863064342074924
tensor(956.2930)


In [11]:
## perplexity on retain set before finetuing with ga
## -> conditional perplexity calculation on answer given a question


qa_perplexity_ul = Perplexity(
    model = model, 
    tokenizer =tokenizer, 
    template =LLAMA3_CHAT_TEMPLATE, 
    batch_size =batch_size, 
    max_length =max_length,
    df =retain,
    case='qa',
    chat_tokens=4)

print(qa_perplexity_ul)

calculating perplexity for qa! Please change this if this is not the case
Average loss for 9 batches: 1.2573475970162287
tensor(3.5161)


#### LoRA Finetuning

In [9]:
model.gradient_checkpointing_enable()

In [10]:
config = LoraConfig(
        r = cfg.LoRA_r,
        lora_alpha = cfg.LoRA_alpha,
        lora_dropout= cfg.LoRA_dropout,
        target_modules = ['v_proj', 'k_proj', 'up_proj', 'o_proj', 'gate_proj', ' q_proj', 'down_proj'],
        bias = 'none',
        task_type = 'CAUSAL_LM',        
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 18,874,368 || all params: 8,049,135,616 || trainable%: 0.2345


In [12]:
data_path = '/home/praveen/theoden/ul_paper/dataset/forget.csv'
dataset = QAForgetDataset(data_path = data_path,
                          tokenizer = tokenizer,
                          max_length = 266) 


In [13]:
training_args = TrainingArguments(
    output_dir = cfg.save_dir,
    learning_rate = cfg.lr,
    per_device_train_batch_size= cfg.batch_size,
    per_device_eval_batch_size=  cfg.batch_size,
    num_train_epochs= 10,
    weight_decay = cfg.weight_decay,
    logging_dir = f'{cfg.save_dir}/logs',
    #save_steps = cfg.forget.save_steps,
    evaluation_strategy= 'no',
    save_total_limit= 2,
    bf16 = True,

)



In [14]:
# Initialize the custom trainer
trainer = ForgetTrainer(
            model = model, 
            args = training_args,
            train_dataset = dataset,
            tokenizer = tokenizer,
            data_collator = custom_data_collator_forget,
            #forget_loss = cfg.forget.forget_loss
)


  trainer = ForgetTrainer(


In [15]:
# train the model
model.config.use_cache = False
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mpraveenbushipaka942[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


TrainOutput(global_step=160, training_loss=-32.037417602539065, metrics={'train_runtime': 544.7009, 'train_samples_per_second': 8.831, 'train_steps_per_second': 0.294, 'total_flos': 5.77583995183104e+16, 'train_loss': -32.037417602539065, 'epoch': 10.0})

In [16]:
#model = model.merge_and_unload()
model.save_pretrained(cfg.save_dir)
tokenizer.save_pretrained(cfg.save_dir)
print(f'Forget LoRA adapter saved at {cfg.save_dir}')

Forget LoRA adapter saved at /home/praveen/theoden/ul_paper/outputs/testing


#### perplexity calculation after finetuning

##### loading model

In [7]:
from peft import PeftModel, PeftConfig

In [8]:
config = PeftConfig.from_pretrained('/home/praveen/theoden/ul_paper/outputs/final2')

In [9]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, token = cfg.access_token)
tokenizer.pad_token = tokenizer.eos_token

In [10]:
base_model = AutoModelForCausalLM.from_pretrained(cfg.model_id,
                                             device_map = 'auto',
                                             torch_dtype = torch.bfloat16,
                                             token = cfg.access_token)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
peft_model_id = '/home/praveen/theoden/ul_paper/outputs/final2'
model = PeftModel.from_pretrained(base_model, peft_model_id)
model.merge_and_unload()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

##### calculating perplexity

In [11]:
## perplexity on forget set after finetuning with ga
## on next token prediction -> includes question and asnwer
batch_size = cfg.batch_size
max_length = 266

next_token_perplexity_ul = Perplexity(
    model = model, 
    tokenizer = tokenizer, 
    template = LLAMA3_CHAT_TEMPLATE, 
    batch_size = batch_size, 
    max_length = max_length,
    df =df,
    case='next_token',
    chat_tokens=4)

print(next_token_perplexity_ul)

calculating perplexity for next_token! Please change this if this is not the case
Average loss for 16 batches: 145.22511100769043
tensor(inf)


In [12]:
## perplexity on forget set after finetuning with ga
## -> conditional perplexity calculation on answer given a question
qa_perplexity_ul = Perplexity(
    model = model, 
    tokenizer =tokenizer, 
    template =LLAMA3_CHAT_TEMPLATE, 
    batch_size =batch_size, 
    max_length =max_length,
    df =df,
    case='qa',
    chat_tokens=4)

print(qa_perplexity_ul)

calculating perplexity for qa! Please change this if this is not the case
Average loss for 16 batches: 149.3817491531372
tensor(inf)


In [12]:
## perplexity on retain after finetuning on gradient ascent
## on next token prediction -> includes question and asnwer

batch_size = cfg.batch_size
max_length = 266

next_token_perplexity_ul = Perplexity(
    model = model, 
    tokenizer = tokenizer, 
    template = LLAMA3_CHAT_TEMPLATE, 
    batch_size = batch_size, 
    max_length = max_length,
    df =retain,
    case='next_token',
    chat_tokens=4)

print(next_token_perplexity_ul)

calculating perplexity for next_token! Please change this if this is not the case
Average loss for 9 batches: 144.9854532877604
tensor(inf)


In [14]:
## perplexity on retain after finetuning on gradient ascent
## -> conditional perplexity calculation on answer given a question

qa_perplexity_ul = Perplexity(
    model = model, 
    tokenizer =tokenizer, 
    template =LLAMA3_CHAT_TEMPLATE, 
    batch_size =batch_size, 
    max_length =max_length,
    df = retain,
    case='qa',
    chat_tokens=4)

print(qa_perplexity_ul)

calculating perplexity for qa! Please change this if this is not the case
Average loss for 9 batches: 147.22256469726562
tensor(inf)
