In [1]:
%%capture
# Install extra packages with the kernel's own interpreter (sys.executable) so they
# land in the environment this notebook runs in. `sys` is also used by later cells
# (the trl install cell and the logging StreamHandler).
# NOTE(review): %%capture hides pip output, including install failures.
# NOTE(review): neither package is pinned, and `rich` is not used in the visible cells.
import sys
!{sys.executable} -m pip install pandas
!{sys.executable} -m pip install rich

In [2]:
import os
import numpy as np
import pandas as pd

In [3]:
# Ask PyTorch's CUDA caching allocator to use expandable segments (less fragmentation
# on long training runs). The allocator reads this only when CUDA is first initialised,
# so it is set before `torch` is imported in a later cell.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [4]:
%%capture
# trl is pinned because the SFTConfig / SFTTrainer calls below (max_seq_length,
# formatting_func, tokenizer=...) are presumably written against trl 0.9.6.
# NOTE(review): peft is unpinned; consider pinning it to a version known to work.
# NOTE(review): %%capture hides pip output, including install failures.
!{sys.executable} -m pip install trl==0.9.6 peft

In [5]:
import logging
from dataclasses import dataclass
from typing import Optional

import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig

In [6]:
# Send INFO-and-above log records to the notebook's stdout.
# `force=True` removes (and closes) any handlers already attached to the root
# logger, so re-running this cell does not duplicate every log line.
logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class TaskConfig:
    """Settings for stage-2 supervised fine-tuning of the merged legal chat model.

    `adapter_name` defaults to None but must be supplied: the trainer merges that
    PEFT adapter into `model_name` before training starts.
    """

    # Hugging Face dataset repo and split that hold the `conversations` column.
    dataset_path: str = 'vohuutridung/Vietnamese-Legal-Chat-Dataset'
    dataset_split: str = 'train'
    # Base causal-LM checkpoint the adapter is merged into.
    model_name: str = 'vohuutridung/qwen3-1.7b-legal-pretrain'
    # PEFT adapter (Hub repo id or local path) to merge into `model_name`.
    adapter_name: Optional[str] = None

    # NOTE(review): `device` is not read by the visible training code.
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # dtype used when loading the base model weights.
    dtype: torch.dtype = torch.bfloat16

    # Per-device batch sizes and schedule.
    train_batch_size: int = 2
    eval_batch_size: int = 4
    epochs: int = 2
    logging_steps: int = 100
    save_total_limit: int = 2
    eval_steps: int = 100
    eval_strategy: str = "steps"
    lr: float = 5e-5
    # Sequences longer than this many tokens are truncated.
    max_seq_length: int = 4096

    # When True, the trained model and tokenizer are pushed to `hf_repo_name`.
    push_to_hub: bool = False
    # Placeholder default; override it for every real run.
    hf_repo_name: str = 'vohuutridung/aaa'
    

