In [None]:
# Enhanced Fine-tuning iapp/chinda-qwen3-4b for Thai Medical Applications
# Company: V89 Technology Ltd.
# Version: 4.0 - FIXED All Critical Issues with A100 Optimization

print("Starting Enhanced Chinda-Qwen3-4B Medical Fine-tuning Setup v4.0...")

import warnings
warnings.filterwarnings('ignore')
import os
import re
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Compatible package versions for A100 with Python 3.12.11
!pip install -q --upgrade pip==24.0
!pip install -q --force-reinstall torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
!pip install -q transformers==4.52.3 accelerate==1.1.0 bitsandbytes==0.43.3
!pip install -q datasets==2.20.0 evaluate==0.4.2 rouge-score==0.1.2
!pip install -q peft==0.12.0
!pip install -q scikit-learn pandas numpy==2.0.2 scipy>=1.14.1 fsspec filelock typing-extensions nltk
!pip install -q sacrebleu

# Import required libraries
import sys
import json
import gc
import torch
import transformers
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from typing import Dict, List, Any, Optional, Union
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import math
import torch.optim as optim
from sacrebleu import corpus_bleu
from collections import defaultdict

# Updated imports for compatibility - FIXED accelerate unwrap_model issue
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    get_linear_schedule_with_warmup,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    EvalPrediction
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_config,
    PeftType
)

