In [None]:
# Enhanced Fine-tuning OpenThaiGPT 1.5-7B for Thai Medical Applications
# Company: V89 Technology Ltd.
# Version: 4.0 - FIXED External Dataset Loading & Processing Issues

print("Starting Enhanced OpenThaiGPT Medical Fine-tuning Setup - FIXED EXTERNAL DATASET VERSION...")

import warnings
warnings.filterwarnings('ignore')
import os
import re
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

# Compatible package versions for A100 with Python 3.12.11 transformers>=4.44.2, CUDA 12.4
!pip install -q --upgrade pip==24.0
!pip install -q --force-reinstall torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
!pip install -q transformers==4.44.2 accelerate==0.33.0 bitsandbytes==0.43.3
!pip install -q datasets==2.20.0 evaluate==0.4.2 rouge-score==0.1.2
!pip install -q peft==0.12.0
!pip install -q scikit-learn pandas numpy==2.0.2 scipy>=1.14.1 fsspec filelock typing-extensions nltk deepspeed
!pip install -q sacrebleu

# Import required libraries with fixed versions
import sys
import json
import gc
import torch
import transformers
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from datasets import Dataset, load_dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from typing import Dict, List, Any, Optional, Union
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import math
import torch.optim as optim
from sacrebleu import corpus_bleu
from collections import defaultdict

# FIXED: Updated imports for compatibility
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    get_linear_schedule_with_warmup,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    EvalPrediction
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_config,
    PeftType
)

def _print_version_report():
    """Print interpreter/library versions and, when CUDA is present, the GPU name and size."""
    print("===== Version Check =====")
    for label, version in (
        ("Python version", sys.version),
        ("Transformers version", transformers.__version__),
        ("PyTorch version", torch.__version__),
        ("Numpy version", np.__version__),
    ):
        print(f"{label}: {version}")
    if torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"CUDA Device: {torch.cuda.get_device_name()}")
        print(f"CUDA Memory: {total_gb:.1f} GB")
    print("==================")


_print_version_report()

In [None]:
# Mount Google Drive with error handling
def setup_drive_connection():
    """Mount Google Drive when running on Colab and return the dataset directory.

    Returns:
        "/content/drive/MyDrive/V89Technology/thai_medicalCare_dataset150" when the
        Colab Drive mount succeeds. Otherwise, returns the relative directory
        "V89Technology/thai_medicalCare_dataset150", creating it if needed.
    """
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except Exception as exc:
        # Not on Colab (ImportError), or the mount failed. Catching Exception
        # instead of a bare `except:` lets KeyboardInterrupt/SystemExit propagate.
        local_path = "V89Technology/thai_medicalCare_dataset150"
        os.makedirs(local_path, exist_ok=True)
        print("Running outside Colab - using local directory")
        print(f"  (Drive unavailable: {type(exc).__name__}: {exc})")
        return local_path

    drive_path = "/content/drive/MyDrive/V89Technology/thai_medicalCare_dataset150"
    print("Google Drive connected successfully")
    return drive_path


drive_path = setup_drive_connection()