class Stage2Trainer:
    """Merge a PEFT adapter into the base model, then fully fine-tune it with SFT.

    The conversations column of the configured dataset is rendered with the
    tokenizer's chat template and trained with trl's SFTTrainer. The checkpoint
    with the lowest eval loss is kept, and the result can optionally be pushed
    to the Hugging Face Hub.
    """

    # ShareGPT-style speaker tags (and already-normalised chat roles) mapped to
    # the role names `tokenizer.apply_chat_template` expects.
    _ROLE_MAP = {
        'human': 'user',
        'user': 'user',
        'gpt': 'assistant',
        'assistant': 'assistant',
        'system': 'system',
    }

    def __init__(self, config: TaskConfig):
        """Load tokenizer and base model and merge `config.adapter_name` into the model.

        Raises:
            ValueError: if `config.adapter_name` is not set.
        """
        if not config.adapter_name:
            raise ValueError(
                'TaskConfig.adapter_name must name the adapter to merge before stage-2 training.'
            )
        self.config = config

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        # The tokenizer is handed to SFTTrainer explicitly, so give it a pad token
        # for eval-time batching if the checkpoint does not define one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        base_model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            dtype=config.dtype,
        )
        peft_model = PeftModel.from_pretrained(base_model, config.adapter_name)
        # Fold the adapter weights into the base weights; the result is a plain model.
        self.model = peft_model.merge_and_unload()
        # Merged weights may be left frozen; unfreeze everything for full fine-tuning.
        self.model.requires_grad_(True)
        self.model.train()
        logger.info('Merged adapter to base model successfully.')

    def _to_chat_messages(self, conv):
        """Convert one conversation (list of {'from', 'value'} dicts) to chat messages.

        Turns with empty or non-string content are dropped. Unknown speaker tags
        alternate user/assistant over the non-system turns kept so far.
        """
        messages = []
        for msg in conv:
            content = msg.get('value', '')
            if not isinstance(content, str) or not content.strip():
                continue
            role = self._ROLE_MAP.get(msg.get('from', ''))
            if role is None:
                n_turns = sum(1 for m in messages if m['role'] != 'system')
                role = 'user' if n_turns % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": content})
        return messages

    def preprocess_function(self, examples):
        """Render a batch of conversations into chat-template strings.

        Used as SFTTrainer's `formatting_func`. `examples` is a batch dict with a
        "conversations" column. Conversations that are empty, malformed, or have
        no assistant turn are skipped, so the returned list may be shorter than
        the batch.

        Returns:
            list[str]: one rendered training text per usable conversation.
        """
        texts = []
        for conv in examples["conversations"]:
            try:
                if not conv or len(conv) < 2:
                    logger.warning('Skipping empty or incomplete conversation: %s', conv)
                    continue

                messages = self._to_chat_messages(conv)
                # A conversation without an assistant reply gives SFT nothing to learn.
                has_reply = any(m['role'] == 'assistant' for m in messages)
                if len(messages) >= 2 and has_reply:
                    texts.append(
                        self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=False,
                        )
                    )
                else:
                    logger.warning('Not enough valid messages in conversation: %s', messages)
            except Exception as e:
                # Best effort: one malformed conversation must not abort the whole map pass.
                logger.error('Failed to process conversation: %s', e)
                logger.error('Conversation: %s', conv)
        return texts

    def prepare_dataset(self):
        """Load, shuffle (seed 42) and split the dataset into train and eval parts.

        The eval split is the last 5% of the shuffled rows, capped at 100 rows and
        never smaller than one row.

        Raises:
            ValueError: if the split has fewer than two rows.
        """
        dataset = load_dataset(self.config.dataset_path, split=self.config.dataset_split)
        if len(dataset) < 2:
            raise ValueError(
                f'Need at least 2 rows to build train/eval splits, got {len(dataset)}.'
            )
        dataset = dataset.shuffle(seed=42)
        logger.info('Dataset loaded and shuffled successfully.')
        logger.info(dataset)

        # An empty eval split would break eval-loss based best-model selection.
        num_eval = max(1, min(int(0.05 * len(dataset)), 100))
        split_at = len(dataset) - num_eval
        train_dataset = dataset.select(range(split_at))
        eval_dataset = dataset.select(range(split_at, len(dataset)))
        logger.info(f'{len(train_dataset)} train samples.')
        logger.info(f'{len(eval_dataset)} eval samples.')

        return train_dataset, eval_dataset

    def _push_to_hub(self, trainer):
        """Push the trainer's model and this tokenizer to `config.hf_repo_name`."""
        # Never hard-code the token: read it from the environment, otherwise rely on
        # whatever credentials huggingface_hub already has cached.
        hf_token = os.environ.get('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
        else:
            logger.warning('HF_TOKEN is not set; relying on cached Hugging Face credentials.')

        trainer.model.push_to_hub(self.config.hf_repo_name)
        self.tokenizer.push_to_hub(self.config.hf_repo_name)
        logger.info(f'Pushed to {self.config.hf_repo_name} successfully.')

    def train(self):
        """Run SFT on the merged model and optionally push the result to the Hub.

        The checkpoint with the lowest eval loss is reloaded at the end. Checkpoints
        are saved with the same strategy as evaluation, as `load_best_model_at_end`
        requires.
        NOTE(review): `save_steps` is left at the TrainingArguments default, so with
        eval_strategy="steps" it must be a multiple of `config.eval_steps`, and
        "best" is chosen only among the saved checkpoints.
        """
        logger.info('=' * 50)
        logger.info('Stepping into training process')
        logger.info('=' * 50)

        train_dataset, eval_dataset = self.prepare_dataset()

        args = SFTConfig(
            output_dir='./output',

            per_device_train_batch_size=self.config.train_batch_size,
            per_device_eval_batch_size=self.config.eval_batch_size,
            num_train_epochs=self.config.epochs,
            logging_steps=self.config.logging_steps,
            save_total_limit=self.config.save_total_limit,
            eval_steps=self.config.eval_steps,
            eval_strategy=self.config.eval_strategy,
            save_strategy=self.config.eval_strategy,

            learning_rate=self.config.lr,

            max_seq_length=self.config.max_seq_length,

            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            # The string "none" disables reporting integrations; None means
            # "all installed integrations" in transformers 4.x.
            report_to="none",
        )

        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # Use the same tokenizer whose chat template formats the data.
            tokenizer=self.tokenizer,
            formatting_func=self.preprocess_function,
        )

        logger.info('Start training.')
        trainer.train()
        logger.info('Training finished.')

        if self.config.push_to_hub:
            self._push_to_hub(trainer)

In [8]:
# Stage-2 run: merge the adapter into the base model, fine-tune the merged
# weights on the chat dataset, then push the result to the Hub.
s2_config = TaskConfig(
    adapter_name='vohuutridung/merged3-v9',
    hf_repo_name='vohuutridung/stage2-v9',
    push_to_hub=True,
    train_batch_size=1,
    eval_batch_size=4,
)
s2_trainer = Stage2Trainer(config=s2_config)
s2_trainer.train()

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.91G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/117 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/920 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/69.8M [00:00<?, ?B/s]

INFO:__main__:Merged adapter to base model successfully.
INFO:__main__:Stepping into training process


README.md:   0%|          | 0.00/151 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.




data/train-00000-of-00001.parquet:   0%|          | 0.00/3.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3537 [00:00<?, ? examples/s]

INFO:__main__:Dataset loaded and shuffled successfully.
INFO:__main__:Dataset({
    features: ['conversations'],
    num_rows: 3537
})
INFO:__main__:3437 train samples.
INFO:__main__:100 eval samples.






Map:   0%|          | 0/3437 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

  super().__init__(


INFO:__main__:Start training.


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss,Validation Loss
100,0.8942,0.839403
200,0.8247,0.843395
300,0.7954,0.836788
400,0.8255,0.826142
500,0.78,0.820136
600,0.7807,0.815505
700,0.7888,0.81045
800,0.7784,0.803366
900,0.7854,0.797314
1000,0.7881,0.792683


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


INFO:__main__:Training finished.


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

INFO:__main__:Pushed to vohuutridung/stage2-v9 successfully.