def _report_versions():
    """Print interpreter/library versions and, when CUDA is present, the GPU."""
    print("===== Version Check =====")
    for label, version in (
        ("Python version:", sys.version),
        ("Transformers version:", transformers.__version__),
        ("PyTorch version:", torch.__version__),
        ("Numpy version:", np.__version__),
    ):
        print(f"{label} {version}")
    if torch.cuda.is_available():
        print(f"CUDA Device: {torch.cuda.get_device_name()}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"CUDA Memory: {total_gb:.1f} GB")
    print("==================")


_report_versions()


In [None]:
# Mount Google Drive with error handling
def setup_drive_connection():
    """Return the directory that holds the Thai medical CSV files.

    In Google Colab this mounts Google Drive at ``/content/drive`` and returns
    the dataset folder on Drive. Anywhere else (``google.colab`` cannot be
    imported), or if mounting fails, it creates a local relative directory
    with the same name and returns that instead.

    Returns:
        str: Path to the dataset directory.
    """
    local_path = "V89Technology/thai_medicalCare_dataset150"
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except Exception as err:
        # `Exception` instead of a bare `except:` so KeyboardInterrupt and
        # SystemExit still stop the notebook rather than triggering the fallback.
        os.makedirs(local_path, exist_ok=True)
        print("Running outside Colab - using local directory")
        if not isinstance(err, ImportError):
            # Colab is present but the mount itself failed: say why.
            print(f"Google Drive mount failed ({type(err).__name__}: {err})")
        return local_path

    drive_path = "/content/drive/MyDrive/V89Technology/thai_medicalCare_dataset150"
    print("Google Drive connected successfully")
    return drive_path


drive_path = setup_drive_connection()

# Enhanced GPU memory management with A100 optimizations
class GPUMemoryManager:
    """Stateless helpers for CUDA memory housekeeping and hardware-based defaults.

    Every method is a ``staticmethod``; instances carry no state.
    """

    @staticmethod
    def clear_memory():
        """Run Python GC and, when CUDA is available, release cached GPU memory.

        On CUDA this empties PyTorch's allocator cache, waits for pending
        kernels (``synchronize``) and resets the peak/accumulated memory
        statistics. Without CUDA only ``gc.collect()`` runs.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.reset_accumulated_memory_stats()

    @staticmethod
    def get_memory_info():
        """Report memory usage of the current CUDA device in GiB.

        Returns:
            dict: ``allocated_gb`` (held by live tensors), ``cached_gb``
            (reserved by PyTorch's caching allocator), ``total_gb`` (device
            capacity) and ``free_gb`` (``total_gb - allocated_gb``; memory
            used by other processes is not subtracted). Without CUDA the
            result is ``{"error": "CUDA not available"}``.
        """
        if not torch.cuda.is_available():
            return {"error": "CUDA not available"}

        gib = 1024**3
        # Query usage and capacity for the *same* device.
        device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(device) / gib
        cached = torch.cuda.memory_reserved(device) / gib
        total = torch.cuda.get_device_properties(device).total_memory / gib
        return {
            "allocated_gb": allocated,
            "cached_gb": cached,
            "total_gb": total,
            "free_gb": total - allocated,
        }

    @staticmethod
    def optimize_for_device():
        """Pick batch, sequence-length, LoRA-rank and precision defaults.

        * Any GPU whose name contains "A100" (regardless of memory size):
          batch 16, 512 tokens, LoRA r=4, bf16.
        * Any other CUDA GPU: batch 8, 256 tokens, LoRA r=2, fp16.
        * No CUDA: batch 4, 256 tokens, LoRA r=1, full precision.

        Gradient checkpointing is enabled in every case.

        Returns:
            dict: Keys ``batch_size``, ``gradient_accumulation_steps``,
            ``max_length``, ``lora_r``, ``use_gradient_checkpointing``,
            ``fp16`` and ``bf16``.
        """
        if not torch.cuda.is_available():
            return {
                "batch_size": 4,
                "gradient_accumulation_steps": 8,
                "max_length": 256,
                "lora_r": 1,
                "use_gradient_checkpointing": True,
                "fp16": False,
                "bf16": False
            }

        if "a100" in torch.cuda.get_device_name().lower():
            return {
                "batch_size": 16,
                "gradient_accumulation_steps": 4,
                "max_length": 512,
                "lora_r": 4,
                "use_gradient_checkpointing": True,
                "fp16": False,
                "bf16": True
            }

        return {
            "batch_size": 8,
            "gradient_accumulation_steps": 8,
            "max_length": 256,
            "lora_r": 2,
            "use_gradient_checkpointing": True,
            "fp16": True,
            "bf16": False
        }

# Initialize memory manager and get optimal config
# Shared memory helper for the rest of the notebook: start from a clean GPU
# state, then derive the hardware-dependent defaults used by `config`.
memory_manager = GPUMemoryManager()
memory_manager.clear_memory()
optimal_config = memory_manager.optimize_for_device()

print("Optimal configuration based on hardware:")
for setting_name in optimal_config:
    print(f"  {setting_name}: {optimal_config[setting_name]}")

# Configuration for Chinda-Qwen3-4B model
# Training configuration for Chinda-Qwen3-4B. Hardware-dependent values come
# from `optimal_config`. The output directory is a sibling of the dataset
# folder returned by setup_drive_connection(), so it lands on Google Drive in
# Colab and inside the local fallback directory elsewhere, instead of always
# pointing at /content/drive.
config = {
    "model_name": "iapp/chinda-qwen3-4b",
    "output_dir": os.path.join(
        os.path.dirname(os.path.normpath(drive_path)),
        "chinda-qwen3-4b-medical-finetuned",
    ),
    "max_length": optimal_config["max_length"],
    "batch_size": optimal_config["batch_size"],
    "gradient_accumulation_steps": optimal_config["gradient_accumulation_steps"],
    "learning_rate": 3e-4,
    "num_epochs": 3,
    "warmup_ratio": 0.05,
    "logging_steps": 5,
    "save_steps": 5,
    "eval_steps": 5,
    "lora_r": optimal_config["lora_r"],
    "lora_alpha": 8,
    "lora_dropout": 0.1,
    "use_gradient_checkpointing": optimal_config["use_gradient_checkpointing"],
    "fp16": optimal_config["fp16"],
    "bf16": optimal_config["bf16"],
    "weight_decay": 0.01,
    "lr_scheduler_type": "cosine",
    "train_ratio": 0.8,
    "test_ratio": 0.1,
    "val_ratio": 0.1,
    "seed": 42
}

# The three ratios must partition the data; the split function assigns
# whatever train/test leave over to validation, so a mismatch would silently
# produce a differently sized validation set.
if not math.isclose(config["train_ratio"] + config["test_ratio"] + config["val_ratio"], 1.0):
    raise ValueError("train_ratio + test_ratio + val_ratio must sum to 1.0")

print(f"Enhanced configuration ready. Output directory: {config['output_dir']}")

# Advanced quantization configuration
def get_advanced_quantization_config():
    """Build the 4-bit bitsandbytes config used to load the base model.

    Weights are stored as NF4 with double quantization. Compute runs in
    bfloat16 when ``config["bf16"]`` is set, otherwise in float16.

    NOTE(review): ``llm_int8_threshold`` is the outlier threshold of the 8-bit
    LLM.int8() scheme and is probably inert with ``load_in_4bit``. It is kept
    for parity; confirm before removing.
    """
    if config["bf16"]:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float16

    settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": compute_dtype,
        "llm_int8_threshold": 6.0,
    }
    return BitsAndBytesConfig(**settings)

# Advanced LoRA configuration for Qwen architecture
def get_advanced_lora_config():
    """Build the LoRA adapter config for Chinda-Qwen3-4B.

    Adapters target the projection modules named ``q/k/v/o_proj`` (attention,
    by Qwen's naming) and ``gate/up/down_proj`` (MLP). Rank, alpha and dropout
    come from ``config``. Biases are left untrained, and plain (non-rsLoRA)
    scaling with default weight initialisation is used.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=config["lora_r"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],
        target_modules=attention_projections + mlp_projections,
        bias="none",
        inference_mode=False,
        init_lora_weights=True,
        use_rslora=False,
    )

# Instantiate both configs once; model loading below consumes them.
quantization_config, lora_config = (
    get_advanced_quantization_config(),
    get_advanced_lora_config(),
)
print("Advanced quantization and LoRA configuration ready")

# Load Thai medical datasets from CSV files - FIXED for Thai columns
# Instruction used for PubMed rows whose th_instruction is absent or empty.
_PUBMED_DEFAULT_INSTRUCTION = 'อธิบายความสัมพันธ์เชิงสาเหตุทางการแพทย์'

# (Thai option label, CSV column) pairs used to render MedMCQA choices.
_MEDMCQA_OPTIONS = (("ก", "th_opa"), ("ข", "th_opb"), ("ค", "th_opc"), ("ง", "th_opd"))


def _optional_text(row, column):
    """Return ``row[column]`` when the column exists and is not NaN, else ''."""
    value = row.get(column)
    return value if pd.notna(value) else ''


def _medmcqa_samples(df):
    """Build QA samples (question + lettered options -> explanation) from MedMCQA rows."""
    samples = []
    for _, row in df.iterrows():
        if not (pd.notna(row['th_question']) and pd.notna(row['th_exp'])):
            continue
        options = [
            f"{label}) {row[column]}"
            for label, column in _MEDMCQA_OPTIONS
            if pd.notna(row.get(column))
        ]
        samples.append({
            "instruction": "ตอบคำถามทางการแพทย์และอธิบายเหตุผล",
            "input": f"{row['th_question']}\n\n" + "\n".join(options),
            "output": f"คำอธิบาย: {row['th_exp']}",
            "dataset_type": "medical_qa"
        })
    return samples


def _mental_health_samples(df):
    """Build counselling samples from th_Context / th_Response rows."""
    samples = []
    for _, row in df.iterrows():
        if pd.notna(row['th_Context']) and pd.notna(row['th_Response']):
            samples.append({
                "instruction": "ให้คำปรึกษาด้านสุขภาพจิต",
                "input": row['th_Context'],
                "output": row['th_Response'],
                "dataset_type": "mental_health"
            })
    return samples


def _instruction_samples(df, dataset_type):
    """Build samples from th_instruction / optional th_input / th_output rows."""
    samples = []
    for _, row in df.iterrows():
        if pd.notna(row['th_instruction']) and pd.notna(row['th_output']):
            samples.append({
                "instruction": row['th_instruction'],
                "input": _optional_text(row, 'th_input'),
                "output": row['th_output'],
                "dataset_type": dataset_type
            })
    return samples


def _pubmed_samples(df):
    """Build causal-reasoning samples from PubMed rows, defaulting the instruction."""
    samples = []
    for _, row in df.iterrows():
        if pd.notna(row['th_input']) and pd.notna(row['th_output']):
            instruction = row.get('th_instruction', _PUBMED_DEFAULT_INSTRUCTION)
            if pd.isna(instruction):
                instruction = _PUBMED_DEFAULT_INSTRUCTION
            samples.append({
                "instruction": instruction,
                "input": row['th_input'],
                "output": row['th_output'],
                "dataset_type": "pubmed"
            })
    return samples


# dataset name -> (CSV filename, DataFrame -> samples converter), in load order.
_THAI_MEDICAL_SOURCES = {
    "medmcqa": ("medmcqa_thai_132.csv", _medmcqa_samples),
    "mental_health": ("mental_health_thai_150.csv", _mental_health_samples),
    "healthcare": ("healthcare_thai_150.csv", lambda df: _instruction_samples(df, "healthcare")),
    "pubmed": ("pubmed_thai_150.csv", _pubmed_samples),
    "medical_qa": ("medical_qa_thai_150.csv", lambda df: _instruction_samples(df, "medical_qa")),
}


def load_thai_medical_datasets(data_dir=None):
    """Load the five Thai medical CSV files into instruction-tuning samples.

    Rows missing their required text columns (NaN) are skipped.

    Args:
        data_dir: Directory containing the CSV files. Defaults to the
            module-level ``drive_path``.

    Returns:
        dict[str, list[dict]]: Maps each dataset name to its samples. Each
        sample has ``instruction``, ``input``, ``output`` and ``dataset_type``
        keys. A dataset whose file is missing or lacks a required column maps
        to ``[]`` and the error is printed, so one bad file does not abort the
        whole load.
    """
    if data_dir is None:
        data_dir = drive_path

    datasets = {}
    for dataset_name, (filename, build_samples) in _THAI_MEDICAL_SOURCES.items():
        file_path = os.path.join(data_dir, filename)
        print(f"Loading {dataset_name} from {file_path}")
        try:
            df = pd.read_csv(file_path, encoding='utf-8')
            samples = build_samples(df)
        except Exception as e:
            # Best-effort by design: report and continue with the other files.
            print(f"Error loading {dataset_name}: {e}")
            datasets[dataset_name] = []
            continue

        datasets[dataset_name] = samples
        print(f"{dataset_name}: {len(samples)} valid samples loaded")

    return datasets

# Load datasets
print("Loading Thai medical datasets from CSV files...")
medical_datasets = load_thai_medical_datasets()

# Pool every source into one list, reporting each source's size as we go.
all_samples = []
for dataset_name, samples in medical_datasets.items():
    print(f"{dataset_name}: {len(samples)} samples")
    all_samples += samples

print(f"Total samples loaded: {len(all_samples)}")

# Enhanced train/validation/test split - FIXED split ratios
def create_balanced_split(samples, train_ratio=0.8, test_ratio=0.1, val_ratio=0.1, seed=42):
    """Shuffle *samples* reproducibly and cut them into train/validation/test.

    The shuffle uses a private ``random.Random(seed)``. The caller's sequence
    is therefore left untouched and the global ``random``/``numpy`` RNG state
    is not modified. The permutation is identical to ``random.seed(seed)``
    followed by ``random.shuffle``, so splits for a given seed are unchanged.
    No stratification by ``dataset_type`` is performed.

    Args:
        samples: Sequence of samples (any element type).
        train_ratio: Fraction for training (size is floored to an int).
        test_ratio: Fraction for testing (size is floored to an int).
        val_ratio: Accepted for API symmetry; validation receives every
            sample not assigned to train or test.
        seed: Shuffle seed.

    Returns:
        tuple: ``(train_samples, val_samples, test_samples)`` as new lists.
    """
    import random

    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)

    total_samples = len(shuffled)
    train_end = int(total_samples * train_ratio)
    test_end = train_end + int(total_samples * test_ratio)

    train_samples = shuffled[:train_end]
    test_samples = shuffled[train_end:test_end]
    val_samples = shuffled[test_end:]

    return train_samples, val_samples, test_samples

