In [None]:
## FINETUNE, MERGE, AND UPLOAD

# Install pinned versions of the fine-tuning stack before anything is imported.
print("Installing required libraries")
# NOTE: the leading "!" is IPython/Jupyter shell syntax, so this line only works when run inside a notebook.
!pip install -q -U "numpy==1.26.4" "torch==2.3.1" "transformers==4.42.3" "peft==0.11.1" "accelerate==0.31.0" "trl==0.9.4" "datasets==2.19.2" "bitsandbytes==0.43.1"
import json
import os
import torch
from datasets import Dataset

from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer
from huggingface_hub import notebook_login
from getpass import getpass

# Hub id of the base chat model; used for the tokenizer, the 4-bit training model and the fp16 merge model.
BASE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"
# Local path the trained LoRA adapter is saved to and later reloaded from for merging.
ADAPTER_SAVE_NAME = "qwen1.5-1.8b-chat-numpy-refactor-adapter-v2"
# Set manually; the merged model is pushed to this repo so it can be downloaded later.
HF_REPO_ID = "julianpins/qwen1.5-1.8b-chat-numpy-refactor-merged-v2"

# Interactive Hugging Face login (presumably required so the push_to_hub calls at the end are authorized).
notebook_login()

# System prompt for code generation and deprecation context.
# It is inserted as the "system" message of every training conversation by create_prompt().
# The two section headings it asks for match the headings create_prompt() writes into the assistant answer.
SYSTEM_PROMPT = (
    "You are a Python code refactoring tool for NumPy. Your task is to replace only the deprecated functions in the given code snippet with their modern equivalents.\n"
    "Your response must be structured with two markdown sections:\n"
    "1. A '### Refactored Code' section containing ONLY the updated Python code block.\n"
    "2. A '### Deprecation Context' section containing a brief explanation of the deprecation.\n"
    "Do not change the code's logic. If no functions are deprecated, return the original code and state that no changes were needed in the context section."
)

# Read the raw training samples from a JSON file in the working directory.
print("Preparing dataset")
PATH_TO_TRAINING = 'training_data.json'
with open(PATH_TO_TRAINING, 'r', encoding='utf-8') as f:
    # The result is later iterated sample by sample; each sample is indexed with
    # 'input', 'output' and 'context' in create_prompt().
    training_data = json.load(f)


def create_prompt(sample):
    """Render one training sample as a single chat-templated string.

    Args:
        sample: Mapping that provides the keys 'output', 'context' and
            'input'.

    Returns:
        The result of ``tokenizer.apply_chat_template`` for a
        system / user / assistant conversation, called with
        ``tokenize=False`` and ``add_generation_prompt=False``.

    Relies on the module-level ``tokenizer`` and ``SYSTEM_PROMPT``.
    """
    # Expected answer: fenced refactored code first, then the deprecation note.
    code_section = f"### Refactored Code\n```python\n{sample['output']}\n```\n"
    context_section = f"### Deprecation Context\n{sample['context']}"

    # Qwen chat format: the tokenizer's chat template does the actual rendering.
    conversation = [
        dict(role="system", content=SYSTEM_PROMPT),
        dict(role="user", content=f"### INPUT CODE:\n```python\n{sample['input']}\n```"),
        dict(role="assistant", content=code_section + context_section),
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )

# model & tokenizer
# The tokenizer has to be loaded BEFORE the dataset is built: create_prompt()
# calls tokenizer.apply_chat_template(), so building the dataset first fails
# with "NameError: name 'tokenizer' is not defined" when run top to bottom.
print("Loading model and tokenizer:")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
# Qwen: only fall back to EOS as the pad token when the tokenizer ships without
# one. Overwriting an existing pad token with EOS makes the default language
# modeling collator ignore every EOS label as padding, so the model would not
# be trained to end its answer.
# NOTE(review): Qwen1.5 chat tokenizers appear to ship a dedicated pad token
# already - confirm with tokenizer.pad_token after loading.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# One row per training sample, with the fully templated conversation in 'text'
# (the column name the trainer is configured to read).
dataset = Dataset.from_list([{'text': create_prompt(s)} for s in training_data])


# 4-bit quantization for memory-efficient training
# NF4 quantization type with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the base model with the quantization config above; device placement is
# delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# LoRA adapter configuration
print("⚙️ Configuring LoRA and training arguments...")
# Rank-32 adapters (lora_alpha=64, dropout 0.05, no bias training) for causal
# language modeling, attached to the projection modules listed in target_modules.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Trainer hyperparameters. Checkpoints are written to ./models once per epoch.
# Batch size 2 with 4 gradient-accumulation steps gives 8 samples per optimizer
# step on each device.
training_args = TrainingArguments(
    output_dir="./models",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    logging_steps=10,
    fp16=True,  # float16 mixed precision, matching the quantization compute dtype
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,  # first 10% of the steps are warmup
    save_strategy="epoch",
)

# Supervised fine-tuning trainer: reads the templated conversations from the
# dataset's 'text' column and receives the LoRA config for the quantized model.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=1024,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,  # samples are not packed together into one sequence
)

# Run the fine-tuning with the trainer configured above.
print("Fine-tuning:")
trainer.train()

# Persist the trained model to ADAPTER_SAVE_NAME; the merge step reloads it
# from that path with PeftModel.from_pretrained().
print(f"Saving final adapter model to '{ADAPTER_SAVE_NAME}':")
trainer.model.save_pretrained(ADAPTER_SAVE_NAME)
print("Adapter model saved.")

print("Merging LoRA adapter with base model")
# Local import keeps this step self-contained; gc is only needed here.
import gc

# Release the 4-bit training model and the trainer before the fp16 base model
# is loaded, so both do not have to fit in memory at the same time.
del model
del trainer
# `del` only drops these two names. Objects that are still kept alive by
# reference cycles are destroyed by an explicit collection pass, and only after
# that can empty_cache() hand their cached CUDA blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

# Reload the base model unquantized in float16 as the merge target.
base_model_fp16 = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Attach the saved adapter to the fp16 base model, then merge it in and unload
# the adapter wrapper via merge_and_unload().
model_merged = PeftModel.from_pretrained(base_model_fp16, ADAPTER_SAVE_NAME)
model_merged = model_merged.merge_and_unload()
print("Adapter merged!")

# Upload the merged model and the tokenizer to the same Hub repository.
print(f"Pushing to '{HF_REPO_ID}':")
model_merged.push_to_hub(HF_REPO_ID)
tokenizer.push_to_hub(HF_REPO_ID)
print("DONE!!")