In [None]:
!pip install -q transformers[torch] datasets
!pip install pydantic
!pip install typing_extensions==4.7.1 --upgrade
!pip install -q bitsandbytes trl peft
!pip install flash-attn --no-build-isolation
!pip install ipython
!pip install wandb

In [None]:
import os
import re

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOTrainer

from chatbot import PromptFormatter, ConversationTruncator

# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv('HF_TOKEN')


def load_data(dataset_name):
    """Load the ``train`` split of a dataset as a pandas DataFrame.

    Args:
        dataset_name: Identifier passed unchanged to ``load_dataset``.

    Returns:
        The result of calling ``to_pandas()`` on the ``train`` split.
    """
    train_split = load_dataset(dataset_name)['train']
    return train_split.to_pandas()


def get_dpo_data(df):
    """Build the ``prompt`` / ``chosen`` / ``rejected`` frame used for DPO training.

    Conversations are formatted with ``format_conversation`` using the
    module-level ``formatter``, which therefore has to be defined before this
    function is called.

    Args:
        df: DataFrame holding the columns read by ``format_conversation`` plus
            ``selected_response`` and ``rejected_response``.

    Returns:
        A new DataFrame with exactly the columns ``prompt``, ``chosen`` and
        ``rejected``. The input frame is not modified.
    """
    column_names = {'formatted_conversations': 'prompt',
                    'selected_response': 'chosen',
                    'rejected_response': 'rejected'}
    prompts = format_conversation(df, formatter)
    # assign() returns a copy, so the caller's frame does not silently gain a column.
    with_prompts = df.assign(formatted_conversations=prompts)
    return with_prompts[list(column_names)].rename(columns=column_names)


def format_conversation(df, formatter, context_window_length=600):
    """Produce one formatted conversation per row of ``df``.

    A single ``ConversationTruncator`` is built from ``formatter`` and
    ``context_window_length``; its ``truncate`` method is then called for every
    row with the ``bot_name``, ``memory``, ``prompt`` and ``chat_history``
    values, in that order.

    Args:
        df: DataFrame containing the four columns listed above.
        formatter: Formatter object handed to ``ConversationTruncator``.
        context_window_length: Forwarded to ``ConversationTruncator``.

    Returns:
        A list with the ``truncate`` result for each row, in row order.
    """
    truncator = ConversationTruncator(formatter, context_window_length=context_window_length)
    formatted_conversations = []
    # iterrows() is a generator with no length, so give tqdm the total explicitly
    # to get a percentage / ETA instead of a bare counter.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        formatted_conversations.append(
            truncator.truncate(
                row['bot_name'],
                row['memory'],
                row['prompt'],
                row['chat_history'],
            )
        )
    return formatted_conversations

In [None]:
# Load the raw preference data and reshape it into prompt / chosen / rejected rows.
df = load_data('Jellywibble/avb_3k_sample')
# get_dpo_data reads `formatter` as a module-level name, so it must exist before the call.
formatter = PromptFormatter()
dpo_data = get_dpo_data(df)
ds = Dataset.from_pandas(dpo_data)
# Seed the shuffle so the 70/30 train/test split is the same on every run.
raw_datasets = ds.train_test_split(test_size=0.3, seed=42)

In [None]:
model_id = "Nitral-AI/Poppy_Porpoise-0.72-L3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Cap sequences at 1024 tokens and truncate on the left-hand side.
tokenizer.model_max_length = 1024
tokenizer.truncation_side = "left"

# Reuse the EOS token id for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# 4-bit NF4 quantisation with bfloat16 compute, expressed as an explicit
# BitsAndBytesConfig instead of the deprecated loose `load_in_4bit` /
# `bnb_4bit_*` keyword arguments to from_pretrained.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    # Replaces the deprecated `use_flash_attention_2=True` flag.
    attn_implementation="flash_attention_2",
)
# Keep the model config in sync with the tokenizer's padding token.
model.config.pad_token_id = tokenizer.pad_token_id
# Turn off the generation cache for training.
model.config.use_cache = False

# LoRA adapters on the query / key / value projections only.
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj"],
)
model = get_peft_model(model, peft_config)

In [None]:
# Total number of elements across all parameters that will receive gradients.
size = sum(param.numel() for param in model.parameters() if param.requires_grad)
# True division with two decimals: floor division by 1e6 dropped everything below
# a whole million and printed a float such as `41.0`.
print(f'Total number of trainable parameters: {size / 1e6:.2f} million')

In [None]:
from transformers import TrainingArguments


# Hyper-parameters for the DPO run.
training_args = TrainingArguments(
    num_train_epochs=2,
    learning_rate=5e-07,
    per_device_train_batch_size=8,
    do_eval=True,
    # Without a step-based evaluation strategy `eval_steps` is ignored and no
    # evaluation runs during training (the strategy defaults to "no").
    # NOTE(review): newer transformers releases name this argument `eval_strategy`.
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    adam_epsilon=1e-08,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=42,
    logging_steps=100,
    save_steps=500,
    eval_steps=50,
    save_strategy="steps",
    output_dir="data/dpo",
    hub_model_id="dpo",
    gradient_checkpointing=True,
    bf16=True,
    # DPOTrainer needs the raw prompt / chosen / rejected columns, so keep them.
    remove_unused_columns=False,
)

# Assemble the DPO trainer. No separate reference model is supplied (ref_model=None).
# NOTE(review): `model` has already been wrapped with get_peft_model and
# `peft_config` is passed again here — confirm the installed trl version treats
# an already-wrapped PEFT model the way this notebook intends.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=training_args,
    # Data splits produced by train_test_split.
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    # DPO temperature and sequence-length limits.
    beta=0.01,
    max_prompt_length=800,
    max_length=1024,
)

In [None]:
# Run the DPO training loop, then persist the trained model. save_model() is
# called without a path, so it uses the trainer's configured output directory.
dpo_trainer.train()
dpo_trainer.save_model()