# FIXED: Enhanced GPU memory management with A100 optimizations
class GPUMemoryManager:
    """Static helpers for CUDA memory housekeeping and hardware-dependent training presets.

    ``optimize_for_device`` picks one of three presets:
    - an A100 preset, when the device name contains "a100";
    - a small-GPU preset, for any other CUDA device;
    - a CPU preset, when CUDA is unavailable.
    Each preset sets batch size, accumulation steps, sequence length, LoRA rank,
    gradient checkpointing and mixed precision. A fresh dict is returned every call.
    """

    # Copied on every call so callers can mutate the returned dict safely.
    _A100_PRESET = {
        "batch_size": 16,
        "gradient_accumulation_steps": 4,
        "max_length": 512,
        "lora_r": 4,
        "use_gradient_checkpointing": False,
        "fp16": False,
        "bf16": True,
    }
    # T4 or other GPUs.
    _OTHER_GPU_PRESET = {
        "batch_size": 1,
        "gradient_accumulation_steps": 12,
        "max_length": 256,
        "lora_r": 1,
        "use_gradient_checkpointing": True,
        "fp16": True,
        "bf16": False,
    }
    _CPU_PRESET = {
        "batch_size": 1,
        "gradient_accumulation_steps": 12,
        "max_length": 256,
        "lora_r": 1,
        "use_gradient_checkpointing": True,
        "fp16": False,
        "bf16": False,
    }

    @staticmethod
    def clear_memory():
        """Run the Python GC; on CUDA also empty the cache, synchronize and reset allocator stats."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()

    @staticmethod
    def get_memory_info():
        """Return GiB figures for allocated, reserved ("cached") and total memory on device 0.

        ``free_gb`` is ``total - allocated``. It ignores memory used outside
        PyTorch's allocator. Without CUDA, returns ``{"error": "CUDA not available"}``.
        """
        if not torch.cuda.is_available():
            return {"error": "CUDA not available"}
        bytes_per_gib = 1024**3
        allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
        reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
        total_gib = torch.cuda.get_device_properties(0).total_memory / bytes_per_gib
        return {
            "allocated_gb": allocated_gib,
            "cached_gb": reserved_gib,
            "total_gb": total_gib,
            "free_gb": total_gib - allocated_gib,
        }

    @staticmethod
    def optimize_for_device():
        """Return a copy of the preset matching the current hardware."""
        if not torch.cuda.is_available():
            return dict(GPUMemoryManager._CPU_PRESET)
        if "a100" in torch.cuda.get_device_name().lower():
            return dict(GPUMemoryManager._A100_PRESET)
        return dict(GPUMemoryManager._OTHER_GPU_PRESET)

# Initialize memory manager and get optimal config
# One shared helper instance: start from a clean CUDA cache, then choose the preset.
memory_manager = GPUMemoryManager()
memory_manager.clear_memory()
optimal_config = memory_manager.optimize_for_device()

print("Optimal configuration based on hardware:")
print("\n".join(f"  {setting}: {setting_value}" for setting, setting_value in optimal_config.items()))

# FIXED: Configuration based on training report analysis
config = {
    "model_name": "openthaigpt/openthaigpt1.5-7b-instruct",
    # Keep results next to the dataset directory. On Colab this resolves to
    # /content/drive/MyDrive/V89Technology/openthaigpt15-7b-medCare-finetuned.
    # When Drive is unavailable it stays under the local V89Technology/ directory
    # instead of pointing at a Colab-only /content/drive path.
    "output_dir": os.path.join(os.path.dirname(drive_path), "openthaigpt15-7b-medCare-finetuned"),
    # Sequence length, batching, LoRA rank and precision come from the hardware preset.
    "max_length": optimal_config["max_length"],
    "batch_size": optimal_config["batch_size"],
    "gradient_accumulation_steps": optimal_config["gradient_accumulation_steps"],
    "learning_rate": 0.0003,
    "num_epochs": 3,
    "warmup_ratio": 0.05,
    # Logging, checkpointing and evaluation all run every 5 optimizer steps.
    "logging_steps": 5,
    "save_steps": 5,
    "eval_steps": 5,
    "lora_r": optimal_config["lora_r"],
    "lora_alpha": 8,
    "lora_dropout": 0.1,
    "use_gradient_checkpointing": optimal_config["use_gradient_checkpointing"],
    "fp16": optimal_config["fp16"],
    "bf16": optimal_config["bf16"],
    "weight_decay": 0.01,
    "lr_scheduler_type": "cosine",
    "seed": 42
}

print(f"Enhanced configuration ready. Output directory: {config['output_dir']}")

# FIXED: Advanced quantization configuration
def get_advanced_quantization_config():
    """Return the 4-bit NF4 BitsAndBytes config used to load the base model.

    The 4-bit compute dtype follows the mixed-precision preset in ``config``:
    bfloat16 when ``config["bf16"]`` is set, float16 otherwise. In the float16
    case the model is also loaded with ``torch_dtype=torch.float16`` and trained
    with fp16. Keeping the compute dtype equal avoids mixing bf16 matmuls into an
    fp16 run, e.g. on GPUs without native bf16 such as T4.
    """
    compute_dtype = torch.bfloat16 if config["bf16"] else torch.float16
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        # Also quantize the quantization constants for extra memory savings.
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
        # LLM.int8() outlier threshold. NOTE(review): appears relevant only to
        # 8-bit loading; kept unchanged.
        llm_int8_threshold=6.0,
    )

# FIXED: Advanced LoRA configuration matching training report
def get_advanced_lora_config():
    """Build the LoRA adapter configuration from the global ``config`` dict.

    The adapters are attached to every attention projection (q/k/v/o) and every
    MLP projection (gate/up/down), with no bias training. Rank, alpha and dropout
    are read from ``config``.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    lora_kwargs = dict(
        task_type=TaskType.CAUSAL_LM,
        r=config["lora_r"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],
        target_modules=attention_projections + mlp_projections,
        bias="none",
        inference_mode=False,
        init_lora_weights=True,
        use_rslora=False,
    )
    return LoraConfig(**lora_kwargs)

# Materialize both configs now; the model-loading step consumes them.
quantization_config, lora_config = (
    get_advanced_quantization_config(),
    get_advanced_lora_config(),
)
print("Advanced quantization and LoRA configuration ready")

