In [None]:
# Enhanced Fine-tuning Typhoon 2.1 Gemma3 4B for Thai Medical Applications
# Model: scb10x/typhoon2.1-gemma3-4b
# Company: V89 Technology Ltd.
# Version: 5.0 - All Critical Fixes Applied for A100 40GB

# Announce which revision of the script is running.
print("Starting Enhanced Typhoon 2.1 Gemma3 4B Medical Fine-tuning - Fixed Version 5.0...")

import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
import os
import re
import sys
# These environment variables are set before torch/transformers are imported
# further down, so they are already in place when those libraries load.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Two separate switches are used to turn Weights & Biases reporting off.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"
# Setting for PyTorch's CUDA caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Install required packages with correct versions for A100
!pip install -q --upgrade pip==24.0
!pip install -q --force-reinstall torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
!pip install -q compressed-tensors>=0.7.0
!pip install -q transformers==4.50.0
!pip install -q accelerate==1.1.0
!pip install -q bitsandbytes==0.47.0
!pip install -q peft==0.12.0
!pip install -q datasets==2.20.0
!pip install -q evaluate==0.4.2
!pip install -q rouge-score==0.1.2
!pip install -q scikit-learn pandas numpy==2.0.2 scipy>=1.14.1
!pip install -q sacrebleu sentencepiece protobuf nltk
!pip install -q flash-attn --no-build-isolation

# Essential imports
import json
import gc
import torch
import transformers
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from typing import Dict, List, Any, Optional, Union
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import math
import torch.optim as optim
from sacrebleu import corpus_bleu
from collections import defaultdict

# Updated imports for compatibility
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    get_linear_schedule_with_warmup,
    TrainerCallback,
    TrainerControl,
    TrainerState
)

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel,
    prepare_model_for_kbit_training
)

# Report interpreter, library and GPU details so a run can be reproduced.
print("===== Version Check =====")
_version_report = (
    ("Python", sys.version),
    ("Transformers", transformers.__version__),
    ("PyTorch", torch.__version__),
    ("Numpy", np.__version__),
)
for _library, _library_version in _version_report:
    print(f"{_library} version: {_library_version}")
if torch.cuda.is_available():
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"CUDA Device: {torch.cuda.get_device_name()}")
    print(f"CUDA Memory: {_total_bytes / 1e9:.1f} GB")
print("=" * 40)

In [None]:
# Mount Google Drive
def setup_drive_connection():
    """Mount Google Drive in Colab, or fall back to a local dataset folder.

    Returns:
        The dataset directory: the Drive folder when the mount succeeds,
        otherwise a relative local folder (created if it does not exist).
    """
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except ImportError:
        # google.colab cannot be imported, so this is not a Colab runtime.
        print("Running outside Colab - using local directory")
    except Exception as mount_error:
        # Keep the best-effort fallback, but say why the mount did not work.
        # Catching Exception (not a bare except) lets Ctrl-C still interrupt.
        print(f"Google Drive mount failed ({mount_error}) - using local directory")
    else:
        print("Google Drive connected successfully")
        return "/content/drive/MyDrive/V89Technology/thai_medicalCare_dataset150"

    local_path = "V89Technology/thai_medicalCare_dataset150"
    os.makedirs(local_path, exist_ok=True)
    return local_path

drive_path = setup_drive_connection()

