In [1]:
import os
# Send both the `datasets` cache and the `transformers` cache to one shared
# directory. This cell runs ahead of the Hugging Face imports in the next cell
# (presumably so the libraries see these values when first loaded — confirm).
hf_cache_dir = '/workspace/.cache/huggingface'
os.environ.update({
    'HF_DATASETS_CACHE': hf_cache_dir,
    'TRANSFORMERS_CACHE': hf_cache_dir,
})

In [2]:
import huggingface_hub
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTConfig, SFTTrainer



In [3]:
# Authenticate against the Hugging Face Hub (needed later because the training
# config pushes checkpoints to the Hub).
# The access token is read from the HF_TOKEN environment variable rather than
# being pasted into the notebook, so it never ends up in a saved or shared copy.
# When HF_TOKEN is unset, `token` is None and login() asks for it interactively.
huggingface_hub.login(token=os.environ.get('HF_TOKEN'))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
# get_device_capability() returns a (major, minor) pair for the current CUDA
# device; only the major number is displayed as this cell's output.
cuda_major, _cuda_minor = torch.cuda.get_device_capability()
cuda_major

8

In [5]:
# Fetch the instruction dataset by its repository id, then display the
# resulting DatasetDict so the available splits and row counts are visible.
dataset_repo = 'yainage90/lawtalk-qna-instruction'
dataset = load_dataset(dataset_repo)
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 49572
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2610
    })
})

In [6]:
# Checkpoint to fine-tune. Alternative checkpoint: 'beomi/Llama-3-Open-Ko-8B'.
base_model = 'beomi/Llama-3-Open-Ko-8B-Instruct-preview'

# Load the tokenizer.
# trust_remote_code is intentionally left at its default (off): the model below
# is loaded without it, and keeping it off prevents Python code hosted in the
# Hub repository from being executed locally.
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Reuse the end-of-sequence token for padding and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# QLoRA config: 4-bit NF4 weights, bfloat16 compute, no double quantization.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Load the quantized base model; the {"": 0} device map places the whole model
# on device 0.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quant_config,
    device_map={"": 0}
)
# Turn off the generation cache for training.
model.config.use_cache = False
model.config.pretraining_tp = 1

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# LoRA adapter settings handed to the trainer: rank 64, alpha 16, 10% dropout,
# no bias terms trained, causal language-modelling task.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

# Training configuration. Everything that controls supervised fine-tuning,
# including `packing`, lives here rather than on the SFTTrainer call.
sft_config = SFTConfig(
    output_dir="./model_ckpt",
    num_train_epochs=10,
    # 2 samples per step x 2 accumulation steps = 4 samples per optimizer update.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    # Save and evaluate on the same schedule so the best checkpoint (lowest
    # eval_loss) can be restored at the end; only 2 checkpoints are kept on disk.
    save_strategy='epoch',
    eval_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    learning_rate=2e-4,
    weight_decay=0.001,
    # NOTE(review): mixed precision is disabled even though the quantized model
    # computes in bfloat16 and the capability check cell reports 8 — confirm
    # whether bf16=True was intended.
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    # -1 leaves the run length to num_train_epochs.
    max_steps=-1,
    # NOTE(review): the "constant" schedule has no warmup phase, so this
    # warmup_ratio likely has no effect; "constant_with_warmup" would apply it —
    # confirm intent before changing.
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    # The dataset exposes a single 'text' column holding the full sample.
    dataset_text_field="text",
    max_seq_length=1024,
    # Moved here from the SFTTrainer call, where it is a deprecated keyword.
    packing=False,
    push_to_hub=True,
    hub_model_id='yainage90/Llama3-open-Ko-3-8B-Law-Chat',
)

# Building the trainer tokenizes both splits (the "Map" progress bars below).
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=sft_config,
)

Map:   0%|          | 0/49572 [00:00<?, ? examples/s]

Map:   0%|          | 0/2610 [00:00<?, ? examples/s]

In [8]:
# Resume from the newest checkpoint in the output directory when there is one,
# otherwise start training from scratch. Passing resume_from_checkpoint=True
# unconditionally raises when no checkpoint exists, which breaks a fresh
# "Restart Kernel -> Run All".
checkpoint_root = trainer.args.output_dir
last_checkpoint = (
    get_last_checkpoint(checkpoint_root) if os.path.isdir(checkpoint_root) else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)

  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
  checkpoint_rng_state = torch.load(rng_file)


Epoch,Training Loss,Validation Loss
6,1.4704,1.608304
7,1.4288,1.583976
8,1.3725,1.562549
9,1.3614,1.546551
10,1.3123,1.537955


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

No files have been modified since last commit. Skipping to prevent empty commit.


TrainOutput(global_step=123930, training_loss=0.6921270694454097, metrics={'train_runtime': 256970.3978, 'train_samples_per_second': 1.929, 'train_steps_per_second': 0.482, 'total_flos': 1.065378167100747e+19, 'train_loss': 0.6921270694454097, 'epoch': 10.0})