# FIXED: Load external CSV datasets from Google Drive
def _clean_cell(value):
    """Return a CSV cell as stripped text; missing cells (None / NaN) become ""."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # pd.isna returns an array for list-like input; such values are not "missing".
        pass
    return str(value).strip()


_COP_LETTERS = {1: "A", 2: "B", 3: "C", 4: "D"}


def _cop_to_letter(value):
    """Map a MedMCQA correct-option cell to "A"-"D", or None if it cannot be interpreted.

    Accepts 1-based option numbers (1-4, also as floats such as 3.0, which pandas
    produces when the column has gaps) and the letters themselves.
    NOTE(review): some MedMCQA exports store ``cop`` 0-based (0-3). Confirm
    against the CSV, because a 0-based file would be shifted by one here.
    """
    text = _clean_cell(value)
    if text.upper() in _COP_LETTERS.values():
        return text.upper()
    try:
        number = float(text)
    except ValueError:
        return None
    if not number.is_integer():
        return None
    return _COP_LETTERS.get(int(number))


def _row_to_sample(dataset_name, dataset_config, row):
    """Build one training sample dict from a CSV row, or return None if the row is unusable."""
    if dataset_name == "medmcqa":
        cop_letter = _cop_to_letter(row.get("cop"))
        if cop_letter is None:
            # Skip rather than silently labelling the answer "A".
            return None
        instruction = dataset_config["instruction_template"]
        input_text = dataset_config["input_template"].format(
            question=_clean_cell(row.get("th_question")),
            opa=_clean_cell(row.get("th_opa")),
            opb=_clean_cell(row.get("th_opb")),
            opc=_clean_cell(row.get("th_opc")),
            opd=_clean_cell(row.get("th_opd")),
        )
        output_text = dataset_config["output_template"].format(
            cop_letter=cop_letter,
            exp=_clean_cell(row.get("th_exp")),
        )
    elif dataset_name == "mental_health":
        instruction = dataset_config["instruction_template"]
        input_text = dataset_config["input_template"].format(context=_clean_cell(row.get("th_Context")))
        output_text = dataset_config["output_template"].format(response=_clean_cell(row.get("th_Response")))
    else:  # healthcare, pubmed, medical_qa
        # The "{instruction}" templates for these datasets are placeholders, so a
        # missing instruction yields "" (and the row is rejected) rather than the
        # literal placeholder text.
        instruction = _clean_cell(row.get("th_instruction"))
        input_text = _clean_cell(row.get("th_input"))
        output_text = _clean_cell(row.get("th_output"))

    instruction = instruction.strip()
    input_text = input_text.strip()
    output_text = output_text.strip()

    # Same minimum lengths as before: instruction >= 5 chars, output >= 10 chars.
    if len(instruction) < 5 or len(output_text) < 10:
        return None

    return {
        "instruction": instruction,
        "input": input_text,
        "output": output_text,
        "dataset_type": dataset_name,
    }


def load_external_csv_datasets(base_path=None):
    """Load the Thai medical CSV files and convert them into instruction samples.

    Args:
        base_path: Directory containing the CSV files. Defaults to the module-level
            ``drive_path``.

    Returns:
        Dict mapping dataset name to a list of sample dicts with keys
        ``instruction``, ``input``, ``output`` and ``dataset_type``. A dataset
        whose file cannot be read maps to an empty list, and the error is printed.
        Missing cells (NaN) are treated as empty strings instead of the text "nan".
    """
    if base_path is None:
        base_path = drive_path

    datasets = {}

    # Define dataset file paths and configurations
    dataset_files = {
        "medmcqa": {
            "path": f"{base_path}/medmcqa_thai_132.csv",
            "thai_columns": ["th_question", "th_opa", "th_opb", "th_opc", "th_opd", "th_exp"],
            "instruction_template": "ตอบคำถามทางการแพทย์และอธิบายเหตุผล",
            "input_template": "คำถาม: {question}\nตัวเลือก:\nA) {opa}\nB) {opb}\nC) {opc}\nD) {opd}",
            "output_template": "คำตอบที่ถูกต้อง: {cop_letter}\nคำอธิบาย: {exp}"
        },
        "mental_health": {
            "path": f"{base_path}/mental_health_thai_150.csv",
            "thai_columns": ["th_Context", "th_Response"],
            "instruction_template": "ให้คำปรึกษาด้านสุขภาพจิต",
            "input_template": "{context}",
            "output_template": "{response}"
        },
        "healthcare": {
            "path": f"{base_path}/healthcare_thai_150.csv",
            "thai_columns": ["th_instruction", "th_input", "th_output"],
            "instruction_template": "{instruction}",
            "input_template": "{input}",
            "output_template": "{output}"
        },
        "pubmed": {
            "path": f"{base_path}/pubmed_thai_150.csv",
            "thai_columns": ["th_input", "th_output", "th_instruction"],
            "instruction_template": "{instruction}",
            "input_template": "{input}",
            "output_template": "{output}"
        },
        "medical_qa": {
            "path": f"{base_path}/medical_qa_thai_150.csv",
            "thai_columns": ["th_instruction", "th_input", "th_output"],
            "instruction_template": "{instruction}",
            "input_template": "{input}",
            "output_template": "{output}"
        }
    }

    for dataset_name, dataset_config in dataset_files.items():
        try:
            print(f"Loading {dataset_name} dataset...")
            # utf-8-sig also accepts files saved with a byte-order mark (common for
            # CSVs exported from spreadsheet tools). A BOM would otherwise be glued
            # onto the first column name.
            df = pd.read_csv(dataset_config["path"], encoding="utf-8-sig")
            print(f"  Loaded {len(df)} samples from {dataset_config['path']}")
        except Exception as e:
            print(f"  Error loading {dataset_name}: {str(e)}")
            datasets[dataset_name] = []
            continue

        expected_columns = list(dataset_config["thai_columns"])
        if dataset_name == "medmcqa":
            expected_columns.append("cop")
        missing_columns = [column for column in expected_columns if column not in df.columns]
        if missing_columns:
            print(f"  Warning: {dataset_name} is missing columns {missing_columns}")

        samples = []
        skipped_rows = 0
        for row in df.to_dict(orient="records"):
            sample = _row_to_sample(dataset_name, dataset_config, row)
            if sample is None:
                skipped_rows += 1
            else:
                samples.append(sample)

        datasets[dataset_name] = samples
        print(f"  {dataset_name}: {len(samples)} valid samples processed")
        if skipped_rows:
            print(f"  {dataset_name}: skipped {skipped_rows} rows (missing/short text or invalid answer)")

    return datasets

# FIXED: Enhanced dataset loading with proper Thai text encoding and validation
def load_enhanced_medical_datasets(min_samples=50):
    """Combine all external CSV datasets, topping up with built-in samples if too few.

    Args:
        min_samples: If the external datasets yield fewer than this many samples,
            three hand-written Thai medical samples are appended. Defaults to 50.

    Returns:
        ``{"combined_dataset": [sample, ...]}``, where each sample is a dict with
        ``instruction``, ``input``, ``output`` and ``dataset_type`` keys.
    """
    print("Loading external CSV datasets...")
    external_datasets = load_external_csv_datasets()

    # Combine all external datasets
    all_samples = []
    for dataset_name, samples in external_datasets.items():
        all_samples.extend(samples)
        print(f"{dataset_name}: {len(samples)} samples")

    print(f"Total external samples: {len(all_samples)}")

    # Add fallback samples if external datasets are insufficient
    if len(all_samples) < min_samples:
        print("Adding fallback Thai medical samples...")

        # These texts become training targets, so medical terms must be spelled
        # correctly (dengue = "เดงกี", aspirin = "แอสไพริน").
        fallback_samples = [
            {
                "instruction": "อธิบายอาการและการรักษาโรคทั่วไปในประเทศไทย",
                "input": "โรคไข้เลือดออกมีอาการอย่างไร",
                "output": "โรคไข้เลือดออกเกิดจากไวรัสเดงกี มีอาการไข้สูง ปวดหัว ปวดกล้ามเนื้อ คลื่นไส้อาเจียน ผื่นแดง อาจมีเลือดออกตามไรฟัน หากรุนแรงอาจเป็นไข้เลือดออกแบบช็อค ควรดื่มน้ำมากๆ พักผ่อน หลีกเลี่ยงยาแอสไพริน และรีบพบแพทย์ทันที",
                "dataset_type": "thai_medical_fallback"
            },
            {
                "instruction": "วิเคราะห์อาการและให้คำแนะนำเบื้องต้น",
                "input": "มีอาการปวดท้อง ท้องเสีย เป็นมา 2 วัน",
                "output": "อาการปวดท้องและท้องเสียอาจเกิดจากการติดเชื้อในทางเดินอาหาร การกินอาหารเป็นพิษ หรือความเครียด ควรดื่มน้ำสะอาดมากๆ กิน ORS หลีกเลี่ยงอาหารมัน เผ็ด หรือยา กินข้าวต้มหรืออาหารอ่อนๆ หากไม่ดีขึ้น ภายใน 1-2 วัน หรือมีไข้ ควรพบแพทย์",
                "dataset_type": "thai_medical_fallback"
            },
            {
                "instruction": "ให้คำแนะนำด้านการป้องกันโรค",
                "input": "วิธีป้องกันโรคเบาหวาน",
                "output": "การป้องกันโรคเบาหวาน: 1) ควบคุมน้ำหนักให้อยู่ในเกณฑ์ปกติ 2) ออกกำลังกายสม่ำเสมออย่างน้อย 30 นาทีต่อวัน 3) รับประทานอาหารที่มีไฟเบอร์สูง ลดน้ำตาล ลดไขมัน 4) หลีกเลี่ยงการสูบบุหรี่และดื่มสุรา 5) ตรวจสุขภาพประจำปี 6) จัดการความเครียด",
                "dataset_type": "prevention_fallback"
            }
        ]

        all_samples.extend(fallback_samples)
        print(f"Added {len(fallback_samples)} fallback samples")

    return {"combined_dataset": all_samples}

# Load enhanced datasets
print("Loading enhanced medical datasets...")
medical_datasets = load_enhanced_medical_datasets()


def _passes_final_check(candidate):
    """Re-apply the sanity rules: a dict whose instruction has >= 5 chars and output >= 10 chars."""
    if not isinstance(candidate, dict):
        return False
    if not candidate.get("instruction") or not candidate.get("output"):
        return False
    return (len(str(candidate["instruction"]).strip()) >= 5
            and len(str(candidate["output"]).strip()) >= 10)


# Merge every group into one list, keeping only samples that pass the check.
all_samples = []
for dataset_name, samples in medical_datasets.items():
    valid_samples = [candidate for candidate in samples if _passes_final_check(candidate)]
    all_samples.extend(valid_samples)
    print(f"{dataset_name}: {len(valid_samples)} valid samples")

print(f"Total valid samples loaded: {len(all_samples)}")

# FIX: Enhanced train/validation/test split with proper ratios (80/10/10)
def create_balanced_split(samples, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=None):
    """Shuffle *samples* deterministically and cut them into train/validation/test lists.

    The split is random, not stratified. The train and validation sizes are
    ``int(len * ratio)``, and the test split receives the remainder, so
    ``test_ratio`` is informational only.

    Args:
        samples: Sequence of samples. It is copied, so the caller's list is
            neither reordered nor mutated.
        train_ratio: Fraction of samples for training.
        val_ratio: Fraction of samples for validation.
        test_ratio: Nominal test fraction; the actual test split is the remainder.
        seed: RNG seed. Defaults to ``config["seed"]``.

    Returns:
        ``(train_samples, val_samples, test_samples)`` lists.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio`` exceeds 1.
    """
    import random

    if seed is None:
        seed = config["seed"]
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be non-negative and sum to at most 1")

    # A private Random seeded like random.seed(seed) yields the same permutation,
    # without resetting the process-wide RNG or shuffling the caller's list.
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)

    total_samples = len(shuffled)
    train_end = int(total_samples * train_ratio)
    val_end = train_end + int(total_samples * val_ratio)

    train_samples = shuffled[:train_end]
    val_samples = shuffled[train_end:val_end]
    test_samples = shuffled[val_end:]

    return train_samples, val_samples, test_samples

train_samples, val_samples, test_samples = create_balanced_split(all_samples)

for _split_label, _split_samples in (("Training", train_samples), ("Validation", val_samples), ("Test", test_samples)):
    print(f"{_split_label} samples: {len(_split_samples)}")

# Refuse to continue with splits too small to train or evaluate meaningfully.
for _split_name, _split_samples, _minimum in (
    ("training", train_samples, 10),
    ("validation", val_samples, 5),
    ("test", test_samples, 5),
):
    if len(_split_samples) < _minimum:
        raise ValueError(f"Insufficient {_split_name} samples. Need at least {_minimum} samples.")

# FIX: Enhanced tokenizer and model loading with proper encoding
print("Loading OpenThaiGPT model and tokenizer with optimizations...")

# Fast tokenizer with right-side padding.
tokenizer = AutoTokenizer.from_pretrained(
    config["model_name"],
    use_fast=True,
    padding_side="right",
    trust_remote_code=True,
)

# Fall back to EOS as the padding token when the checkpoint defines none, so
# fixed-length padding is possible.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

print("Loading model with advanced optimizations...")

# Load the base checkpoint with the 4-bit BitsAndBytes config and let
# device_map="auto" place it on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    config["model_name"],
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map="auto",
    # Presumably the dtype of the modules left unquantized; follows the bf16/fp16 preset.
    torch_dtype=torch.bfloat16 if config["bf16"] else torch.float16,
    # KV cache is not needed for training (and is reported as incompatible with
    # gradient checkpointing).
    use_cache=False,
)

# Enable gradient checkpointing for memory efficiency
if config["use_gradient_checkpointing"]:
    model.gradient_checkpointing_enable()

# Prepare model for k-bit training
# NOTE(review): per the peft docs, prepare_model_for_kbit_training defaults to
# use_gradient_checkpointing=True and enables checkpointing itself. The A100
# preset's use_gradient_checkpointing=False therefore does not actually disable
# it here. Passing use_gradient_checkpointing=config["use_gradient_checkpointing"]
# would make the flag authoritative. Verify memory headroom first: the A100 preset
# (batch_size 16, max_length 512) may rely on checkpointing being on.
model = prepare_model_for_kbit_training(model)

# Apply LoRA configuration: wraps the model so only the adapter weights train.
print("Applying LoRA configuration...")
model = get_peft_model(model, lora_config)

# Print trainable parameters
def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts for *model* and return them.

    bitsandbytes 4-bit weights (``Params4bit``) pack two values per stored byte,
    so ``numel()`` under-counts them. They are scaled back up, mirroring PEFT's
    own parameter counting. A model with no parameters reports 0% instead of
    raising ZeroDivisionError.

    Returns:
        ``(trainable_params, all_params)``. Callers that ignore the return value
        are unaffected.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "element_size"):
                num_bytes = param.element_size()
            elif hasattr(param, "quant_storage"):
                num_bytes = param.quant_storage.itemsize
            else:
                num_bytes = 1
            num_params = num_params * 2 * num_bytes
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable%: {trainable_pct:.2f}")
    return trainable_params, all_param

# Report how many parameters remain trainable after wrapping the model with LoRA.
print_trainable_parameters(model)

# FIX: Enhanced data preprocessing with proper Thai text handling and padding
def format_training_sample(sample):
    """Render a sample dict as the Thai instruction prompt used for training.

    The prompt has these sections, separated by blank lines:
    - "### คำสั่ง:" with the instruction;
    - an optional "### คำถาม:" with the input, omitted when it is empty after stripping;
    - "### คำตอบ:" with the output.
    Each field is passed through ``str()`` and stripped; missing keys count as "".
    """
    fields = {key: str(sample.get(key, "")).strip() for key in ("instruction", "input", "output")}
    sections = [f"### คำสั่ง:\n{fields['instruction']}"]
    if fields["input"]:
        sections.append(f"### คำถาม:\n{fields['input']}")
    sections.append(f"### คำตอบ:\n{fields['output']}")
    return "\n\n".join(sections)

def preprocess_function(examples):
    """Tokenize a batch of samples for causal-LM fine-tuning.

    Steps:
    1. Render each sample with ``format_training_sample``.
    2. Append the tokenizer's EOS token, so the model learns where an answer ends.
       Without it, nothing teaches the model to stop.
    3. Truncate/pad to ``config["max_length"]``.

    Args:
        examples: Batched columns; ``instruction`` and ``output`` are required,
            ``input`` is optional.

    Returns:
        The tokenizer output plus ``labels``. ``labels`` is a copy of
        ``input_ids`` with padding positions (``attention_mask == 0``) set to -100,
        so code that reads labels straight from the dataset does not score padding.
        NOTE(review): the training collator may rebuild labels itself; this
        masking matters mainly for dataset-based evaluation.
    """
    num_rows = len(examples["instruction"])
    input_column = examples.get("input", [""] * num_rows)
    eos_text = tokenizer.eos_token or ""

    formatted_texts = [
        format_training_sample({
            "instruction": examples["instruction"][i],
            "input": input_column[i],
            "output": examples["output"][i],
        }) + eos_text
        for i in range(num_rows)
    ]

    # Tokenize with proper settings and explicit padding
    tokenized = tokenizer(
        formatted_texts,
        truncation=True,
        padding="max_length",
        max_length=config["max_length"],
        return_tensors=None,
        add_special_tokens=True
    )

    attention_masks = tokenized.get("attention_mask")
    if attention_masks is None:
        tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]
    else:
        tokenized["labels"] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(tokenized["input_ids"], attention_masks)
        ]

    return tokenized

# Convert samples to dataset format
def samples_to_dataset_dict(samples):
    """Turn a list of sample dicts into column lists suitable for ``Dataset.from_dict``.

    Non-dict entries are ignored. Missing ``instruction``/``input``/``output``
    values become "", and a missing ``dataset_type`` becomes "general".
    """
    column_defaults = {"instruction": "", "input": "", "output": "", "dataset_type": "general"}
    dict_samples = [entry for entry in samples if isinstance(entry, dict)]
    return {
        column: [entry.get(column, default) for entry in dict_samples]
        for column, default in column_defaults.items()
    }

print("Creating datasets...")
train_dataset_dict, val_dataset_dict, test_dataset_dict = (
    samples_to_dataset_dict(split_samples)
    for split_samples in (train_samples, val_samples, test_samples)
)

# Wrap the column dicts as Hugging Face Dataset objects.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_dict(split_columns)
    for split_columns in (train_dataset_dict, val_dataset_dict, test_dataset_dict)
)


def _tokenize_split(raw_split, split_label):
    """Tokenize one split with preprocess_function, dropping the raw text columns."""
    return raw_split.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_split.column_names,
        desc=f"Preprocessing {split_label} data",
    )


print("Applying preprocessing...")
train_dataset = _tokenize_split(train_dataset, "training")
val_dataset = _tokenize_split(val_dataset, "validation")
test_dataset = _tokenize_split(test_dataset, "test")

for _split_label, _split in (("Training", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{_split_label} dataset size: {len(_split)}")

# FIX: Enhanced data collator with proper padding
# Batch collation for causal-LM training (mlm=False: no token masking).
# NOTE(review): per the transformers docs, with mlm=False this collator rebuilds
# `labels` from `input_ids` and sets positions equal to tokenizer.pad_token_id to
# -100, so labels produced during preprocessing are replaced. When pad_token is
# the EOS token (the fallback in the tokenizer setup), genuine EOS tokens are
# masked too. Confirm this is acceptable.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    # Round padded lengths up to a multiple of 8 (commonly chosen for tensor-core
    # efficiency). Features are already padded to max_length during preprocessing.
    pad_to_multiple_of=8,
)

# FIX: Advanced evaluation metrics
def compute_metrics(eval_pred):
    """Compute next-token perplexity from the Trainer's evaluation logits.

    Args:
        eval_pred: ``(predictions, labels)`` pair (or ``EvalPrediction``).
            ``predictions`` are logits shaped ``(batch, seq, vocab)``, or a tuple
            whose first element holds them. ``labels`` are ``(batch, seq)`` token
            ids with -100 marking ignored positions. A single unbatched
            ``(seq, vocab)`` / ``(seq,)`` pair is also accepted.

    Returns:
        ``{"perplexity": exp(mean_nll), "eval_loss": mean_nll}``, where
        ``mean_nll`` is the mean negative log-likelihood over non-ignored targets.
        Both values are ``inf`` when no target remains. NOTE(review): transformers'
        Trainer appears to overwrite ``eval_loss`` with its own loss after calling
        this, so the value here is informational.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]

    logits = predictions if isinstance(predictions, torch.Tensor) else torch.from_numpy(np.asarray(predictions))
    targets = labels if isinstance(labels, torch.Tensor) else torch.from_numpy(np.asarray(labels))
    logits = logits.detach().cpu().float()
    targets = targets.detach().cpu().long()

    if logits.dim() == 2:
        logits = logits.unsqueeze(0)
        targets = targets.unsqueeze(0)

    # Causal LM: the logits at position t predict token t + 1. Drop the last
    # logit and the first label so each prediction is compared with its target.
    shift_logits = logits[:, :-1, :]
    shift_targets = targets[:, 1:]

    total_nll = 0.0
    total_tokens = 0
    # Score one sequence at a time instead of materializing vocab-sized
    # log-probabilities for the whole evaluation set at once.
    for row_logits, row_targets in zip(shift_logits, shift_targets):
        row_tokens = int((row_targets != -100).sum())
        if row_tokens == 0:
            continue
        total_nll += F.cross_entropy(row_logits, row_targets, ignore_index=-100, reduction="sum").item()
        total_tokens += row_tokens

    if total_tokens == 0:
        return {"perplexity": float('inf'), "eval_loss": float('inf')}

    mean_nll = total_nll / total_tokens
    try:
        perplexity = math.exp(mean_nll)
    except OverflowError:
        perplexity = float('inf')

    return {
        "perplexity": perplexity,
        "eval_loss": mean_nll
    }