# GPU Memory Management
class GPUMemoryManager:
    """Stateless helpers for freeing, inspecting and budgeting GPU memory."""

    @staticmethod
    def clear_memory():
        """Run the garbage collector and, on CUDA, empty the cache and reset peak stats."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    @staticmethod
    def get_memory_info():
        """Return allocated/reserved/total/free memory of device 0 in GiB.

        Without CUDA the result is ``{"error": "CUDA not available"}``.
        """
        if not torch.cuda.is_available():
            return {"error": "CUDA not available"}

        bytes_per_gib = 1024**3
        allocated = torch.cuda.memory_allocated() / bytes_per_gib
        cached = torch.cuda.memory_reserved() / bytes_per_gib
        total = torch.cuda.get_device_properties(0).total_memory / bytes_per_gib
        # "free" is measured against allocated memory, not reserved memory.
        return {
            "allocated_gb": allocated,
            "cached_gb": cached,
            "total_gb": total,
            "free_gb": total - allocated,
        }

    @staticmethod
    def optimize_for_device():
        """Pick training hyper-parameters for an A100, another GPU, or no GPU."""

        def preset(accumulation_steps, max_length, lora_rank, fp16, bf16):
            # Batch size and gradient checkpointing are the same in every preset.
            return {
                "batch_size": 1,
                "gradient_accumulation_steps": accumulation_steps,
                "max_length": max_length,
                "lora_r": lora_rank,
                "use_gradient_checkpointing": True,
                "fp16": fp16,
                "bf16": bf16,
            }

        if not torch.cuda.is_available():
            return preset(32, 256, 1, False, False)
        if "a100" in torch.cuda.get_device_name().lower():
            return preset(8, 512, 2, False, True)
        return preset(16, 256, 1, True, False)

# Free cached memory first, then choose hyper-parameters for this device.
memory_manager = GPUMemoryManager()
memory_manager.clear_memory()
optimal_config = memory_manager.optimize_for_device()

print("Optimal configuration:")
for _setting, _chosen in optimal_config.items():
    print(f"  {_setting}: {_chosen}")

# Configuration for Typhoon 2.1 Gemma3 4B
# The output folder is placed next to the dataset folder returned by
# setup_drive_connection(). On Colab that resolves to the same Drive path as
# before; outside Colab it follows the local fallback instead of pointing at
# a /content/drive path that does not exist there.
_output_root = os.path.dirname(drive_path.rstrip("/"))
config = {
    "model_name": "scb10x/typhoon2.1-gemma3-4b",
    "output_dir": os.path.join(_output_root, "typhoon21-gemma3-4b-medCare-finetuned"),
    "max_length": optimal_config["max_length"],
    "batch_size": optimal_config["batch_size"],
    "gradient_accumulation_steps": optimal_config["gradient_accumulation_steps"],
    "learning_rate": 2e-4,
    "num_epochs": 2,
    "warmup_ratio": 0.05,
    "logging_steps": 5,
    "save_steps": 10,
    "eval_steps": 10,
    "lora_r": optimal_config["lora_r"],
    # LoRA alpha is kept at twice the rank.
    "lora_alpha": optimal_config["lora_r"] * 2,
    "lora_dropout": 0.1,
    "use_gradient_checkpointing": optimal_config["use_gradient_checkpointing"],
    "fp16": optimal_config["fp16"],
    "bf16": optimal_config["bf16"],
    "weight_decay": 0.01,
    "lr_scheduler_type": "cosine",
    "seed": 42
}

print(f"Configuration ready. Output: {config['output_dir']}")

# FIXED: Compatible quantization configuration
def get_quantization_config():
    """Build the 4-bit NF4 bitsandbytes configuration for loading the model.

    The compute dtype follows ``config["bf16"]``: bfloat16 when it is set,
    float16 otherwise.
    """
    compute_dtype = torch.float16
    if config["bf16"]:
        compute_dtype = torch.bfloat16

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
        llm_int8_threshold=6.0,
    )

# FIXED: LoRA configuration for Gemma3
def get_lora_config():
    """Build the LoRA adapter configuration from the global ``config``.

    Adapters are attached to the attention projections (q/k/v/o) and the MLP
    projections (gate/up/down); bias terms are not trained.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]

    lora_settings = {
        "task_type": TaskType.CAUSAL_LM,
        "r": config["lora_r"],
        "lora_alpha": config["lora_alpha"],
        "lora_dropout": config["lora_dropout"],
        "target_modules": attention_projections + mlp_projections,
        "bias": "none",
        "inference_mode": False,
        "init_lora_weights": True,
    }
    return LoraConfig(**lora_settings)

# Build both configurations once at module level; the model-loading code
# further down passes them to from_pretrained() and get_peft_model().
quantization_config = get_quantization_config()
lora_config = get_lora_config()

# FIXED: Dataset loading with proper Thai encoding
def load_external_csv_datasets():
    """Load the four Thai medical CSV files under ``drive_path``.

    Every usable row becomes a dict with the keys ``instruction``, ``input``,
    ``output`` and ``dataset_type``. A row is kept when its instruction has at
    least 5 characters and its output at least 10.

    Empty cells are treated as empty strings (pandas reads them as NaN, which
    ``str()`` would turn into the literal text "nan"). A blank instruction
    falls back to the generic medical-advice instruction.

    Returns:
        dict mapping dataset name to its list of samples. A dataset whose
        file is missing or unreadable maps to an empty list.
    """
    default_instruction = 'ให้คำแนะนำด้านการแพทย์'

    def _cell_text(value):
        """Return a stripped string for one CSV cell, '' for a missing value."""
        if value is None or (not isinstance(value, str) and pd.isna(value)):
            return ""
        return str(value).strip()

    datasets = {}
    dataset_files = {
        "mental_health": {
            "path": f"{drive_path}/mental_health_thai_150.csv",
            "columns": ["th_Context", "th_Response"]
        },
        "healthcare": {
            "path": f"{drive_path}/healthcare_thai_150.csv",
            "columns": ["th_instruction", "th_input", "th_output"]
        },
        "pubmed": {
            "path": f"{drive_path}/pubmed_thai_150.csv",
            "columns": ["th_input", "th_output", "th_instruction"]
        },
        "medical_qa": {
            "path": f"{drive_path}/medical_qa_thai_150.csv",
            "columns": ["th_instruction", "th_input", "th_output"]
        }
    }

    for dataset_name, dataset_config in dataset_files.items():
        try:
            print(f"Loading {dataset_name}...")
            # utf-8-sig reads plain UTF-8 too and strips a leading BOM, which
            # would otherwise become part of the first column name.
            df = pd.read_csv(dataset_config["path"], encoding='utf-8-sig')
            print(f"  Loaded {len(df)} samples")

            missing_columns = [
                column for column in dataset_config["columns"]
                if column not in df.columns
            ]
            if missing_columns:
                print(f"  Warning: {dataset_name} is missing columns: {missing_columns}")

            samples = []
            skipped = 0
            for _, row in df.iterrows():
                if dataset_name == "mental_health":
                    instruction = "ให้คำปรึกษาด้านสุขภาพจิต"
                    input_text = _cell_text(row.get('th_Context'))
                    output_text = _cell_text(row.get('th_Response'))
                else:
                    instruction = _cell_text(row.get('th_instruction')) or default_instruction
                    input_text = _cell_text(row.get('th_input'))
                    output_text = _cell_text(row.get('th_output'))

                # Validate sample
                if len(instruction) >= 5 and len(output_text) >= 10:
                    samples.append({
                        "instruction": instruction,
                        "input": input_text,
                        "output": output_text,
                        "dataset_type": dataset_name
                    })
                else:
                    skipped += 1

            datasets[dataset_name] = samples
            print(f"  {dataset_name}: {len(samples)} valid samples")
            if skipped:
                print(f"  {dataset_name}: skipped {skipped} rows that failed validation")

        except Exception as e:
            # Best effort: one unreadable file must not block the others.
            print(f"  Error loading {dataset_name}: {str(e)}")
            datasets[dataset_name] = []

    return datasets

