In [1]:
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-12b-it",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    quantization_config=bnb_config,
    device_map="auto",
)

model.config.use_cache = False  # 훈련 시 캐싱 비활성화

processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it", use_fast=True)
processor.tokenizer.pad_token = processor.tokenizer.eos_token
processor.tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
# from unsloth import FastModel


# model, tokenizer = FastModel.from_pretrained(
#     model_name="unsloth/gemma-3-4b-it",
#     max_seq_length=2048,
#     load_in_4bit=False,
#     load_in_8bit=False,
#     full_finetuning=False,
#     dtype=torch.bfloat16,
#     # token = "hf_...",
# )

# tokenizer.pad_token = tokenizer.eos_token

In [4]:
# LoRA 구성을 위한 설정 (인과적 언어 모델링용)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # 타겟 모듈 지정
    bias="none",
    task_type="CAUSAL_LM",  # 작업 유형 설정
    # use_rslora=True,
    use_dora=True,
)

model = get_peft_model(model, lora_config)

In [5]:
# model = FastModel.get_peft_model(
#     model,
#     finetune_vision_layers=False,
#     finetune_language_layers=True,
#     finetune_attention_modules=True,
#     finetune_mlp_modules=True,
#     r=8,
#     lora_alpha=16,
#     lora_dropout=0,
#     bias="none",
#     random_state=3407,
#     use_gradient_checkpointing=True,
# )

In [6]:
dataset = load_dataset("HuggingFaceH4/Bespoke-Stratos-17k", split="train")

In [7]:
def format_restruction(dataset):
    system = dataset["system"]
    question = dataset["conversations"][0]["value"]
    response = dataset["conversations"][1]["value"]

    return {"system": system, "question": question, "response": response}

In [8]:
restructured_dataset = dataset.map(
    format_restruction,
    remove_columns=dataset.column_names,  # 기존 컬럼 제거
)

In [9]:
def filter_long_tokens(dataset):
    return (
        len(processor.tokenizer.tokenize(dataset["system"]))
        + len(processor.tokenizer.tokenize(dataset["question"]))
        + len(processor.tokenizer.tokenize(dataset["response"]))
        <= 2044
    )

In [10]:
# def filter_long_tokens(dataset):
#     return (
#         len(tokenizer.tokenizer(dataset["system"])["input_ids"])
#         + len(tokenizer.tokenizer(dataset["question"])["input_ids"])
#         + len(tokenizer.tokenizer(dataset["response"])["input_ids"])
#         <= 2044
#     )

In [11]:
filtered_dataset = restructured_dataset.filter(
    filter_long_tokens, num_proc=4, desc="Filtering dataset"
)

In [12]:
# # 토큰 수를 저장할 리스트
# token_lengths = []

# # 데이터셋의 모든 텍스트에 대해 토큰 수 계산
# for idx, text in enumerate(filtered_dataset):
#     num_tokens = len(tokenizer.tokenizer(text['system'])['input_ids']) + len(tokenizer.tokenizer(text['question'])['input_ids']) + len(tokenizer.tokenizer(text['response'])['input_ids'])
#     token_lengths.append((num_tokens, idx))

# # 최대 및 최소 토큰 수를 가진 데이터의 인덱스 찾기
# max_token_info = max(token_lengths, key=lambda x: x[0])  # 토큰 수가 최대인 데이터
# min_token_info = min(token_lengths, key=lambda x: x[0])  # 토큰 수가 최소인 데이터

# # 최대 및 최소 토큰 수를 가진 데이터 출력
# max_tokens, max_idx = max_token_info
# min_tokens, min_idx = min_token_info

# print(f"최대 토큰 수: {max_tokens}")
# print(f"최대 토큰 수를 가진 데이터: {filtered_dataset[max_idx]}, idx: {max_idx}")
# print(f"최소 토큰 수: {min_tokens}")
# print(f"최소 토큰 수를 가진 데이터: {filtered_dataset[min_idx]}, idx: {min_idx}")

In [13]:
def preprocess_function(dataset):
    # 입력 텍스트 준비 (시스템 프롬프트 + 질문)
    inputs = [
        f"{system}\n\n{question}"
        for system, question in zip(dataset["system"], dataset["question"])
    ]

    # 출력 텍스트 준비 (응답)
    outputs = [f"{response}" for response in dataset["response"]]

    # 입력 토큰화
    model_inputs = processor.tokenizer(
        inputs, max_length=2048, truncation=True, padding="max_length"
    )

    # 라벨(출력) 토큰화
    labels = processor.tokenizer(
        outputs, max_length=2048, truncation=True, padding="max_length"
    )

    # 라벨 ID를 모델 입력에 추가
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [14]:
tokenized_dataset = filtered_dataset.map(
    preprocess_function, batched=True, remove_columns=restructured_dataset.column_names
)

In [15]:
model.train()  # 반드시 추가!
model.enable_input_require_grads()  # 그래디언트 계산 강제 활성화

In [16]:
# 참고: https://huggingface.co/blog/open-r1/update-3

trainer = SFTTrainer(
    model=model,
    processing_class=processor.tokenizer,  # 멀티모달
    # tokenizer=tokenizer, # Only 언어모델
    train_dataset=tokenized_dataset,
    eval_dataset=None,
    args=SFTConfig(
        bf16=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        max_grad_norm=0.2,
        warmup_ratio=0.1,
        num_train_epochs=1,
        # max_steps = 50, # test only
        learning_rate=2e-5,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        save_steps=50,
        save_total_limit=50,
        seed=3407,
        report_to="wandb",
        run_name="gemma3-12b-lora",
        label_names=["labels"],
        # packing=True, # unsloth paking 버그로 인한 비활성화
        output_dir="outputs",
    ),
)

In [17]:
trainer.train(resume_from_checkpoint=None)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhiyo2044[0m ([33mhiyo2044-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,25.0719
20,24.2919
30,22.4506
40,20.0704
50,17.944
60,15.5618
70,12.9693
80,9.6825
90,6.6008
100,4.3729


TrainOutput(global_step=228, training_loss=8.390843784599973, metrics={'train_runtime': 11491.2618, 'train_samples_per_second': 0.318, 'train_steps_per_second': 0.02, 'total_flos': 2.507715295838208e+17, 'train_loss': 8.390843784599973})