In [1]:
import os

# Restrict CUDA to device index "0" for this process. The variable is assigned
# here, before torch is imported in a later cell.
# NOTE(review): the recorded `nvidia-smi` output in this notebook reports the
# command as missing, so this setting may have no effect on this machine - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [2]:
# Notebook-level scratch dictionary (presumably for stashing debugging values;
# nothing in the visible cells reads it).
# A `global` statement has no effect at module scope, so a plain assignment is used.
debug = {}


In [3]:
# Query the NVIDIA driver/GPU status through the shell (IPython `!` escape).
!nvidia-smi

zsh:1: command not found: nvidia-smi


In [3]:
# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional
# import time

import torch
from datasets import Dataset,  load_from_disk#, load_dataset, load_metric
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, BitsAndBytesConfig

# from transformers.trainer_utils import EvalPrediction# , EvalLoopOutput
# from transformers.trainer_pt_utils import find_batch_size, nested_concat

# import pandas as pd

from peft import LoraConfig, get_peft_model

# from torch.utils.data import DataLoader

from dpo import DPOTrainer

In [6]:
# from metrics import AutomaticNgramEval, Rouge, AutomaticFactEval

# ngram_eval = AutomaticNgramEval()
# # factev = AutomaticFactEval()

In [8]:
# global eval_output_record
# eval_output_record = {}


In [4]:
def extract_prompt(prompt_and_response):
    """Return the leading part of the text up to and including the instruction marker.

    The marker is searched from the right, so when it occurs several times the
    last occurrence delimits the prompt.

    Raises:
        AssertionError: if the marker is not present in ``prompt_and_response``.
    """
    marker = "\n\nGenerate the corresponding Discharge Instructions according to the input article:"
    marker_start = prompt_and_response.rfind(marker)
    assert marker_start != -1, f"Prompt and response does not contain '{marker}'"
    prompt_end = marker_start + len(marker)
    return prompt_and_response[:prompt_end]


def load_dataset(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
    """Load a preference-data split from disk and separate prompts from responses.

    The split is read with ``load_from_disk`` from ``'data/avs/' + split``.
    Every row must provide ``chosen`` and ``rejected`` strings. Each row is
    mapped to a dictionary with the following entries:

        prompt:   the ``chosen`` text up to and including the instruction
                  marker located by ``extract_prompt``
        chosen:   the remainder of the ``chosen`` text after that prompt
        rejected: the ``rejected`` text with its first ``len(prompt)``
                  characters removed

    Args:
        split: name of the sub-directory under ``data/avs/`` to load.
        sanity_check: when True, keep at most the first 1000 rows.
        silent: unused; kept so the signature stays unchanged.
        cache_dir: unused; kept so the signature stays unchanged.

    Returns:
        The result of ``dataset.map`` with the per-row split applied.

    Raises:
        ValueError: if ``script_args.alignment_function`` is not one of
            'sft', 'dpo' or 'salt'. (Before this check the function failed
            later with an UnboundLocalError on ``dataset``.)
        AssertionError: raised by ``extract_prompt`` when a ``chosen`` text
            does not contain the instruction marker.

    Note:
        Reads the module-level ``script_args``; call only after the script
        arguments have been parsed.
    """
    supported_alignment_functions = ('sft', 'dpo', 'salt')
    if script_args.alignment_function not in supported_alignment_functions:
        raise ValueError(
            f"Unsupported alignment_function {script_args.alignment_function!r}; "
            f"expected one of {supported_alignment_functions}"
        )
    dataset = load_from_disk('data/avs/' + split)

    if sanity_check:
        print('only train on 1000 samples')
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        # The prompt is derived from the chosen text only.
        prompt = extract_prompt(sample["chosen"])
        # NOTE(review): `rejected` is cut at the same character offset, which
        # assumes it starts with the identical prompt - confirm against the data.
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)



