- Script for training variants for comparison:
      - Full Fine-tuning
      - LoRA Training Only
      - QLoRA Training Only

In [1]:
%%capture
# %%capture hides this cell's output, i.e. pip's install log.
import sys
# Run pip through the interpreter backing this kernel (sys.executable) so the
# package is installed into the environment the notebook actually uses.
# NOTE(review): pandas is unpinned — consider pinning a version for reproducibility.
!{sys.executable} -m pip install pandas

In [None]:
import os
import numpy as np
import pandas as pd

In [3]:
# Configure the PyTorch CUDA caching allocator through its environment variable.
# This cell sits above the cell that imports torch, so the setting is in place
# before torch is loaded.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [4]:
%%capture
# Output of both installs is captured. Relies on `sys` having been imported by
# an earlier cell.
!{sys.executable} -m pip install bitsandbytes
# trl is pinned to 0.9.6; peft (and bitsandbytes above) are unpinned.
# NOTE(review): consider pinning those too so re-runs resolve the same versions.
!{sys.executable} -m pip install trl==0.9.6 peft

In [5]:
# %%capture
# !pip install bitsandbytes
# !pip uninstall -y transformers tokenizers huggingface_hub
# !pip install transformers -U #==4.45.2
# !pip install tokenizers==0.21.0
# !pip install huggingface_hub==0.24.6
# !pip install trl==0.9.6 peft==0.11.1 accelerate==1.0.0
# !pip uninstall -y pyarrow
# !pip install pyarrow==14.0.2 --force-reinstall

In [6]:
# def patch_accelerator():
#     try:
#         from accelerate import Accelerator
#         original_unwrap = Accelerator.unwrap_model
        
#         def patched_unwrap(self, model, keep_fp32_wrapper=True, keep_torch_compile=False):
#             try:
#                 return original_unwrap(self, model, keep_fp32_wrapper=keep_fp32_wrapper, keep_torch_compile=keep_torch_compile)
#             except TypeError:
#                 return original_unwrap(self, model, keep_fp32_wrapper=keep_fp32_wrapper)
        
#         Accelerator.unwrap_model = patched_unwrap
#         print("✓ Accelerator patched successfully")
#     except Exception as e:
#         print(f"Warning: Could not patch Accelerator: {e}")

# patch_accelerator()

In [7]:
import torch
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import logging
from dataclasses import dataclass
from huggingface_hub import login
from typing import Optional

In [8]:
# Empty the root logger first: logging.basicConfig() does nothing when the
# root logger already has handlers attached.
while logging.root.handlers:
    logging.root.removeHandler(logging.root.handlers[0])

# Route INFO-and-above records to stdout so they appear in the cell output.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)])

logger = logging.getLogger(__name__)

In [None]:
@dataclass
class TaskConfig:
    """Settings for a single fine-tuning run.

    Exactly one of ``full``, ``lora`` and ``qlora`` must be True; it selects
    the training variant.

    Besides the declared fields, ``__post_init__`` attaches:
        lora_config:  a ``LoraConfig`` when ``lora`` or ``qlora`` is set.
        qlora_config: a ``BitsAndBytesConfig`` (4-bit NF4) when ``qlora`` is set.

    Raises:
        ValueError: if zero or several mode flags are set, or if
            ``push_to_hub`` is True without ``hf_repo_name``.
    """
    # Training mode flags — exactly one must be enabled.
    full: bool = False
    lora: bool = False
    qlora: bool = False
    
    # Hub identifiers of the training data and of the base checkpoint.
    dataset_path: str = 'vohuutridung/Vietnamese-Legal-Chat-Dataset'
    dataset_split: str = 'train'
    model_name: str = 'vohuutridung/qwen3-1.7b-legal-pretrain'
    
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype: torch.dtype = torch.bfloat16
    
    train_batch_size: int = 2
    eval_batch_size: int = 8
    epochs: int = 3
    logging_steps: int = 50
    save_total_limit: int = 2
    # None means "use the mode default" (5e-5 for full fine-tuning, 2e-4 for
    # LoRA/QLoRA). An explicit value is kept as given.
    lr: Optional[float] = None
    eval_steps: int = 50
    eval_strategy: str = "steps"
    max_seq_length: int = 4096
    
    push_to_hub: bool = False
    hf_repo_name: Optional[str] = None

    def __post_init__(self):
        """Validate the flags, resolve the learning rate and build adapter configs."""
        selected_modes = [
            name for name in ('full', 'lora', 'qlora') if getattr(self, name)
        ]
        if len(selected_modes) != 1:
            raise ValueError(
                "Exactly one of full, lora, qlora must be True "
                f"(got: {selected_modes or 'none'})"
            )

        # Only fall back to the per-mode default when the caller gave no value;
        # previously any user-supplied lr was silently overwritten here.
        if self.lr is None:
            self.lr = 5e-5 if self.full else 2e-4
        
        if self.push_to_hub and not self.hf_repo_name:
            raise ValueError("hf_repo_name is required when push_to_hub=True")

        # LoRA and QLoRA share the same adapter hyperparameters.
        if self.lora or self.qlora:
            self.lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules='all-linear',
                bias="none",
                task_type="CAUSAL_LM",
            )
            
        # QLoRA additionally loads the base model with 4-bit quantization.
        if self.qlora:
            self.qlora_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

    
