In [1]:
import argparse
import logging
import random
import time
# from dotenv import load_dotenv
import os

import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import AdaLoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from utils import (
    compute_kl,
    create_pku_dataloader_from_dataset,
    create_truthfulqa_dataloader,
    get_answer_loss,
    get_rand_ans_loss,
    get_truthfulQA_answers_plaintext,
)

In [2]:
# Pin every random number generator the run touches (PyTorch, NumPy, stdlib)
# to one shared seed so a fresh "Restart & Run All" starts from the same state.
SEED = 8888
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

In [3]:
# SECURITY: credentials must never be written into the notebook. The Hugging
# Face token is read from the environment instead; export HF_TOKEN (or the
# legacy HUGGING_FACE_HUB_TOKEN) before starting the kernel, or populate the
# environment from a .env loader earlier in the notebook. Any token that was
# ever committed in this cell should be treated as leaked and revoked.
def read_hf_token(env=None):
    """Return the Hugging Face access token found in ``env``, or ``None``.

    Args:
        env: Mapping to look the token up in. Defaults to ``os.environ``.

    Returns:
        The value of ``HF_TOKEN`` if set and non-empty, otherwise the value of
        ``HUGGING_FACE_HUB_TOKEN`` if set and non-empty, otherwise ``None``.
    """
    env = os.environ if env is None else env
    return env.get("HF_TOKEN") or env.get("HUGGING_FACE_HUB_TOKEN") or None


# None when no token is configured; it is passed as-is to from_pretrained(token=...).
access_token = read_hf_token()

In [4]:
def main(args) -> None:
    """Run the unlearning fine-tune described by ``args`` and save the result.

    The model named by ``args.model_name`` is loaded in 8-bit, wrapped with an
    AdaLoRA adapter, and optimised on a weighted sum of three terms computed by
    helpers from ``utils``: the ``"ga"`` answer loss on PKU-SafeRLHF batches, a
    random-answer loss built from TruthfulQA answers, and a KL term against a
    second, frozen copy of the model on TruthfulQA batches.

    Training stops as soon as the step counter reaches
    ``args.max_unlearn_steps`` or the latest ``bad_loss`` is no longer below
    ``args.max_bad_loss``; the check runs after every optimisation step.

    Args:
        args: Namespace providing ``model_name``, ``batch_size``, ``lr``,
            ``max_unlearn_steps``, ``max_bad_loss``, ``bad_weight``,
            ``random_weight``, ``normal_weight``, ``save_every``,
            ``model_save_dir`` and ``use_lora``.

    Returns:
        None. Checkpoints and the final model are written to
        ``args.model_save_dir``.
    """
    accelerator = Accelerator()
    device = accelerator.device
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, token=access_token, load_in_8bit=True
    )
    model = prepare_model_for_int8_training(model)

    # The adapter is attached unconditionally; ``args.use_lora`` is only
    # consulted at the end to decide whether to merge it into the base weights.
    # NOTE(review): confirm whether the adapter should be gated on the flag too.
    peft_config = AdaLoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=32,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, peft_config)

    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=access_token)
    # Set the padding token to the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

    # Load harmful data.
    train_dataset = load_dataset("PKU-Alignment/PKU-SafeRLHF", split="train")
    train_bad_loader = create_pku_dataloader_from_dataset(
        tokenizer, train_dataset, batch_size=args.batch_size
    )

    # Get normal data.
    train_normal_loader, _, _ = create_truthfulqa_dataloader(
        tokenizer, batch_size=args.batch_size
    )

    # Load normal answers used for random mismatch.
    normal_ans = get_truthfulQA_answers_plaintext()

    optimizer = AdamW(model.parameters(), lr=args.lr)

    # The linear schedule decays over exactly the step budget.
    num_training_steps = args.max_unlearn_steps
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    (
        model,
        optimizer,
        train_bad_loader,
        train_normal_loader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_bad_loader, train_normal_loader, lr_scheduler
    )

    model.train()

    # Reference model for computing KL. It needs the same token as the trained
    # model (the checkpoint may be gated). It is never handed to the optimizer,
    # so its parameters are frozen to avoid tracking gradients for them.
    pretrained_model = AutoModelForCausalLM.from_pretrained(
        args.model_name, token=access_token
    )
    pretrained_model.to(device)
    pretrained_model.requires_grad_(False)

    # Start unlearning.
    bad_loss = 0.0
    idx = 0
    start_time = time.time()
    # Stop if bad loss is big enough or reaching max step.
    # NOTE(review): the log line below prints ``-bad_loss``, which suggests the
    # "ga" helper returns a negated loss; if so, the ``max_bad_loss`` threshold
    # can never trigger. Confirm against utils.get_answer_loss.
    while bad_loss < args.max_bad_loss and idx < args.max_unlearn_steps:
        steps_at_epoch_start = idx
        for bad_batch, normal_batch in zip(train_bad_loader, train_normal_loader):
            ############ GA on answer only. ############
            bad_loss = get_answer_loss("ga", bad_batch, model, device=device)

            ############ Random mismatch. ############
            random_loss = get_rand_ans_loss(
                bad_batch,
                tokenizer,
                normal_ans,
                model,
                K=5,
                device=device,
            )

            ############ KL on normal samples. ############
            normal_loss = compute_kl(pretrained_model, model, normal_batch, device)

            # Final loss = bad loss + random smoothing + normal loss.
            loss = (
                args.bad_weight * bad_loss
                + args.random_weight * random_loss
                + args.normal_weight * normal_loss
            )

            # Backprop.
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Print.
            stats = (
                f"batch: {idx}, "
                f"bad_loss: {-bad_loss:.2f}, "
                f"current_div_loss: {normal_loss:.2f}, "
            )
            logging.info(stats)
            print(stats)
            idx += 1

            # Save a checkpoint. ``prepare`` may have wrapped the model, so
            # unwrap it to reach ``save_pretrained``.
            if idx % args.save_every == 0:
                accelerator.unwrap_model(model).save_pretrained(
                    args.model_save_dir, from_pt=True
                )

            # Re-check the stop condition after every step; otherwise the
            # budget would only be enforced once per pass over the data.
            if not (bad_loss < args.max_bad_loss and idx < args.max_unlearn_steps):
                break

        # If the zipped loaders produced nothing, another pass cannot make
        # progress and the outer loop would never terminate.
        if idx == steps_at_epoch_start:
            logging.warning("Data loaders yielded no batches; stopping.")
            break
    end_time = time.time()
    logging.info("Total time: %d sec", end_time - start_time)

    final_model = accelerator.unwrap_model(model)
    if args.use_lora:
        # Fold the adapter weights into the base model before the final save.
        final_model = final_model.merge_and_unload()

    # Save final model.
    final_model.save_pretrained(args.model_save_dir, from_pt=True)
    logging.info("Unlearning finished")

