In [None]:
## FINETUNE, MERGE, AND UPLOAD

print("Installing required libraries")
!pip install -q -U "numpy==1.26.4" "torch==2.3.1" "transformers==4.42.3" "peft==0.11.1" "accelerate==0.31.0" "trl==0.9.4" "datasets==2.19.2" "bitsandbytes==0.43.1"

In [None]:
import gc
import json
import os
from getpass import getpass

import torch
from datasets import Dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer
# Hub id of the base checkpoint. It is loaded three times below: for the
# tokenizer, as the 4-bit model used for training, and in bf16 for the merge.
BASE_MODEL_ID = "google/gemma-2-2b-it"
# Local path the trained LoRA adapter is saved to and later reloaded from.
ADAPTER_SAVE_NAME = "gemma-2-2b-it-numpy-refactor-adapter-v1"
# Set manually: Hub repository the merged model and tokenizer are pushed to,
# and the id to use when downloading the merged model later.
HF_REPO_ID = "priyam-turakhia/gemma-2-2b-it-numpy-refactor-merged-v1"

# Interactive Hugging Face Hub login; the push_to_hub calls at the end of the
# notebook rely on these credentials.
notebook_login()

In [None]:
# Instruction text that create_prompt places at the start of every user turn.
# It fixes the two-section answer format ("### Refactored Code" followed by
# "### Deprecation Context") that the training targets follow.
SYSTEM_PROMPT = "\n".join((
    "You are a Python code refactoring tool for NumPy. Your task is to replace only the deprecated functions in the given code snippet with their modern equivalents.",
    "Your response must be structured with two markdown sections:",
    "1. A '### Refactored Code' section containing ONLY the updated Python code block.",
    "2. A '### Deprecation Context' section containing a brief explanation of the deprecation.",
    "Do not change the code's logic. If no functions are deprecated, return the original code and state that no changes were needed in the context section.",
))

print("Loading model and tokenizer:")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Pad on the right-hand side of each sequence during training.
tokenizer.padding_side = 'right'

print("Preparing dataset")
PATH_TO_TRAINING = 'training_data.json'
# The file is expected to hold a JSON array of records; each record is later
# read by create_prompt through the keys "input", "output" and "context".
with open(PATH_TO_TRAINING, encoding='utf-8') as training_file:
    training_data = json.loads(training_file.read())


def create_prompt(sample):
    """Render one training record as a single chat-formatted string.

    The user turn is the module-level SYSTEM_PROMPT followed by the record's
    original code in a fenced python block. The model turn is the refactored
    code under "### Refactored Code" followed by the explanation under
    "### Deprecation Context".

    Args:
        sample: Mapping providing the keys "input", "output" and "context".

    Returns:
        The result of ``tokenizer.apply_chat_template`` called with
        ``tokenize=False`` and ``add_generation_prompt=False`` on the two-turn
        conversation.

    Raises:
        KeyError: If ``sample`` lacks one of the three keys.
    """
    user_turn = "{}\n\n### INPUT CODE:\n```python\n{}\n```".format(
        SYSTEM_PROMPT, sample["input"]
    )
    model_turn = "### Refactored Code\n```python\n{}\n```\n### Deprecation Context\n{}".format(
        sample["output"], sample["context"]
    )

    conversation = [
        {"role": "user", "content": user_turn},
        {"role": "model", "content": model_turn},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )

# One row per training record; the rendered chat string lives in the "text"
# column, which is the column the trainer is pointed at further down.
dataset = Dataset.from_list(
    [{"text": create_prompt(example)} for example in training_data]
)

# 4-bit NF4 weight quantization with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Quantized copy of the base model that the LoRA adapter is trained on.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    attn_implementation="eager",
    device_map="auto",
    quantization_config=bnb_config,
)

# LoRA configuration - same target modules work for Gemma
print("Configuring LoRA and training args")
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # feed-forward (MLP) projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

training_args = TrainingArguments(
    # Checkpoints are written under ./models once per epoch.
    output_dir="./models",
    save_strategy="epoch",
    logging_steps=10,
    # Batch size 1 with 8 accumulation steps: 8 examples per optimizer step
    # on each device.
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    group_by_length=True,
    # Optimizer and learning-rate schedule.
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=0.3,
    # bfloat16 mixed precision, matching the quantization compute dtype.
    bf16=True,
)

# Supervised fine-tuning trainer: takes the quantized base model together with
# the LoRA configuration and trains on the "text" column of the dataset,
# truncating each example to 1024 tokens and without packing examples together.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=training_args,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=False,
    # The "text" column was rendered with the tokenizer's chat template, which
    # already emits the special tokens (Gemma's template starts with <bos>).
    # By default SFTTrainer tokenizes with add_special_tokens=True, which
    # would prepend a second BOS token to every training example.
    dataset_kwargs={"add_special_tokens": False},
)

# Run the fine-tuning loop with the trainer configured above.
print("Fine-tuning:")
trainer.train()

# Save the trained model held by the trainer under ADAPTER_SAVE_NAME; the merge
# step below reloads it from that path as a PEFT adapter.
# NOTE(review): no extra name is bound to trainer.model here on purpose, so the
# later `del model` / `del trainer` can actually release it.
print(f"Saving final adapter model to '{ADAPTER_SAVE_NAME}':")
trainer.model.save_pretrained(ADAPTER_SAVE_NAME)
print("Adapter model saved.")

print("Merging LoRA adapter with base model")
# Release the 4-bit training objects before loading the bf16 copy of the base
# model. `del` only removes the names; gc.collect() breaks any remaining
# reference cycles so the tensors are really freed, and only then can
# empty_cache() hand the cached GPU blocks back to the driver.
del model
del trainer
gc.collect()
torch.cuda.empty_cache()

# Reload the base model unquantized (bf16) so the adapter weights can be
# folded into regular dense weights.
base_model_bf16 = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager"
)

# Attach the saved adapter, then merge it into the base weights and drop the
# PEFT wrapper, leaving a plain transformers model.
model_merged = PeftModel.from_pretrained(base_model_bf16, ADAPTER_SAVE_NAME)
model_merged = model_merged.merge_and_unload()
print("Adapter merged!")

# Upload the merged weights and the tokenizer to the same Hub repository so the
# repo is loadable on its own.
print(f"Pushing to '{HF_REPO_ID}':")
model_merged.push_to_hub(HF_REPO_ID)
tokenizer.push_to_hub(HF_REPO_ID)
print("DONE!!")