# SQL Codegen SLM - Training Notebook

Fine-tune Mistral-7B for PostgreSQL query generation using LoRA and 4-bit quantization.

**Requirements:**
- Google Colab Pro+ (for A100 GPU access)
- A GCP project with Cloud Storage
- About 8-12 hours of training time on an A100

**Data:** The JSONL training files are already uploaded to `gs://sql-codegen-slm-data/data/`.

## 1. Check GPU Allocation

In [None]:
!nvidia-smi

import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\n✅ GPU: {gpu_name} ({gpu_mem:.1f} GB)")

    if "A100" in gpu_name:
        print("🎉 Got A100 - optimal for training!")
    elif "V100" in gpu_name:
        print("⚠️ V100 - good, but A100 is faster")
    elif "T4" in gpu_name:
        print("⚠️ T4 - training will be slower. Consider reconnecting for A100.")
    else:
        print("ℹ️ Other GPU - the training time estimates below assume an A100.")
else:
    print("❌ No GPU! Go to Runtime > Change runtime type > GPU")
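The GPU memory printed above matters because of the 4-bit quantization mentioned at the top. A rough back-of-the-envelope sketch, assuming Mistral-7B has about 7.24 billion parameters and counting weights only (activations, LoRA adapters, optimizer state, and quantization constants all add more on top):

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory needed to hold model weights, in decimal GB."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: ~7.24e9 parameters for Mistral-7B; weights only, no activations
MISTRAL_7B_PARAMS = 7.24e9

for label, bits in [("fp16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: ~{weight_memory_gb(MISTRAL_7B_PARAMS, bits):.1f} GB")
```

At fp16 the weights alone take roughly 14.5 GB, which nearly fills a 16 GB T4, while the 4-bit copy needs roughly 3.6 GB and leaves headroom for activations and LoRA training. Decimal GB is used to match the `/ 1e9` in the GPU check.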

## 2. Configure & Authenticate GCP

In [None]:
# GCP Configuration
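PROJECT_ID = "your-gcp-project-id"  # replace with your own project ID
BUCKET_NAME = "sql-codegen-slm-data"

if PROJECT_ID == "your-gcp-project-id":
    raise ValueError("Set PROJECT_ID to your GCP project ID before running this cell.")

import os
os.environ["GCP_PROJECT_ID"] = PROJECT_ID
os.environ["GCS_BUCKET"] = BUCKET_NAME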

# Authenticate
from google.colab import auth
auth.authenticate_user()

!gcloud config set project {PROJECT_ID}
print(f"\n✅ Authenticated with project: {PROJECT_ID}")
print(f"   Bucket: gs://{BUCKET_NAME}")

## 3. Clone Repository & Install Dependencies

In [None]:
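import os

# Use an absolute path so re-running this cell does not clone into the repo itself
REPO_DIR = "/content/sql-codegen-slm"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/rajesh-manikka/sql-codegen-slm.git {REPO_DIR}
%cd {REPO_DIR}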

# Install dependencies
!pip install -q -r training/requirements.txt

print("\n✅ Dependencies installed")

## 4. Download Data from GCS

In [None]:
# Create local directories
!mkdir -p /content/data /content/models /content/logs /content/tensorboard

# Download data from GCS
!gsutil -m cp gs://{BUCKET_NAME}/data/*.jsonl /content/data/

# Verify
print("\n📊 Dataset:")
!wc -l /content/data/*.jsonl
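`wc -l` counts lines, not valid records. A minimal stdlib sketch that also checks that every line parses as a JSON object; the `required_keys` argument is optional, and any field names you pass (for example `question` or `sql`) are hypothetical placeholders until checked against this repo's actual data format.

```python
import json
from pathlib import Path


def validate_jsonl(path, required_keys=()):
    """Return (valid_count, problems) for a JSONL file.

    problems is a list of (line_number, message) tuples.
    """
    valid, problems = 0, []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # skip blank lines, e.g. a trailing newline
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                problems.append((lineno, f"invalid JSON: {exc.msg}"))
                continue
            if not isinstance(record, dict):
                problems.append((lineno, "record is not a JSON object"))
                continue
            missing = [k for k in required_keys if k not in record]
            if missing:
                problems.append((lineno, f"missing keys: {missing}"))
                continue
            valid += 1
    return valid, problems


data_dir = Path("/content/data")
if data_dir.is_dir():
    for path in sorted(data_dir.glob("*.jsonl")):
        count, problems = validate_jsonl(path)  # add required_keys=(...) once known
        print(f"{path.name}: {count} valid records, {len(problems)} problems")
```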

## 5. Verify Environment

In [None]:
from training.colab_setup import check_gpu, estimate_training_time

# Check GPU
gpu_info = check_gpu()

# Estimate training time
print("\n")
estimate_training_time()

## 6. Start Training

**Estimated time:** 8-12 hours on A100

Checkpoints are saved to `/content/models/` every 500 steps.
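To sanity-check the time estimate and see roughly how many checkpoints to expect, the usual single-GPU arithmetic is: optimizer steps per epoch ≈ ceil(examples ÷ (per-device batch × gradient accumulation)). The dataset size, batch settings, and seconds per step in the usage line are illustrative placeholders, not values from this project's config, and the trainer's exact step count can differ by one per epoch.

```python
import math


def training_plan(num_examples, epochs, batch_size, grad_accum,
                  save_steps=500, seconds_per_step=None):
    """Estimate optimizer steps, periodic checkpoints, and wall-clock hours (single GPU)."""
    effective_batch = batch_size * grad_accum
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    checkpoints = total_steps // save_steps  # one save every `save_steps` steps
    hours = None if seconds_per_step is None else total_steps * seconds_per_step / 3600
    return {"total_steps": total_steps, "checkpoints": checkpoints, "hours": hours}


# Illustrative numbers only; read the real ones from mistral_lora_config.yaml
print(training_plan(num_examples=50_000, epochs=3, batch_size=4, grad_accum=4,
                    seconds_per_step=3.0))
```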

In [None]:
# Start training
!python -m training.train --config training/configs/mistral_lora_config.yaml

In [None]:
# If training was interrupted, resume from checkpoint:
# !python -m training.train --config training/configs/mistral_lora_config.yaml --resume
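The `--resume` logic lives inside `training.train`, and this notebook does not show how it picks a checkpoint. To confirm which checkpoint exists before resuming, here is a small sketch that assumes Hugging Face Trainer's `checkpoint-<step>` directory naming under `/content/models` (an assumption about this repo's layout). If `transformers` is available, `transformers.trainer_utils.get_last_checkpoint` serves the same purpose.

```python
import re
from pathlib import Path

# Assumed naming convention: checkpoint-<global step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare steps numerically so checkpoint-1000 beats checkpoint-500
    return max(candidates, key=lambda item: item[0])[1]


print(latest_checkpoint("/content/models"))
```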

## 7. Sync Checkpoints to GCS

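Run this cell periodically to back up checkpoints. Colab executes cells one at a time, so this cell only runs once the training cell finishes or is interrupted.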

In [None]:
!gsutil -m rsync -r /content/models gs://{BUCKET_NAME}/models/
print(f"\n✅ Synced to gs://{BUCKET_NAME}/models/")
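Because the training cell blocks the notebook, the sync cell cannot run *during* training. One workaround is a background thread, started before the training cell, that re-runs the same `gsutil -m rsync` command on a timer. This is a sketch under assumptions: `start_periodic_sync` is a hypothetical helper, the 30-minute interval is arbitrary, the bucket name is read from the `GCS_BUCKET` variable set in section 2, and the thread dies with the Colab runtime, so the final sync in section 10 is still needed.

```python
import os
import subprocess
import threading


def start_periodic_sync(cmd, interval_s, stop_event):
    """Run `cmd` every `interval_s` seconds in a daemon thread until stop_event is set.

    The first run happens after one interval, not immediately.
    """
    def loop():
        # wait() returns True as soon as stop_event is set, which ends the loop
        while not stop_event.wait(interval_s):
            try:
                result = subprocess.run(cmd, capture_output=True, text=True)
            except OSError as exc:  # e.g. the command is not installed
                print(f"sync could not start: {exc}")
                continue
            if result.returncode != 0:
                print(f"sync failed ({result.returncode}): {result.stderr.strip()[:200]}")

    thread = threading.Thread(target=loop, daemon=True)
    thread.start()
    return thread


bucket = os.environ.get("GCS_BUCKET", "sql-codegen-slm-data")
SYNC_CMD = ["gsutil", "-m", "rsync", "-r", "/content/models", f"gs://{bucket}/models/"]

# Usage, in a cell run before starting training:
#   stop_sync = threading.Event()
#   start_periodic_sync(SYNC_CMD, interval_s=30 * 60, stop_event=stop_sync)
# and afterwards:
#   stop_sync.set()
```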

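## 8. Monitor with TensorBoard

For live curves, launch TensorBoard before starting training. The training cell occupies the notebook while it runs, but a TensorBoard instance that is already running keeps refreshing as new logs are written to `/content/tensorboard`.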

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/tensorboard

## 9. Test the Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_path = "/content/models"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
)
print("✅ Model loaded")

In [None]:
schema = """
CREATE TABLE customers (id SERIAL PRIMARY KEY, name VARCHAR(100), email VARCHAR(100));
CREATE TABLE orders (id SERIAL PRIMARY KEY, customer_id INTEGER REFERENCES customers(id), total DECIMAL(10,2), created_at TIMESTAMP);
"""

question = "Find customers who have placed more than 5 orders"

prompt = f"[INST] Given the PostgreSQL schema:\n{schema}\nWrite SQL to: {question} [/INST]"

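inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # Mistral's tokenizer has no pad token by default
    )

# Decode only the newly generated tokens instead of splitting the text on "[/INST]"
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
sql = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
print(f"📝 Question: {question}\n\n🔍 SQL:\n{sql}")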
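Generated text can include markdown fences or trailing commentary, and a query can look plausible while referencing columns that do not exist. Below is a rough, stdlib-only sketch. `clean_sql` and `sqlite_smoke_check` are hypothetical helpers, and SQLite only stands in for PostgreSQL: PostgreSQL-specific syntax such as `ILIKE` or `::` casts is reported as a failure even when it is valid. For a real check, run `EXPLAIN` against a PostgreSQL instance.

```python
import re
import sqlite3


def clean_sql(text):
    """Strip markdown code fences and keep only the first SQL statement.

    Naive: a ';' inside a string literal would cut the statement short.
    """
    text = re.sub(r"```[A-Za-z]*", "", text).strip()
    statement = text.split(";", 1)[0].strip()
    return statement + ";" if statement else ""


def sqlite_smoke_check(schema_ddl, query):
    """Return (ok, error) after compiling `query` against `schema_ddl` in in-memory SQLite."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema_ddl)
        # EXPLAIN compiles the statement (catching unknown tables/columns) without running it
        conn.execute("EXPLAIN " + query.strip().rstrip(";"))
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


# In the notebook, after the cell above:
#   ok, err = sqlite_smoke_check(schema, clean_sql(sql))
```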

## 10. Final Sync to GCS

In [None]:
!gsutil -m rsync -r /content/models gs://{BUCKET_NAME}/models/
!gsutil -m rsync -r /content/tensorboard gs://{BUCKET_NAME}/tensorboard/

print(f"\n✅ All files synced to gs://{BUCKET_NAME}/")
print(f"View: https://console.cloud.google.com/storage/browser/{BUCKET_NAME}")

---
## Troubleshooting

**Session disconnected?**
1. Reconnect and re-run sections 1-4 (GPU check, GCP auth, clone and install, data download)
2. Download checkpoint: `!gsutil -m cp -r gs://{BUCKET_NAME}/models/* /content/models/`
3. Resume: `!python -m training.train --config training/configs/mistral_lora_config.yaml --resume`

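**Out of memory?** In `training/configs/mistral_lora_config.yaml`, reduce the per-device batch size to 2 and increase gradient accumulation steps to 8. The effective batch size is batch size × accumulation steps, so the larger accumulation compensates for the smaller batch.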
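Gradient accumulation sums gradients over several micro-batches before each optimizer step, so the effective batch size is per-device batch × accumulation steps (× number of GPUs). A small sketch for picking the accumulation steps that reach a target effective batch; the target of 16 below is just an example that corresponds to the 2 × 8 suggestion above.

```python
def accumulation_steps(target_effective_batch, per_device_batch, num_gpus=1):
    """Smallest accumulation count whose effective batch is >= the target.

    Returns (accumulation_steps, resulting_effective_batch).
    """
    per_step = per_device_batch * num_gpus
    steps = -(-target_effective_batch // per_step)  # ceiling division on ints
    return steps, steps * per_step


print(accumulation_steps(16, per_device_batch=2))  # -> (8, 16)
```

A smaller per-device batch lowers peak activation memory, while the higher accumulation count keeps each optimizer update based on the same number of examples. The trade-off is more forward/backward passes per update, so training runs somewhat slower.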