In [None]:

# Notebook entry point: build the run configuration with argparse (so defaults
# and help text stay in one place), set up file logging, then start training.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# main() in this file attaches the adapter unconditionally; this flag only
# controls whether the adapter is merged into the base model at the end.
parser.add_argument("--use_lora", action="store_true")

parser.add_argument(
    "--max_unlearn_steps",
    type=int,
    default=1000,
    help="Max number of unlearning steps.",
)
parser.add_argument(
    "--bad_weight", type=float, default=0.5, help="Weight on the bad loss."
)
parser.add_argument(
    "--random_weight",
    type=float,
    default=1,
    help="Weight on learning the random outputs.",
)
parser.add_argument(
    "--normal_weight",
    type=float,
    default=1,
    help="Weight on normal loss.",
)
parser.add_argument(
    "--batch_size", type=int, default=2, help="Batch size of unlearning."
)
parser.add_argument("--lr", type=float, default=2e-6, help="Unlearning LR.")
parser.add_argument(
    "--max_bad_loss",
    type=float,
    default=100,
    help="Maximum loss on bad samples to terminate.",
)
parser.add_argument(
    "--model_name",
    type=str,
    default="meta-llama/Llama-2-7b-hf",
    help="Name of the pretrained model.",
)
parser.add_argument(
    "--model_save_dir",
    type=str,
    default="models/llama-7b_unlearned_ipy",
    help="Directory to save model.",
)
parser.add_argument(
    "--save_every", type=int, default=500, help="How many steps to save model."
)
parser.add_argument(
    "--log_file",
    type=str,
    default="logs/llama-7b-unlearn_ipy.log",
    help="Log file name",
)
# Parse an explicit empty list so the defaults above are used and the
# process's own command line is never read.
args = parser.parse_args(args=[])

# logging opens the log file immediately and does not create missing parent
# directories, so make sure the target directory exists first.
log_dir = os.path.dirname(args.log_file)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(
    filename=args.log_file,
    filemode="w+",
    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d-%H-%M",
    level=logging.INFO,
)
# Record the full configuration at the top of the log for reproducibility.
for arg in vars(args):
    logging.info(f"{arg}: {getattr(args, arg)}")
main(args)