# XiYan-SQL Training on Google Colab

This notebook walks through training the XiYan-SQL model on Google Colab, step by step.

## Prerequisites
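- Upload your base model folder (e.g., `Qwen2.5-Coder-3B-Instruct`) to Google Drive under `My Drive/Xiyan-SQL/Models/Qwen/`
- (Optional) Upload the processed dataset `nl2sql_standard_train_en.json` to `My Drive/Xiyan-SQL/Dataset/` as a backup; Step 5 only needs it if Git LFS did not fetch the file
- Enable a GPU runtime in Colab (Runtime → Change runtime type → GPU)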

## Step 1: Install Dependencies

Install all required packages for XiYan-SQL training.
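
Each `!pip` line is run by a shell, so an unquoted specifier such as `llama-index>=0.9.6.post2` is split at `>`: pip only sees `llama-index`, and the rest of the line becomes an output redirect to a file named `=0.9.6.post2`. The sketch below shows one way to avoid the problem by passing the install as an argument list. The helper name `pip_install_cmd` is hypothetical and not part of the repository. `shlex.join` then prints the correctly quoted form you could paste into a `!pip` line.

```python
import shlex
import sys


def pip_install_cmd(*requirements):
    """Build a pip install command as an argv list (hypothetical helper).

    Passing a list to subprocess.run bypasses the shell, so '>=' inside a
    requirement specifier can never be parsed as an output redirect.
    """
    return [sys.executable, "-m", "pip", "install", "-q",
            "--disable-pip-version-check", *requirements]


cmd = pip_install_cmd("llama-index>=0.9.6.post2", "swanlab>=0.7.6")
# shlex.join quotes every argument that contains shell metacharacters such as '>'
print(shlex.join(cmd))
# To actually install: subprocess.run(cmd, check=True)
```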

In [None]:
# Install system dependencies
!apt-get update -qq
!apt-get install -y -qq libaio-dev  # Required for DeepSpeed

print("📦 Installing Python packages...")
print("⚠️  Note: Installing in a specific order to avoid numpy/DeepSpeed conflicts.\n")

# Colab ships with numpy and torch preinstalled; DeepSpeed is installed on top of those versions
print("\n🔧 Installing DeepSpeed (may show some warnings)...")
!pip install -q --disable-pip-version-check --no-cache-dir deepspeed

# Install remaining packages.
# Version specifiers are quoted: the shell treats an unquoted `>=` as an output
# redirect, so pip would silently ignore the minimum version.
!pip install -q --disable-pip-version-check "llama-index>=0.9.6.post2"
!pip install -q --disable-pip-version-check "modelscope>=1.33.0"
!pip install -q --disable-pip-version-check "mysql-connector-python>=9.5.0"
!pip install -q --disable-pip-version-check "protobuf>=6.33.3"
!pip install -q --disable-pip-version-check "psycopg2-binary>=2.9.11"
!pip install -q --disable-pip-version-check "swanlab>=0.7.6"
!pip install -q --disable-pip-version-check "textdistance>=4.6.3"
!pip install -q --disable-pip-version-check "jedi>=0.16"

# Install flash-attn (optional, for faster attention)
print("\n🔨 Attempting to install flash-attn (this may take a few minutes)...")
import subprocess
import sys
# Use the kernel's own interpreter so the package lands in the same environment
result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "--no-build-isolation", "flash-attn"],
    capture_output=True,
    text=True
)
if result.returncode == 0:
    print("✅ flash-attn installed successfully")
else:
    print("⚠️  flash-attn installation failed (this is optional, continuing without it)")

print("\n✅ Core dependencies installed!")
print("\n💡 If you see numpy warnings, they are expected and won't affect training.")

# Verify installation
print("\n🔍 Verifying installation...")
try:
    import torch
    import transformers
    import accelerate
    import deepspeed
    import peft
    import numpy as np

    print(f"✅ PyTorch: {torch.__version__}")
    print(f"✅ Transformers: {transformers.__version__}")
    print(f"✅ Accelerate: {accelerate.__version__}")
    print(f"✅ DeepSpeed: {deepspeed.__version__}")
    print(f"✅ PEFT: {peft.__version__}")
    print(f"✅ NumPy: {np.__version__}")
    print(f"✅ CUDA Available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"✅ GPU Memory: {gpu_mem:.1f} GB")

        if gpu_mem >= 14:
            print("\n🎯 Your GPU has 14 GB+ of memory (e.g., a T4 or larger) - good for training!")
        elif gpu_mem >= 10:
            print("\n📊 Your GPU has 10-14 GB of memory - suitable for moderate training.")
        else:
            print("\n⚠️  Your GPU has limited memory - training will use conservative settings.")
    else:
        print("\n❌ No GPU detected! Make sure to enable GPU in Runtime → Change runtime type")

    print("\n🚀 Ready to proceed!")

except ImportError as e:
    print(f"\n❌ Import error: {e}")
    print("\n🔄 If you see numpy errors, restart the runtime and run this cell again.")
    print("   Go to: Runtime → Restart runtime")

## Step 1.5: Login to SwanLab (Optional but Recommended)

SwanLab can track your training metrics, including:
- Training loss
- Learning rate
- GPU/CPU/RAM usage
- All hyperparameters
- Model checkpoints

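**Get your API key:** https://swanlab.cn (sign up or log in, then go to Settings → API Key)

**Note:** The training command in Step 7 passes `--report_to none`, so metrics reach SwanLab only if `train/sft4xiyan.py` logs to SwanLab directly.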

In [None]:
!swanlab login

## Step 2: Clone Repository

Clone the XiYan-SQL repository to Colab.
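
`os.system` returns the exit code without raising, so a failed `git clone` (for example a mistyped `REPO_URL`) goes unnoticed until a later step fails. Below is a minimal sketch of an idempotent clone that stops on errors instead. The helper name `clone_if_missing` is hypothetical. Passing an explicit destination keeps the folder name fixed even if you point `REPO_URL` at a fork with a different name.

```python
import os
import subprocess


def clone_if_missing(repo_url, dest):
    """Clone repo_url into dest unless dest already exists (hypothetical helper).

    Returns True if a clone was performed. check=True makes a failed clone
    raise subprocess.CalledProcessError instead of being silently ignored.
    """
    if os.path.exists(dest):
        return False
    subprocess.run(["git", "clone", repo_url, dest], check=True)
    return True


# Usage in this notebook: clone_if_missing(REPO_URL, "/content/XiYan-SQL")
```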

In [None]:
# Change to the Colab content directory
import os
import subprocess
import sys
os.chdir('/content')

# Clone the repository
# Replace with your repository URL
REPO_URL = "https://github.com/rezaarrazi/XiYan-SQL.git"  # ⚠️ UPDATE THIS

if not os.path.exists('XiYan-SQL'):
    # check=True stops the cell if the clone fails (os.system would ignore the error);
    # the explicit destination keeps the folder name fixed for the steps below
    subprocess.run(['git', 'clone', REPO_URL, 'XiYan-SQL'], check=True)
    print("✅ Repository cloned successfully")
else:
    print("✅ Repository already exists")

# Navigate to training directory
os.chdir('XiYan-SQL/XiYan-SQLTraining')

# Add to Python path so imports work correctly
TRAINING_DIR = os.getcwd()
if TRAINING_DIR not in sys.path:
    sys.path.insert(0, TRAINING_DIR)
if os.path.dirname(TRAINING_DIR) not in sys.path:
    sys.path.insert(0, os.path.dirname(TRAINING_DIR))

print(f"\n📁 Current directory: {os.getcwd()}")
print("✅ Python path configured")

## Step 3: Mount Google Drive

Mount your Google Drive to access model and dataset files.

In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

print("✅ Google Drive mounted successfully")
print("\n📂 Drive path: /content/drive/MyDrive")

## Step 4: Copy Model from Google Drive

Copy your pre-downloaded model from Google Drive into the repository's model directory (`train/model/Qwen`).

**Configured Path:** `My Drive/Xiyan-SQL/Models/Qwen/`

The script will automatically detect and copy the model folder(s) from this location.
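
The cell below treats every sub-folder of the Drive path as a model. A stricter check is sketched here. It assumes that each model folder is a standard Hugging Face checkpoint, which always contains a `config.json`, and it only counts folders that actually hold model files. `find_model_dirs` is a hypothetical helper, and sorting makes the "first model" choice deterministic.

```python
import os


def find_model_dirs(root):
    """Return the sorted names of sub-folders of root that contain a config.json."""
    if not os.path.isdir(root):
        return []
    return sorted(
        name for name in os.listdir(root)
        # isfile is False both for plain files in root and for folders without config.json
        if os.path.isfile(os.path.join(root, name, "config.json"))
    )


# Example with this notebook's Drive path:
# find_model_dirs("/content/drive/MyDrive/Xiyan-SQL/Models/Qwen")
```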

In [None]:
import shutil
import os

# Path to your model in Google Drive
MODEL_DRIVE_PATH = "/content/drive/MyDrive/Xiyan-SQL/Models/Qwen"

# Target directory in the repository
MODEL_TARGET_DIR = "train/model/Qwen"

# Create target directory if it doesn't exist
os.makedirs(MODEL_TARGET_DIR, exist_ok=True)

# Check if model directory exists in Drive
if os.path.exists(MODEL_DRIVE_PATH):
    print(f"📥 Found model directory at {MODEL_DRIVE_PATH}")

    # List contents to see what's inside
    contents = os.listdir(MODEL_DRIVE_PATH)
    print(f"📁 Contents: {contents}")

    # Check if it's a single model folder or contains multiple model folders
    # (sorted so the default choice below does not depend on listing order)
    model_folders = sorted(item for item in contents if os.path.isdir(os.path.join(MODEL_DRIVE_PATH, item)))

    if len(model_folders) == 1:
        # Single model folder - copy it directly
        model_name = model_folders[0]
        source_path = os.path.join(MODEL_DRIVE_PATH, model_name)
        target_path = os.path.join(MODEL_TARGET_DIR, model_name)

        if os.path.exists(target_path):
            print(f"⚠️  Model already exists at {target_path}")
            print("Skipping copy (delete manually if you want to re-copy)")
        else:
            print(f"📥 Copying model '{model_name}' from {source_path}...")
            shutil.copytree(source_path, target_path)
            print(f"✅ Model copied to {target_path}")

        MODEL_PATH = target_path
    else:
        # Zero or several model folders - copy everything in the Qwen directory
        target_path = MODEL_TARGET_DIR
        if os.path.exists(target_path) and os.listdir(target_path):
            print(f"⚠️  Model directory already exists at {target_path}")
            print("Skipping copy (delete manually if you want to re-copy)")
        else:
            print(f"📥 Copying all models from {MODEL_DRIVE_PATH}...")
            for item in contents:
                source_item = os.path.join(MODEL_DRIVE_PATH, item)
                target_item = os.path.join(target_path, item)
                if os.path.isdir(source_item):
                    if not os.path.exists(target_item):
                        shutil.copytree(source_item, target_item)
                        print(f"  ✅ Copied {item}")
                else:
                    if not os.path.exists(target_item):
                        shutil.copy2(source_item, target_item)
                        print(f"  ✅ Copied {item}")
            print(f"✅ All models copied to {target_path}")

        # Default to the first model folder found; set MODEL_PATH manually to use another one
        if model_folders:
            MODEL_PATH = os.path.join(MODEL_TARGET_DIR, model_folders[0])
            print(f"\n📌 Using model: {MODEL_PATH}")
            print("💡 To use a different model, set MODEL_PATH before running Step 6")
        else:
            MODEL_PATH = MODEL_TARGET_DIR
            print(f"\n📌 Model directory: {MODEL_PATH}")
            print("💡 Please set MODEL_PATH to the exact model folder before running Step 6")

    print(f"\n📌 Model path for training: {MODEL_PATH}")
else:
    print(f"❌ Model not found at {MODEL_DRIVE_PATH}")
    print("\nPlease check:")
    print("1. Google Drive is mounted correctly")
    print("2. The path 'My Drive/Xiyan-SQL/Models/Qwen/' exists in your Drive")
    MODEL_PATH = None

## Step 5: Verify Training Dataset

The English training dataset should already be in the repository (via Git LFS).

**Expected file:** `train/datasets/nl2sql_standard_train_en.json` (55MB)

If the file is not present, you can download it from Google Drive as a backup.
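
The check in the cell below can also be written as a small reusable function. This sketch reuses the layout that cell relies on: a JSON list of records whose `conversations` list starts with the prompt under `content`. It also reuses the English prefix `You are a SQLite expert`. `summarize_dataset` is a hypothetical helper, and the `role` fields in the test data are illustrative assumptions.

```python
import json


def summarize_dataset(path, english_prefix="You are a SQLite expert"):
    """Return (sample_count, first_prompt_is_english) for a training JSON file.

    Assumes a JSON list of records whose 'conversations' list starts with
    the user prompt under the 'content' key.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    is_english = False
    if data and data[0].get("conversations"):
        is_english = data[0]["conversations"][0]["content"].startswith(english_prefix)
    return len(data), is_english


# Example: summarize_dataset("train/datasets/nl2sql_standard_train_en.json")
```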

In [None]:
import os
import json

# Check if training dataset exists in repository
TRAIN_DATASET_PATH = "train/datasets/nl2sql_standard_train_en.json"

if os.path.exists(TRAIN_DATASET_PATH):
    print("✅ Training dataset found in repository!")
    print(f"   Path: {TRAIN_DATASET_PATH}")

    size_mb = os.path.getsize(TRAIN_DATASET_PATH) / (1024 * 1024)
    print(f"   Size: {size_mb:.1f} MB")

    # Without Git LFS, the clone contains a small text pointer instead of the JSON data
    with open(TRAIN_DATASET_PATH, 'r', encoding='utf-8') as f:
        is_lfs_pointer = f.read(100).startswith("version https://git-lfs")

    if is_lfs_pointer:
        print("   ⚠️  This file is a Git LFS pointer, not the dataset.")
        print("   Run `git lfs pull` in the repository, or delete the file and rerun this cell to copy it from Google Drive.")
    else:
        # Quick verification
        with open(TRAIN_DATASET_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"   Samples: {len(data)}")

        # The English prompt template starts with "You are a SQLite expert"
        if data and data[0].get('conversations'):
            prompt = data[0]['conversations'][0]['content']
            if prompt.startswith("You are a SQLite expert"):
                print("   Language: ✅ English")
            else:
                print("   Language: ⚠️ Not English")

        print("\n🎉 Dataset ready! Continue with Step 6.")

else:
    print("⚠️  Training dataset not found in repository")
    print(f"   Expected: {TRAIN_DATASET_PATH}")
    print("\n📥 Attempting to copy it from Google Drive as a backup...")

    # Backup: copy from Google Drive
    DRIVE_DATASET_PATH = "/content/drive/MyDrive/Xiyan-SQL/Dataset/nl2sql_standard_train_en.json"

    if os.path.exists(DRIVE_DATASET_PATH):
        import shutil
        os.makedirs("train/datasets", exist_ok=True)
        shutil.copy2(DRIVE_DATASET_PATH, TRAIN_DATASET_PATH)
        print("✅ Dataset copied from Google Drive")

        size_mb = os.path.getsize(TRAIN_DATASET_PATH) / (1024 * 1024)
        print(f"   Size: {size_mb:.1f} MB")
    else:
        print("❌ Dataset not found in Google Drive either")
        print(f"   Expected: {DRIVE_DATASET_PATH}")
        print("\n💡 Options:")
        print("1. Make sure Git LFS pulled the dataset when cloning")
        print("2. Upload nl2sql_standard_train_en.json to My Drive/Xiyan-SQL/Dataset/")
        print("3. Or run data processing from BIRD raw data")

## Step 6: Configure Training Parameters

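**Auto-detects your GPU memory and picks conservative settings to avoid out-of-memory errors:**
- 70 GB+ (e.g., H100 80GB): 8K context, batch 2 × gradient accumulation 16, LoRA rank 64
- 35-70 GB (e.g., A100 40GB): 6K context, batch 1 × gradient accumulation 32, LoRA rank 64
- 22-35 GB (e.g., L4, A10, RTX 4090): 6K context, batch 1 × gradient accumulation 32, LoRA rank 32
- 14-22 GB (e.g., T4): 4K context, batch 1 × gradient accumulation 32, LoRA rank 64
- 10-14 GB: 2K context, LoRA rank 32; below 10 GB: 1K context, LoRA rank 16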

Flash Attention is automatically enabled for Ampere+ GPUs (A100, H100, L4) and disabled for T4.
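
The tier selection in the cell below is a threshold lookup. The following sketch restates it as a pure function so it can be checked without a GPU. The numbers are copied from that cell, and the helper names are hypothetical. It also shows two pieces of arithmetic the cell relies on. The effective batch size is `per_device_batch × grad_acc × num_processes`, which is `1 × 32 × 1 = 32` on a T4. Weight memory is roughly `params × bytes_per_param`. Assuming about 3.09 billion parameters for the 3B model, bf16 weights (2 bytes each) need under 6 GiB, so most of the remaining GPU memory goes to activations, which grow with `max_length` and batch size.

```python
# (min_gb, max_length, lora_r, batch_size, grad_acc), copied from the Step 6 cell
TIERS = [
    (70, 8192, 64, 2, 16),
    (35, 6144, 64, 1, 32),
    (22, 6144, 32, 1, 32),
    (14, 4096, 64, 1, 32),
    (10, 2048, 32, 1, 64),
    (0, 1024, 16, 1, 128),
]


def pick_tier(gpu_memory_gb):
    """Return the settings of the first tier whose memory threshold the GPU meets."""
    for min_gb, max_length, lora_r, batch_size, grad_acc in TIERS:
        if gpu_memory_gb >= min_gb:
            return {"max_length": max_length, "lora_r": lora_r, "lora_alpha": 2 * lora_r,
                    "batch_size": batch_size, "grad_acc": grad_acc}
    raise ValueError("gpu_memory_gb must be non-negative")


def steps_per_epoch(samples, batch_size, grad_acc, num_processes=1):
    """Optimizer steps per epoch (floored, as in the cell below)."""
    return samples // (batch_size * grad_acc * num_processes)


def weight_memory_gib(num_params, bytes_per_param=2):
    """Memory for the weights alone (bf16 = 2 bytes/param); excludes activations."""
    return num_params * bytes_per_param / 1024**3


t4 = pick_tier(15.0)  # e.g. a 15 GB T4
print(t4)
print(steps_per_epoch(9431, t4["batch_size"], t4["grad_acc"]))
print(f"{weight_memory_gib(3.09e9):.1f} GiB")  # assumes ~3.09B parameters
```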

In [None]:
# Check available GPU memory
import subprocess
import os

# Reduce CUDA memory fragmentation; the training process launched in Step 7 inherits this variable
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

try:
    result = subprocess.run(['nvidia-smi', '--query-gpu=memory.total', '--format=csv,noheader,nounits'],
                            capture_output=True, text=True)
    # nvidia-smi prints one line per GPU; use the first one
    gpu_memory_mb = int(result.stdout.strip().splitlines()[0])
    gpu_memory_gb = gpu_memory_mb / 1024
    print(f"🎮 Detected GPU Memory: {gpu_memory_gb:.1f} GB")

    # Detect GPU architecture for flash attention compatibility
    gpu_name_result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
                                     capture_output=True, text=True)
    gpu_name = gpu_name_result.stdout.strip().splitlines()[0]
    print(f"🎮 GPU: {gpu_name}")

    # Flash attention needs Ampere or newer (A100, A10, RTX 30xx/40xx, L4, H100).
    # T4 is Turing architecture, so it is not supported.
    supports_flash = any(x in gpu_name.upper() for x in ['A100', 'A10', 'RTX 30', 'RTX 40', 'H100', 'L4'])
    if not supports_flash and 'T4' in gpu_name:
        print("⚠️  T4 GPU detected - Flash Attention will be disabled (T4 is Turing, needs Ampere+)")
    elif supports_flash:
        print("✅ Flash Attention supported on this GPU!")
except Exception:
    gpu_memory_gb = 15.0  # Default assumption
    supports_flash = False
    print(f"⚠️  Could not detect GPU, assuming {gpu_memory_gb} GB")
    print("⚠️  Flash Attention will be disabled for compatibility")

# Auto-configure based on GPU memory (deliberately conservative to avoid OOM)
if gpu_memory_gb >= 70:
    # 80GB GPU (e.g., H100 80GB)
    MAX_LENGTH = 8192
    LORA_R = 64
    BATCH_SIZE = 2
    GRAD_ACC = 16          # Effective batch = 32
    print(f"🔥 Using CONSERVATIVE config for {gpu_memory_gb:.1f}GB GPU (80GB class)")
    print("   ⚠️  Long contexts make activation memory large, so settings stay conservative")
elif gpu_memory_gb >= 35:
    # 40GB GPU (e.g., A100 40GB)
    MAX_LENGTH = 6144
    LORA_R = 64
    BATCH_SIZE = 1
    GRAD_ACC = 32          # Effective batch = 32
    print(f"🚀 Using CONSERVATIVE config for {gpu_memory_gb:.1f}GB GPU (A100 40GB class)")
    print("   ⚠️  Very conservative settings to avoid OOM")
elif gpu_memory_gb >= 22:
    # 24GB GPU (L4, RTX 4090, A10)
    MAX_LENGTH = 6144
    LORA_R = 32
    BATCH_SIZE = 1
    GRAD_ACC = 32          # Effective batch = 32
    print(f"🎯 Using CONSERVATIVE config for {gpu_memory_gb:.1f}GB GPU")
    print("   ⚠️  Very conservative for stability")
elif gpu_memory_gb >= 14:
    # 15-16GB GPU (T4, P100) - balanced
    MAX_LENGTH = 4096      # Standard context
    LORA_R = 64            # Balanced LoRA
    BATCH_SIZE = 1         # Small batch
    GRAD_ACC = 32          # Effective batch = 32
    print(f"📊 Using BALANCED config for {gpu_memory_gb:.1f}GB GPU")
    print("   (Conservative settings to avoid OOM)")
elif gpu_memory_gb >= 10:
    # 12GB GPU - low memory
    MAX_LENGTH = 2048
    LORA_R = 32
    BATCH_SIZE = 1
    GRAD_ACC = 64          # Effective batch = 64
    print(f"📊 Using LOW MEMORY config for {gpu_memory_gb:.1f}GB GPU")
else:
    # 8GB GPU - ultra-low memory
    MAX_LENGTH = 1024
    LORA_R = 16
    BATCH_SIZE = 1
    GRAD_ACC = 128         # Effective batch = 128
    print(f"📊 Using ULTRA-LOW MEMORY config for {gpu_memory_gb:.1f}GB GPU")

TRAINING_CONFIG = {
    # Experiment ID
    "expr_id": "nl2sql_3b_colab_en",
    
    # Model path (set in Step 4; falls back to the default folder if Step 4 found no model)
    "model_path": globals().get('MODEL_PATH') or "train/model/Qwen/Qwen2.5-Coder-3B-Instruct",
    
    # Dataset path - Using English version
    "data_path": "train/datasets/nl2sql_standard_train_en.json",
    
    # Output directory
    "output_dir": "train/output/dense/nl2sql_3b_colab_en/",
    
    # Training hyperparameters
    "epochs": 5 if gpu_memory_gb >= 35 else 3,  # More epochs for A100/H100
    "learning_rate": 2e-5,
    "weight_decay": 0.1,
    "max_length": MAX_LENGTH,
    
    # LoRA configuration
    "use_lora": True,
    "lora_r": LORA_R,
    "lora_alpha": LORA_R * 2,
    
    # Batch configuration
    "batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACC,
    
    # Other settings
    "save_steps": 200,
    "group_by_length": True,
    "shuffle": True,
    "use_flash_attention": supports_flash,  # Auto-detect based on GPU
    "bf16": True,
}

print("\nüìã Training Configuration:")
print(f"  Experiment ID: {TRAINING_CONFIG['expr_id']}")
print(f"  Dataset: {TRAINING_CONFIG['data_path']}")
print(f"  Max Length: {TRAINING_CONFIG['max_length']} tokens")
print(f"  LoRA Rank: {TRAINING_CONFIG['lora_r']}")
print(f"  Batch Size: {TRAINING_CONFIG['batch_size']}")
print(f"  Gradient Accumulation: {TRAINING_CONFIG['gradient_accumulation_steps']}")
print(f"  Effective Batch Size: {TRAINING_CONFIG['batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps']}")
print(f"  Epochs: {TRAINING_CONFIG['epochs']}")
print(f"  Learning Rate: {TRAINING_CONFIG['learning_rate']}")
print(f"  Flash Attention: {'✅ Enabled' if TRAINING_CONFIG['use_flash_attention'] else '❌ Disabled (unsupported or undetected GPU)'}")

print("\nüí° Estimated Training Time:")
samples = 9431
steps_per_epoch = samples // (TRAINING_CONFIG['batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps'])
total_steps = steps_per_epoch * TRAINING_CONFIG['epochs']

# Estimate time based on GPU and configuration
if gpu_memory_gb >= 70:
    time_per_step_sec = 2.5  # H100 80GB - slower due to conservative settings
elif gpu_memory_gb >= 35:
    time_per_step_sec = 3.0  # A100/H100 40GB
elif gpu_memory_gb >= 22:
    time_per_step_sec = 3.5  # L4/A10
elif supports_flash:
    time_per_step_sec = 3.5  # Other GPUs with flash attention
else:
    time_per_step_sec = 4    # Without flash attention

total_hours = (total_steps * time_per_step_sec) / 3600
print(f"  Steps per epoch: ~{steps_per_epoch}")
print(f"  Total steps: ~{total_steps}")
print(f"  Estimated time: ~{total_hours:.1f} hours")

if gpu_memory_gb >= 70:
    print("\n⚠️  80GB GPU Note:")
    print("   Using conservative settings: 8K context, batch 2, LoRA 64")
    print("   The ~3B model's bf16 weights need only ~6 GB; most memory goes to activations at long context")
elif gpu_memory_gb >= 35:
    print("\n⚠️  A100 40GB class:")
    print("   Using very conservative settings to ensure stability")

print("\n⚠️  Colab Tips:")
if gpu_memory_gb >= 35:
    print("  - Colab Pro+: 24 hour runtime limit with A100/H100")
    print("  - Training should complete well within the time limit")
else:
    print("  - Colab Pro: Longer runtime than the free tier")
print("  - Keep the browser tab active to prevent disconnection")
print("  - Enable background execution in Colab settings")

print("\n💾 Expected Memory Usage (rough estimates):")
if gpu_memory_gb >= 70:
    print("  - Conservative: 8K context, batch 2, LoRA 64")
    print("  - Expected total usage: ~65-70GB")
    print("  - This leaves minimal headroom")
elif gpu_memory_gb >= 35:
    print("  - Very conservative: 6K context, batch 1, LoRA 64")
    print("  - Expected usage: ~28-33GB")
elif gpu_memory_gb >= 22:
    print("  - Conservative: 6K context, batch 1, LoRA 32")
    print("  - Expected usage: ~15-18GB")
else:
    print("  - Conservative settings to prevent OOM")
    print("  - Expected usage: ~6-12GB")

print("\n🔧 Memory optimization: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")

print("\n💡 Tip: If you still hit OOM, lower MAX_LENGTH or BATCH_SIZE first; activation memory grows with both.")

## Step 7: Start Training

Run the training with your optimized configuration.
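
The cell below runs training with a plain `subprocess.run`. An optional alternative is sketched here: stream the combined output line by line and also append it to a log file, for example on Google Drive, so the log survives a Colab disconnect. `run_and_tee` is a hypothetical helper, and the log path in the usage line is only an example.

```python
import subprocess
import sys


def run_and_tee(cmd, log_path, cwd=None):
    """Run cmd, echo its combined stdout/stderr live, and append it to log_path.

    Returns the process exit code (the Popen context manager waits for exit).
    """
    with open(log_path, "a", encoding="utf-8") as log, subprocess.Popen(
        cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        text=True, bufsize=1,  # line-buffered text mode
    ) as proc:
        for line in proc.stdout:
            print(line, end="")
            log.write(line)
            log.flush()  # keep the file current even if the session dies mid-run
    return proc.returncode


# Usage with this notebook's variables (log path is an example):
# run_and_tee(cmd, "/content/drive/MyDrive/Xiyan-SQL/train.log", cwd=TRAINING_DIR)
```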

In [None]:
import os
import subprocess
import json

# Set training directory
TRAINING_DIR = "/content/XiYan-SQL/XiYan-SQLTraining"
os.chdir(TRAINING_DIR)

# Create an accelerate config that runs DeepSpeed ZeRO-2 on a single GPU.
# Parameters stay on the GPU (offload_param_device: none); only the optimizer
# states are offloaded to CPU to save GPU memory.
ds_config_yaml = """compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_accumulation_steps: {grad_acc}
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 2
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
""".format(grad_acc=TRAINING_CONFIG["gradient_accumulation_steps"])

# Save DeepSpeed config
os.makedirs("train/config", exist_ok=True)
ds_config_path = "train/config/colab_zero2.yaml"
with open(ds_config_path, 'w') as f:
    f.write(ds_config_yaml)

print("üöÄ Starting XiYan-SQL Training")
print("="*60)
print(f"üìÅ Model: {TRAINING_CONFIG['model_path']}")
print(f"üìä Dataset: {TRAINING_CONFIG['data_path']} (English)")
print(f"üíæ Output: {TRAINING_CONFIG['output_dir']}")
print(f"üéØ Effective Batch: {TRAINING_CONFIG['batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps']}")
print(f"üìè Max Length: {TRAINING_CONFIG['max_length']} tokens")
print(f"üîß LoRA Rank: {TRAINING_CONFIG['lora_r']}")
print("="*60)
print("\n‚è≥ Training will take several hours...")
print("üí° Keep this tab active to prevent disconnection\n")

# Build training command - use absolute paths
cmd = [
    "accelerate", "launch",
    "--config_file", ds_config_path,
    "--num_processes", "1",
    "train/sft4xiyan.py",
    "--save_only_model", "True",
    "--resume", "False",
    "--model_name_or_path", TRAINING_CONFIG["model_path"],
    "--data_path", TRAINING_CONFIG["data_path"],
    "--output_dir", TRAINING_CONFIG["output_dir"],
    "--num_train_epochs", str(TRAINING_CONFIG["epochs"]),
    "--per_device_train_batch_size", str(TRAINING_CONFIG["batch_size"]),
    "--gradient_accumulation_steps", str(TRAINING_CONFIG["gradient_accumulation_steps"]),
    "--save_strategy", "steps",
    "--save_steps", str(TRAINING_CONFIG["save_steps"]),
    "--save_total_limit", "3",
    "--learning_rate", str(TRAINING_CONFIG["learning_rate"]),
    "--weight_decay", str(TRAINING_CONFIG["weight_decay"]),
    "--adam_beta2", "0.95",
    "--warmup_ratio", "0.1",
    "--lr_scheduler_type", "cosine",
    "--logging_steps", "10",
    "--report_to", "none",
    "--model_max_length", str(TRAINING_CONFIG["max_length"]),
    "--lazy_preprocess", "False",
    "--gradient_checkpointing", "True",
    "--predict_with_generate", "True",
    "--include_inputs_for_metrics", "True",
    "--use_lora", str(TRAINING_CONFIG["use_lora"]),
    "--lora_r", str(TRAINING_CONFIG["lora_r"]),
    "--lora_alpha", str(TRAINING_CONFIG["lora_alpha"]),
    "--do_shuffle", str(TRAINING_CONFIG["shuffle"]),
    "--torch_compile", "False",
    "--group_by_length", str(TRAINING_CONFIG["group_by_length"]),
    "--model_type", "auto",
    "--use_flash_attention", str(TRAINING_CONFIG["use_flash_attention"]),
    "--bf16",
    "--expr_id", TRAINING_CONFIG["expr_id"]
]

# Show the full command for debugging (shell-quoted so it can be copied and rerun)
import shlex
print("📝 Training command:")
print(shlex.join(cmd))
print("\n" + "="*60 + "\n")

# Run training
try:
    result = subprocess.run(cmd, cwd=TRAINING_DIR, check=False)

    if result.returncode == 0:
        print("\n" + "="*60)
        print("✅ Training completed successfully!")
        print(f"📁 Output saved to: {TRAINING_CONFIG['output_dir']}")
        print("="*60)
    else:
        print("\n" + "="*60)
        print(f"❌ Training failed with return code {result.returncode}")
        print("="*60)
        print("\n💡 Common issues:")
        print("  - Dataset not found: Check that nl2sql_standard_train_en.json exists")
        print("  - Model not found: Check that MODEL_PATH (Step 4) is correct")
        print("  - Out of memory: Try reducing max_length or batch_size")
        print("\n🔍 Check the error messages above for more details")
except Exception as e:
    print(f"\n❌ Error during training: {e}")

## Step 8: Save Trained Model to Google Drive (Optional)

After training completes, save your model to Google Drive for future use.
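
The cell below deletes any existing Drive copy with `shutil.rmtree` before copying, which produces an exact mirror but briefly leaves no copy at all. The alternative sketched here overwrites files in place with `shutil.copytree(..., dirs_exist_ok=True)`, which requires Python 3.8 or later. It keeps files in the destination that no longer exist in the source, such as checkpoints removed by `--save_total_limit`. Whether that trade-off suits you is a judgment call. `sync_to_drive` is a hypothetical helper.

```python
import os
import shutil


def sync_to_drive(src, dest):
    """Copy the src directory tree into dest, overwriting files that already exist.

    Files present only in dest are kept; nothing is deleted first.
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    return shutil.copytree(src, dest, dirs_exist_ok=True)


# Usage with this notebook's variables:
# sync_to_drive(TRAINED_MODEL_PATH, DRIVE_SAVE_PATH)
```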

In [None]:
import shutil
import os

# Path to trained model
TRAINED_MODEL_PATH = TRAINING_CONFIG["output_dir"]

# Destination in Google Drive (same top-level Xiyan-SQL folder as the model and dataset)
DRIVE_SAVE_PATH = f"/content/drive/MyDrive/Xiyan-SQL/Trained-Models/{TRAINING_CONFIG['expr_id']}"

if os.path.exists(TRAINED_MODEL_PATH):
    print("📥 Copying trained model to Google Drive...")
    print(f"   From: {TRAINED_MODEL_PATH}")
    print(f"   To: {DRIVE_SAVE_PATH}")

    # Create parent directory
    os.makedirs(os.path.dirname(DRIVE_SAVE_PATH), exist_ok=True)

    # Replace any previous copy with a fresh one
    if os.path.exists(DRIVE_SAVE_PATH):
        shutil.rmtree(DRIVE_SAVE_PATH)

    shutil.copytree(TRAINED_MODEL_PATH, DRIVE_SAVE_PATH)
    print("\n✅ Model saved to Google Drive!")
    print(f"📁 Location: {DRIVE_SAVE_PATH}")
else:
    print(f"⚠️  Trained model not found at {TRAINED_MODEL_PATH}")
    print("Make sure training completed successfully in Step 7.")

## Step 9: Merge LoRA Adapter with Base Model (Optional)

Merge the trained LoRA adapter into the base model to create a single, deployable model. This step is optional for quick testing, because Step 10 can load the adapter on top of the base model. It is **required** if you want to use the model without loading the adapter separately.
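
The cell below picks the adapter by folder modification time. Modification times can change when folders are copied or synced. The step number in the Hugging Face Trainer's `checkpoint-<step>` folder names does not change. The sketch below prefers an adapter saved directly in `output_dir`, identified by the `adapter_config.json` file that PEFT writes with adapter weights. Otherwise it picks the checkpoint with the highest step. Whether `train/sft4xiyan.py` saves a final adapter at the top of `output_dir` is an assumption, and `pick_adapter_dir` is a hypothetical helper. For reference, `merge_and_unload()` folds each LoRA update into its base weight as `W + (lora_alpha / r) * B @ A`. With `lora_alpha = 2 * lora_r` from Step 6, that update is scaled by 2.

```python
import os
import re

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def pick_adapter_dir(output_dir):
    """Choose which adapter folder to merge (hypothetical helper).

    Returns output_dir if it holds adapter_config.json, otherwise the
    checkpoint-<step> folder with the highest step, or None if there is none.
    """
    if os.path.isfile(os.path.join(output_dir, "adapter_config.json")):
        return output_dir
    best_step, best_dir = -1, None
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_dir = int(match.group(1)), path
    return best_dir


# Usage: ADAPTER_CHECKPOINT = pick_adapter_dir(TRAINING_CONFIG["output_dir"])
```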

In [None]:
import os
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Paths
BASE_MODEL_PATH = TRAINING_CONFIG["model_path"]
ADAPTER_PATH = TRAINING_CONFIG["output_dir"]

# Find the latest checkpoint
checkpoint_dirs = []
if os.path.exists(ADAPTER_PATH):
    for item in os.listdir(ADAPTER_PATH):
        item_path = os.path.join(ADAPTER_PATH, item)
        if os.path.isdir(item_path) and ("checkpoint" in item.lower() or "adapter" in item.lower()):
            checkpoint_dirs.append(item_path)

if checkpoint_dirs:
    checkpoint_dirs.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    ADAPTER_CHECKPOINT = checkpoint_dirs[0]
else:
    ADAPTER_CHECKPOINT = ADAPTER_PATH

# Output path for merged model (same Drive folder as Step 8)
MERGED_MODEL_PATH = f"/content/drive/MyDrive/Xiyan-SQL/Trained-Models/{TRAINING_CONFIG['expr_id']}-merged"

print("🔄 Merging LoRA adapter with base model...")
print(f"   Base model: {BASE_MODEL_PATH}")
print(f"   Adapter: {ADAPTER_CHECKPOINT}")
print(f"   Output: {MERGED_MODEL_PATH}")
print("\n⏳ This may take a few minutes...\n")

try:
    # Load base model and tokenizer
    print("📥 Loading base model...")
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_PATH,
        use_fast=False,
        trust_remote_code=True
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )

    # Load the LoRA adapter on top of the bf16 base model
    # (no separate float16 dtype, which would mismatch the bf16 base weights)
    print("📥 Loading LoRA adapter...")
    lora_model = PeftModel.from_pretrained(
        base_model,
        ADAPTER_CHECKPOINT,
    )

    # Fold the adapter weights into the base weights and drop the LoRA wrappers
    print("🔗 Merging adapter into base model...")
    merged_model = lora_model.merge_and_unload()

    # Save merged model
    print(f"💾 Saving merged model to {MERGED_MODEL_PATH}...")
    os.makedirs(MERGED_MODEL_PATH, exist_ok=True)
    merged_model.save_pretrained(MERGED_MODEL_PATH)
    tokenizer.save_pretrained(MERGED_MODEL_PATH)

    print("\n✅ Merged model saved successfully!")
    print(f"📁 Location: {MERGED_MODEL_PATH}")
    print("\n💡 You can now use this merged model directly without loading the adapter separately.")

    # Remember the path so Step 10 can test the merged model
    MERGED_MODEL_FOR_TESTING = MERGED_MODEL_PATH

except Exception as e:
    print(f"\n❌ Error during merging: {e}")
    print("\n💡 Troubleshooting:")
    print("  - Make sure the base model path is correct")
    print("  - Verify the adapter checkpoint exists")
    print("  - Check that you have enough disk space in Google Drive")

## Step 10: Quick Inference Test

Test your trained model with sample questions to verify it's generating reasonable SQL queries.
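
A cheap extra check is whether the SQL returned by `extract_sql_only` in the cell below even compiles against the schema. This sketch uses Python's built-in `sqlite3`. `EXPLAIN` makes SQLite prepare the statement, which resolves its tables and columns, without running the query. This only catches syntax and schema errors and says nothing about whether the query answers the question. The toy DDL is a hypothetical stand-in, not the real `movie_3` database, and `sql_compiles` is a hypothetical helper.

```python
import sqlite3


def sql_compiles(sql, schema_ddl):
    """Return (True, None) if SQLite can prepare sql against schema_ddl, else (False, error)."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema_ddl)   # build the (empty) schema
        conn.execute("EXPLAIN " + sql)   # compile only; the query itself does not run
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


# Hypothetical toy schema for illustration; the real movie_3 database has more tables and columns
TOY_DDL = "CREATE TABLE film (film_id INTEGER PRIMARY KEY, title TEXT, rental_rate REAL);"
print(sql_compiles("SELECT title FROM film WHERE rental_rate > 2.99", TOY_DDL))
print(sql_compiles("SELECT name FROM film", TOY_DDL))
```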

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Use the same template format as training
NL2SQLITE_TEMPLATE_EN = """You are a SQLite expert. You need to read and understand the following【Database Schema】description and the possible provided【Evidence】, and use valid SQLite knowledge to generate SQL for answering the【Question】.
【Question】
{question}

【Database Schema】
{db_schema}

【Evidence】
{evidence}

【Question】
{question}

```sql"""


import re

SQL_KEYWORDS = ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'WITH')
# Lower-case phrases that usually mark the start of a natural-language explanation
# (matched against line.lower(), so capitalised variants are covered automatically)
EXPLANATION_INDICATORS = (
    'this query', "here's", 'to find', 'you can use',
    'the query', 'this will', 'selects', 'groups',
)


def _take_sql_lines(text):
    """Collect SQL lines until an explanation line or a code fence follows them."""
    sql_lines = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        is_explanation = any(ind in line.lower() for ind in EXPLANATION_INDICATORS)
        # Only stop once some SQL has been collected
        if sql_lines and (is_explanation or '```' in line):
            break
        sql_lines.append(line)
    return ' '.join(sql_lines)


def extract_sql_only(text):
    """Extract only SQL from model output, removing explanations."""
    if not text:
        return text

    text = text.strip()

    # Pattern 1: SQL in a ```sql markdown block
    if '```sql' in text:
        return text.split('```sql', 1)[1].split('```', 1)[0].strip()

    # Pattern 2: plain code fences. The prompt already ends with ```sql, so a
    # completion often looks like "<SQL>\n```..."; check every fenced segment.
    if '```' in text:
        for segment in text.split('```'):
            segment = segment.strip()
            if segment.upper().startswith(SQL_KEYWORDS):
                return segment

    # Pattern 3: output starts directly with a SQL keyword
    if text.upper().startswith(SQL_KEYWORDS):
        return _take_sql_lines(text)

    # Pattern 4: SQL after explanation text; match whole keywords only, so words
    # such as "selects" or "without" are not mistaken for SQL
    for keyword in SQL_KEYWORDS:
        match = re.search(rf'\b{keyword}\b', text, re.IGNORECASE)
        if match:
            return _take_sql_lines(text[match.start():])

    return text


# Determine which model to test
# Priority: 1) Merged model, 2) Latest checkpoint with adapter, 3) Base model with adapter
model_to_test = None
use_adapter = False
adapter_path = None

if 'MERGED_MODEL_FOR_TESTING' in globals() and os.path.exists(globals()['MERGED_MODEL_FOR_TESTING']):
    model_to_test = globals()['MERGED_MODEL_FOR_TESTING']
    print(f"üéØ Testing merged model: {model_to_test}")
elif 'ADAPTER_CHECKPOINT' in globals() and os.path.exists(globals()['ADAPTER_CHECKPOINT']):
    model_to_test = TRAINING_CONFIG["model_path"]
    adapter_path = globals()['ADAPTER_CHECKPOINT']
    use_adapter = True
    print("🎯 Testing base model + adapter:")
    print(f"   Base: {model_to_test}")
    print(f"   Adapter: {adapter_path}")
else:
    print("❌ No trained model found!")
    print("   Please complete Step 7 (training) and optionally Step 9 (merging) first.")
    model_to_test = None

if model_to_test:
    print("\n📥 Loading model and tokenizer...")
    print("   (This may take 1-2 minutes)\n")
    
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_to_test,
            use_fast=False,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model
        if use_adapter:
            # Load base model first
            base_model = AutoModelForCausalLM.from_pretrained(
                model_to_test,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )
            # Load adapter on top
            model = PeftModel.from_pretrained(
                base_model,
                adapter_path,
                torch_dtype=torch.bfloat16
            )
            model.eval()
        else:
            # Load merged model directly
            model = AutoModelForCausalLM.from_pretrained(
                model_to_test,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )
            model.eval()
        
        print("✅ Model loaded successfully!\n")
        
        # Test cases (using m-schema format like training data)
        # All from movie_3 database
        movie_3_schema = """【DB_ID】 movie_3
【Schema】
# Table: film
[
(film_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(title:TEXT, Examples: [ACADEMY DINOSAUR, ACE GOLDFINGER, ADAPTATION HOLES]),
(description:TEXT),
(release_year:TEXT, Examples: [2006]),
(language_id:INTEGER, Examples: [1]),
(original_language_id:INTEGER),
(rental_duration:INTEGER, Examples: [6, 3, 7]),
(rental_rate:REAL, Examples: [0.99, 4.99, 2.99]),
(length:INTEGER, Examples: [86, 48, 50]),
(replacement_cost:REAL, Examples: [20.99, 12.99, 18.99]),
(rating:TEXT, Examples: [PG, G, NC-17]),
(special_features:TEXT, Examples: [Trailers,Deleted Scenes, Commentaries,Behind the Scenes]),
(last_update:DATETIME, Examples: [2006-02-15 05:03:42.0])
]
# Table: rental
[
(rental_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(rental_date:DATETIME, Examples: [2005-05-24 22:53:30.0]),
(inventory_id:INTEGER, Examples: [367, 1525, 1711]),
(customer_id:INTEGER, Examples: [130, 459, 408]),
(return_date:DATETIME, Examples: [2005-05-26 22:04:30.0]),
(staff_id:INTEGER, Examples: [1, 2]),
(last_update:DATETIME, Examples: [2006-02-15 21:30:53.0])
]
# Table: store
[
(store_id:INTEGER, Primary Key, Examples: [1, 2]),
(manager_staff_id:INTEGER, Examples: [1, 2]),
(address_id:INTEGER, Examples: [1, 2]),
(last_update:DATETIME, Examples: [2006-02-15 04:57:12.0])
]
# Table: inventory
[
(inventory_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(film_id:INTEGER, Examples: [1, 2, 3]),
(store_id:INTEGER, Examples: [1, 2]),
(last_update:DATETIME, Examples: [2006-02-15 05:09:17.0])
]
# Table: address
[
(address_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(address:TEXT, Examples: [47 MySakila Drive, 28 MySQL Boulevard, 23 Workhaven Lane]),
(address2:TEXT),
(district:TEXT, Examples: [Alberta, QLD, Nagasaki]),
(city_id:INTEGER, Examples: [300, 576, 463]),
(postal_code:TEXT, Examples: [35200, 17886, 83579]),
(phone:TEXT, Examples: [14033335568, 6172235589, 28303384290]),
(last_update:DATETIME, Examples: [2006-02-15 04:45:30.0])
]
# Table: country
[
(country_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(country:TEXT, Examples: [Afghanistan, Algeria, American Samoa]),
(last_update:DATETIME, Examples: [2006-02-15 04:44:00.0])
]
# Table: city
[
(city_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(city:TEXT, Examples: [A Corua (La Corua), Abha, Abu Dhabi]),
(country_id:INTEGER, Examples: [87, 82, 101]),
(last_update:DATETIME, Examples: [2006-02-15 04:45:25.0])
]
# Table: film_actor
[
(actor_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(film_id:INTEGER, Primary Key, Examples: [1, 23, 25]),
(last_update:DATETIME, Examples: [2006-02-15 05:05:03.0])
]
# Table: payment
[
(payment_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(customer_id:INTEGER, Examples: [1, 2, 3]),
(staff_id:INTEGER, Examples: [1, 2]),
(rental_id:INTEGER, Examples: [76, 573, 1185]),
(amount:REAL, Examples: [2.99, 0.99, 5.99]),
(payment_date:DATETIME, Examples: [2005-05-25 11:30:37.0]),
(last_update:DATETIME, Examples: [2006-02-15 22:12:30.0])
]
# Table: film_text
[
(film_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(title:TEXT, Examples: [ACADEMY DINOSAUR, ACE GOLDFINGER, ADAPTATION HOLES]),
(description:TEXT)
]
# Table: customer
[
(customer_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(store_id:INTEGER, Examples: [1, 2]),
(first_name:TEXT, Examples: [MARY, PATRICIA, LINDA]),
(last_name:TEXT, Examples: [SMITH, JOHNSON, WILLIAMS]),
(email:TEXT),
(address_id:INTEGER, Examples: [5, 6, 7]),
(active:INTEGER, Examples: [1, 0]),
(create_date:DATETIME, Examples: [2006-02-14 22:04:36.0]),
(last_update:DATETIME, Examples: [2006-02-15 04:57:20.0])
]
# Table: staff
[
(staff_id:INTEGER, Primary Key, Examples: [1, 2]),
(first_name:TEXT, Examples: [Mike, Jon]),
(last_name:TEXT, Examples: [Hillyer, Stephens]),
(address_id:INTEGER, Examples: [3, 4]),
(picture:BLOB),
(email:TEXT),
(store_id:INTEGER, Examples: [1, 2]),
(active:INTEGER, Examples: [1]),
(username:TEXT, Examples: [Mike, Jon]),
(password:TEXT),
(last_update:DATETIME, Examples: [2006-02-15 04:57:16.0])
]
# Table: language
[
(language_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(name:TEXT, Examples: [English, Italian, Japanese]),
(last_update:DATETIME, Examples: [2006-02-15 05:02:19.0])
]
# Table: film_category
[
(film_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(category_id:INTEGER, Primary Key, Examples: [6, 11, 8]),
(last_update:DATETIME, Examples: [2006-02-15 05:07:09.0])
]
# Table: category
[
(category_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(name:TEXT, Examples: [Action, Animation, Children]),
(last_update:DATETIME, Examples: [2006-02-15 04:46:27.0])
]
# Table: actor
[
(actor_id:INTEGER, Primary Key, Examples: [1, 2, 3]),
(first_name:TEXT, Examples: [PENELOPE, NICK, ED]),
(last_name:TEXT, Examples: [GUINESS, WAHLBERG, CHASE]),
(last_update:DATETIME, Examples: [2006-02-15 04:34:33.0])
]
【Foreign keys】
film.original_language_id=language.language_id
film.language_id=language.language_id
rental.staff_id=staff.staff_id
rental.customer_id=customer.customer_id
rental.inventory_id=inventory.inventory_id
store.address_id=address.address_id
store.manager_staff_id=staff.staff_id
inventory.store_id=store.store_id
inventory.film_id=film.film_id
address.city_id=city.city_id
city.country_id=country.country_id
film_actor.film_id=film.film_id
film_actor.actor_id=actor.actor_id
payment.rental_id=rental.rental_id
payment.staff_id=staff.staff_id
payment.customer_id=customer.customer_id
customer.address_id=address.address_id
customer.store_id=store.store_id
staff.store_id=store.store_id
staff.address_id=address.address_id
film_category.category_id=category.category_id
film_category.film_id=film.film_id"""

        test_cases = [
            {
                "schema": movie_3_schema,
                "question": "Among the times Mary Smith had rented a movie, how many of them happened in June, 2005?",
                "evidence": "in June 2005 refers to year(payment_date) = 2005 and month(payment_date) = 6"
            },
            {
                "schema": movie_3_schema,
                "question": "Please give the full name of the customer who had made the biggest amount of payment in one single film rental.",
                "evidence": "full name refers to first_name, last_name; the biggest amount refers to max(amount)"
            },
            {
                "schema": movie_3_schema,
                "question": "How much in total had the customers in Italy spent on film rentals?",
                "evidence": "total = sum(amount); Italy refers to country = 'Italy'"
            },
            {
                "schema": movie_3_schema,
                "question": "Among the payments made by Mary Smith, how many of them are over 4.99?",
                "evidence": "over 4.99 refers to amount > 4.99"
            },
            {
                "schema": movie_3_schema,
                "question": "What is the average amount of money spent by a customer in Italy on a single film rental?",
                "evidence": "Italy refers to country = 'Italy'; average amount = divide(sum(amount), count(customer_id)) where country = 'Italy'"
            }
        ]
        
        print("="*80)
        print("🧪 QUICK INFERENCE TEST")
        print("="*80)
        
        for i, test_case in enumerate(test_cases, 1):
            print(f"\n{'='*80}")
            print(f"Test Case {i}:")
            print(f"{'='*80}")
            print(f"\n📝 Question: {test_case['question']}")
            print("\n📊 Schema:")
            for line in test_case['schema'].split('\n')[:5]:  # Show first 5 lines
                print(f"   {line}")
            print("   ...")
            
            # Build prompt using the same template as training
            prompt_text = NL2SQLITE_TEMPLATE_EN.format(
                question=test_case['question'],
                db_schema=test_case['schema'],
                evidence=test_case['evidence']
            )
            
            # Create conversation format (same as training data)
            conversations = [
                {
                    "role": "user",
                    "content": prompt_text
                }
            ]
            
            # Apply chat template (same as sql_infer.py)
            text = tokenizer.apply_chat_template(
                conversations,
                tokenize=False,
                add_generation_prompt=True
            )
            
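            # Tokenize (truncation cuts from the right by default, so a prompt longer than
            # max_length would lose the assistant turn added by the chat template)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)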
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            # Generate (using similar params as sql_infer.py)
            print("\n⏳ Generating SQL...")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            
            # Extract only the generated part (after input)
            generated_ids = outputs[0][len(inputs['input_ids'][0]):]
            raw_output = tokenizer.decode(generated_ids, skip_special_tokens=True)
            
            # Extract only SQL, removing any explanations
            sql_output = extract_sql_only(raw_output)
            
            print("\n✅ Generated SQL:")
            print("─"*80)
            print(sql_output)
            print("─"*80)
            
            # Show raw output if it differs significantly (for debugging)
            if raw_output != sql_output and len(raw_output) > len(sql_output) + 20:
                print("\n⚠️  Note: Model also generated explanatory text (removed)")
                print(f"   Raw output length: {len(raw_output)} chars, SQL length: {len(sql_output)} chars")
        
        print(f"\n{'='*80}")
        print("✅ Quick inference test completed!")
        print("="*80)
        
    except Exception as e:
        print(f"\n❌ Error during inference: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Free GPU memory whether or not the test succeeded
        import gc
        print("\n🧹 Cleaning up memory...")
        if 'model' in globals():
            del model
        if 'base_model' in globals():
            del base_model
        if 'tokenizer' in globals():
            del tokenizer
        gc.collect()  # drop unreachable Python references before releasing cached CUDA memory
        torch.cuda.empty_cache()
        print("✅ Memory freed")

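The keyword fallback in `extract_sql_only` (Pattern 4) is easiest to reason about in isolation. Below is a minimal standalone sketch of the same idea, not the notebook's function: `SQL_KEYWORDS`, `EXPLANATION_MARKERS`, and `extract_sql_fallback` are hypothetical names, and it swaps the plain substring search for a word-boundary regex so that prose such as "selected" is not mistaken for `SELECT`. The markers are lowercase because each line is compared through `line.lower()`.

```python
import re

# Hypothetical keyword list for this sketch; the notebook defines its own `sql_keywords`.
SQL_KEYWORDS = ["SELECT"]

# Lowercase markers, because each line is compared via line.lower().
EXPLANATION_MARKERS = ["this query", "here's", "to find", "you can use", "the query", "this will"]


def extract_sql_fallback(text: str) -> str:
    """Return SQL starting at the first keyword match, stopping at explanatory prose."""
    for keyword in SQL_KEYWORDS:
        # \b word boundaries keep words such as "selected" from matching SELECT.
        match = re.search(rf"\b{re.escape(keyword)}\b", text, flags=re.IGNORECASE)
        if match is None:
            continue
        sql_lines = []
        for line in text[match.start():].split("\n"):
            line = line.strip()
            # Skip blank lines and markdown code-fence lines.
            if not line or line.startswith("```"):
                continue
            # Once some SQL has been collected, stop at the first explanatory sentence.
            if sql_lines and any(marker in line.lower() for marker in EXPLANATION_MARKERS):
                break
            sql_lines.append(line)
        return " ".join(sql_lines)
    # No keyword found: return the text unchanged apart from surrounding whitespace.
    return text.strip()


print(extract_sql_fallback("I selected the film table.\nSELECT title FROM film;\nThis query lists titles."))
# SELECT title FROM film;
```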


## Troubleshooting

### Out of Memory (OOM) Errors
- Reduce `batch_size` to 1
- Reduce `max_length` to 8192 or 4096
- Increase `gradient_accumulation_steps` to maintain effective batch size
- The DeepSpeed config already uses CPU offloading, which helps
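
When you lower `batch_size`, the effective batch size is the per-GPU batch size times `gradient_accumulation_steps` times the number of GPUs. The small sketch below picks the new accumulation steps; the helper name and the example numbers are illustrative assumptions, not values taken from this notebook's training config.

```python
def accumulation_steps_for(target_effective_batch: int, per_device_batch: int, num_gpus: int = 1) -> int:
    """Return gradient_accumulation_steps so that
    per_device_batch * steps * num_gpus == target_effective_batch."""
    per_step = per_device_batch * num_gpus
    if target_effective_batch % per_step != 0:
        raise ValueError(f"{target_effective_batch} is not divisible by {per_step}")
    return target_effective_batch // per_step


# Example (illustrative numbers): batch_size=4 with 4 accumulation steps on one GPU
# gives an effective batch of 16; dropping to batch_size=1 needs 16 steps to match.
original_effective = 4 * 4 * 1
print(accumulation_steps_for(original_effective, per_device_batch=1))  # 16
```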

### Model Not Found
- Check that `MODEL_DRIVE_PATH` in Step 4 is correct
- Verify the model folder exists in Google Drive
- Ensure the model folder contains all required files (config.json, tokenizer files, etc.)
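
A quick way to confirm the model folder is complete before training is to check for the files a Hugging Face checkpoint normally contains. This sketch assumes a `transformers`-style layout (`config.json`, tokenizer files, and `*.safetensors` or `pytorch_model*.bin` weights); exact tokenizer file names vary by model, `check_model_folder` is a hypothetical helper, and the path in the usage line is a placeholder for your `MODEL_DRIVE_PATH`.

```python
from pathlib import Path


def check_model_folder(model_dir: str) -> list[str]:
    """Return problems found in a Hugging Face-style model folder (an empty list means none)."""
    path = Path(model_dir)
    if not path.is_dir():
        return [f"folder not found: {path}"]
    problems = []
    if not (path / "config.json").is_file():
        problems.append("missing config.json")
    # Tokenizer files differ between models, so accept any of the common ones.
    tokenizer_files = ["tokenizer_config.json", "tokenizer.json", "vocab.json"]
    if not any((path / name).is_file() for name in tokenizer_files):
        problems.append("no tokenizer files found")
    # Weights may be one file or several shards, in safetensors or PyTorch format.
    weights = list(path.glob("*.safetensors")) + list(path.glob("pytorch_model*.bin"))
    if not weights:
        problems.append("no *.safetensors or pytorch_model*.bin weight files found")
    return problems


# Usage (placeholder path; substitute your MODEL_DRIVE_PATH from Step 4):
# print(check_model_folder("/content/drive/MyDrive/Qwen2.5-Coder-3B-Instruct") or "looks complete")
```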

### Dataset Not Found
- Check that dataset paths in Step 5 are correct
- Verify files exist in Google Drive
- If processing raw data, ensure `db_conn.json` exists
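
A similar pre-flight check works for the dataset files. The sketch below only verifies that each path exists and that `.json` files parse; `check_dataset_files` is a hypothetical helper, the paths in the usage line are placeholders for the ones you set in Step 5, and it makes no assumption about the internal structure of `db_conn.json`.

```python
import json
from pathlib import Path


def check_dataset_files(paths: list[str]) -> dict[str, str]:
    """Map each path to 'ok' or a short description of what is wrong with it."""
    report = {}
    for p in paths:
        path = Path(p)
        if not path.exists():
            report[p] = "missing"
        elif path.is_file() and path.suffix == ".json":
            # Parse JSON files so a truncated upload is caught before training starts.
            try:
                with path.open(encoding="utf-8") as f:
                    json.load(f)
                report[p] = "ok"
            except json.JSONDecodeError as err:
                report[p] = f"invalid JSON: {err.msg} (line {err.lineno})"
        else:
            report[p] = "ok"
    return report


# Usage (hypothetical paths; substitute the ones from Step 5):
# for p, status in check_dataset_files(["/content/drive/MyDrive/data/train.json",
#                                       "/content/drive/MyDrive/data/db_conn.json"]).items():
#     print(status, p)
```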

### Training Too Slow
- Colab free tier has limited GPU time
- Consider using Colab Pro for longer training sessions
- Reduce dataset size for testing (set `sample_num` in dataset config)
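
If you want a smaller copy of the training file for a quick smoke test, in addition to `sample_num`, the sketch below writes a reproducible random subset. It assumes the processed dataset is a single JSON array of records, which may not match your files; `write_subset` and any paths you pass it are hypothetical.

```python
import json
import random


def write_subset(src: str, dst: str, n: int, seed: int = 42) -> int:
    """Write a random subset of at most n records from a JSON-array file; return the count written."""
    with open(src, encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise ValueError("expected the file to contain a JSON array")
    rng = random.Random(seed)  # fixed seed so the same subset is produced every run
    subset = rng.sample(records, min(n, len(records)))
    with open(dst, "w", encoding="utf-8") as f:
        json.dump(subset, f, ensure_ascii=False, indent=2)
    return len(subset)
```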

### Connection Issues
- Colab sessions may disconnect after inactivity
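- Save checkpoints to Google Drive frequently; `nohup` can keep a background process running if the cell is interrupted, but it cannot keep the Colab runtime itself alive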
- Consider running training in multiple sessions if needed
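
To resume in a new session, point training at the newest checkpoint saved to Drive. The sketch below assumes the Hugging Face `Trainer` naming convention of `checkpoint-<step>` subfolders (`transformers.trainer_utils.get_last_checkpoint` performs a similar lookup); `latest_checkpoint` is a hypothetical helper and the output directory you pass it is your own.

```python
import re
from pathlib import Path
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subfolder with the highest step, or None if there is none."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        # Compare steps numerically so checkpoint-1000 beats checkpoint-999.
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(child)
    return best_path
```

If your training script uses `transformers.Trainer`, the returned path can be passed as `trainer.train(resume_from_checkpoint=...)`; whether the XiYan-SQL training entry point exposes that option is not covered here.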