def load_enhanced_medical_datasets():
    """Combine every external CSV dataset into one list of samples.

    When fewer than 50 samples were loaded, three built-in Thai fallback
    samples are appended.

    Returns:
        ``{"combined_dataset": [...]}`` holding all collected samples.
    """
    print("Loading datasets...")
    per_dataset = load_external_csv_datasets()

    # Flatten the per-dataset lists in dictionary order.
    all_samples = [sample for samples in per_dataset.values() for sample in samples]
    print(f"Total samples: {len(all_samples)}")

    if len(all_samples) >= 50:
        return {"combined_dataset": all_samples}

    # Add fallback if insufficient data
    print("Adding fallback samples...")
    all_samples += [
        {
            "instruction": "อธิบายอาการและการรักษาโรคทั่วไปในประเทศไทย",
            "input": "โรคไข้เลือดออกมีอาการอย่างไร",
            "output": "โรคไข้เลือดออกเกิดจากไวรัสเดงกี่ มีอาการไข้สูง ปวดหัว ปวดกล้ามเนื้อ คลื่นไส้อาเจียน ผื่นแดง อาจมีเลือดออกตามรูพรุ้น หากรุนแรงอาจเป็นไข้เลือดออกแบบช็อค ควรดื่มน้ำมากๆ พักผ่อน หลีกเลี่ยงยาแอสไพริน และรีบพบแพทย์ทันที",
            "dataset_type": "fallback"
        },
        {
            "instruction": "วิเคราะห์อาการและให้คำแนะนำเบื้องต้น",
            "input": "มีอาการปวดท้อง ท้องเสีย เป็นมา 2 วัน",
            "output": "อาการปวดท้องและท้องเสียอาจเกิดจากการติดเชื้อในทางเดินอาหาร การกินอาหารเป็นพิษ หรือความเครียด ควรดื่มน้ำสะอาดมากๆ กิน ORS หลีกเลี่ยงอาหารมัน เผ็ด หรือยา กินข้าวต้มหรืออาหารอ่อนๆ หากไม่ดีขึ้น ภายใน 1-2 วัน หรือมีไข้ ควรพบแพทย์",
            "dataset_type": "fallback"
        },
        {
            "instruction": "ให้คำปรึกษาด้านสุขภาพจิต",
            "input": "รู้สึกเครียด นอนไม่หลับ กังวลเกี่ยวกับงาน",
            "output": "ความเครียดจากงานเป็นเรื่องปกติ แต่ต้องจัดการอย่างเหมาะสม ควรหาเวลาพักผ่อน ออกกำลังกาย ทำสมาธิ หรือทำกิจกรรมที่ชอบ หลีกเลี่ยงคาเฟอีนก่อนนอน สร้างสภาพแวดล้อมที่เหมาะสำหรับการนอน หากอาการไม่ดีขึ้น ควรปรึกษาผู้เชี่ยวชาญด้านสุขภาพจิต",
            "dataset_type": "fallback"
        },
    ]
    return {"combined_dataset": all_samples}

# Load datasets
print("Loading enhanced datasets...")
medical_datasets = load_enhanced_medical_datasets()

# Keep only dict samples whose stripped output has at least 10 characters.
all_samples = [
    sample
    for samples in medical_datasets.values()
    for sample in samples
    if isinstance(sample, dict)
    and sample.get("output")
    and len(str(sample["output"]).strip()) >= 10
]

print(f"Total valid samples: {len(all_samples)}")

