### GPU 초기화

In [1]:
import gc
import torch
gc.collect()
torch.cuda.empty_cache()

### 라이브러리  import

In [2]:
import argparse

from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config
)

  from .autonotebook import tqdm as notebook_tqdm


### 훈련 Argument 정의

In [3]:
model_args = ModelConfig(
    model_name_or_path='../models/EXAONE-3.5-2.4B-Instruct/',
    torch_dtype="bfloat16",
    trust_remote_code=True,
    
    
    ####################
    # lora 
    ####################
    use_peft= True,
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head"
    ],
    lora_task_type="CAUSAL_LM",
    
    ####################
    # quantization
    ####################
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    use_bnb_nested_quant= False    
)

In [4]:
training_args = SFTConfig(
    output_dir = '../models/EXAONE-3.5-2.4B-Instruct-SFT/',
    eval_strategy='no',
    push_to_hub=False,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=16,
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    save_strategy="epoch",
    logging_steps=10,
    num_train_epochs=300,
    # max_steps=,
    fp16=True,
    packing = False,
    max_seq_length = 1024,
    # packing=True
)

In [5]:
script_args = ScriptArguments(
    dataset_name='../datasets/train_data/train_data_v2',
    dataset_config = None,
    dataset_train_split='train'
)

### 훈련 함수

In [6]:
def train_llm(script_args, training_args, model_args):
    quantization_config = get_quantization_config(model_args)
    
    model_kwargs = dict(
        revision = model_args.model_revision,
        trust_remote_code = model_args.trust_remote_code,
        attn_implementation = model_args.attn_implementation,
        torch_dtype = model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map = get_kbit_device_map() if quantization_config is not None else None,
        quantization_config = quantization_config,
    )
    
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    
    dataset = load_from_disk(script_args.dataset_name)
    # formatting_func 예시
    def formatting_prompt_func(example):
        output_texts = []
        system_message = "너는 병역이행법 상담을 도와주는 챗봇이야."
        for i in range(len(example['question'])):
            text = f"[|system|]{system_message}[|endofturn|]\n[|user|]{example['question'][i]}[|endofturn|]\n[|assistant|]{example['answer'][i]}[|endofturn|]"
            output_texts.append(text)
        return output_texts
    trainer = SFTTrainer(
        model = model_args.model_name_or_path,
        args = training_args,
        train_dataset = dataset[script_args.dataset_train_split],
        processing_class = tokenizer,
        peft_config = get_peft_config(model_args),
        formatting_func = formatting_prompt_func,
    )
    
    trainer.train()
    
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

### 훈련

In [7]:
train_llm(script_args, training_args, model_args)



Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


{'loss': 34.1785, 'grad_norm': 9.935717582702637, 'learning_rate': 0.00019945218953682734, 'epoch': 5.0}




{'loss': 20.7715, 'grad_norm': 7.12577486038208, 'learning_rate': 0.00019781476007338058, 'epoch': 10.0}




{'loss': 14.4567, 'grad_norm': 5.240175247192383, 'learning_rate': 0.00019510565162951537, 'epoch': 15.0}




{'loss': 11.3841, 'grad_norm': 3.8306822776794434, 'learning_rate': 0.0001913545457642601, 'epoch': 20.0}




{'loss': 9.4934, 'grad_norm': 3.815941572189331, 'learning_rate': 0.00018660254037844388, 'epoch': 25.0}




{'loss': 7.9073, 'grad_norm': 5.515296936035156, 'learning_rate': 0.00018090169943749476, 'epoch': 30.0}




{'loss': 6.5738, 'grad_norm': 4.281012535095215, 'learning_rate': 0.00017431448254773944, 'epoch': 35.0}




{'loss': 5.2414, 'grad_norm': 5.9789299964904785, 'learning_rate': 0.00016691306063588583, 'epoch': 40.0}




{'loss': 4.332, 'grad_norm': 5.6618266105651855, 'learning_rate': 0.00015877852522924732, 'epoch': 45.0}




{'loss': 3.688, 'grad_norm': 7.096960067749023, 'learning_rate': 0.00015000000000000001, 'epoch': 50.0}




{'loss': 3.1189, 'grad_norm': 6.010748863220215, 'learning_rate': 0.00014067366430758004, 'epoch': 55.0}




{'loss': 2.6091, 'grad_norm': 5.626856327056885, 'learning_rate': 0.00013090169943749476, 'epoch': 60.0}




{'loss': 2.1953, 'grad_norm': 6.859731674194336, 'learning_rate': 0.00012079116908177593, 'epoch': 65.0}




{'loss': 1.8227, 'grad_norm': 6.231405258178711, 'learning_rate': 0.00011045284632676536, 'epoch': 70.0}




{'loss': 1.5237, 'grad_norm': 5.587397575378418, 'learning_rate': 0.0001, 'epoch': 75.0}




{'loss': 1.2991, 'grad_norm': 6.430514335632324, 'learning_rate': 8.954715367323468e-05, 'epoch': 80.0}




{'loss': 1.1344, 'grad_norm': 6.370896816253662, 'learning_rate': 7.920883091822408e-05, 'epoch': 85.0}




{'loss': 1.0037, 'grad_norm': 5.768457889556885, 'learning_rate': 6.909830056250527e-05, 'epoch': 90.0}




{'loss': 0.9137, 'grad_norm': 6.007272243499756, 'learning_rate': 5.932633569241999e-05, 'epoch': 95.0}




{'loss': 0.8484, 'grad_norm': 6.069665431976318, 'learning_rate': 5.000000000000002e-05, 'epoch': 100.0}




{'loss': 0.8029, 'grad_norm': 4.618069171905518, 'learning_rate': 4.12214747707527e-05, 'epoch': 105.0}




{'loss': 0.7687, 'grad_norm': 6.521856784820557, 'learning_rate': 3.308693936411421e-05, 'epoch': 110.0}




{'loss': 0.7413, 'grad_norm': 5.235042572021484, 'learning_rate': 2.5685517452260567e-05, 'epoch': 115.0}




{'loss': 0.7189, 'grad_norm': 6.197702407836914, 'learning_rate': 1.9098300562505266e-05, 'epoch': 120.0}




{'loss': 0.7089, 'grad_norm': 5.881902694702148, 'learning_rate': 1.339745962155613e-05, 'epoch': 125.0}




{'loss': 0.6951, 'grad_norm': 5.797318458557129, 'learning_rate': 8.645454235739903e-06, 'epoch': 130.0}




{'loss': 0.6876, 'grad_norm': 4.403503894805908, 'learning_rate': 4.8943483704846475e-06, 'epoch': 135.0}




{'loss': 0.6843, 'grad_norm': 4.8043317794799805, 'learning_rate': 2.1852399266194314e-06, 'epoch': 140.0}




{'loss': 0.6836, 'grad_norm': 5.980234146118164, 'learning_rate': 5.478104631726711e-07, 'epoch': 145.0}




{'loss': 0.6815, 'grad_norm': 5.510958194732666, 'learning_rate': 0.0, 'epoch': 150.0}


100%|██████████| 300/300 [17:45:22<00:00, 213.07s/it]


{'train_runtime': 63922.2707, 'train_samples_per_second': 1.535, 'train_steps_per_second': 0.005, 'train_loss': 4.722296509742737, 'epoch': 150.0}