train_samples, val_samples, test_samples = create_balanced_split(
    all_samples,
    seed=config["seed"],
    **{ratio: config[ratio] for ratio in ("train_ratio", "test_ratio", "val_ratio")},
)

for split_label, split_samples in (
    ("Training", train_samples),
    ("Validation", val_samples),
    ("Test", test_samples),
):
    print(f"{split_label} samples: {len(split_samples)}")

# Fail fast when the CSVs yielded too little data for a meaningful run.
for split_samples, minimum, message in (
    (train_samples, 10, "Insufficient training samples. Need at least 10 samples."),
    (val_samples, 2, "Insufficient validation samples. Need at least 2 samples."),
    (test_samples, 2, "Insufficient test samples. Need at least 2 samples."),
):
    if len(split_samples) < minimum:
        raise ValueError(message)

# Enhanced tokenizer and model loading with proper encoding - FIXED attention mask issue
print("Loading Chinda-Qwen3-4B model and tokenizer with optimizations...")

# Right padding, matching CustomDataCollator which appends padding at the end.
tokenizer = AutoTokenizer.from_pretrained(
    config["model_name"],
    trust_remote_code=True,
    padding_side='right',
    use_fast=True,
)

# Batched training needs a pad token. Reuse an existing special token when
# possible. Adding a brand-new token grows the vocabulary; that case is
# handled by the embedding resize after the model is loaded.
if tokenizer.pad_token is None:
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.pad_token_id = tokenizer.unk_token_id
    else:
        tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

# Ensure the tokenizer has an EOS token.
if tokenizer.eos_token is None:
    tokenizer.add_special_tokens({'eos_token': '<|endoftext|>'})

print(f"Tokenizer configured - Pad token: {tokenizer.pad_token}, EOS token: {tokenizer.eos_token}")

print("Loading model with advanced optimizations...")

model = AutoModelForCausalLM.from_pretrained(
    config["model_name"],
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16 if config["bf16"] else torch.float16,
    use_cache=False,
    low_cpu_mem_usage=True,
)

# Tokens added above are appended to the vocabulary. If len(tokenizer) now
# exceeds the embedding matrix, those ids would index out of range, so resize.
# The resize runs only when needed, because the embedding matrix may already
# be larger than the tokenizer vocabulary.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# Enable gradient checkpointing for memory efficiency.
if config["use_gradient_checkpointing"]:
    model.gradient_checkpointing_enable()

# Prepare the quantized model for adapter training (PEFT helper).
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=config["use_gradient_checkpointing"])

# Wrap the base model with the LoRA adapters.
print("Applying LoRA configuration...")
model = get_peft_model(model, lora_config)

