# Train ADS/SciX NL Query Translator

This notebook fine-tunes Qwen3-1.7B with LoRA (via Unsloth) to translate natural-language requests into ADS search queries.

**Requirements:**
- Google Colab with GPU runtime (T4 is fine)
- ~30 minutes for training
- HuggingFace account for upload

**Output:**
- Merged model ready for upload to `adsabs/scix-nls-translator`

## 1. Setup Environment

In [None]:
# Check GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv

In [None]:
# Install dependencies
!pip install -q torch transformers datasets peft accelerate trl
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
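
Qwen3 needs a recent `transformers`: the Qwen3 model card notes that versions older than 4.51.0 fail with `KeyError: 'qwen3'`. Both the merge step and the deployment targets below depend on this. The following is a minimal sketch of a pre-flight check. It assumes version strings follow the usual `major.minor.patch` pattern, and `parse_version` and `transformers_supports_qwen3` are illustrative helpers, not part of any library:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed minimum, per the Qwen3 model card's note about KeyError: 'qwen3'.
MIN_TRANSFORMERS = (4, 51, 0)


def parse_version(text):
    """Turn '4.51.3' or '4.52.0.dev0' into a tuple of leading integers, e.g. (4, 51, 3)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at suffixes such as 'dev0'
        parts.append(int(digits))
    # Pad so that '4.51' compares like '4.51.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def transformers_supports_qwen3():
    """True if the installed transformers is at least MIN_TRANSFORMERS."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return False
    return parse_version(installed) >= MIN_TRANSFORMERS


print("transformers supports Qwen3:", transformers_supports_qwen3())
```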

## 2. Download Training Data

In [None]:
# Option A: Download from GitHub (if public)
# !wget https://raw.githubusercontent.com/adsabs/nls-finetune-scix/main/data/datasets/processed/train.jsonl

# Option B: Upload manually
from google.colab import files
print("Upload train.jsonl file:")
uploaded = files.upload()

In [None]:
import json

# Load training data (one JSON object per line; blank lines are skipped)
train_data = []
with open('train.jsonl') as f:
    for line in f:
        if line.strip():
            train_data.append(json.loads(line))

print(f"Loaded {len(train_data)} training examples")
print("\nExample:")
print(json.dumps(train_data[0], indent=2))
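
Before training, check that every record has the shape `format_chat_template` expects. The sketch below makes two assumptions. First, each line of `train.jsonl` holds a `messages` list ending in an assistant reply. Second, that reply is itself JSON like `{"query": "..."}`, mirroring the output format in the test prompt later on. `validate_example` is a hypothetical helper, so adjust the checks if your data differs:

```python
import json


def validate_example(example):
    """Return a list of problems in one training record (empty list means OK).

    Assumes a "messages" list whose last entry is the assistant reply,
    containing JSON such as {"query": "..."}.
    """
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]
    problems = []
    for i, msg in enumerate(messages):
        if msg.get("role") not in {"system", "user", "assistant"}:
            problems.append(f"message {i}: unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i}: content is not a string")
    last = messages[-1]
    if last.get("role") != "assistant":
        problems.append("last message is not from the assistant")
    else:
        try:
            reply = json.loads(last.get("content", ""))
        except (json.JSONDecodeError, TypeError):
            problems.append("assistant reply is not valid JSON")
        else:
            if not isinstance(reply, dict) or not isinstance(reply.get("query"), str):
                problems.append("assistant reply has no string 'query' field")
    return problems
```

In the notebook, `[i for i, ex in enumerate(train_data) if validate_example(ex)]` lists the indices of records that need attention.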

## 3. Load Model with Unsloth

In [None]:
import torch
from unsloth import FastLanguageModel, UnslothTrainer, UnslothTrainingArguments
from datasets import Dataset

# Configuration
MAX_SEQ_LENGTH = 512
EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 2e-4

# Load model with fp16 for T4 GPU (bf16 requires Ampere+)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-1.7B",
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=torch.float16,  # Use fp16 for T4
    load_in_4bit=True,
)

print(f"Model loaded: vocab_size={tokenizer.vocab_size}")

In [None]:
# Prepare dataset
def format_chat_template(example):
    text = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}

dataset = Dataset.from_list(train_data)
dataset = dataset.map(format_chat_template, remove_columns=dataset.column_names)
print(f"Dataset ready: {len(dataset)} examples")
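
With `MAX_SEQ_LENGTH = 512`, longer rendered conversations are typically truncated during training. Truncation can cut off the assistant reply, which is the part the model is meant to learn. The sketch below spots such examples. `find_overlong` is a hypothetical helper that accepts any token-counting callable:

```python
def find_overlong(texts, count_tokens, max_len):
    """Return (index, token_count) pairs for texts longer than max_len tokens.

    count_tokens is any callable mapping a string to a token count.
    """
    overlong = []
    for i, text in enumerate(texts):
        n = count_tokens(text)
        if n > max_len:
            overlong.append((i, n))
    return overlong
```

For example, `find_overlong(dataset["text"], lambda t: len(tokenizer(t)["input_ids"]), MAX_SEQ_LENGTH)` lists the offending rows. If the result is non-empty, either raise `MAX_SEQ_LENGTH` or shorten those examples.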

In [None]:
# Apply LoRA
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
    max_seq_length=MAX_SEQ_LENGTH,
)

# Report trainable vs. total parameters. PEFT's helper corrects for 4-bit
# weights being packed two per byte, which a raw numel() sum undercounts.
model.print_trainable_parameters()
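
Each targeted projection W (d_out x d_in) gets two low-rank factors, A (r x d_in) and B (d_out x r). LoRA therefore adds r * (d_in + d_out) parameters per module, and the update is scaled by `lora_alpha / r` (32 / 16 = 2 here). The sketch below estimates the trainable count for the seven modules in `target_modules`. The Qwen3-1.7B dimensions are assumptions; check them against the model's `config.json`:

```python
def lora_param_count(shapes, r, num_layers):
    """Trainable LoRA parameters given (d_in, d_out) per targeted Linear module.

    Each adapted weight gets A (r x d_in) and B (d_out x r), i.e.
    r * (d_in + d_out) extra parameters; bias="none" adds nothing.
    """
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


# Assumed Qwen3-1.7B dimensions (verify in config.json): hidden_size=2048,
# intermediate_size=6144, 16 query heads and 8 KV heads with head_dim=128,
# num_hidden_layers=28.
HIDDEN, INTERMEDIATE, HEAD_DIM, N_HEADS, N_KV_HEADS, N_LAYERS = 2048, 6144, 128, 16, 8, 28

qwen3_shapes = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

estimate = lora_param_count(qwen3_shapes, r=16, num_layers=N_LAYERS)
print(f"Estimated LoRA parameters at r=16: {estimate:,}")
print("LoRA scaling lora_alpha / r =", 32 / 16)
```

If the estimate disagrees with `print_trainable_parameters()`, the assumed dimensions are the first thing to re-check.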

## 4. Train

In [None]:
training_args = UnslothTrainingArguments(
    output_dir="./output",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    fp16=True,  # Use fp16 for T4 GPU (bf16 requires Ampere+)
    optim="adamw_8bit",
    seed=42,
    report_to="none",
)

trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,  # Required for Unsloth
    args=training_args,
    train_dataset=dataset,
)

print("Starting training...")
trainer.train()
print("Training complete!")
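
With `BATCH_SIZE = 8` and `gradient_accumulation_steps=2`, each optimizer step sees 16 examples. `warmup_ratio=0.1` spends roughly the first tenth of all steps ramping up the learning rate. The sketch below approximates the resulting step counts. The rounding mirrors a common Trainer pattern but may differ across `transformers` versions, and the 1,000-example dataset size is hypothetical:

```python
import math


def training_schedule(num_examples, batch_size, grad_accum, epochs, warmup_ratio, num_devices=1):
    """Approximate optimizer-step counts for a Hugging Face-style Trainer run.

    Assumes the partial last batch is kept and a partial accumulation window
    at the end of an epoch is dropped; exact behaviour can vary by version.
    """
    effective_batch = batch_size * grad_accum * num_devices
    batches_per_epoch = math.ceil(num_examples / (batch_size * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# Hypothetical dataset size; substitute len(dataset).
plan = training_schedule(num_examples=1000, batch_size=8, grad_accum=2, epochs=3, warmup_ratio=0.1)
print(plan)
```

Comparing `total_steps` with `save_steps=100` shows how many intermediate checkpoints to expect. A short run may write few or none, but the next cell saves the adapter explicitly either way.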

In [None]:
# Save LoRA adapter
model.save_pretrained("./output/lora_adapter")
tokenizer.save_pretrained("./output/lora_adapter")
print("LoRA adapter saved")

## 5. Merge LoRA into Base Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
base_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)

print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, "./output/lora_adapter")

print("Merging...")
merged_model = model.merge_and_unload()

print("Saving merged model...")
merged_model.save_pretrained("./output/merged", safe_serialization=True)
base_tokenizer.save_pretrained("./output/merged")

print("Merged model saved to ./output/merged")
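
For a standard LoRA adapter, merging folds the low-rank update into each targeted weight: W' = W + (lora_alpha / r) * B @ A. The merged model then computes the same outputs as base-plus-adapter without the extra matrix multiplies. The toy pure-Python sketch below demonstrates that identity; `merge_lora` and the matrices are illustrative, not PEFT internals. One caveat: the adapter was trained against 4-bit quantized weights, so merging it into the fp16 base can give slightly different outputs than the model saw during training. That makes the quick test in section 6 worth running.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def merge_lora(W, A, B, lora_alpha, r):
    """Fold a LoRA update into the base weight: W + (lora_alpha / r) * B @ A."""
    scale = lora_alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


# Toy shapes: W is d_out x d_in = 2 x 3 and r = 1, so A is 1 x 3 and B is 2 x 1.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
A = [[0.5, -0.5, 1.0]]
B = [[2.0], [1.0]]
x = [1.0, 2.0, 3.0]
alpha, r = 2, 1

merged = merge_lora(W, A, B, lora_alpha=alpha, r=r)

# Unmerged forward pass: base output plus the scaled low-rank path B(Ax).
unmerged_out = [
    base + (alpha / r) * lora
    for base, lora in zip(matvec(W, x), matvec(B, matvec(A, x)))
]
print("merged output:  ", matvec(merged, x))
print("unmerged output:", unmerged_out)
```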

## 6. Test the Model

In [None]:
# Quick test
from transformers import pipeline

pipe = pipeline("text-generation", model="./output/merged", torch_dtype=torch.float16, device_map="auto")

test_queries = [
    "papers about exoplanets published in 2023",
    "articles by John Smith on machine learning",
    "highly cited papers about dark matter",
]

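for query in test_queries:
    messages = [
        # This system prompt should match the one used in train.jsonl
        {"role": "system", "content": "Convert natural language to ADS search query. Output JSON: {\"query\": \"...\"}"},
        {"role": "user", "content": f"Query: {query}\nDate: 2025-01-23"},
    ]
    # enable_thinking=False makes Qwen3's chat template pre-fill an empty
    # <think></think> block, which should match how the training targets were rendered
    prompt = base_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    # return_full_text=False returns only the newly generated text
    output = pipe(prompt, max_new_tokens=128, do_sample=False, return_full_text=False)
    response = output[0]["generated_text"].strip()
    print(f"\nInput: {query}")
    print(f"Output: {response}")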
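
To score more than a handful of queries, it helps to turn each reply into the bare query string. The sketch below assumes replies follow the `{"query": "..."}` format from the system prompt and may be preceded by a Qwen3 `<think>...</think>` block. `extract_query` is a hypothetical helper that returns `None` for anything it cannot parse:

```python
import json
import re

# Qwen3 output may begin with a (possibly empty) <think>...</think> block.
THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)


def extract_query(response):
    """Return the "query" string from a model reply, or None if it can't be parsed."""
    text = THINK_BLOCK.sub("", response).strip()
    # Use the outermost {...} span in case the model added text around the JSON.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        payload = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    query = payload.get("query") if isinstance(payload, dict) else None
    return query if isinstance(query, str) else None
```

Inside the loop above, a `None` from `extract_query(response)` flags a malformed reply.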

## 7. Upload to HuggingFace

In [None]:
# Login to HuggingFace
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload model
REPO_ID = "adsabs/scix-nls-translator"  # Change if needed

merged_model.push_to_hub(REPO_ID, safe_serialization=True)
base_tokenizer.push_to_hub(REPO_ID)

print(f"\nModel uploaded to: https://huggingface.co/{REPO_ID}")

In [None]:
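# Optional: download the merged model (several GB) to your local machine
from google.colab import files  # needed if the manual upload cell was skipped

!zip -r merged_model.zip ./output/merged
files.download('merged_model.zip')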

## Done!

The model is now available at `https://huggingface.co/adsabs/scix-nls-translator` (or at whichever `REPO_ID` you used).

### Deployment Options

**vLLM (recommended):**
```bash
pip install vllm
vllm serve adsabs/scix-nls-translator --max-model-len 512
```

**Text Generation Inference:**
```bash
docker run --gpus all -p 8080:80 \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id adsabs/scix-nls-translator
```

**AWS SageMaker:**
```python
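import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

# Qwen3 needs transformers >= 4.51, so the transformers 4.37 DLC cannot load it.
# Serve from the Hub with the TGI-based LLM container instead, choosing an image
# version recent enough to support Qwen3.
role = sagemaker.get_execution_role()  # outside SageMaker, pass an IAM role ARN
model = HuggingFaceModel(
    role=role,
    image_uri=get_huggingface_llm_image_uri("huggingface"),
    env={
        "HF_MODEL_ID": "adsabs/scix-nls-translator",  # Hub ID; model_data expects an S3 URI
        "SM_NUM_GPUS": "1",
    },
)
predictor = model.deploy(instance_type="ml.g5.xlarge", initial_instance_count=1)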
```