# Local Inference with LoRA Checkpoint

This notebook runs inference on a single test sample using the fine-tuned LoRA checkpoint.


In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from pathlib import Path
import sys
import os

# Put the notebook's working directory at the front of the module search
# path so the local `utils` module is importable.
notebook_dir = os.getcwd()
sys.path.insert(0, notebook_dir)

from utils import (
    load_conversations,
    display_text,
    display_message,
)

print("✓ Imports successful")


In [None]:
# Configuration
# Model and data locations. Each value can be overridden with an environment
# variable so the notebook runs on another machine without editing this cell;
# when the variable is unset the original value is used unchanged.
BASE_MODEL = os.environ.get("ARXIV_ABSTRACT_BASE_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
CHECKPOINT_PATH = os.environ.get(
    "ARXIV_ABSTRACT_CHECKPOINT",
    "/Users/ryanarman/code/lab/arxiv_abstract/output/arxiv_abstract_qwen3_4b_gpt5_lora_2808/checkpoint-100",
)
TEST_FILE = os.environ.get(
    "ARXIV_ABSTRACT_TEST_FILE",
    "/Users/ryanarman/code/lab/arxiv_abstract/data/arxiv_summarization_test_filtered_10k.jsonl",
)

# Generation parameters
MAX_NEW_TOKENS = 16000  # upper bound on generated tokens, not a target length
TEMPERATURE = 1.0
DO_SAMPLE = True

print(f"Base model: {BASE_MODEL}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Test file: {TEST_FILE}")


In [None]:
# Load tokenizer
# The tokenizer is taken from the base model identifier, not from the
# LoRA checkpoint directory.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
)
print("✓ Tokenizer loaded")


In [None]:
# Load base model with quantization to save memory
print("Loading base model...")

# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
quantization_settings = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_settings)

# Loading options for the base model; device placement is left to "auto".
base_model_options = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "attn_implementation": "sdpa",
}
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **base_model_options)

print("✓ Base model loaded")


In [None]:
# Load LoRA adapter
print("Loading LoRA adapter...")
# This cell rebinds `model` to the adapter-wrapped model, so running it a
# second time would wrap an already-wrapped model. Skip the load in that case.
# To switch to a different checkpoint, re-run the base-model cell first.
if isinstance(model, PeftModel):
    print("LoRA adapter already attached; skipping reload")
else:
    model = PeftModel.from_pretrained(model, CHECKPOINT_PATH)
model.eval()  # Set to evaluation mode
print("✓ LoRA adapter loaded")


In [None]:
# Load test conversations and get first sample
print("Loading test data...")
test_conversations = load_conversations(TEST_FILE)
print(f"✓ Loaded {len(test_conversations)} test conversations")

# Fail with a clear message instead of a bare IndexError when the file
# yielded no conversations.
if len(test_conversations) == 0:
    raise ValueError(f"No conversations found in test file: {TEST_FILE}")

# Get first sample
test_sample = test_conversations[0]
print("\n✓ Using first test sample")


In [None]:
# Display the input (system + user messages)
banner = "=" * 80
print(banner)
print("INPUT (System + User messages)")
print(banner)

# NOTE(review): display_message receives the whole conversation plus a role,
# once per matching message; if a role occurs several times this may repeat
# output — confirm against utils.display_message.
for message in test_sample:
    if message['role'] not in ('system', 'user'):
        continue
    display_message(test_sample, message['role'])
    print()


In [None]:
# Prepare messages for inference: keep every message except assistant turns,
# so the reference answer is never part of the prompt.
inference_messages = list(
    filter(lambda message: message['role'] != 'assistant', test_sample)
)
print(f"Prepared {len(inference_messages)} messages for inference")

# Render the messages through the tokenizer's chat template as plain text,
# ending with the generation prompt so the model continues as the assistant.
template_options = {"tokenize": False, "add_generation_prompt": True}
formatted_prompt = tokenizer.apply_chat_template(inference_messages, **template_options)
print(f"Prompt length: {len(formatted_prompt)} characters")


In [None]:
# Tokenize input
print("Tokenizing input...")
# The prompt was produced by apply_chat_template(tokenize=False), so it already
# contains the special tokens it needs. add_special_tokens=False stops the
# tokenizer from adding its own on top (e.g. a duplicate BOS on tokenizers
# that prepend one).
inputs = tokenizer(
    formatted_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)
print(f"Input tokens: {inputs['input_ids'].shape[1]}")


In [None]:
# Run inference
print("="*80)
print("RUNNING INFERENCE...")
print("="*80)
print(f"Max new tokens: {MAX_NEW_TOKENS}")
print(f"Temperature: {TEMPERATURE}")
print(f"Sampling: {DO_SAMPLE}")
print()

# Sampling draws from torch's global RNG. Seeding it right before generation
# makes a re-run of this cell (or a Restart & Run All) produce the same sample
# on the same setup, instead of a different abstract every time.
GENERATION_SEED = 42
if DO_SAMPLE:
    torch.manual_seed(GENERATION_SEED)

# Gradients are not needed for generation.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        do_sample=DO_SAMPLE,
        pad_token_id=tokenizer.eos_token_id,
    )

print("✓ Inference complete")


In [None]:
# Decode the generated text
# Full sequence (prompt + completion) with special tokens kept; used only for
# the length report below.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
print(f"Generated text length: {len(generated_text)} characters")

# Extract only the assistant's response by slicing at the token level.
# Cutting the decoded string at len(formatted_prompt) characters is fragile:
# decoding the prompt tokens is not guaranteed to reproduce the prompt text
# character for character, which would shift the cut point. The output
# sequence starts with the prompt tokens, so everything after them is new.
prompt_token_count = inputs['input_ids'].shape[1]
response_token_ids = outputs[0][prompt_token_count:]

# skip_special_tokens=True keeps end-of-turn markers out of the displayed text.
assistant_response = tokenizer.decode(
    response_token_ids,
    skip_special_tokens=True,
).strip()

print(f"Assistant response length: {len(assistant_response)} characters")
print(f"Assistant response tokens: {len(response_token_ids)}")


In [None]:
# Display the generated abstract under a framed heading
rule = "=" * 80
print(rule, "GENERATED ABSTRACT", rule, sep="\n")
display_text(assistant_response, role="assistant")


In [None]:
# Compare with ground truth (if available)
# Scan for the first assistant message; stop as soon as one is found.
has_ground_truth = False
for message in test_sample:
    if message['role'] == 'assistant':
        has_ground_truth = True
        break

if has_ground_truth:
    divider = "=" * 80
    print(divider, "GROUND TRUTH ABSTRACT (for comparison)", divider, sep="\n")
    display_message(test_sample, "assistant")