# Print trainable parameters
def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts and the trainable percentage.

    bitsandbytes 4-bit weights (``Params4bit``) pack several 4-bit values into
    each storage element, so their ``numel()`` undercounts the logical
    parameters. They are scaled by ``2 * element_size()``, mirroring how PEFT
    counts parameters. Without this the total is too small and the
    trainable percentage is inflated.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if param.__class__.__name__ == "Params4bit":
            num_params *= 2 * param.element_size()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    # Guard against a parameterless module instead of dividing by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable%: {trainable_pct:.2f}")
    return trainable_params, all_param

# Report trainable vs. total parameters now that LoRA has been applied.
print_trainable_parameters(model)

# Enhanced data preprocessing with proper Thai text handling - FIXED tensor creation issue
def format_training_sample(sample):
    """Render one sample as a Qwen ChatML conversation string.

    Layout: system prompt, then a user turn (the instruction, plus the optional
    input after a blank line), then an assistant turn (the output). Each turn
    is closed with ``<|im_end|>``.

    Missing keys, ``None`` and float NaN (pandas' marker for empty CSV cells)
    are all treated as empty strings. This keeps them from leaking into the
    prompt as the literal text "None" or "nan".

    Args:
        sample: Mapping with optional ``instruction``, ``input`` and ``output``.

    Returns:
        str: The full prompt, including the target answer.
    """
    def _field(key):
        # Normalise one field to stripped text, mapping None/NaN to "".
        value = sample.get(key, "")
        if value is None or (isinstance(value, float) and math.isnan(value)):
            return ""
        return str(value).strip()

    instruction = _field("instruction")
    input_text = _field("input")
    output_text = _field("output")

    user_turn = f"{instruction}\n\n{input_text}" if input_text else instruction
    return (
        "<|im_start|>system\nคุณเป็นผู้ช่วยทางการแพทย์ที่เชี่ยวชาญด้านภาษาไทย<|im_end|>\n"
        f"<|im_start|>user\n{user_turn}<|im_end|>\n"
        f"<|im_start|>assistant\n{output_text}<|im_end|>"
    )

def preprocess_function(examples):
    """Tokenize a batch of raw samples for causal-LM fine-tuning.

    Each sample is rendered with ``format_training_sample``. It is then
    tokenized to exactly ``config["max_length"]`` tokens, truncated or
    right-padded. The tokenizer adds no special tokens, because the ChatML
    markers are already in the text.

    Args:
        examples: Batched ``datasets`` mapping with ``instruction`` and
            ``output`` columns and an optional ``input`` column.

    Returns:
        The tokenizer output with ``input_ids``, ``attention_mask`` and
        ``labels``. Labels equal ``input_ids`` on real tokens and are ``-100``
        on padding, so every non-padding token (prompt included) counts
        toward the loss.
    """
    instructions = examples["instruction"]
    inputs = examples.get("input", [""] * len(instructions))
    outputs = examples["output"]
    formatted_texts = [
        format_training_sample({"instruction": ins, "input": inp, "output": out})
        for ins, inp, out in zip(instructions, inputs, outputs)
    ]

    tokenized = tokenizer(
        formatted_texts,
        truncation=True,
        padding="max_length",
        max_length=config["max_length"],
        return_tensors=None,
        return_attention_mask=True,
        add_special_tokens=False  # ChatML markers are already in the text
    )

    # Build labels from the tokenizer's own attention mask rather than by
    # comparing ids with pad_token_id. The pad token may be shared with a token
    # that can occur in real text (e.g. the unk fallback used when a tokenizer
    # has no dedicated pad token). An id comparison would then also drop
    # genuine tokens from attention and from the loss.
    tokenized["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]

    return tokenized

# Convert samples to dataset format
def samples_to_dataset_dict(samples):
    """Pivot a list of sample dicts into column lists for ``Dataset.from_dict``.

    Non-dict entries are skipped. Missing fields default to "" (and
    ``dataset_type`` to "general"). Columns are returned in the order
    instruction, input, output, dataset_type.
    """
    column_defaults = (
        ("instruction", ""),
        ("input", ""),
        ("output", ""),
        ("dataset_type", "general"),
    )
    records = [item for item in samples if isinstance(item, dict)]
    return {
        column: [record.get(column, fallback) for record in records]
        for column, fallback in column_defaults
    }

print("Creating datasets...")
train_dataset_dict, val_dataset_dict, test_dataset_dict = (
    samples_to_dataset_dict(split) for split in (train_samples, val_samples, test_samples)
)

# Raw-text Dataset objects; test_dataset stays untokenized.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_dict(columns) for columns in (train_dataset_dict, val_dataset_dict, test_dataset_dict)
)


def _tokenize_split(split_dataset, description):
    """Tokenize one split with preprocess_function, dropping the raw text columns."""
    return split_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=split_dataset.column_names,
        desc=description,
    )


print("Applying preprocessing...")
train_dataset = _tokenize_split(train_dataset, "Preprocessing training data")
val_dataset = _tokenize_split(val_dataset, "Preprocessing validation data")

for size_label, sized_dataset in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{size_label} dataset size: {len(sized_dataset)}")

