In [None]:
!pip install torch transformers peft accelerate ipywidgets

In [None]:
import os
from pathlib import Path

# Dataset tag used to locate the adapter checkpoint, and the format label of
# the fine-tuned model (reported below).
ds_name = "surfing"
finetuned_format = "chat"

# Directory the LoRA adapter is loaded from; depends on `ds_name`.
FINETUNED_ADAPTER_DIR = f"outputs/granite-4.0-micro-raft-peft-{ds_name}/checkpoint-127"
# Directory the merged model is written to. Plain literal: there is nothing
# to interpolate, so no f-string prefix is needed.
MERGED_MODEL_DIR = "outputs/granite-4.0-micro-raft-full"

print("Using fine-tuned adapter directory:", FINETUNED_ADAPTER_DIR)
print("Fine-tuned model format:", finetuned_format)

# Fail early, before any model is loaded, if the adapter directory is missing.
if not Path(FINETUNED_ADAPTER_DIR).is_dir():
    raise FileNotFoundError(f"Directory not found: {FINETUNED_ADAPTER_DIR}")


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

BASE_MODEL_ID = "ibm-granite/granite-4.0-micro"

# bfloat16 when CUDA is available, float32 otherwise.
if torch.cuda.is_available():
    load_dtype = torch.bfloat16
else:
    load_dtype = torch.float32

print("Loading base model…")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    dtype=load_dtype,
    device_map="auto",
)

# Wrap the base model with the fine-tuned LoRA adapter.
print("Attaching LoRA adapter from:", FINETUNED_ADAPTER_DIR)
peft_model = PeftModel.from_pretrained(base_model, FINETUNED_ADAPTER_DIR)

# Merge the adapter into the base model and switch to eval mode.
print("Merging LoRA adapter into base model…")
merged_model = peft_model.merge_and_unload()
merged_model.eval()

# Write the merged model and the tokenizer into the same output directory.
print("Saving merged model to:", MERGED_MODEL_DIR)
Path(MERGED_MODEL_DIR).mkdir(parents=True, exist_ok=True)
merged_model.save_pretrained(MERGED_MODEL_DIR)
tokenizer.save_pretrained(MERGED_MODEL_DIR)

print("Done. Merged model + tokenizer are ready for vLLM / deployment.")

In [None]:
# Smoke-test the merged model with a single greedy generation.
# The pipeline must wrap `merged_model` (created in the merge cell); no
# variable named `model` is defined in the visible notebook, so referencing it
# raises NameError on a fresh kernel.
# NOTE(review): device_map/dtype are kept as in the original call; they look
# redundant for an already-loaded model — confirm before removing.
gen_pipe = pipeline(
    "text-generation",
    model=merged_model,
    tokenizer=tokenizer,
    device_map="auto",
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

prompt = (
    "You are a helpful assistant answering questions about surfing.\n\n"
    "Question: What is the Snurfer and how is it related to surfing?\n"
    "Answer:"
)

# do_sample=False requests deterministic (non-sampled) decoding, capped at
# 128 newly generated tokens.
out = gen_pipe(
    prompt,
    max_new_tokens=128,
    do_sample=False,
)
print(out[0]["generated_text"])