In [None]:
!pip install datasets bitsandbytes

In [None]:
import os
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training,PeftModel
from huggingface_hub import login

In [None]:
# Authenticate against the Hugging Face Hub. The token is read from the
# HF_TOKEN environment variable so that no credential is stored in the
# notebook; when the variable is unset, `login` receives None.
# NOTE(review): with token=None `login` is expected to prompt interactively —
# confirm against the installed huggingface_hub version.
login(token=os.environ.get("HF_TOKEN"))

# Process-wide settings, set before any model is loaded or trainer is built.
os.environ["WANDB_DISABLED"] = "true"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---

In [None]:
# Hub identifier of the base model; used below to load both the model and its tokenizer.
model_name = "Bllossom/llama-3.2-Korean-Bllossom-3B"

In [None]:
# Quantization settings: 4-bit loading with the "nf4" quant type, double
# quantization enabled, and float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

In [None]:
# Load the base causal LM with the 4-bit quantization config applied.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=bnb_config)

In [None]:
# Load the tokenizer that belongs to the base model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Padding is requested later (preprocessing, collation, generation), so fall
# back to the EOS token when the tokenizer does not define a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# LoRA configuration: rank 8, alpha 32, no bias training, causal-LM task.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    bias="none",
    # Attention projections followed by the MLP projections.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] + ["gate_proj", "down_proj", "up_proj"],
)

In [None]:
# Prepare the quantized model for k-bit training and wrap it with the LoRA
# adapters described by `lora_config`, then report the trainable parameters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
model.print_trainable_parameters()

# Switch to training mode and enable gradient checkpointing.
model.train()
model.gradient_checkpointing_enable()

In [None]:
# Load the preference dataset; later cells read its "train" split and the
# "question", "chosen" and "rejected" fields.
dataset = load_dataset("mncai/orca_dpo_pairs_ko")

In [None]:
# Preprocessing function applied to every dataset row.
def preprocess_text(sample):
    """Tokenize the question and both candidate answers of one preference pair.

    Each of ``sample["question"]``, ``sample["chosen"]`` and
    ``sample["rejected"]`` is passed separately to the module-level
    ``tokenizer`` with ``padding="max_length"``, ``max_length=256`` and
    ``truncation=True``.

    Args:
        sample: Mapping with "question", "chosen" and "rejected" entries.

    Returns:
        dict with the question's "input_ids" and "attention_mask" plus the
        token ids of the chosen answer ("preferred_ids") and of the rejected
        answer ("dispreferred_ids"). The attention masks of the two answers
        are not returned.
    """
    def encode(field):
        return tokenizer(sample[field], padding="max_length", max_length=256, truncation=True)

    question_enc, chosen_enc, rejected_enc = (
        encode(field) for field in ("question", "chosen", "rejected")
    )

    features = {key: question_enc[key] for key in ("input_ids", "attention_mask")}
    features["preferred_ids"] = chosen_enc["input_ids"]
    features["dispreferred_ids"] = rejected_enc["input_ids"]
    return features

In [None]:
# Tokenize the training split and expose the resulting columns as PyTorch tensors.
raw_columns = ["id", "system", "question", "chosen", "rejected"]
tensor_columns = ["input_ids", "attention_mask", "preferred_ids", "dispreferred_ids"]

# The raw text columns are dropped once `preprocess_text` has produced the token ids.
tokenized_dataset = dataset["train"].map(preprocess_text, remove_columns=raw_columns)
tokenized_dataset.set_format(type="torch", columns=tensor_columns)

In [None]:
# Batch collation function handed to the trainer.
def collate_fn(batch):
    """Assemble a list of tokenized preference samples into one batch.

    Args:
        batch: Non-empty list of dicts with "input_ids", "attention_mask",
            "preferred_ids" and "dispreferred_ids" entries. Each entry may be
            a 1-D ``torch.Tensor`` or a list of ints; the "input_ids" /
            "attention_mask" entries must have equal length across the batch.

    Returns:
        dict with stacked "input_ids" and "attention_mask" tensors and with
        "preferred_ids" / "dispreferred_ids" as ``torch.long`` tensors that
        are right-padded with ``tokenizer.pad_token_id`` to one common width
        (the longest response of either kind in the batch, at least 1).
    """
    response_keys = ("preferred_ids", "dispreferred_ids")

    def _stack(key):
        # Prompt-side fields already share one length, so a plain stack suffices.
        return torch.stack([torch.as_tensor(item[key]) for item in batch])

    def _to_list(ids):
        return ids.tolist() if isinstance(ids, torch.Tensor) else list(ids)

    # The width must cover BOTH response kinds; sizing it from the preferred
    # responses alone would leave longer dispreferred rows unpadded.
    max_length = max(
        max(len(item[key]) for item in batch for key in response_keys),
        1,
    )

    def _pad_and_stack(key):
        rows = []
        for item in batch:
            ids = _to_list(item[key])
            rows.append(ids + [tokenizer.pad_token_id] * (max_length - len(ids)))
        return torch.tensor(rows, dtype=torch.long)

    return {
        "input_ids": _stack("input_ids"),
        # This key must be spelled "attention_mask": the loss looks it up by that name.
        "attention_mask": _stack("attention_mask"),
        "preferred_ids": _pad_and_stack("preferred_ids"),
        "dispreferred_ids": _pad_and_stack("dispreferred_ids"),
    }