# FIX: Advanced training arguments with A100 optimizations
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    overwrite_output_dir=True,
    num_train_epochs=config["num_epochs"],
    per_device_train_batch_size=config["batch_size"],
    per_device_eval_batch_size=config["batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    gradient_checkpointing=config["use_gradient_checkpointing"],
    learning_rate=config["learning_rate"],
    weight_decay=config["weight_decay"],
    lr_scheduler_type=config["lr_scheduler_type"],
    warmup_ratio=config["warmup_ratio"],
    logging_steps=config["logging_steps"],
    # Evaluate and checkpoint on the same step cadence so the best checkpoint
    # (lowest eval_loss) can be reloaded at the end.
    eval_strategy="steps",
    eval_steps=config["eval_steps"],
    save_strategy="steps",
    save_steps=config["save_steps"],
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=config["fp16"],
    bf16=config["bf16"],
    dataloader_pin_memory=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    # "none" disables every logging integration. In transformers 4.x (pinned to
    # 4.44.2 above), report_to=None means "all installed integrations", so e.g.
    # TensorBoard would still be used despite the W&B environment variables.
    report_to="none",
    run_name="enhanced_openthaigpt_medical_v4.0",
    seed=config["seed"],
    data_seed=config["seed"],
    # group_by_length is off, so no length_column_name is set. The argument is
    # only read for length grouping and expects precomputed lengths, not the
    # token lists stored in "input_ids".
    group_by_length=False,
    ddp_find_unused_parameters=False,
    dataloader_drop_last=False,
    # Move accumulated eval predictions off the GPU after every step (the full
    # logits are large and are needed by compute_metrics).
    eval_accumulation_steps=1,
    prediction_loss_only=False,
    push_to_hub=False,
)