In [5]:
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    alpha1: Optional[float] = field(default=0.0, metadata={"help": "the alpha parameter for Edit-DPO loss"})
    alpha2: Optional[float] = field(default=1.0, metadata={"help": "the alpha parameter for Edit-DPO loss"})
    omega1: Optional[float] = field(default=1.0, metadata={"help": "the omega parameter for SALT loss"})
    omega2: Optional[float] = field(default=1.0, metadata={"help": "the omega parameter for SALT loss"})
    S_generated_C_weight: Optional[float] = field(default=1.0, metadata={"help": "sequence alignment weights"})
    S_generated_D_weight: Optional[float] = field(default=-0.1, metadata={"help": "sequence alignment weights"})
    S_generated_S_weight: Optional[float] = field(default=-0.1, metadata={"help": "sequence alignment weights"})
    S_edited_C_weight: Optional[float] = field(default=1.0, metadata={"help": "sequence alignment weights"})
    S_edited_I_weight: Optional[float] = field(default=1.0, metadata={"help": "sequence alignment weights"})
    S_edited_S_weight: Optional[float] = field(default=1.0, metadata={"help": "sequence alignment weights"})

    # training parameters
    model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    num_train_epochs: Optional[int] = field(default=1, metadata={"help": "the number of training epochs"})
    evaluation_strategy: Optional[str] = field(default=None, metadata={"help": "the evaluation strategy, None, epoch, or steps"})
    eval_steps: Optional[int] = field(default=500, metadata={"help": "Number of update steps between two evaluations if evaluation_strategy=steps"})
    eval_first_step: Optional[bool] = field(default=False, metadata={"help": "Wether to eval first step"})
    logging_strategy: Optional[str] = field(default=None, metadata={"help": "the logging strategy, None, epoch, or steps"})
    log_steps: Optional[int] = field(default=500, metadata={"help": "Number of update steps between two logging if logging_strategy=steps"})
    logging_first_step: Optional[bool] = field(default=False, metadata={"help": "Wether to log first step"})
    save_strategy: Optional[str] = field(default=None, metadata={"help": "the saving strategy, None, epoch, or steps"})
    save_steps: Optional[int] = field(default=500, metadata={"help": "Number of update steps between two saving if save_strategy=steps"})
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
    alignment_function: Optional[str] = field(default='dpo', metadata={"help": "alignment function will be used"})
    output_dir: Optional[str] = field(default='./test', metadata={"help": "output path"})
    run_name: Optional[str] = field(default='test', metadata={"help": "A descriptor for the run. Typically used for wandb and mlflow logging."})
    save_total_limit: Optional[int] = field(default=1, metadata={"help": "If a value is passed, will limit the total amount of checkpoints."})
    load_best_model_at_end: Optional[bool] = field(default=False, metadata={"help": "Whether or not to load the best model found during training at the end of training."})
    metric_for_best_model: Optional[str] = field(default=None, metadata={"help": "Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models."})
    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )



In [6]:
# Argument parser built from the ScriptArguments dataclass; it is used below
# to turn a list of `--option value` strings into a ScriptArguments instance.
parser = HfArgumentParser(ScriptArguments)

# #for SFT
# script_args = parser.parse_args_into_dataclasses(
#         args=[
#                 '--per_device_train_batch_size', '4',
#                 '--per_device_eval_batch_size', '16',
#                 '--gradient_accumulation_steps', '2',
#                 '--model_name_or_path', 'gpt2',
#                 # '--model_name_or_path', 'huggy llama/llama-7b',
#                 # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
#                 # '--load_in_4bit',
#                 # '--use_peft',
#                 # '--learning_rate', '1e-3',
#                 '--learning_rate', '1e-4',
#                 # '--report_to', 'wandb',
#                 '--run_name', 'SFT-avs-gpt2',
#                 '--max_length', '1024',
#                 '--max_prompt_length', '768',
#                 '--num_train_epochs', '20',
#                 '--max_steps', '-1',
#                 '--evaluation_strategy', 'epoch',
#                 '--eval_steps', '-1',
#                 # '--eval_first_step',
#                 '--logging_strategy', 'steps',
#                 '--log_steps', '10',
#                 '--logging_first_step',
#                 '--save_strategy', 'epoch',
#                 '--save_steps', '-1',
#                 '--save_total_limit', '3',
#                 '--load_best_model_at_end',
#                 '--metric_for_best_model', 'metrics_policy_rouge1',
#                 '--alignment_function', 'sft',
#                 '--output_dir', './results/avs/SFT_model/gpt2',
#                 # '--output_dir', './results/SFT_model/llama2_7b',
#             ]
#         )[0]