In [None]:
class DPOTrainer(Trainer):
    """Trainer with a preference loss built from two language-model losses.

    For every batch the model is run twice on the same ``input_ids``: once
    with the preferred token ids as labels and once with the dispreferred
    ones. The training loss rewards a lower LM loss on the preferred labels.

    NOTE(review): no reference model is involved, and the response ids are
    used as labels for the prompt positions as-is (padding included) — confirm
    this simplification of DPO is intended.
    """

    def compute_loss(self, model, inputs, beta=0.1, *args, **kwargs):
        """Compute ``-logsigmoid(beta * (dispreferred_loss - preferred_loss))``.

        Args:
            model: The model being trained; called with ``input_ids``,
                ``attention_mask`` and ``labels``.
            inputs: Batch dict with "input_ids", "attention_mask",
                "preferred_ids" and "dispreferred_ids" tensors.
            beta: Scale applied to the loss difference before the log-sigmoid.
            *args: Ignored.
            **kwargs: Extra keyword arguments from the caller. Only
                ``return_outputs`` is honoured; everything else is ignored.

        Returns:
            The scalar loss, or ``(loss, preferred_outputs)`` when
            ``return_outputs=True`` is passed.
        """
        return_outputs = kwargs.get("return_outputs", False)

        target_device = model.device
        input_ids = inputs["input_ids"].to(target_device)
        attention_mask = inputs["attention_mask"].to(target_device)
        preferred_ids = inputs["preferred_ids"].to(target_device)
        # Key name must match what the collator emits ("dispreferred_ids").
        dispreferred_ids = inputs["dispreferred_ids"].to(target_device)

        preferred_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=preferred_ids)
        dispreferred_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=dispreferred_ids)

        # A lower LM loss on the preferred labels makes the margin positive,
        # which drives the log-sigmoid term (and hence the loss) towards zero.
        margin = dispreferred_outputs.loss - preferred_outputs.loss
        loss = -F.logsigmoid(beta * margin).mean()

        if return_outputs:
            return loss, preferred_outputs
        return loss

In [None]:
# Hyper-parameters, checkpointing and logging settings for the training run.
training_args = TrainingArguments(
    output_dir="./dpo_llama3_korean",
    # Batching: one sample per device, accumulated over 16 steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # Optimisation.
    num_train_epochs=3,
    learning_rate=1e-4,
    optim="adamw_bnb_8bit",
    fp16=True,
    # NOTE(review): a value of 0 presumably turns gradient clipping off — confirm this is intended.
    max_grad_norm=0,
    # Checkpointing and logging cadence.
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    logging_steps=50,
    # NOTE(review): presumably required so the preference columns reach `collate_fn` — verify.
    remove_unused_columns=False,
)

# Wire the model, arguments, tokenized data and custom collator into the trainer.
trainer = DPOTrainer(
    args=training_args,
    model=model,
    data_collator=collate_fn,
    train_dataset=tokenized_dataset,
)

In [None]:
# Run the training loop; checkpoints are written under the `output_dir` given in `training_args`.
trainer.train()

---

In [None]:
# Directory of the saved adapter checkpoint to evaluate.
# NOTE(review): "checkpoint-xxx" looks like a placeholder — replace it with a real checkpoint directory.
checkpoint_path = './dpo_llama3_korean/checkpoint-xxx'

# NOTE(review): `model` here is already the PEFT-wrapped model returned by
# `get_peft_model` earlier in the notebook, so this call wraps it a second
# time, and re-running the cell wraps it again. Confirm whether the adapter
# should instead be loaded onto a freshly loaded base model.
model = PeftModel.from_pretrained(model, checkpoint_path)
model.eval()

# First five training examples, used as a quick qualitative check.
sample_data = dataset['train'].select(range(5))

In [None]:
def generate_response(question, max_new_tokens=256):
    """Sample a completion for ``question`` and return only the generated text.

    The prompt is tokenized with truncation to 256 tokens. The generation
    budget is expressed with ``max_new_tokens`` so that a long prompt cannot
    consume it (a total ``max_length`` of 256 would leave no room to generate
    after a 256-token prompt). The prompt tokens echoed back at the start of
    the generated sequence are stripped before decoding.

    Args:
        question: Prompt text.
        max_new_tokens: Maximum number of tokens to generate after the prompt.

    Returns:
        The decoded continuation, without the prompt and without special tokens.
    """
    inputs = tokenizer(
        question,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Pass the pad id explicitly instead of relying on generate's fallback.
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the prompt tokens so that only the newly generated part is decoded.
    generated_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

In [None]:
# Print the question, the dataset's chosen answer and the model's generation
# for every sampled example, followed by a separator line.
for i, example in enumerate(sample_data):
    question, preferred_answer = example['question'], example['chosen']
    generated_response = generate_response(question)

    report_lines = (
        f"질문: {question}",
        f"정답 (선호 응답): {preferred_answer}",
        f"실제 모델 응답: {generated_response}",
        "=" * 100,
    )
    print(*report_lines, sep="\n")