# FIX: Custom callback for enhanced monitoring
class EnhancedTrainingCallback(TrainerCallback):
    """Trainer callback that prints progress, evaluation results and GPU memory usage.

    It calls the module-level ``memory_manager.clear_memory()`` at the start and
    end of every epoch and at the end of training.
    """

    def __init__(self):
        # on_log stays silent until on_train_begin has fired.
        self.training_started = False

    @staticmethod
    def _safe_exp(value):
        """Return exp(value), or inf instead of raising OverflowError for very large losses."""
        try:
            return math.exp(value)
        except OverflowError:
            return float('inf')

    def on_train_begin(self, args, state, control, **kwargs):
        """Print the run configuration and current GPU memory when training starts."""
        self.training_started = True
        print("🚀 Enhanced OpenThaiGPT Medical Fine-tuning Started!")
        print(f"📊 Training Configuration:")
        print(f"   • Total epochs: {args.num_train_epochs}")
        print(f"   • Batch size: {args.per_device_train_batch_size}")
        print(f"   • Gradient accumulation: {args.gradient_accumulation_steps}")
        print(f"   • Learning rate: {args.learning_rate}")
        # Under distributed training every process contributes its own batches.
        effective_batch = (
            args.per_device_train_batch_size
            * args.gradient_accumulation_steps
            * max(1, args.world_size)
        )
        print(f"   • Effective batch size: {effective_batch}")

        # Display GPU memory info
        memory_info = memory_manager.get_memory_info()
        if "error" not in memory_info:
            print(f"   • GPU Memory: {memory_info['allocated_gb']:.1f}GB allocated, {memory_info['free_gb']:.1f}GB free")

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Announce the 1-based epoch number and clear cached GPU memory."""
        current_epoch = int(state.epoch) + 1 if state.epoch is not None else 1
        print(f"\n📈 Starting Epoch {current_epoch}/{args.num_train_epochs}")

        # Clear memory at epoch start
        memory_manager.clear_memory()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the training loss and, for evaluation logs, loss/perplexity and GPU memory."""
        if logs and self.training_started:
            step = logs.get('step', state.global_step)

            # Training metrics
            if 'loss' in logs:
                print(f"Step {step}: Loss = {logs['loss']:.4f}")

            # Evaluation metrics
            if 'eval_loss' in logs:
                eval_loss = logs['eval_loss']
                # Only derive perplexity when the metric is absent, and never let
                # a huge loss raise OverflowError inside the training loop.
                perplexity = logs.get('eval_perplexity')
                if perplexity is None:
                    perplexity = self._safe_exp(eval_loss)
                print(f"📊 Evaluation at Step {step}:")
                print(f"   • Eval Loss: {eval_loss:.4f}")
                print(f"   • Perplexity: {perplexity:.2f}")

                # GPU memory status
                memory_info = memory_manager.get_memory_info()
                if "error" not in memory_info:
                    print(f"   • GPU Memory: {memory_info['allocated_gb']:.1f}GB used")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Announce the completed epoch and clear cached GPU memory."""
        current_epoch = int(state.epoch) if state.epoch is not None else 1
        print(f"✅ Completed Epoch {current_epoch}")

        # Memory cleanup
        memory_manager.clear_memory()

    def on_train_end(self, args, state, control, **kwargs):
        """Print final step count and output directory, then clear cached GPU memory."""
        print("🎉 Enhanced OpenThaiGPT Medical Fine-tuning Completed!")
        print(f"📊 Final Training Statistics:")
        print(f"   • Total steps completed: {state.global_step}")
        print(f"   • Best model saved at: {args.output_dir}")

        # Final memory cleanup
        memory_manager.clear_memory()

# FIX: Enhanced trainer with optimized settings
# Progress printing plus early stopping: halt after 3 evaluations without an
# improvement of at least 0.01 in the monitored metric.
_training_callbacks = [
    EnhancedTrainingCallback(),
    EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01),
]
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=_training_callbacks,
)

# FIX: Pre-training system checks
print("🔍 Pre-training System Checks:")
# Note: this counts trainable parameter *tensors*, not individual parameters.
_trainable_tensor_count = sum(1 for _, param in model.named_parameters() if param.requires_grad)
for _check_line in (
    f"✅ Model loaded: {model.__class__.__name__}",
    f"✅ LoRA applied: {_trainable_tensor_count} trainable parameters",
    f"✅ Training samples: {len(train_dataset)}",
    f"✅ Validation samples: {len(val_dataset)}",
    f"✅ Test samples: {len(test_dataset)}",
):
    print(_check_line)

# Report GPU memory before training (or that we are on CPU).
memory_info = memory_manager.get_memory_info()
if "error" in memory_info:
    print("⚠️ GPU not available - running on CPU")
else:
    print(f"✅ GPU Memory: {memory_info['allocated_gb']:.1f}GB allocated, {memory_info['free_gb']:.1f}GB available")

# FIX: Enhanced training execution with error handling
try:
    print("\n🚀 Starting Enhanced OpenThaiGPT Medical Fine-tuning...")
    print("=" * 60)

    # Clear memory before training
    memory_manager.clear_memory()

    # Start training. training_args sets load_best_model_at_end=True with
    # metric_for_best_model="eval_loss", so the best checkpoint is presumably
    # restored before returning — confirm against the transformers docs.
    trainer.train()

    # Save the final model plus the tokenizer into <output_dir>/final_model.
    # For a PEFT-wrapped model this is expected to write the LoRA adapter rather
    # than merged full weights — TODO confirm.
    print("\n💾 Saving final model...")
    final_model_path = os.path.join(config["output_dir"], "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    print(f"✅ Model saved to: {final_model_path}")

    # Save training configuration next to the model.
    # ensure_ascii=False keeps any non-ASCII text human-readable in the JSON.
    config_path = os.path.join(final_model_path, "training_config.json")
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    print(f"✅ Training configuration saved to: {config_path}")

except Exception as e:
    # Print generic out-of-memory remedies, then re-raise so the cell still fails.
    # Errors raised while saving also land here.
    print(f"❌ Training error: {str(e)}")
    print("💡 Troubleshooting suggestions:")
    print("   • Reduce batch_size in config")
    print("   • Enable gradient_checkpointing")
    print("   • Reduce max_length")
    print("   • Check GPU memory availability")
    raise

# FIX: Post-training model testing and perplexity measurement
def calculate_perplexity(model, tokenizer, dataset):
    """Compute token-weighted perplexity of *model* over a tokenized dataset.

    Each example must provide ``input_ids`` and ``labels`` (-100 = ignored).
    When ``attention_mask`` is present, it is passed to the model and padding
    positions are excluded from scoring, even if the stored labels still contain
    pad ids.

    Each example's mean loss is weighted by the number of targets the model
    actually scores. A causal-LM head shifts labels internally, so only
    ``labels[1:]`` can contribute. A sample whose forward pass raises is
    skipped, and the skip count is printed.

    Args:
        model: Causal LM returning an object with ``.loss`` when given ``labels``.
        tokenizer: Unused; kept for interface compatibility.
        dataset: Indexable collection of tokenized examples.

    Returns:
        ``(perplexity, mean_nll)``, or ``(inf, inf)`` if no token was scored.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0
    skipped = 0

    with torch.no_grad():
        for i in range(len(dataset)):
            example = dataset[i]
            input_ids = torch.tensor(example["input_ids"]).unsqueeze(0).to(model.device)
            labels = torch.tensor(example["labels"]).unsqueeze(0).to(model.device)

            attention_mask = None
            if "attention_mask" in example:
                attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(model.device)
                # Never score padding positions.
                labels = labels.masked_fill(attention_mask == 0, -100)

            # Position t's logits are scored against labels[t + 1], so the first
            # label never contributes to the model's mean loss.
            valid_tokens = int((labels[:, 1:] != -100).sum().item())
            if valid_tokens == 0:
                continue

            try:
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            except Exception as exc:  # best-effort: keep scoring the remaining samples
                skipped += 1
                print(f"Skipping sample {i} during perplexity evaluation: {exc}")
                continue

            total_nll += outputs.loss.item() * valid_tokens
            total_tokens += valid_tokens

    if skipped:
        print(f"Perplexity: skipped {skipped} of {len(dataset)} samples due to errors")

    if total_tokens == 0:
        return float('inf'), float('inf')

    avg_loss = total_nll / total_tokens
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf')
    return perplexity, avg_loss

