## Library import

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import bitsandbytes as bnb
from trl import SFTTrainer


## Model, Tokenizer loading


In [None]:
# Base model: Mistral 7B variant tuned on a Korean QA dataset.
# Single source of truth for the checkpoint; both the model and the tokenizer load from it.
base_model = 'davidkim205/komt-mistral-7b-v1'

# Quantization: 4-bit NF4 weights, bfloat16 compute dtype, double quantization off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Load model.
# Uses `base_model` (instead of repeating the checkpoint string) so that changing the
# checkpoint above can never leave the model and tokenizer out of sync.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map={"": 0},  # map the whole model (root module "") to device 0
)
model.config.use_cache = False  # generation cache is not needed while training
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.unk_token  # reuse the unknown token as the padding token
tokenizer.add_eos_token = True

## LoRA, PEFT

In [None]:
# Get linear layers
def find_all_linear_names(model, linear_cls=None):
    """Return the leaf names of all matching linear layers in ``model``.

    Walks ``model.named_modules()`` and, for every module that is an instance
    of ``linear_cls``, keeps the last dot-separated component of its qualified
    name (e.g. ``"model.layers.0.self_attn.q_proj"`` -> ``"q_proj"``). The
    ``lm_head`` entry is dropped from the result if present.

    Args:
        model: Any object exposing ``named_modules()`` that yields
            ``(name, module)`` pairs.
        linear_cls: Class (or tuple of classes) to match with ``isinstance``.
            Defaults to ``bnb.nn.Linear4bit`` when omitted.

    Returns:
        A sorted list of unique leaf module names, so the result is
        deterministic across runs.
    """
    if linear_cls is None:
        linear_cls = bnb.nn.Linear4bit

    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }
    # Drop the output head once, after the scan (no need to re-check per module).
    lora_module_names.discard('lm_head')
    return sorted(lora_module_names)

# LoRA adapter settings: rank 16, alpha 16, dropout 0.05, no bias terms,
# applied to every module name returned by find_all_linear_names(model).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=find_all_linear_names(model),
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

# PEFT: first prepare the quantized model for k-bit training,
# then wrap it with the LoRA adapters configured above.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [None]:
# Report the trainable parameter count against the total parameter count.
trainable, total = model.get_nb_trainable_parameters()
print(
    "Trainable: {} | total: {} | Percentage: {:.4f}%".format(
        trainable, total, trainable / total * 100
    )
)

## Load Dataset

<b> 본 논문에서 사용한 "AI-hub 상담음성 데이터"는 아래 링크에서 다운로드 받을 수 있습니다.</b>  
"https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=data&dataSetSn=100"

<b>AI-hub 데이터는 재가공하여 배포할 수 없기에, 본 논문 언어모델의 학습 및 검증에 사용된 대화 데이터셋의 
이해를 위한 예시만 아래 링크에 공개합니다.</b>  
https://huggingface.co/datasets/zerothweek/ko-diarizationlm_example

In [None]:
# Load the public example dataset from the Hugging Face Hub.
# The training cell reads its 'train' split and the "chat_sample" text field.
dataset = load_dataset("zerothweek/ko-diarizationlm_example")

## Training

In [None]:
# Release cached, unused GPU memory before building the trainer.
torch.cuda.empty_cache()

# Hyper-parameters for the fine-tuning run.
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # 8 * 2 = 16 samples per optimizer step per device
    optim="paged_adamw_8bit",
    save_strategy="epoch",
    logging_steps=1,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,  # -1: training length is determined by num_train_epochs
    # NOTE(review): the "constant" scheduler does not use warmup, so warmup_ratio
    # appears to have no effect here; switch to "constant_with_warmup" if a warmup
    # phase is actually intended. Values are kept as-is to preserve the original run.
    warmup_ratio=0.3,
    group_by_length=True,
    lr_scheduler_type="constant",
)

# Setting sft parameters.
# `model` was already wrapped with the LoRA adapters by get_peft_model(), so no
# `peft_config` is passed here: handing the trainer an existing PeftModel together
# with a peft_config is redundant.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    max_seq_length=None,
    dataset_text_field="chat_sample",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)

In [None]:
# Run the fine-tuning loop with the SFTTrainer built in the previous cell.
trainer.train()

## Model save

In [None]:
# Save in local: write the trained model to the `new_model` directory.
new_model = 'ko-diarizationlm'
trainer.model.save_pretrained(new_model)
# Also save the tokenizer: its pad token and add_eos_token setting were changed in
# this notebook, so keeping it next to the weights lets inference reuse the same setup.
tokenizer.save_pretrained(new_model)