# Train/validation/test split
def create_split(samples, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Shuffle ``samples`` reproducibly and split into train/validation/test.

    Args:
        samples: list of samples. It is not modified; a copy is shuffled.
        train_ratio: fraction of samples assigned to the training split.
        val_ratio: fraction assigned to the validation split. The remainder
            forms the test split.
        seed: shuffle seed. Defaults to ``config["seed"]``.

    Returns:
        Tuple ``(train, validation, test)`` of lists.
    """
    import random

    if seed is None:
        seed = config["seed"]

    # A private Random instance keeps the shuffle reproducible without
    # reseeding the process-wide random generator as a side effect.
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)

    total = len(shuffled)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

train_samples, val_samples, test_samples = create_split(all_samples)

# Report the size of each split.
for _split_label, _split_samples in (
    ("Training", train_samples),
    ("Validation", val_samples),
    ("Test", test_samples),
):
    print(f"{_split_label}: {len(_split_samples)}")

# Refuse to start fine-tuning on fewer than 10 training samples.
if not len(train_samples) >= 10:
    raise ValueError("Insufficient training samples")

# Load the fast tokenizer with right-side padding.
print("Loading model and tokenizer...")

_tokenizer_options = {
    "trust_remote_code": True,
    "padding_side": 'right',
    "use_fast": True,
}
tokenizer = AutoTokenizer.from_pretrained(config["model_name"], **_tokenizer_options)

# When the tokenizer ships without a pad token, reuse the EOS token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("Loading model...")
import importlib.util

_model_dtype = torch.bfloat16 if config["bf16"] else torch.float16

# FlashAttention-2 needs the flash_attn package and a GPU with compute
# capability 8.0 or newer. The install step above can fail quietly, and the
# non-A100 preset may run on an older GPU, so fall back to eager attention
# instead of crashing while loading the model.
_use_flash_attention = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 8
    and importlib.util.find_spec("flash_attn") is not None
)

model = AutoModelForCausalLM.from_pretrained(
    config["model_name"],
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=_model_dtype,
    use_cache=False,
    attn_implementation="flash_attention_2" if _use_flash_attention else "eager"
)

# Enable gradient checkpointing
if config["use_gradient_checkpointing"]:
    model.gradient_checkpointing_enable()

# Prepare the quantized model for training. The checkpointing choice is passed
# on explicitly so this helper does not switch it on when the config says off.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=config["use_gradient_checkpointing"],
)
model = get_peft_model(model, lora_config)

def print_trainable_parameters(model):
    """Print trainable and total parameter counts plus the trainable share.

    Args:
        model: object whose ``named_parameters()`` yields ``(name, param)``
            pairs; each param provides ``numel()`` and ``requires_grad``.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, is_trainable in sizes if is_trainable)
    print(f"Trainable: {trainable_params:,} || All: {all_param:,} || Ratio: {100 * trainable_params / all_param:.2f}%")

# Report how much of the LoRA-wrapped model is trainable.
print_trainable_parameters(model)

# FIXED: Data preprocessing with proper tokenization
def format_training_sample(sample):
    """Render one sample as the Thai instruction / question / answer prompt.

    Args:
        sample: mapping that may hold "instruction", "input" and "output".
            Missing keys count as empty strings.

    Returns:
        The prompt text. The question section is left out when the input is
        empty after stripping.
    """
    instruction, input_text, output_text = (
        str(sample.get(field, "")).strip()
        for field in ("instruction", "input", "output")
    )

    sections = [f"### คำสั่ง:\n{instruction}"]
    if input_text:
        sections.append(f"### คำถาม:\n{input_text}")
    sections.append(f"### คำตอบ:\n{output_text}")

    return "\n\n".join(sections)

def preprocess_function(examples):
    """Tokenize a batch of samples into fixed-length causal-LM features.

    Args:
        examples: batch mapping with "instruction" and "output" lists and an
            optional "input" list of the same length.

    Returns:
        The tokenizer output (padded/truncated to ``config["max_length"]``)
        with an added "labels" entry: a copy of ``input_ids`` where every
        position the attention mask marks as padding is set to -100.
    """
    instructions = examples["instruction"]
    inputs = examples.get("input", [""] * len(instructions))

    # Append the EOS token so the model sees where an answer ends; without it
    # generation has no learned stopping point. Long samples may still lose it
    # to truncation.
    eos_text = tokenizer.eos_token or ""
    formatted_texts = [
        format_training_sample(
            {"instruction": instruction, "input": input_text, "output": output_text}
        ) + eos_text
        for instruction, input_text, output_text in zip(instructions, inputs, examples["output"])
    ]

    tokenized = tokenizer(
        formatted_texts,
        truncation=True,
        padding="max_length",
        max_length=config["max_length"],
        return_tensors=None,
        add_special_tokens=True
    )

    # The tokenizer's own attention mask is kept. Rebuilding it from
    # pad_token_id would also mask real tokens whenever the pad token is the
    # EOS token (the fallback used when no pad token exists).
    # NOTE(review): the DataCollatorForLanguageModeling configured later in
    # this file may rebuild labels from input_ids - confirm which one wins.
    tokenized["labels"] = [
        [token_id if attended else -100 for token_id, attended in zip(input_ids, attention_mask)]
        for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]

    return tokenized

# FIXED: Convert to dataset format
def samples_to_dataset_dict(samples):
    """Turn a list of sample dicts into a column-oriented dict of lists.

    Missing text fields default to "", a missing "dataset_type" to "general".
    """
    column_defaults = (
        ("instruction", ""),
        ("input", ""),
        ("output", ""),
        ("dataset_type", "general"),
    )
    return {
        column: [sample.get(column, fallback) for sample in samples]
        for column, fallback in column_defaults
    }

print("Creating datasets...")
train_dataset, val_dataset, test_dataset = (
    Dataset.from_dict(samples_to_dataset_dict(_split))
    for _split in (train_samples, val_samples, test_samples)
)