def test_model_generation(model, tokenizer, test_samples, num_examples=3):
    """Print model generations for the first ``num_examples`` held-out samples.

    The prompt is rendered with the reference answer blanked out, so it ends at
    the "### คำตอบ:" header and the model has to produce the answer itself.
    Feeding the full training-format text would show the model the reference
    answer before it "generates", making the comparison meaningless.

    Args:
        model: Causal LM supporting ``generate``; it is switched to eval mode.
        tokenizer: Tokenizer matching the model.
        test_samples: List of raw sample dicts (``instruction``/``input``/``output``).
        num_examples: Maximum number of samples to try.

    Per-example generation errors are printed, and the loop continues.
    """
    print(f"\n🧪 Testing Fine-tuned Model with {num_examples} examples:")
    print("=" * 60)

    model.eval()

    generation_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=200,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    for i in range(min(num_examples, len(test_samples))):
        sample = test_samples[i]
        # Blank the answer: the prompt now ends with "### คำตอบ:\n".
        prompt = format_training_sample({**sample, "output": ""})

        print(f"\n📝 Example {i+1}:")
        print(f"Input: {prompt[:100]}...")

        try:
            # Tokenize input. NOTE(review): truncation keeps the beginning, so an
            # over-long prompt would lose the trailing answer header.
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config["max_length"] // 2,
                padding=True
            ).to(model.device)

            # Generate response
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    generation_config=generation_config,
                    use_cache=True
                )

            # Decode only the newly generated tokens (everything after the prompt).
            response = tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"Generated Output: {response.strip()}")
            print(f"Expected Output: {str(sample.get('output', ''))[:100]}...")

        except Exception as e:
            print(f"Generation error: {str(e)}")

        print("-" * 50)

