In [None]:
import os
import sys
import numpy as np
import pandas as pd

In [None]:
# Configure the PyTorch CUDA allocator through the environment. This cell runs before the
# cell that imports `torch`, so the setting is already in place when torch is loaded.
# NOTE(review): "expandable_segments:True" is presumably meant to reduce memory
# fragmentation during training — confirm against the PyTorch CUDA memory docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
%%capture
# Dependency setup for the notebook; `%%capture` above silences the pip output of this cell.
!pip install bitsandbytes
# Remove any existing Hugging Face packages before reinstalling them below.
!pip uninstall -y transformers tokenizers huggingface_hub
# NOTE(review): transformers is upgraded to the latest release (the ==4.45.2 pin is commented
# out) while tokenizers / huggingface_hub / trl / peft / accelerate are pinned — confirm this
# combination still resolves, since an unpinned package makes the run non-reproducible.
!pip install transformers -U #==4.45.2
!pip install tokenizers==0.21.0
!pip install huggingface_hub==0.24.6
!pip install trl==0.9.6 peft==0.11.1 accelerate==1.0.0
# pyarrow is removed and force-reinstalled at a fixed version.
!pip uninstall -y pyarrow
!pip install pyarrow==14.0.2 --force-reinstall

In [None]:
def patch_accelerator():
    """Make ``Accelerator.unwrap_model`` accept the ``keep_torch_compile`` keyword.

    The replacement method has the signature
    ``(self, model, keep_fp32_wrapper=True, keep_torch_compile=False)``. It forwards
    ``keep_torch_compile`` to the original method only when that method's signature
    accepts it; otherwise the keyword is dropped.

    The patch is idempotent: calling this function again (for example by re-running
    the notebook cell) does not stack another wrapper on top of the first one.

    Any failure while patching (including ``accelerate`` not being importable) is
    reported with a printed warning instead of being raised.
    """
    try:
        import inspect

        from accelerate import Accelerator

        current_unwrap = Accelerator.unwrap_model
        if getattr(current_unwrap, "_keep_torch_compile_shim", False):
            # Already patched by an earlier call; wrapping again would only nest shims.
            print("✓ Accelerator already patched")
            return

        original_unwrap = current_unwrap
        parameters = inspect.signature(original_unwrap).parameters
        # Decide once, from the signature, whether the keyword can be forwarded.
        # Catching TypeError around the call instead would also swallow genuine
        # TypeErrors raised inside unwrap_model and silently run it a second time.
        accepts_keep_torch_compile = "keep_torch_compile" in parameters or any(
            parameter.kind is inspect.Parameter.VAR_KEYWORD
            for parameter in parameters.values()
        )

        def patched_unwrap(self, model, keep_fp32_wrapper=True, keep_torch_compile=False):
            if accepts_keep_torch_compile:
                return original_unwrap(
                    self,
                    model,
                    keep_fp32_wrapper=keep_fp32_wrapper,
                    keep_torch_compile=keep_torch_compile,
                )
            return original_unwrap(self, model, keep_fp32_wrapper=keep_fp32_wrapper)

        # Marker read by the idempotency check above.
        patched_unwrap._keep_torch_compile_shim = True
        Accelerator.unwrap_model = patched_unwrap
        print("✓ Accelerator patched successfully")
    except Exception as e:
        print(f"Warning: Could not patch Accelerator: {e}")

patch_accelerator()

In [None]:
import torch
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import logging
from dataclasses import dataclass
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
import gc

In [None]:
# Reset the root logger so that re-running this cell does not accumulate handlers,
# then route INFO-and-above records to stdout.
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)

stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, handlers=[stdout_handler])

# Logger used by the rest of the notebook.
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class TaskConfig:
    task_name: str
    dataset_path: str = 'vohuutridung/Vietnamese-Legal-Chat-Dataset'
    dataset_split: str = 'train'
    model_name: str = 'vohuutridung/qwen3-1.7b-legal-pretrain'
    dtype: torch.dtype = torch.bfloat16
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules: list[str] | str = "all-linear"
    train_batch_size: int = 2
    eval_batch_size: int = 8
    epochs: int = 3
    logging_steps: int = 50
    save_total_limit: int = 2
    lr: float = 2e-4
    eval_steps: int = 50
    eval_strategy: str = "steps"
    max_seq_length: int = 4096
    push_to_hub: bool = False
    hf_repo_name: str = 'vohuutridung/stage1-mcq' # remember to change
    