print("Preprocessing...")


def _tokenize_split(dataset, description):
    """Tokenize one split in batches, dropping its original text columns."""
    return dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc=description,
    )


train_dataset = _tokenize_split(train_dataset, "Processing training data")

val_dataset = _tokenize_split(val_dataset, "Processing validation data")

test_dataset = _tokenize_split(test_dataset, "Processing test data")

print(f"Processed - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Collator for causal language modelling (mlm disabled), padding to a multiple of 8.
_collator_options = {"mlm": False, "pad_to_multiple_of": 8}
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, **_collator_options)

# FIXED: Compute metrics function
def compute_metrics(eval_pred):
    """Compute next-token cross-entropy and perplexity from evaluation logits.

    In a causal language model the logits at position ``t`` predict the token
    at position ``t + 1``. Logits and labels are therefore shifted by one
    before the loss is computed. Label positions equal to -100 are ignored.

    Args:
        eval_pred: pair ``(predictions, labels)``. ``predictions`` holds
            logits whose last axis is the vocabulary (or a tuple whose first
            item is such an array). ``labels`` holds token ids with the same
            leading axes.

    Returns:
        ``{"perplexity": ..., "eval_loss": ...}``. Perplexity is capped at
        10000. Both values are ``inf`` when no label position is usable.
    """
    predictions, labels = eval_pred

    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Accept numpy arrays as well as tensors; all work happens on the CPU.
    logits = torch.as_tensor(predictions).detach().cpu()
    targets = torch.as_tensor(labels).detach().cpu()

    # Drop the last logit step and the first label so that row t of the
    # logits is scored against token t + 1.
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].reshape(-1, vocab_size).float()
    shift_labels = targets[..., 1:].reshape(-1).long()

    keep = shift_labels != -100
    if not bool(keep.any()):
        return {"perplexity": float('inf'), "eval_loss": float('inf')}

    nll_loss = F.cross_entropy(shift_logits[keep], shift_labels[keep], reduction='mean')
    perplexity = torch.exp(nll_loss).item()

    return {
        "perplexity": min(perplexity, 10000.0),  # Cap perplexity
        "eval_loss": nll_loss.item()
    }

# Training arguments
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    overwrite_output_dir=True,
    num_train_epochs=config["num_epochs"],
    per_device_train_batch_size=config["batch_size"],
    per_device_eval_batch_size=config["batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    gradient_checkpointing=config["use_gradient_checkpointing"],
    learning_rate=config["learning_rate"],
    weight_decay=config["weight_decay"],
    lr_scheduler_type=config["lr_scheduler_type"],
    warmup_ratio=config["warmup_ratio"],
    logging_steps=config["logging_steps"],
    eval_strategy="steps",
    eval_steps=config["eval_steps"],
    save_strategy="steps",
    save_steps=config["save_steps"],
    save_total_limit=2,
    # Evaluation and checkpointing share the same step interval, so the best
    # checkpoint by eval_loss can be restored at the end.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=config["fp16"],
    bf16=config["bf16"],
    dataloader_pin_memory=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    # The string "none" disables every logging integration. Passing None
    # falls back to the library default of "all installed integrations".
    report_to="none",
    run_name="typhoon21_gemma3_4b_medical_fixed_v5",
    seed=config["seed"],
    data_seed=config["seed"],
    group_by_length=False,
    dataloader_drop_last=False,
    eval_accumulation_steps=1,
    prediction_loss_only=False,
    push_to_hub=False,
)