# Calculate perplexity on test dataset
# Score the held-out split, then spot-check a few generations.
print("\n📊 Calculating final perplexity...")
test_perplexity, test_loss = calculate_perplexity(model, tokenizer, test_dataset)
for _metric_label, _metric_text in (
    ("Test Perplexity", f"{test_perplexity:.2f}"),
    ("Test Loss", f"{test_loss:.4f}"),
):
    print(f"✅ {_metric_label}: {_metric_text}")

test_model_generation(model, tokenizer, test_samples)

# FIX: Final memory cleanup and summary
memory_manager.clear_memory()
final_memory = memory_manager.get_memory_info()

_separator = "=" * 60
_summary_lines = [
    "\n🎯 Training Summary:",
    _separator,
    "✅ Training completed successfully!",
    f"📊 Final Test Perplexity: {test_perplexity:.2f}",
    f"💾 Model saved to: {config['output_dir']}",
    f"🧠 Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}",
    f"📈 Training samples: {len(train_dataset)}",
    f"📊 Validation samples: {len(val_dataset)}",
    f"🧪 Test samples: {len(test_dataset)}",
]
print("\n".join(_summary_lines))

# GPU usage line only when CUDA statistics are available.
if "error" not in final_memory:
    print(f"💾 GPU Memory usage: {final_memory['allocated_gb']:.1f}GB / {final_memory['total_gb']:.1f}GB")

print("\n🚀 Enhanced OpenThaiGPT Medical Fine-tuning Complete!")
print(_separator)