# # # #for DPO
# script_args = parser.parse_args_into_dataclasses(args=['--per_device_train_batch_size', '1',
#                                                        '--per_device_eval_batch_size', '2',
#                                                        '--gradient_accumulation_steps', '8',
#                                                        # '--model_name_or_path', 'results/avs/BASE_model/gpt2',
#                                                        '--model_name_or_path', 'gpt2',
#                                                        # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
#                                                        # '--load_in_4bit',
#                                                        # '--use_peft',
#                                                        # '--learning_rate', '1e-3',
#                                                        '--learning_rate', '1e-4',
#                                                        # '--report_to', 'wandb',
#                                                        '--run_name', 'DPO-avs-gpt2',
#                                                        '--max_length', '1024',
#                                                        '--max_prompt_length', '768',
#                                                        '--num_train_epochs', '5',
#                                                        '--max_steps', '-1',
#                                                        '--evaluation_strategy', 'epoch',
#                                                        '--eval_steps', '-1',
#                                                        # '--eval_first_step',
#                                                        '--logging_strategy', 'steps',
#                                                        '--log_steps', '20',
#                                                        '--logging_first_step',
#                                                        # '--save_strategy', 'epoch',
#                                                        '--save_strategy', 'steps',
#                                                        '--save_steps', '10000000',
#                                                        # '--save_total_limit', '3',
#                                                        # '--load_best_model_at_end',
#                                                        # '--metric_for_best_model', 'metrics_policy_rouge1',
#                                                        '--alignment_function', 'dpo',
#                                                        '--output_dir', './results/avs/DPO_model/DPO-avs-gpt2(1|1|0.3)',
#                                                        '--alpha1', '1.0', #sft loss
#                                                        '--alpha2', '1.0', #dpo loss
#                                                        '--beta', '0.3',
#                                                       ])[0]

# # # for SALT
# Active configuration: SALT alignment of Llama-2-7b, loaded in 4-bit with PEFT.
# Each valued option is written next to its value; bare entries are boolean flags.
script_args = parser.parse_args_into_dataclasses(
    args=[
        # --- batching ---
        '--per_device_train_batch_size', '4',
        '--per_device_eval_batch_size', '8',
        '--gradient_accumulation_steps', '2',
        # --- model (other checkpoints tried: 'gpt2', 'results/avs/BASE_model/gpt2') ---
        '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
        '--load_in_4bit',
        '--use_peft',
        # --- optimisation and reporting (run_name is left at its default) ---
        '--learning_rate', '1e-3',
        '--report_to', 'wandb',
        # --- sequence lengths ---
        '--max_length', '1024',
        '--max_prompt_length', '768',
        # --- schedule: epoch-based, max_steps disabled with -1 ---
        '--num_train_epochs', '5',
        '--max_steps', '-1',
        # --- evaluation once per epoch ---
        '--evaluation_strategy', 'epoch',
        '--eval_steps', '-1',
        # --- logging every 10 steps, including the first ---
        '--logging_strategy', 'steps',
        '--log_steps', '10',
        '--logging_first_step',
        # --- checkpointing: step-based with a very large interval ---
        '--save_strategy', 'steps',
        '--save_steps', '10000000',
        # --- alignment objective and output location ---
        '--alignment_function', 'salt',
        '--output_dir', './results/avs/SALT_model/SALT-avs-llama2(1|-0.1|-0.1|1|1.1|1.1)',
        # --- SALT loss weights ---
        '--omega1', '1.0',  # chosen likelihood loss weight
        '--omega2', '0.1',  # rejected unlikelihood loss weight
        # --- sequence alignment weights ---
        '--S_generated_C_weight', '1.0',
        '--S_generated_D_weight', '-0.1',
        '--S_generated_S_weight', '-0.1',
        '--S_edited_C_weight', '1.0',
        '--S_edited_I_weight', '1.1',
        '--S_edited_S_weight', '1.1',
    ]
)[0]

# 2./3. Load the training split and the evaluation subset; both honour --sanity_check.
train_dataset = load_dataset(split="train", sanity_check=script_args.sanity_check)
eval_dataset = load_dataset(split="sub_eval", sanity_check=script_args.sanity_check)

In [11]:
# import wandb
# wandb.init()

In [9]:
# Read the Hugging Face access token from the local `hg_secret` file.
# Surrounding whitespace is stripped so that a trailing newline in the file
# does not become part of the token passed to `from_pretrained`.
with open('hg_secret', 'r') as f:
    hg_auth_token = f.read().strip()

In [10]:
# 1. load a pretrained model
# 8-bit and 4-bit loading are mutually exclusive.
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

if not (script_args.load_in_8bit or script_args.load_in_4bit):
    # No quantisation requested: no quantisation config and no explicit device map.
    quantization_config = None
    device_map = None
else:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=script_args.load_in_4bit,
        load_in_8bit=script_args.load_in_8bit,
    )
    # The empty-string key places the entire model on GPU:0.
    device_map = {"": 0}

# Policy model (the one being trained). It is not wrapped with PEFT here; the
# LoRA config is handed to the trainer in a later cell.
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    device_map=device_map,
    quantization_config=quantization_config,
    use_auth_token=hg_auth_token,
)

if script_args.ignore_bias_buffers:
    # torch distributed hack: list every boolean buffer by name so DDP ignores it
    model._ddp_params_and_buffers_to_ignore = [
        buf_name for buf_name, buf in model.named_buffers() if buf.dtype == torch.bool
    ]

# Reference model: a second copy loaded from the same checkpoint with the same settings.
model_ref = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    device_map=device_map,
    quantization_config=quantization_config,
    use_auth_token=hg_auth_token,
)

# Tokenizer for the same checkpoint; the end-of-sequence token doubles as padding.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, use_auth_token=hg_auth_token)
tokenizer.pad_token = tokenizer.eos_token



ImportError: Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or pip install bitsandbytes` 

In [13]:
# 4. initialize training arguments:
# The *_strategy fields of ScriptArguments default to None, which is not a
# valid strategy name for TrainingArguments. When a strategy was not supplied,
# fall back to the library defaults ("no" for evaluation, "steps" for logging
# and saving) instead of forwarding None.
training_args = TrainingArguments(
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    # Keep every dataset column; the trainer receives the raw prompt/chosen/rejected fields.
    remove_unused_columns=False,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    evaluation_strategy=script_args.evaluation_strategy or "no",
    eval_steps=script_args.eval_steps,
    logging_strategy=script_args.logging_strategy or "steps",
    logging_steps=script_args.log_steps,
    logging_first_step=script_args.logging_first_step,
    save_strategy=script_args.save_strategy or "steps",
    save_steps=script_args.save_steps,
    output_dir=script_args.output_dir,
    report_to=script_args.report_to,
    run_name=script_args.run_name,
    save_total_limit=script_args.save_total_limit,
    load_best_model_at_end=script_args.load_best_model_at_end,
    metric_for_best_model=script_args.metric_for_best_model,
)

# 5. initialize the DPO trainer

# LoRA adapter settings are only built when --use_peft was given; otherwise the
# trainer receives peft_config=None.
if script_args.use_peft:
    lora_config = LoraConfig(
        r=256,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    lora_config = None

# The project-local DPOTrainer takes the policy and reference models plus all
# loss weights parsed into script_args.
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    alpha1=script_args.alpha1,
    alpha2=script_args.alpha2,
    omega1=script_args.omega1,
    omega2=script_args.omega2,
    S_generated_C_weight=script_args.S_generated_C_weight,
    S_generated_D_weight=script_args.S_generated_D_weight,
    S_generated_S_weight=script_args.S_generated_S_weight,
    S_edited_C_weight=script_args.S_edited_C_weight,
    S_edited_I_weight=script_args.S_edited_I_weight,
    S_edited_S_weight=script_args.S_edited_S_weight,
    output_dir=script_args.output_dir,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    max_length=script_args.max_length,
    max_prompt_length=script_args.max_prompt_length,
    peft_config=lora_config,
    alignment_function=script_args.alignment_function,
)

# Optional evaluation before any training step. The previous condition ended
# in `and 0`, which made this branch unreachable and silently ignored the
# --eval_first_step flag.
if script_args.eval_first_step:
    print(dpo_trainer.evaluate())

# 6. train
dpo_trainer.train()

Token indices sequence length is longer than the specified maximum sequence length for this model (3160 > 1024). Running this sequence through the model will result in indexing errors


: 

  0%|          | 0/10900 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 17.00 GB, other allocations: 896.61 MB, max allowed: 18.13 GB). Tried to allocate 348.84 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Write each recorded evaluation table to <output_dir>/<key>.csv.
# `eval_output_record` is not assigned by any live cell of this notebook (its
# initialisation is commented out), so it is looked up defensively: when the
# name is absent the cell does nothing instead of raising NameError.
# NOTE(review): values are presumably pandas DataFrames (they must provide
# `to_csv`) - confirm where the record is populated.
eval_records = globals().get("eval_output_record", {})
if eval_records:
    # Make sure the target directory exists before writing any file.
    os.makedirs(script_args.output_dir, exist_ok=True)
for key, df in eval_records.items():
    df.to_csv(script_args.output_dir + '/' + str(key) + ".csv", index=False)