class Stage2Trainer:
    """Runs supervised fine-tuning for the variant selected in a ``TaskConfig``.

    Construction loads the tokenizer and the model (full, LoRA-wrapped, or
    4-bit quantized + LoRA-wrapped). ``train()`` then loads the dataset, holds
    out 5% for evaluation, trains with ``SFTTrainer`` and optionally pushes the
    result to the Hugging Face Hub.
    """

    def __init__(self, config: "TaskConfig"):
        """Load tokenizer and model for ``config``.

        Raises:
            ValueError: propagated from ``load_model`` when no mode is selected.
        """
        self.config = config

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.model = self.load_model()
        
        # load_model() has already rejected the "no flag set" case, so the
        # final branch here really is QLoRA.
        logger.info(f"Model loaded in mode: "
                    f"{'FULL' if config.full else 'LoRA' if config.lora else 'QLoRA'}")

    def load_model(self):
        """Build the model for the configured mode.

        Flags are checked in the order full, lora, qlora; the first one set wins.

        Returns:
            The plain causal LM (full) or a PEFT-wrapped model (lora / qlora).

        Raises:
            ValueError: if none of ``full``, ``lora``, ``qlora`` is set.
        """
        if self.config.full:
            return AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                dtype=self.config.dtype,
            )
            
        if self.config.lora:
            model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                dtype=self.config.dtype,
            )
            model = get_peft_model(model, self.config.lora_config)
            model.print_trainable_parameters()

            return model

        if self.config.qlora:
            model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                quantization_config=self.config.qlora_config,
                device_map="auto",
            )
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, self.config.lora_config)
            model.print_trainable_parameters()

            return model

        # Previously this fell through and returned None, which only failed
        # later inside the trainer with an unrelated-looking error.
        raise ValueError(
            "No training mode selected: set one of full, lora or qlora on the config"
        )

    
    def preprocess_function(self, examples):
        """Render a batch of conversations to chat-template strings.

        Args:
            examples: mapping with a ``"conversations"`` entry holding a list of
                conversations; each conversation is a list of dicts with
                ``'from'`` and ``'value'`` keys.

        Returns:
            A list of formatted strings. Conversations that are empty, have
            fewer than two usable messages, or fail to process are logged and
            skipped, so the list may be shorter than the input batch.
        """
        texts = []
        conversations = examples["conversations"]
        
        for conv in conversations:
            try:
                if not conv or len(conv) < 2:
                    logger.warning(f"Skipping empty or incomplete conversation: {conv}")
                    continue
                
                messages = []
                for msg in conv:
                    role = msg.get('from', '')
                    if role == 'human':
                        role = 'user'
                    elif role == 'gpt':
                        role = 'assistant'
                    else:
                        # Unknown speaker tag: alternate user/assistant based on
                        # how many messages have been kept so far.
                        role = 'user' if len(messages) % 2 == 0 else 'assistant'
                    
                    # Messages with empty content are dropped.
                    content = msg.get('value', '')
                    if content:  
                        messages.append({
                            "role": role,
                            "content": content
                        })
                
                if len(messages) >= 2: 
                    text = self.tokenizer.apply_chat_template(
                        messages, 
                        tokenize=False, 
                        add_generation_prompt=False
                    )
                    texts.append(text)
                else:
                    logger.warning(f"Not enough valid messages in conversation: {messages}")
                        
            except Exception as e:
                # Best effort: one malformed conversation must not abort the batch.
                logger.error(f"Failed to process conversation: {e}")
                logger.error(f"Conversation: {conv}")
                continue
        
        return texts

    
    def prepare_dataset(self):
        """Load, shuffle (seed 42) and split the dataset 95% / 5%.

        Returns:
            ``(train_dataset, eval_dataset)``.
        """
        dataset = load_dataset(self.config.dataset_path, split=self.config.dataset_split)
        dataset = dataset.shuffle(seed=42)
        logger.info('Dataset loaded and shuffled successfully.')
        logger.info(dataset)

        num_train = int(0.95 * len(dataset))
        train_dataset = dataset.select(range(num_train))
        eval_dataset = dataset.select(range(num_train, len(dataset)))
        logger.info(f'{len(train_dataset)} train samples.')
        logger.info(f'{len(eval_dataset)} eval samples.')

        return train_dataset, eval_dataset

    
    def train(self):
        """Train the model and, if configured, push model and tokenizer to the Hub.

        When ``push_to_hub`` is set, the access token is read from the
        ``HF_TOKEN`` environment variable. If it is unset, no explicit login is
        performed and the upload relies on whatever credentials are already
        available to ``huggingface_hub``.
        """
        logger.info('=' * 50)
        logger.info(f'Stepping into training process')
        logger.info('=' * 50)
        
        train_dataset, eval_dataset = self.prepare_dataset()
        
        args = SFTConfig(
            output_dir='./output',
            
            per_device_train_batch_size=self.config.train_batch_size,
            per_device_eval_batch_size=self.config.eval_batch_size,
            num_train_epochs=self.config.epochs,
            logging_steps=self.config.logging_steps,
            save_total_limit=self.config.save_total_limit,
            eval_steps=self.config.eval_steps,
            eval_strategy=self.config.eval_strategy,
            
            learning_rate=self.config.lr,

            max_seq_length=self.config.max_seq_length,

            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            # The string "none" disables reporting integrations; passing None
            # is interpreted by transformers as "all installed integrations".
            report_to="none",
        )
        
        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            formatting_func=self.preprocess_function,
        )

        logger.info('Start training.')
        trainer.train()
        logger.info('Training finished.')

        if self.config.push_to_hub:
            import os

            # Never hardcode the token in the notebook; take it from the environment.
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                login(hf_token)
            else:
                logger.warning(
                    "HF_TOKEN is not set; relying on existing Hugging Face credentials for the upload."
                )
            
            trainer.model.push_to_hub(self.config.hf_repo_name)
            trainer.processing_class.push_to_hub(self.config.hf_repo_name)
            logger.info(f'Pushed to {self.config.hf_repo_name} successfully.')