# Training callback
class EnhancedCallback(TrainerCallback):
    """Print training progress to the console and free GPU memory each epoch."""

    def __init__(self):
        # Set by on_train_begin; on_log stays silent until then.
        self.training_started = False

    def on_train_begin(self, args, state, control, **kwargs):
        """Mark training as started and print the run configuration."""
        self.training_started = True
        print("Training Started!")
        print(f"Configuration: {args.num_train_epochs} epochs, batch size {args.per_device_train_batch_size}")

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Print the 1-based epoch number and clear cached GPU memory."""
        current_epoch = 1 if state.epoch is None else int(state.epoch) + 1
        print(f"\nEpoch {current_epoch}/{args.num_train_epochs}")
        memory_manager.clear_memory()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print training and evaluation loss whenever they appear in the logs."""
        if not logs or not self.training_started:
            return
        step = logs.get('step', state.global_step)
        for metric_key, label in (('loss', 'Loss'), ('eval_loss', 'Eval Loss')):
            if metric_key in logs:
                print(f"Step {step}: {label} = {logs[metric_key]:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        """Announce the end of training."""
        print("Training Completed!")

# Initialize trainer
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `tokenizer=` is deprecated in the transformers release pinned above;
    # `processing_class` is its replacement.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EnhancedCallback(),
        # Stop after 3 evaluations without an improvement of at least 0.001.
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)
    ],
)

# FIXED: Generation function
def generate_response(model, tokenizer, instruction, input_text="", max_length=256, temperature=0.7, top_p=0.9):
    """Generate an answer for one instruction (and optional question).

    Args:
        model: causal language model exposing ``generate``.
        tokenizer: tokenizer matching ``model``.
        instruction: task instruction placed under the instruction header.
        input_text: optional question. The question section is omitted when
            this is empty.
        max_length: maximum number of prompt tokens; longer prompts are
            truncated.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Returns:
        The generated answer text with surrounding whitespace stripped. Up to
        128 new tokens are sampled.
    """
    if input_text:
        prompt = f"### คำสั่ง:\n{instruction}\n\n### คำถาม:\n{input_text}\n\n### คำตอบ:\n"
    else:
        prompt = f"### คำสั่ง:\n{instruction}\n\n### คำตอบ:\n"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False
    )

    # Move every input tensor to the model's device and keep ids as int64.
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    inputs['input_ids'] = inputs['input_ids'].long()
    if 'attention_mask' in inputs:
        inputs['attention_mask'] = inputs['attention_mask'].long()

    generation_config = GenerationConfig(
        max_new_tokens=128,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
    )

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            generation_config=generation_config
        )

    # The output sequence starts with the prompt tokens. Decoding only what
    # follows them avoids searching the decoded text for the answer header,
    # which breaks if decoding alters whitespace or the model repeats the
    # header in its answer.
    prompt_length = inputs['input_ids'].shape[-1]
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

def evaluate_model_performance(model, tokenizer, test_samples, num_samples=3):
    """Generate answers for the first ``num_samples`` test samples and print them.

    Returns:
        One dict per evaluated sample with the keys "instruction", "input",
        "expected" and "generated". When generation fails, "generated" holds
        the error text instead.
    """
    print(f"Evaluating on {min(num_samples, len(test_samples))} samples...")

    results = []
    for sample_number, sample in enumerate(test_samples[:num_samples], start=1):
        instruction = sample.get("instruction", "")
        input_text = sample.get("input", "")
        expected = sample.get("output", "")

        print(f"\n--- Sample {sample_number} ---")
        print(f"Instruction: {instruction}")
        if input_text:
            print(f"Input: {input_text}")
        print(f"Expected: {expected[:100]}...")

        try:
            generated = generate_response(
                model, tokenizer, instruction, input_text,
                max_length=config["max_length"], temperature=0.7
            )
            print(f"Generated: {generated[:100]}...")
        except Exception as generation_error:
            # Record the failure in place of an answer and keep evaluating.
            print(f"Generation error: {generation_error}")
            generated = f"Error: {str(generation_error)}"

        results.append({
            "instruction": instruction,
            "input": input_text,
            "expected": expected,
            "generated": generated
        })

    return results

# Baseline: sample two generations before any fine-tuning step has run.
print("\n=== Pre-training Evaluation ===")
pre_results = evaluate_model_performance(model, tokenizer, test_samples, num_samples=2)

# Release cached GPU memory before the training loop starts.
memory_manager.clear_memory()

print("\n=== Starting Training ===", "=" * 50, sep="\n")

try:
    # Run fine-tuning, then persist the adapter, tokenizer and trainer state.
    trainer.train()
    print("Training completed successfully!")

    print("\n=== Saving Model ===")
    _final_dir = config["output_dir"]
    trainer.save_model()
    tokenizer.save_pretrained(_final_dir)
    trainer.save_state()

    print(f"Model saved to: {_final_dir}")

except Exception as training_error:
    print(f"Training failed: {str(training_error)}")
    import traceback
    traceback.print_exc()

    # Best effort: keep whatever was learned so far in a sibling directory.
    try:
        print("Attempting to save partial progress...")
        trainer.save_model(output_dir=f"{config['output_dir']}_partial")
        print("Partial model saved successfully")
    except Exception as partial_save_error:
        print(f"Failed to save partial model: {str(partial_save_error)}")

# Release cached GPU memory whether or not training succeeded.
memory_manager.clear_memory()

# Post-training evaluation
print("\n=== Post-training Evaluation ===")
try:
    print("Loading trained model for evaluation...")

    # Only evaluate when a fine-tuned adapter was written to the output
    # directory; otherwise report the problem through the handler below.
    # NOTE(review): assumes trainer.save_model() writes adapter_config.json
    # for a PEFT-wrapped model - confirm.
    _adapter_config_path = os.path.join(config["output_dir"], "adapter_config.json")
    if not os.path.isfile(_adapter_config_path):
        raise FileNotFoundError(f"No saved adapter found in {config['output_dir']}")

    # `model` was already wrapped by get_peft_model() and trained in place.
    # Passing it to PeftModel.from_pretrained() again would wrap the adapter
    # layers a second time, so the trainer's in-memory model is used directly.
    trained_model = trainer.model
    trained_model.eval()

    print("Evaluating trained model performance...")
    post_results = evaluate_model_performance(trained_model, tokenizer, test_samples, 5)

    # Compare results
    print("\n=== Training Results Comparison ===")
    print("Pre-training vs Post-training samples:")

    for i, (pre, post) in enumerate(zip(pre_results[:2], post_results[:2])):
        print(f"\n--- Comparison {i+1} ---")
        print(f"Instruction: {pre['instruction']}")
        if pre['input']:
            print(f"Input: {pre['input']}")
        print(f"Expected: {pre['expected'][:100]}...")
        print(f"Pre-training:  {pre['generated'][:100]}...")
        print(f"Post-training: {post['generated'][:100]}...")
        print("-" * 50)

except Exception as e:
    print(f"Post-training evaluation failed: {str(e)}")

# Model testing function
def test_medical_model(model, tokenizer):
    """Run four fixed Thai medical prompts through the model and print the replies.

    A generation error is printed for the affected case and the remaining
    cases still run.
    """
    print("\n=== Medical Model Testing ===")

    # (description, instruction, input) for each scenario.
    scenarios = (
        ("Common cold symptoms", "ให้คำแนะนำด้านการแพทย์", "มีอาการไข้ ปวดหัว เจ็บคอ มา 2 วัน"),
        ("Mental health consultation", "ให้คำปรึกษาด้านสุขภาพจิต", "รู้สึกเครียดและกังวลมาก ทำงานไม่ได้"),
        ("Diabetes explanation", "อธิบายเกี่ยวกับโรคและการรักษา", "โรคเบาหวานคืออะไร"),
        ("Heart health advice", "ให้คำแนะนำการดูแลสุขภาพ", "วิธีดูแลสุขภาพหัวใจให้แข็งแรง"),
    )

    for case_number, (description, instruction, question) in enumerate(scenarios, start=1):
        print(f"\n--- Test Case {case_number}: {description} ---")
        print(f"Instruction: {instruction}")
        print(f"Input: {question}")

        try:
            response = generate_response(
                model, tokenizer,
                instruction,
                question,
                max_length=512,
                temperature=0.7
            )
            print(f"Response: {response}")
        except Exception as generation_error:
            print(f"Error generating response: {str(generation_error)}")

        print("-" * 80)

# Run the scenario test only when post-training evaluation produced a model.
try:
    if 'trained_model' not in locals():
        print("Trained model not available for testing")
    else:
        test_medical_model(trained_model, tokenizer)
except Exception as testing_error:
    print(f"Model testing failed: {str(testing_error)}")

# Save training metrics and results
def save_training_results():
    """Write the run's configuration, metrics and sample generations to disk.

    Creates ``training_results.json`` and ``training_summary.txt`` inside
    ``config["output_dir"]``. Any failure is printed instead of raised.
    """
    print("\n=== Saving Training Results ===")

    try:
        log_history = []
        if hasattr(trainer, 'state') and trainer.state.log_history:
            log_history = trainer.state.log_history

        # Report completion from what the trainer logged instead of always
        # claiming success.
        # NOTE(review): relies on Trainer appending a closing log entry that
        # contains "train_runtime" once train() finishes - confirm.
        training_completed = any("train_runtime" in entry for entry in log_history)

        results_data = {
            "config": config,
            "training_completed": training_completed,
            "timestamp": datetime.now().isoformat(),
            "model_name": config["model_name"],
            "output_dir": config["output_dir"],
            "dataset_info": {
                "train_samples": len(train_samples),
                "val_samples": len(val_samples),
                "test_samples": len(test_samples),
                "total_samples": len(all_samples)
            },
            "memory_info": memory_manager.get_memory_info(),
            "optimal_config": optimal_config
        }

        # Add training history if available
        if log_history:
            results_data["training_history"] = log_history

        # pre_results / post_results are module-level names. locals() inside
        # this function never contains them, so look them up in globals().
        module_scope = globals()
        if 'pre_results' in module_scope:
            results_data["pre_training_samples"] = module_scope['pre_results']
        if 'post_results' in module_scope:
            results_data["post_training_samples"] = module_scope['post_results']

        # Save to JSON file
        results_file = f"{config['output_dir']}/training_results.json"
        os.makedirs(os.path.dirname(results_file), exist_ok=True)

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)

        print(f"Training results saved to: {results_file}")

        # Create summary report
        summary_file = f"{config['output_dir']}/training_summary.txt"
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("=== Typhoon 2.1 Gemma3 4B Medical Fine-tuning Summary ===\n")
            f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Model: {config['model_name']}\n")
            f.write(f"Output Directory: {config['output_dir']}\n")
            f.write(f"Training Samples: {len(train_samples)}\n")
            f.write(f"Validation Samples: {len(val_samples)}\n")
            f.write(f"Test Samples: {len(test_samples)}\n")
            f.write(f"Epochs: {config['num_epochs']}\n")
            f.write(f"Learning Rate: {config['learning_rate']}\n")
            f.write(f"LoRA Rank: {config['lora_r']}\n")
            f.write(f"Max Length: {config['max_length']}\n")
            f.write(f"Batch Size: {config['batch_size']}\n")
            f.write(f"Gradient Accumulation Steps: {config['gradient_accumulation_steps']}\n")

            memory_info = memory_manager.get_memory_info()
            if "error" not in memory_info:
                f.write(f"GPU Memory Used: {memory_info['allocated_gb']:.2f} GB\n")
                f.write(f"GPU Memory Total: {memory_info['total_gb']:.2f} GB\n")

            f.write("\n=== Configuration ===\n")
            for key, value in config.items():
                f.write(f"{key}: {value}\n")

        print(f"Training summary saved to: {summary_file}")

    except Exception as e:
        print(f"Failed to save training results: {str(e)}")

save_training_results()

# Create inference function for deployment
def create_inference_function():
    """Write a standalone ``inference.py`` into ``config["output_dir"]``.

    The generated module offers ``load_trained_model`` (base model plus the
    fine-tuned adapter) and ``medical_chat_inference`` (prompt formatting and
    sampling). A failure to write the file is printed, not raised.
    """
    print("\n=== Creating Inference Function ===")

    # torch is imported at module level in the generated file because both
    # generated functions use it. A function-local import would leave
    # medical_chat_inference() with an undefined name.
    inference_code = '''import torch


def load_trained_model(model_path):
    """Load the trained model for inference"""
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel

    # Load base model and tokenizer
    base_model_name = "scb10x/typhoon2.1-gemma3-4b"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    # Load fine-tuned weights
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()

    return model, tokenizer


def medical_chat_inference(model, tokenizer, instruction, input_text="", max_length=256):
    """Generate medical advice response"""
    if input_text:
        prompt = f"### คำสั่ง:\\n{instruction}\\n\\n### คำถาม:\\n{input_text}\\n\\n### คำตอบ:\\n"
    else:
        prompt = f"### คำสั่ง:\\n{instruction}\\n\\n### คำตอบ:\\n"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)

    if torch.cuda.is_available():
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        inputs['input_ids'] = inputs['input_ids'].long()
        if 'attention_mask' in inputs:
            inputs['attention_mask'] = inputs['attention_mask'].long()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "### คำตอบ:\\n" in response:
        response = response.split("### คำตอบ:\\n")[1].strip()

    return response

# Example usage:
# model, tokenizer = load_trained_model("path/to/your/model")
# response = medical_chat_inference(model, tokenizer, "ให้คำแนะนำด้านการแพทย์", "มีอาการไข้")
'''

    try:
        # The output directory may not exist yet if earlier steps failed.
        os.makedirs(config['output_dir'], exist_ok=True)
        inference_file = f"{config['output_dir']}/inference.py"
        with open(inference_file, 'w', encoding='utf-8') as f:
            f.write(inference_code)
        print(f"Inference function saved to: {inference_file}")
    except Exception as e:
        print(f"Failed to save inference function: {str(e)}")

create_inference_function()

# Final cleanup and summary
# Report only what is actually on disk, so a failed run is not announced as
# a success.
# NOTE(review): adapter_config.json is assumed to be the marker file that
# trainer.save_model() writes for a PEFT model - confirm. Files left over
# from an earlier run in the same directory also count as present.
_summary_dir = config['output_dir']
_model_saved = os.path.isfile(os.path.join(_summary_dir, "adapter_config.json"))

print("\n" + "=" * 80)
if _model_saved:
    print("=== FINE-TUNING COMPLETE ===")
else:
    print("=== FINE-TUNING FINISHED WITHOUT A SAVED MODEL ===")
print("=" * 80)

print(f"✓ Model: {config['model_name']}")
print(f"✓ Training samples: {len(train_samples)}")
print(f"✓ Validation samples: {len(val_samples)}")
print(f"✓ Test samples: {len(test_samples)}")
print(f"✓ Epochs: {config['num_epochs']}")
print(f"✓ Output directory: {config['output_dir']}")

memory_info = memory_manager.get_memory_info()
if "error" not in memory_info:
    print(f"✓ Final GPU memory usage: {memory_info['allocated_gb']:.2f} GB / {memory_info['total_gb']:.2f} GB")

# (status message, file label, displayed path, present on disk)
_artifacts = [
    ("Model files saved successfully", "Model files",
     f"{_summary_dir}/", _model_saved),
    ("Training results and metrics saved", "Training results",
     f"{_summary_dir}/training_results.json",
     os.path.isfile(os.path.join(_summary_dir, "training_results.json"))),
    (None, "Training summary",
     f"{_summary_dir}/training_summary.txt",
     os.path.isfile(os.path.join(_summary_dir, "training_summary.txt"))),
    ("Inference function created", "Inference function",
     f"{_summary_dir}/inference.py",
     os.path.isfile(os.path.join(_summary_dir, "inference.py"))),
]

for _message, _label, _path, _present in _artifacts:
    if _message is None:
        continue
    if _present:
        print(f"✓ {_message}")
    else:
        print(f"✗ {_label} not found: {_path}")

print("\n=== Files Created ===")
for _message, _label, _path, _present in _artifacts:
    if _present:
        print(f"• {_label}: {_path}")

print("\n=== Next Steps ===")
if _model_saved:
    print("1. Test the model with medical queries")
    print("2. Deploy for inference using the provided inference.py")
    print("3. Monitor performance and gather feedback")
    print("4. Consider additional fine-tuning if needed")
else:
    print("1. Review the training error output above")
    print("2. Fix the reported problem and re-run the training step")

# Final memory cleanup
memory_manager.clear_memory()

if _model_saved:
    print("\n🎉 Enhanced Typhoon 2.1 Gemma3 4B Medical Fine-tuning Completed Successfully! 🎉")
else:
    print("\nEnhanced Typhoon 2.1 Gemma3 4B Medical Fine-tuning did not complete - no saved model was found.")
print("=" * 80)