# FIXED: Enhanced data collator with proper padding
class CustomDataCollator:
    """Right-pad tokenized features into ``torch.long`` batch tensors.

    Every feature is padded to the longest ``input_ids`` in the batch, rounded
    up to a multiple of ``pad_to_multiple_of`` when that is truthy. Padding
    values: ``tokenizer.pad_token_id`` for ``input_ids``, 0 for
    ``attention_mask`` and -100 (ignored by the loss) for ``labels``.
    """

    def __init__(self, tokenizer, pad_to_multiple_of=8):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def _target_length(self, features):
        """Length every feature in this batch is padded to."""
        longest = max(len(feature["input_ids"]) for feature in features)
        step = self.pad_to_multiple_of
        if not step:
            return longest
        return ((longest + step - 1) // step) * step

    def __call__(self, features):
        target = self._target_length(features)
        fill_values = {
            "input_ids": self.tokenizer.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        rows = {name: [] for name in fill_values}

        for feature in features:
            shortfall = target - len(feature["input_ids"])
            for name, fill in fill_values.items():
                rows[name].append(feature[name] + [fill] * shortfall)

        return {name: torch.tensor(values, dtype=torch.long) for name, values in rows.items()}

# Shared collator: right-pads each batch to its longest sequence, rounded up to a multiple of 8.
data_collator = CustomDataCollator(tokenizer=tokenizer, pad_to_multiple_of=8)

# Advanced evaluation metrics - FIXED BLEU score calculation
def compute_metrics(eval_pred):
    """Token-level perplexity and teacher-forced BLEU for causal-LM evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair, e.g. a Trainer
            ``EvalPrediction``.
            - ``predictions``: logits, or a tuple whose first element is the
              logits. Assumed shaped ``(batch, seq_len, vocab)``.
            - ``labels``: token ids shaped ``(batch, seq_len)``, with ``-100``
              on ignored positions.

    Returns:
        dict with keys:
        - ``perplexity``: ``exp`` of ``eval_loss``.
        - ``eval_loss``: mean token NLL over non-ignored positions.
        - ``bleu``: a score in ``[0, 1]``.
        If no position can be scored, perplexity and loss are ``inf`` and
        bleu is ``0.0``.

    Notes:
        * Logits at position ``t`` predict token ``t + 1``, so both arrays are
          shifted by one before comparison, as in the model's own loss.
        * BLEU compares, per sample, the argmax next-token predictions with the
          reference tokens. These are not free-running generations.
        * This needs the full eval-set logits in host memory.
        * NOTE(review): the Trainer also computes its own ``eval_loss``, which
          likely takes precedence over the value returned here. Confirm
          against the installed transformers.
    """
    predictions, labels = eval_pred

    # Some models return (logits, *extras); only the logits are needed.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.float().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Align next-token predictions with their targets.
    shift_logits = predictions[..., :-1, :]
    shift_labels = labels[..., 1:]

    # Filter out ignored tokens (-100).
    mask = shift_labels != -100
    if not mask.any():
        return {"perplexity": float('inf'), "eval_loss": float('inf'), "bleu": 0.0}

    logits_kept = torch.as_tensor(shift_logits[mask], dtype=torch.float32)
    targets_kept = torch.as_tensor(shift_labels[mask], dtype=torch.long)
    nll_loss = F.cross_entropy(logits_kept, targets_kept, reduction='mean')
    perplexity = torch.exp(nll_loss).item()

    bleu_score = 0.0
    try:
        pred_token_ids = shift_logits.argmax(axis=-1)
        hypotheses = []
        references = []
        # One hypothesis/reference pair per sample, restricted to scored positions.
        for row_pred, row_label, row_mask in zip(pred_token_ids, shift_labels, mask):
            if not row_mask.any():
                continue
            pred_text = tokenizer.decode(row_pred[row_mask].tolist(), skip_special_tokens=True).strip()
            ref_text = tokenizer.decode(row_label[row_mask].tolist(), skip_special_tokens=True).strip()
            if pred_text and ref_text:
                hypotheses.append(pred_text)
                references.append(ref_text)

        if hypotheses:
            # sacrebleu takes a list of reference *streams*, each aligned with
            # the hypotheses. There is one reference per sample, hence [references].
            # NOTE(review): the default 13a tokenizer splits on spaces and
            # punctuation, and Thai is written without spaces, so scores are
            # coarse. Consider tokenize="char" — confirm.
            bleu_score = corpus_bleu(hypotheses, [references]).score / 100.0
            bleu_score = min(max(bleu_score, 0.0), 1.0)  # Ensure 0 <= BLEU <= 1
    except Exception as e:
        print(f"BLEU calculation error: {e}")
        bleu_score = 0.0

    return {
        "perplexity": perplexity,
        "eval_loss": nll_loss.item(),
        "bleu": bleu_score
    }

# FIXED: Custom Trainer to handle accelerate unwrap_model issue
class CustomTrainer(Trainer):
    """Trainer that tolerates accelerate releases lacking ``keep_torch_compile``.

    NOTE(review): the stock ``Trainer._wrap_model`` in recent transformers
    appears to call ``accelerator.unwrap_model(model, keep_torch_compile=False)``.
    Older accelerate releases (the notebook pins 1.1.0) reject that keyword
    with a ``TypeError``. Confirm against the installed versions; upgrading
    accelerate would make this subclass unnecessary.
    """

    def _wrap_model(self, model, training=True, dataloader=None):
        """Use the stock implementation, falling back to the bare model.

        The fallback returns ``model`` unchanged and does not call
        ``accelerator.prepare``. Preparing here on every train/evaluate call
        registered the same model with the accelerator again each time.
        NOTE(review): this relies on the Trainer preparing the returned
        unwrapped model itself, as the stock single-process path does.
        Confirm with the installed transformers.
        """
        try:
            return super()._wrap_model(model, training=training, dataloader=dataloader)
        except TypeError as err:
            if "keep_torch_compile" not in str(err):
                raise
            # Single-process case: there is nothing to wrap.
            return model

# Advanced training arguments with A100 optimizations - FIXED accelerate compatibility
# Training arguments. Evaluation and checkpointing both run every 5 steps
# (load_best_model_at_end requires save_steps to be a multiple of eval_steps).
# The best checkpoint is the one with the lowest eval_loss.
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    overwrite_output_dir=True,
    num_train_epochs=config["num_epochs"],
    per_device_train_batch_size=config["batch_size"],
    per_device_eval_batch_size=config["batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    gradient_checkpointing=config["use_gradient_checkpointing"],
    learning_rate=config["learning_rate"],
    weight_decay=config["weight_decay"],
    lr_scheduler_type=config["lr_scheduler_type"],
    warmup_ratio=config["warmup_ratio"],
    logging_steps=config["logging_steps"],
    eval_strategy="steps",
    eval_steps=config["eval_steps"],
    save_strategy="steps",
    save_steps=config["save_steps"],
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=config["fp16"],
    bf16=config["bf16"],
    dataloader_pin_memory=True,
    dataloader_num_workers=0,  # Reduced for stability
    remove_unused_columns=False,
    # "none" disables every experiment tracker. In transformers 4.x, `None`
    # does NOT do that: it falls back to reporting to all installed
    # integrations (e.g. TensorBoard).
    report_to="none",
    run_name="chinda_qwen3_4b_medical_v4.0",
    seed=config["seed"],
    data_seed=config["seed"],
    group_by_length=False,
    length_column_name="input_ids",  # only consulted when group_by_length=True
    ddp_find_unused_parameters=False,
    dataloader_drop_last=False,
    # Move eval outputs to CPU after every step. compute_metrics needs full
    # logits, which would otherwise accumulate on the GPU.
    eval_accumulation_steps=1,
    prediction_loss_only=False,
    push_to_hub=False,
)

# Custom callback for enhanced monitoring
class EnhancedTrainingCallback(TrainerCallback):
    """Console progress reporter with GPU memory housekeeping.

    It prints:
    - the run configuration when training starts;
    - a banner for each epoch;
    - training loss and evaluation metrics as they are logged;
    - a summary at the end.

    GPU memory is cleared through the module-level ``memory_manager`` at every
    epoch boundary and when training ends.
    """

    def __init__(self):
        # Flipped by on_train_begin; on_log stays silent until then.
        self.training_started = False

    def on_train_begin(self, args, state, control, **kwargs):
        """Mark training as started and print the key hyper-parameters."""
        self.training_started = True
        batch = args.per_device_train_batch_size
        accumulation = args.gradient_accumulation_steps

        print("🚀 Enhanced Chinda-Qwen3-4B Medical Fine-tuning Started!")
        print("📊 Training Configuration:")
        print(f"   • Total epochs: {args.num_train_epochs}")
        print(f"   • Batch size: {batch}")
        print(f"   • Gradient accumulation: {accumulation}")
        print(f"   • Learning rate: {args.learning_rate}")
        # Per-device figure; the number of processes/GPUs is not factored in.
        print(f"   • Effective batch size: {batch * accumulation}")

        memory = memory_manager.get_memory_info()
        if "error" in memory:
            return
        print(f"   • GPU Memory: {memory['allocated_gb']:.1f}GB allocated, {memory['free_gb']:.1f}GB free")

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Announce the (1-based) epoch and free GPU memory before it runs."""
        epoch_number = 1 if state.epoch is None else int(state.epoch) + 1
        print(f"\n📈 Starting Epoch {epoch_number}/{args.num_train_epochs}")
        memory_manager.clear_memory()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Echo training loss and evaluation metrics from each log event."""
        if not logs or not self.training_started:
            return

        step = logs.get('step', state.global_step)

        if 'loss' in logs:
            print(f"Step {step}: Loss = {logs['loss']:.4f}")

        if 'eval_loss' not in logs:
            return

        eval_loss = logs['eval_loss']
        # Used when compute_metrics reported no perplexity: exp(loss), shown
        # as inf once the loss reaches 10.
        estimated_perplexity = math.exp(eval_loss) if eval_loss < 10 else float('inf')
        perplexity = logs.get('eval_perplexity', estimated_perplexity)
        bleu = logs.get('eval_bleu', 0.0)

        print(f"📊 Evaluation at Step {step}:")
        print(f"   • Eval Loss: {eval_loss:.4f}")
        print(f"   • Perplexity: {perplexity:.2f}")
        print(f"   • BLEU Score: {bleu:.4f}")

        memory = memory_manager.get_memory_info()
        if "error" not in memory:
            print(f"   • GPU Memory: {memory['allocated_gb']:.1f}GB used")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report the finished epoch and free GPU memory."""
        finished_epoch = 1 if state.epoch is None else int(state.epoch)
        print(f"✅ Completed Epoch {finished_epoch}")
        memory_manager.clear_memory()

    def on_train_end(self, args, state, control, **kwargs):
        """Print the closing summary and do a final memory cleanup."""
        print("🎉 Enhanced Chinda-Qwen3-4B Medical Fine-tuning Completed!")
        print("📊 Final Training Statistics:")
        print(f"   • Total steps completed: {state.global_step}")
        print(f"   • Best model saved at: {args.output_dir}")
        memory_manager.clear_memory()

# FIXED: Enhanced trainer with optimized settings
# Trainer wired with the custom collator and metrics, the console callback, and
# early stopping. Early stopping halts after 3 evaluations in which
# metric_for_best_model fails to improve by at least 0.01.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `processing_class` replaces the deprecated `tokenizer=` keyword
    # (transformers >= 4.46).
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EnhancedTrainingCallback(),
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
    ],
)

# Pre-training validation: cheap end-to-end checks before committing to a long run
def validate_training_setup():
    """Run pre-flight checks on the module-level training objects.

    The checks run in order and the first failure stops them:
    - ``model`` is loaded and has parameters with ``requires_grad=True``;
    - ``train_dataset`` and ``val_dataset`` are non-empty;
    - ``tokenizer`` has a pad token;
    - ``data_collator`` returns ``input_ids`` and ``attention_mask`` for a
      two-sample batch;
    - a single-sample no-grad forward pass returns a loss.

    If the model was switched to eval mode for the forward pass, it is put back
    into training mode on every exit path, including failures.

    Returns:
        True if every check passes. False otherwise; the failure is printed, not
        raised, so the caller decides whether to abort.
    """
    switched_to_eval = False
    try:
        print("🔍 Pre-training System Checks:")

        # Check model
        if model is None:
            raise ValueError("Model not loaded")
        print(f"✅ Model loaded: {model.__class__.__name__}")

        # Only requires_grad parameters get updated; with LoRA these should be the adapters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if trainable_params == 0:
            raise ValueError("No trainable parameters found")
        print(f"✅ LoRA applied: {trainable_params:,} trainable parameters")

        # Check datasets
        if len(train_dataset) == 0:
            raise ValueError("No training samples")
        if len(val_dataset) == 0:
            raise ValueError("No validation samples")
        print(f"✅ Training samples: {len(train_dataset)}")
        print(f"✅ Validation samples: {len(val_dataset)}")

        # Check tokenizer
        if tokenizer.pad_token is None:
            raise ValueError("Tokenizer pad_token not set")
        print("✅ Tokenizer configured properly")

        # Collate two samples (the only sample twice if there is just one)
        first_sample = train_dataset[0]
        second_sample = train_dataset[1] if len(train_dataset) > 1 else train_dataset[0]
        collated_batch = data_collator([first_sample, second_sample])

        if 'input_ids' not in collated_batch or 'attention_mask' not in collated_batch:
            raise ValueError("Data collator failed")
        print("✅ Data collator working properly")

        # Single-sample forward pass; only tensor entries can be sliced and moved to a device
        test_input = {
            k: v[:1].to(model.device)
            for k, v in collated_batch.items()
            if torch.is_tensor(v)
        }

        model.eval()
        switched_to_eval = True
        with torch.no_grad():
            outputs = model(**test_input)
        # HF models only compute a loss when `labels` are passed, so None means the
        # collator produced no labels (or the model ignored them).
        if outputs.loss is None:
            raise ValueError("Model forward pass failed")
        print("✅ Model forward pass successful")

        print("✅ All pre-training checks passed!")
        return True

    except Exception as e:
        print(f"❌ Training setup validation failed: {str(e)}")
        return False

    finally:
        # Restore training mode even when the forward pass or loss check failed
        if switched_to_eval:
            model.train()

# Abort before the training loop starts if any pre-flight check failed
setup_ok = validate_training_setup()
if not setup_ok:
    raise RuntimeError("Training setup validation failed. Aborting training.")

# Report GPU memory right before training; an "error" key means no usable GPU info
memory_info = memory_manager.get_memory_info()
if "error" in memory_info:
    print("⚠️ GPU not available - running on CPU")
else:
    print(f"✅ GPU Memory: {memory_info['allocated_gb']:.1f}GB allocated, {memory_info['free_gb']:.1f}GB available")

# Train, then persist the final weights, the tokenizer and the run's config dict.
# Any failure prints troubleshooting hints and is re-raised so the notebook stops here.
try:
    print("\n🚀 Starting Enhanced Chinda-Qwen3-4B Medical Fine-tuning...")
    print("=" * 60)

    memory_manager.clear_memory()
    trainer.train()

    print("\n💾 Saving final model...")
    final_model_path = os.path.join(config["output_dir"], "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✅ Model saved to: {final_model_path}")

    # The config is saved next to the weights; ensure_ascii=False keeps Thai text readable
    config_path = os.path.join(final_model_path, "training_config.json")
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    print(f"✅ Training configuration saved to: {config_path}")

except Exception as e:
    print(f"❌ Training error: {e}")
    print("💡 Troubleshooting suggestions:")
    for tip in (
        "Reduce batch_size in config",
        "Enable gradient_checkpointing",
        "Reduce max_length",
        "Check GPU memory availability",
    ):
        print(f"   • {tip}")
    raise

# Post-training testing functions
def calculate_perplexity(model, tokenizer, test_samples, max_samples=20,
                         max_length=None, format_fn=None):
    """Compute the token-weighted perplexity of ``model`` over ``test_samples``.

    Each sample is rendered to text with ``format_fn``, tokenized, and scored
    with ``labels`` equal to ``input_ids`` (padding positions masked to -100).
    The model's loss is a mean over predicted tokens, so each sample's loss is
    re-weighted by its number of predicted tokens. The result is
    ``exp(total_token_loss / total_tokens)``: the perplexity of the evaluated
    text as a whole, rather than an average of per-sample means, which would
    over-weight short samples. Samples that raise, have a non-finite loss, or
    have no predicted tokens are skipped.

    Args:
        model: causal LM whose call returns an object with a ``.loss`` tensor;
            it is switched to eval mode.
        tokenizer: callable returning ``input_ids`` (and optionally
            ``attention_mask``) tensors for ``return_tensors="pt"``.
        test_samples: sliceable sequence of samples accepted by ``format_fn``.
        max_samples: only the first ``max_samples`` samples are scored
            (default 20; ``None`` scores all of them).
        max_length: truncation length; defaults to ``config["max_length"]``.
        format_fn: ``sample -> str``; defaults to ``format_training_sample``.

    Returns:
        The perplexity as a float, or ``float('inf')`` if no sample could be
        scored or the average loss overflows ``math.exp``.
    """
    if max_length is None:
        max_length = config["max_length"]
    if format_fn is None:
        format_fn = format_training_sample

    model.eval()
    total_loss = 0.0  # sum over samples of (mean token loss * predicted tokens)
    total_tokens = 0  # number of predicted (non-masked) label positions

    print("\n📊 Calculating Perplexity on Test Set...")

    for i, sample in enumerate(test_samples[:max_samples]):
        try:
            formatted_text = format_fn(sample)

            inputs = tokenizer(
                formatted_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=True
            )
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Score every real token; padding positions are ignored via -100
            labels = inputs["input_ids"].clone()
            if "attention_mask" in inputs:
                labels[inputs["attention_mask"] == 0] = -100
            inputs["labels"] = labels

            with torch.no_grad():
                loss = model(**inputs).loss

            # HF causal LMs shift labels internally (token t is predicted from
            # tokens < t), so the first position is never predicted.
            n_tokens = int((labels[:, 1:] != -100).sum().item())
            if n_tokens > 0 and torch.isfinite(loss):
                total_loss += loss.item() * n_tokens
                total_tokens += n_tokens

        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            continue

    if total_tokens == 0:
        print("❌ Failed to calculate perplexity")
        return float('inf')

    avg_loss = total_loss / total_tokens
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf')
    print(f"✅ Test Perplexity: {perplexity:.2f}")
    return perplexity

def test_model_inference(model, tokenizer, test_samples, num_examples=3):
    """Print sampled generations for the first ``num_examples`` test samples.

    Each prompt is ChatML-style and has three parts:
    - a fixed Thai medical-assistant system turn;
    - a user turn containing ``instruction``, plus ``input`` when it is
      non-blank;
    - an open assistant turn.

    Generation uses a fixed sampling ``GenerationConfig``. Only the newly
    generated tokens are decoded and printed next to the sample's expected
    ``output``. A failure on one example is printed and the loop continues.

    Prompts longer than ``config["max_length"] // 2`` tokens are truncated from
    the LEFT, so the trailing ``<|im_start|>assistant`` header is always kept.
    The tokenizer's default right-truncation would cut that header off and
    leave the model without the cue to start its answer.

    Args:
        model: model exposing ``eval``, ``generate`` and a ``device`` attribute.
        tokenizer: tokenizer used to encode prompts and decode generations.
        test_samples: indexable sequence of dict-like samples with
            ``instruction``, ``input`` and ``output`` keys (missing keys read
            as "").
        num_examples: maximum number of samples to run (default 3).
    """
    model.eval()

    generation_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
    )

    print("\n🧪 Model Inference Testing:")
    print("=" * 50)

    for i in range(min(num_examples, len(test_samples))):
        sample = test_samples[i]
        instruction = sample.get("instruction", "")
        input_text = sample.get("input", "")
        expected_output = sample.get("output", "")

        # Prompt for generation (without the expected output); input is appended only if non-blank
        if input_text and input_text.strip():
            user_content = f"{instruction}\n\n{input_text}"
        else:
            user_content = f"{instruction}"
        prompt = (
            "<|im_start|>system\nคุณเป็นผู้ช่วยทางการแพทย์ที่เชี่ยวชาญด้านภาษาไทย<|im_end|>\n"
            f"<|im_start|>user\n{user_content}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        print(f"\n🔍 Example {i+1}:")
        print(f"Instruction: {instruction}")
        if input_text:
            print(f"Input: {input_text}")

        try:
            max_prompt_tokens = config["max_length"] // 2
            encoded = tokenizer(prompt, return_tensors="pt")
            # Keep the END of an over-long prompt so the assistant header survives
            if 0 < max_prompt_tokens < encoded["input_ids"].shape[1]:
                encoded = {k: v[:, -max_prompt_tokens:] for k, v in encoded.items()}
            inputs = {k: v.to(model.device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    generation_config=generation_config,
                    use_cache=True
                )

            # Decode only the tokens generated after the prompt
            response = tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            ).strip()

            print(f"Generated Response: {response}")
            print(f"Expected Response: {expected_output}")
            print("-" * 40)

        except Exception as e:
            print(f"Generation error: {str(e)}")
            print("-" * 40)

# Post-training evaluation on the held-out test split. Failures are reported but
# not re-raised: by this point the final model has already been saved above.
n_inference_examples = 3
try:
    print("\n🧪 Running post-training tests...")

    eval_model = trainer.model
    perplexity = calculate_perplexity(eval_model, tokenizer, test_samples)
    print(f"📊 Perplexity Measurement Output: {perplexity:.2f}")

    print(f"\n🔍 Testing Model Inference ({n_inference_examples} examples):")
    test_model_inference(eval_model, tokenizer, test_samples, num_examples=n_inference_examples)

except Exception as e:
    print(f"⚠️ Post-training testing failed: {e}")

# Reload the artifacts written to <output_dir>/final_model and run one generation
def load_and_test_finetuned_model():
    """Reload the saved fine-tuned model from disk and generate one answer.

    Loads ``<config["output_dir"]>/final_model`` with ``AutoModelForCausalLM``
    (bf16 when ``config["bf16"]`` is truthy, else fp16; ``device_map="auto"``)
    together with its tokenizer. It then samples up to 200 new tokens for a
    fixed Thai question ("explain the symptoms and treatment of diabetes") and
    prints the decoded continuation.

    Returns:
        True on success; False if anything raises (the error is printed).
    """
    try:
        print("\n🔄 Loading fine-tuned model for testing...")

        saved_dir = os.path.join(config["output_dir"], "final_model")
        load_dtype = torch.bfloat16 if config["bf16"] else torch.float16

        # NOTE(review): if trainer.save_model wrote only a LoRA adapter, this load relies
        # on transformers' PEFT integration to fetch the base model — confirm.
        reloaded_model = AutoModelForCausalLM.from_pretrained(
            saved_dir,
            torch_dtype=load_dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        reloaded_tokenizer = AutoTokenizer.from_pretrained(saved_dir, trust_remote_code=True)
        print("✅ Loaded fine-tuned model for testing")

        question = "อธิบายอาการและการรักษาโรคเบาหวาน"
        chat_prompt = f"<|im_start|>system\nคุณเป็นผู้ช่วยทางการแพทย์ที่เชี่ยวชาญด้านภาษาไทย<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        encoded = reloaded_tokenizer(
            chat_prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(reloaded_model.device)

        sampling_kwargs = {
            "do_sample": True,
            "temperature": 0.7,
            "max_new_tokens": 200,
            "pad_token_id": reloaded_tokenizer.pad_token_id,
            "eos_token_id": reloaded_tokenizer.eos_token_id,
        }
        with torch.no_grad():
            generated = reloaded_model.generate(**encoded, **sampling_kwargs)

        # Drop the echoed prompt tokens and decode only the answer
        prompt_len = encoded["input_ids"].shape[1]
        answer = reloaded_tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

        print("✅ Fine-tuned model test successful!")
        print(f"Question: {question}")
        print(f"Response: {answer.strip()}")
        return True

    except Exception as err:
        print(f"❌ Fine-tuned model test failed: {err}")
        return False

# Reload the saved model from disk and run one generation as an end-to-end smoke test
load_and_test_finetuned_model()

# Final cleanup and summary
print("\n🎊 Enhanced Chinda-Qwen3-4B Medical Fine-tuning Summary:")
print("=" * 60)
print(f"✅ Model: {config['model_name']}")
print(f"✅ Training samples: {len(train_samples)}")
print(f"✅ Validation samples: {len(val_samples)}")
print(f"✅ Test samples: {len(test_samples)}")
print(f"✅ Train/Test/Val split: {config['train_ratio']}/{config['test_ratio']}/{config['val_ratio']}")
# EarlyStoppingCallback can end training before config['num_epochs'], so report the
# (fractional) epoch the Trainer actually reached next to the configured target.
completed_epochs = getattr(getattr(trainer, "state", None), "epoch", None)
if completed_epochs is None:
    print(f"✅ Epochs completed: {config['num_epochs']}")
else:
    print(f"✅ Epochs completed: {completed_epochs:.2f}/{config['num_epochs']}")
print(f"✅ LoRA rank: {config['lora_r']}")
print(f"✅ Learning rate: {config['learning_rate']}")
print(f"✅ Final model saved to: {os.path.join(config['output_dir'], 'final_model')}")

# Final memory cleanup
memory_manager.clear_memory()
print("✅ Memory cleanup completed")

print("\n🚀 Enhanced Chinda-Qwen3-4B Medical Fine-tuning Process Completed Successfully!")
print("📚 The model is now ready for Thai medical applications.")
print("💡 Remember to test thoroughly before production use.")

# Additional metrics and version fixes summary
print("\n🔧 Version 4.0 Fixes Applied:")
print("✅ Fixed: Accelerator.unwrap_model() compatibility issue")
print("✅ Fixed: Dataset processing with proper Thai text handling")
print("✅ Fixed: Attention mask configuration to avoid inference warnings")
print("✅ Fixed: BLEU score calculation (0 ≤ BLEU ≤ 1)")
print("✅ Fixed: Proper train/test/val split (80/10/10)")
print("✅ Fixed: Tensor creation with proper padding and truncation")
print("✅ Added: Perplexity measurement and output")
print("✅ Added: Model inference testing with 3 examples")
print("✅ Added: Comprehensive error handling and validation")