In [10]:
# full_config = TaskConfig(
#     full=True,
#     train_batch_size=1,
#     eval_batch_size=4,
#     push_to_hub=True,
#     hf_repo_name='vohuutridung/3150-fullft-v2',
# )
# full_trainer = Stage2Trainer(full_config)
# full_trainer.train()

In [11]:
# LoRA variant: train with per-device batch size 1 (eval 4) and upload the
# result to the Hub repo named below once training finishes.
train_lora_config = TaskConfig(
    lora=True,
    push_to_hub=True,
    hf_repo_name='vohuutridung/3150-lora-v2',
    train_batch_size=1,
    eval_batch_size=4,
)

# Constructing the trainer loads tokenizer and model; train() does the rest.
train_lora_trainer = Stage2Trainer(train_lora_config)
train_lora_trainer.train()

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.91G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/117 [00:00<?, ?B/s]

trainable params: 17,432,576 || all params: 1,738,007,552 || trainable%: 1.0030
INFO:__main__:Model loaded in mode: LoRA
INFO:__main__:Stepping into training process


README.md:   0%|          | 0.00/151 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.




data/train-00000-of-00001.parquet:   0%|          | 0.00/3.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3537 [00:00<?, ? examples/s]

INFO:__main__:Dataset loaded and shuffled successfully.
INFO:__main__:Dataset({
    features: ['conversations'],
    num_rows: 3537
})
INFO:__main__:3360 train samples.
INFO:__main__:177 eval samples.






Map:   0%|          | 0/3360 [00:00<?, ? examples/s]

Map:   0%|          | 0/177 [00:00<?, ? examples/s]

  super().__init__(


INFO:__main__:Start training.


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss,Validation Loss
50,1.1854,1.032971
100,1.026,0.988788
150,0.9591,0.969743
200,0.9645,0.955773
250,0.9674,0.950011
300,0.9813,0.947552
350,0.9463,0.939271
400,0.9752,0.930669
450,0.8987,0.934949
500,0.955,0.922995


INFO:__main__:Training finished.


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...pnl8xa0z5/adapter_model.safetensors:   0%|          | 45.8kB / 69.8MB            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  /tmp/tmp71v0farm/tokenizer.json       :  97%|#########7| 11.1MB / 11.4MB            

INFO:__main__:Pushed to vohuutridung/3150-lora-v2 successfully.


In [12]:
# train_qlora_config = TaskConfig(
#     qlora=True,
#     train_batch_size=4,
#     eval_batch_size=4,
#     push_to_hub=True,
#     hf_repo_name='vohuutridung/3150-qlora',
# )
# train_qlora_trainer = Stage2Trainer(train_qlora_config)
# train_qlora_trainer.train()