In [None]:
import os

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fine-tuning corpus, fetched from the Hugging Face Hub.
dataset = load_dataset("s076923/llama3-wikibook-ko")

# bitsandbytes 4-bit loading: NF4 weight format, float16 compute,
# double quantization turned off.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False
)

# Read the Hub access token from the environment rather than storing it in
# the notebook, where it would be saved and shared with the file. If
# HF_TOKEN is unset this is None, and from_pretrained falls back to the
# credentials cached by `huggingface-cli login`.
token = os.environ.get("HF_TOKEN")
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# trust_remote_code is intentionally left at its default (False): it allows
# code shipped in the model repository to run locally.
# NOTE(review): assumes this checkpoint has no custom tokenizer code - confirm.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=token
)
# device_map={"": 0} places the entire model on GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map={"": 0},
    token=token
)

# Reuse EOS as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Turn off the KV cache on the model config for the training cells below.
model.config.use_cache = False

print(dataset)
# Index the row first so only one example is read, instead of pulling the
# entire "text" column just to show a single sample.
print(dataset["train"][7]["text"])

In [None]:
from peft import LoraConfig

# LoRA adapter settings; passed to SFTTrainer as `peft_config` in the
# training cell.
# NOTE(review): in standard LoRA the update is scaled by lora_alpha / r,
# which here is 4 / 128 = 0.03125 - confirm this small factor is intended.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",  # adapter is configured for causal language modeling
    r=128,                  # rank of the low-rank update matrices
    lora_alpha=4,           # numerator of the LoRA scaling factor
    lora_dropout=0.1,       # dropout applied inside the LoRA layers
)

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer

# Hyperparameters for the supervised fine-tuning run.
training_args = TrainingArguments(
    output_dir="LLaMa-3.1",
    seed=42,
    # Batching: 1 sample per device, gradients accumulated over 5 steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=5,
    # Schedule: 500 optimizer steps in total, the first 100 are warmup.
    max_steps=500,
    warmup_steps=100,
    learning_rate=2e-4,
    # Precision and optimizer.
    fp16=True,
    optim="paged_adamw_8bit",
    # Emit a log entry every 100 steps.
    logging_steps=100,
)

# NOTE(review): `tokenizer`, `dataset_text_field` and `max_seq_length` are
# passed straight to SFTTrainer here; their accepted names/location may
# depend on the installed TRL version - confirm against the pinned release.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=training_args,
    train_dataset=dataset["train"],
    dataset_text_field="text",  # column of the dataset that holds the training text
    max_seq_length=64,
)

# Kept as the last expression so the notebook displays the training result.
trainer.train()

In [None]:
# Put the model in evaluation mode before sampling.
model.eval()

messages = [
    {"role": "user", "content": "위키북스 대표 저자는 누구예요?"},
]

# Render the conversation with the tokenizer's chat template;
# add_generation_prompt=True ends the prompt where the assistant reply begins.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# The prompt is a single unpadded sequence, so every position is a real
# token. The mask is passed explicitly because the pad token was set equal
# to EOS, which prevents generate() from inferring it reliably.
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # Explicit padding id avoids generate() choosing one implicitly.
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
        no_repeat_ngram_size=2
    )

# Drop the prompt tokens so only the newly generated reply is decoded.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))