## Unlearn finetuned model

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
from peft import PeftModel, LoraConfig, TaskType, get_peft_model
import torch
import config
import args
import pickle
import utils
from utils import create_lotr_dataloader_from_dataset, create_bookcorpse_dataloader, get_answer_loss, get_rand_ans_loss,compute_kl
from torch import nn
from transformers import DataCollatorForLanguageModeling, get_scheduler
from torch.optim import AdamW, Adam
from accelerate import Accelerator
import importlib
import pandas as pd
import time
import logging

In [2]:
# Pick up edits made to the local helper modules without restarting the kernel.
importlib.reload(utils)

# reload() refreshes the module object only: names that were bound earlier with
# `from utils import ...` keep pointing at the pre-reload functions. Rebind them
# so the cells below actually run the reloaded code.
from utils import (
    create_lotr_dataloader_from_dataset,
    create_bookcorpse_dataloader,
    get_answer_loss,
    get_rand_ans_loss,
    compute_kl,
)

# Kept as the last expression so the cell still displays the reloaded module.
importlib.reload(args)

<module 'args' from '/scratch/sa6981/llm_unlearn/finetune_copyright/args.py'>

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Echo the selected device as the cell output.
device

device(type='cuda')

In [5]:
# Load the base causal LM named in config; the LoRA adapter is attached to it
# in a later cell.
pretrained_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    return_dict=True,
)

In [6]:

# Tokenizer for the same checkpoint name as the base model.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Attach the previously fine-tuned LoRA adapter to the base model and keep the
# adapter weights trainable so they can be updated during unlearning.
# NOTE(review): PEFT appears to inject the adapter layers into the wrapped
# module in place, so `pretrained_model` would share weights (and the adapter)
# with `model` after this call. Confirm before using `pretrained_model` as a
# frozen reference elsewhere.
model = PeftModel.from_pretrained(
    pretrained_model,
    "/scratch/sa6981/llm_unlearn/finetune_copyright/models/finetune_1.3b",
    is_trainable=True,
)

# Private PEFT hook; presumably freezes everything except the adapter weights.
# TODO confirm it is still needed alongside is_trainable=True.
model._mark_only_adapters_as_trainable()

In [7]:

# Move the adapter-wrapped model to the selected device. As the last expression
# in the cell, the returned module is also displayed.
model.to(device)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): OPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
                (v_proj): Linear(
                  in_features=2048, out_features=2048, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=2048, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(

In [9]:
# Report how many parameters are trainable versus the total parameter count.
model.print_trainable_parameters()

trainable params: 3,145,728 || all params: 1,318,903,808 || trainable%: 0.23851079820371554


In [10]:
# AdamW over the model's parameters, with the learning rate taken from args.
optimizer = AdamW(
    params=model.parameters(),
    lr=args.lr,
)

In [11]:
# Linear schedule with no warm-up, spanning the configured number of
# unlearning steps.
num_training_steps = args.max_unlearn_steps
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

In [13]:
# Load the pickled dataset that should be unlearned.
# NOTE: pickle executes arbitrary code on load; only open files you created.
with open("unlearned_DatasetDict_small.pkl", mode="rb") as dataset_file:
    bad_dataset = pickle.load(dataset_file)

In [15]:
# Build the dataloader for the "train" split of the data to be unlearned.
train_dataset = bad_dataset["train"]
train_bad_loader = create_lotr_dataloader_from_dataset(
    tokenizer,
    train_dataset,
    batch_size=args.batch_size,
)

Map:   0%|          | 0/20151 [00:00<?, ? examples/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [16]:
# Normal (retain) data: keep the training loader and the answer list that the
# random-mismatch loss uses later; the last two returned values are unused here.
train_normal_loader, normal_ans, _, _ = create_bookcorpse_dataloader(
    tokenizer,
    batch_size=args.batch_size,
)

1
# of examples 30000
text combined 34564874
30000
dataset created


In [19]:
# Create the Accelerate helper with its default settings.
accelerator = Accelerator()
# Rebind `device` to the device Accelerate selected; later cells pass this
# value to the loss helpers.
device = accelerator.device

In [20]:
# Hand the training objects to Accelerate and rebind each name to the prepared
# version it returns (same order in, same order out).
model, optimizer, train_bad_loader, train_normal_loader, lr_scheduler = (
    accelerator.prepare(
        model,
        optimizer,
        train_bad_loader,
        train_normal_loader,
        lr_scheduler,
    )
)

In [21]:
# Switch the model to training mode (enables dropout, e.g. the LoRA dropout
# shown in the module printout). The returned module is displayed by the cell.
model.train()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): OPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
                (v_proj): Linear(
                  in_features=2048, out_features=2048, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=2048, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(

In [26]:
# Initial state for the unlearning loop: the most recent forget-data loss and
# the number of optimisation steps taken so far.
bad_loss, idx = 0.0, 0

In [None]:
# Unlearning loop.
#
# Every step combines three terms, weighted by args.bad_weight,
# args.random_weight and args.normal_weight:
#   * a gradient-ascent ("ga") loss on the data to forget,
#   * a random-answer mismatch loss on the same batch,
#   * a divergence loss on normal data against `pretrained_model`.
#
# Training stops as soon as the forget loss is big enough or the step budget
# (args.max_unlearn_steps, which is also the LR scheduler's horizon) is used up.
#
# About the threshold: the stats line below prints `-bad_loss` as the bad loss,
# and the "ga" value held in `bad_loss` is presumably the negated forget loss
# (confirm in utils.get_answer_loss). Comparing the raw, non-positive value
# against a positive args.max_bad_loss could never trigger, so the check is
# applied to -bad_loss instead.
#
# NOTE(review): `pretrained_model` was wrapped by PEFT earlier; if PEFT injects
# the adapter in place, it shares weights with `model`, so the divergence term
# compares the model with itself rather than a frozen reference. Confirm.
start_time = time.time()
while -float(bad_loss) < args.max_bad_loss and idx < args.max_unlearn_steps:
    steps_at_epoch_start = idx
    for bad_batch, normal_batch in zip(train_bad_loader, train_normal_loader):

        ############ GA on answer only. ############
        bad_loss = get_answer_loss("ga", bad_batch, model, device=device)

        ############ Random mismatch. ############
        random_loss = get_rand_ans_loss(
            bad_batch,
            tokenizer,
            normal_ans,
            model,
            K=5,
            device=device,
        )

        ########### KL on normal samples. ############
        normal_loss = compute_kl(pretrained_model, model, normal_batch, device)

        # Final loss = bad loss + random smoothing + normal loss.
        loss = (
            args.bad_weight * bad_loss
            + args.random_weight * random_loss
            + args.normal_weight * normal_loss
        )

        # Backprop.
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Print.
        stats = (
            f"batch: {idx}, "
            f"bad_loss: {-bad_loss:.2f}, "
            f"current_div_loss: {normal_loss:.2f}, "
        )
        logging.info(stats)
        print(stats)
        idx += 1

        # Save model.
        if idx % args.save_every == 0:
            model.save_pretrained(args.model_save_dir, from_pt=True)

        # Re-check the stop condition after every step. The outer `while` is
        # only evaluated between full passes over the loaders, which would
        # otherwise overshoot the step budget by up to a whole epoch.
        if not (
            -float(bad_loss) < args.max_bad_loss
            and idx < args.max_unlearn_steps
        ):
            break

    # A pass that made no progress means the loaders yielded nothing; without
    # this guard the outer loop would spin forever.
    if idx == steps_at_epoch_start:
        raise RuntimeError(
            "Dataloaders produced no batches; cannot continue unlearning."
        )

end_time = time.time()
logging.info("Total time: %d sec", end_time - start_time)

# Fold the LoRA weights into the base model so a plain model can be saved.
model = model.merge_and_unload()

# Save final model.
# NOTE(review): this is the same directory the periodic adapter checkpoints
# were written to, so adapter files and merged weights end up side by side.
model.save_pretrained(args.model_save_dir, from_pt=True)
logging.info("Unlearning finished")

        
        

batch: 0, bad_loss: -0.00, current_div_loss: 5.33, 
batch: 1, bad_loss: -0.00, current_div_loss: 5.19, 
batch: 2, bad_loss: -0.00, current_div_loss: 5.06, 
batch: 3, bad_loss: -0.00, current_div_loss: 5.29, 
batch: 4, bad_loss: -0.00, current_div_loss: 5.43, 
batch: 5, bad_loss: -0.00, current_div_loss: 5.26, 
batch: 6, bad_loss: -0.00, current_div_loss: 5.47, 
batch: 7, bad_loss: -0.00, current_div_loss: 5.40, 
batch: 8, bad_loss: -0.00, current_div_loss: 5.53, 
batch: 9, bad_loss: -0.00, current_div_loss: 5.70, 
batch: 10, bad_loss: -0.00, current_div_loss: 5.73, 
batch: 11, bad_loss: -0.00, current_div_loss: 5.63, 
batch: 12, bad_loss: -0.00, current_div_loss: 5.59, 
batch: 13, bad_loss: -0.00, current_div_loss: 5.53, 
batch: 14, bad_loss: 1.76, current_div_loss: 5.36, 
batch: 15, bad_loss: -0.00, current_div_loss: 5.47, 
batch: 16, bad_loss: -0.00, current_div_loss: 5.46, 
batch: 17, bad_loss: -0.00, current_div_loss: 5.70, 
batch: 18, bad_loss: -0.00, current_div_loss: 5.41, 
batc