class Stage1Trainer:
    """LoRA supervised fine-tuning of a causal LM on one task slice of the chat dataset.

    Construction loads the tokenizer and model named in the config and wraps the
    model with LoRA adapters. ``train()`` then selects the rows that belong to
    ``config.task_name``, runs TRL's ``SFTTrainer``, saves the adapter locally and
    optionally pushes it to the Hugging Face Hub.
    """

    # Row layout of the combined dataset split: the first 803 rows are used for
    # 'mcq', the next 745 for 'nli', and everything after that for 'sqa'.
    # Each entry is (start, stop); a stop of None means "through the last row".
    _MCQ_ROW_COUNT = 803
    _NLI_ROW_COUNT = 745
    TASK_ROW_RANGES = {
        'mcq': (0, _MCQ_ROW_COUNT),
        'nli': (_MCQ_ROW_COUNT, _MCQ_ROW_COUNT + _NLI_ROW_COUNT),
        'sqa': (_MCQ_ROW_COUNT + _NLI_ROW_COUNT, None),
    }

    def __init__(self, config: TaskConfig):
        """Load tokenizer and model from ``config.model_name`` and attach LoRA adapters.

        Args:
            config: Run settings; see ``TaskConfig``.
        """
        self.config = config
        logger.info(f'Initializing trainer for task: {config.task_name}')

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            dtype=config.dtype,
        )
        # Gradient checkpointing is switched on before the model is wrapped with LoRA.
        # NOTE(review): use_cache is turned off alongside it, presumably because the two
        # do not combine during training — confirm against the transformers docs.
        self.model.gradient_checkpointing_enable()
        self.model.config.use_cache = False
        logger.info('Model & Tokenizer loaded successfully.')

        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.target_modules,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)
        # self.model.enable_input_require_grads()
        self.model.print_trainable_parameters()

    @staticmethod
    def _to_chat_messages(conv):
        """Convert one raw conversation into a list of ``{"role", "content"}`` messages.

        'human' maps to 'user' and 'gpt' maps to 'assistant'. Any other ``from``
        value is assigned by alternation based on how many messages have been kept
        so far. Turns whose ``value`` is empty or missing are dropped.
        """
        messages = []
        for msg in conv:
            role = msg.get('from', '')
            if role == 'human':
                role = 'user'
            elif role == 'gpt':
                role = 'assistant'
            else:
                role = 'user' if len(messages) % 2 == 0 else 'assistant'

            content = msg.get('value', '')
            if content:
                messages.append({"role": role, "content": content})
        return messages

    def preprocess_function(self, examples):
        """Render a batch of conversations to chat-template strings.

        Args:
            examples: Mapping with a "conversations" key holding a list of
                conversations; each conversation is a list of dicts read via the
                'from' and 'value' keys.

        Returns:
            A list of formatted strings. Conversations that are empty, have fewer
            than two usable messages, or raise during processing are logged and
            skipped, so the result may be shorter than the input batch.
        """
        texts = []
        conversations = examples["conversations"]

        for conv in conversations:
            try:
                if not conv or len(conv) < 2:
                    logger.warning(f"Skipping empty or incomplete conversation: {conv}")
                    continue

                messages = self._to_chat_messages(conv)
                if len(messages) < 2:
                    logger.warning(f"Not enough valid messages in conversation: {messages}")
                    continue

                text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                texts.append(text)

            except Exception as e:
                # Best effort: one malformed conversation must not abort the whole batch.
                logger.error(f"Failed to process conversation: {e}")
                logger.error(f"Conversation: {conv}")
                continue

        return texts

    def prepare_dataset(self):
        """Load the dataset, keep the rows of the configured task, and split train/eval.

        Returns:
            ``(train_dataset, eval_dataset)``: the first 95% of the task's rows and
            the remaining rows, in their original order.

        Raises:
            ValueError: If ``config.task_name`` is not a key of ``TASK_ROW_RANGES``.
        """
        task_name = self.config.task_name
        if task_name not in self.TASK_ROW_RANGES:
            # Fail before downloading anything; previously an unknown task left
            # `dataset` unbound and crashed later with an UnboundLocalError.
            raise ValueError(
                f"Unknown task_name {task_name!r}; expected one of {sorted(self.TASK_ROW_RANGES)}"
            )

        total_dataset = load_dataset(self.config.dataset_path, split=self.config.dataset_split)

        start, stop = self.TASK_ROW_RANGES[task_name]
        if stop is None:
            stop = len(total_dataset)
        dataset = total_dataset.select(range(start, stop))
        logger.info('Dataset loaded successfully.')
        logger.info(dataset)

        num_train = int(0.95 * len(dataset))
        train_dataset = dataset.select(range(num_train))
        eval_dataset = dataset.select(range(num_train, len(dataset)))
        logger.info(f'{len(train_dataset)} train samples.')
        logger.info(f'{len(eval_dataset)} eval samples.')

        return train_dataset, eval_dataset

    def train(self):
        """Run fine-tuning for the configured task and persist the LoRA adapter.

        The adapter and tokenizer are saved under ``./<task_name>/lora``. When
        ``config.push_to_hub`` is true, the "HF_TOKEN" Kaggle secret is used to log
        in and both are pushed to ``config.hf_repo_name``.
        """
        logger.info('=' * 50)
        logger.info(f'Starting training for task: {self.config.task_name}')
        logger.info('=' * 50)

        train_dataset, eval_dataset = self.prepare_dataset()

        # NOTE(review): no save_steps / save_strategy is set, so checkpointing uses the
        # library defaults — confirm checkpoints are written often enough for
        # load_best_model_at_end to have something to reload.
        args = SFTConfig(
            output_dir=f'./{self.config.task_name}',

            per_device_train_batch_size=self.config.train_batch_size,
            per_device_eval_batch_size=self.config.eval_batch_size,
            num_train_epochs=self.config.epochs,
            logging_steps=self.config.logging_steps,
            save_total_limit=self.config.save_total_limit,
            eval_steps=self.config.eval_steps,
            eval_strategy=self.config.eval_strategy,

            learning_rate=self.config.lr,
            max_seq_length=self.config.max_seq_length,

            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            report_to=["tensorboard"],

            gradient_checkpointing=True,
            optim="adamw_8bit",
            group_by_length=True,
        )

        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            formatting_func=self.preprocess_function,
        )

        # Release whatever can be released before the training loop starts.
        gc.collect()
        torch.cuda.empty_cache()

        logger.info('Start training.')
        trainer.train()
        logger.info('Training finished.')

        adapter_dir = f'./{self.config.task_name}/lora'
        trainer.model.save_pretrained(adapter_dir)
        trainer.processing_class.save_pretrained(adapter_dir)
        logger.info(f"Saved LoRA adapter for task {self.config.task_name} at {adapter_dir}") 

        if self.config.push_to_hub:
            # The token is read from Kaggle secrets rather than being hard-coded.
            user_secrets = UserSecretsClient()
            HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
            login(HF_TOKEN)

            trainer.model.push_to_hub(self.config.hf_repo_name)
            trainer.processing_class.push_to_hub(self.config.hf_repo_name)
            logger.info(f'Pushed LoRA adapter for task {self.config.task_name} to {self.config.hf_repo_name}')

In [None]:
# mcq_config = TaskConfig(
#     task_name='mcq',
#     epochs=3,
#     push_to_hub=True,
#     hf_repo_name='vohuutridung/stage1-mcq-v3-3e',
# )

# mcq_trainer = Stage1Trainer(mcq_config)
# mcq_trainer.train()

In [None]:
# nli_config = TaskConfig(
#     task_name='nli',
#     epochs=2,
#     push_to_hub=True,
#     hf_repo_name='vohuutridung/stage1-nli-v2-2e',
# )

# nli_trainer = Stage1Trainer(nli_config)
# nli_trainer.train()

In [None]:
# Stage-1 run for the 'sqa' task: overrides the TaskConfig defaults for both batch sizes
# and the epoch count, and pushes the resulting adapter to the Hub repo named below.
sqa_config = TaskConfig(
    task_name='sqa',
    train_batch_size=4,
    eval_batch_size=4,
    epochs=5,
    push_to_hub=True,
    hf_repo_name='vohuutridung/stage1-sqa-v2-5e',
)

# Constructing the trainer loads the model and tokenizer; train() then runs the whole job
# (dataset preparation, training, local save and Hub push).
sqa_trainer = Stage1Trainer(sqa_config)
sqa_trainer.train()