In [2]:
# Install required packages if not already installed
# %pip install transformers torch peft trl datasets

In [3]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, LoraConfig
import logging
from typing import List, Dict, Any

  warn(


In [4]:
# --- Model / checkpoint configuration ---------------------------------------
# Name of the original (base) model.
BASE_MODEL_NAME = "Qwen/Qwen3-4B"
# Path to the trained model output.
CHECKPOINT_PATH = "./sft_output"
# True when the checkpoint was produced with LoRA training.
USE_LORA = True
# Device mapping strategy handed to the model loader.
DEVICE_MAP = "auto"

# --- Inference settings -----------------------------------------------------
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.8
TOP_K = 20
DO_SAMPLE = True

# Echo the key settings so the run is self-describing.
print(
    f"Base model: {BASE_MODEL_NAME}",
    f"Checkpoint path: {CHECKPOINT_PATH}",
    f"Using LoRA: {USE_LORA}",
    sep="\n",
)

Base model: Qwen/Qwen3-4B
Checkpoint path: ./sft_output
Using LoRA: True


In [5]:
def load_base_model_and_tokenizer(model_name: str, device_map: str = "auto"):
    """
    Fetch the pretrained causal LM and its tokenizer.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device_map: Forwarded unchanged as the model's ``device_map`` argument.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading base model: {model_name}")

    # Tokenizer: reuse EOS as the padding token when none is defined.
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Model: loaded in bfloat16 with the requested device placement.
    model_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
    )
    lm = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    print("✓ Base model and tokenizer loaded")
    return lm, tok


def load_lora_checkpoint(base_model, checkpoint_path: str):
    """
    Attach the LoRA adapters stored in ``checkpoint_path`` to ``base_model``.

    Args:
        base_model: Model object the adapters are loaded onto.
        checkpoint_path: Directory expected to hold ``adapter_config.json`` and
            either ``adapter_model.safetensors`` or ``adapter_model.bin``.

    Returns:
        The object returned by ``PeftModel.from_pretrained``.

    Raises:
        FileNotFoundError: If the adapter config or the adapter weights are
            missing from ``checkpoint_path``.
    """
    print(f"Loading LoRA adapters from: {checkpoint_path}")

    # The adapter config is checked first; without it nothing else matters.
    config_file = os.path.join(checkpoint_path, "adapter_config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"LoRA adapter config not found at {config_file}")

    # Weights may be saved as safetensors or as a .bin file; either one is accepted.
    weight_names = ("adapter_model.safetensors", "adapter_model.bin")
    has_weights = any(
        os.path.exists(os.path.join(checkpoint_path, fname)) for fname in weight_names
    )
    if not has_weights:
        raise FileNotFoundError(f"LoRA adapter weights not found in {checkpoint_path}")

    # Wrap the base model with the adapters.
    peft_model = PeftModel.from_pretrained(base_model, checkpoint_path)

    print("✓ LoRA adapters loaded")
    return peft_model


def load_sft_model(base_model_name: str, checkpoint_path: str, use_lora: bool = True, device_map: str = "auto"):
    """
    Load an SFT model for inference - handles both LoRA and full fine-tuning.

    Args:
        base_model_name: Name of the base model; only used when ``use_lora`` is True.
        checkpoint_path: Directory holding either the LoRA adapters or the fully
            fine-tuned model.
        use_lora: When True, load the base model and attach the adapters from
            ``checkpoint_path``; otherwise load the whole model from ``checkpoint_path``.
        device_map: Forwarded as ``device_map`` to the model loader.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model.eval()`` has been called and the
        tokenizer has a pad token (EOS is reused when none was defined).

    Raises:
        FileNotFoundError: Propagated from ``load_lora_checkpoint`` when the
            adapter files are missing.
    """
    if use_lora:
        # Load base model first, then LoRA adapters
        base_model, tokenizer = load_base_model_and_tokenizer(base_model_name, device_map)
        model = load_lora_checkpoint(base_model, checkpoint_path)

        # Prefer a tokenizer saved next to the adapters. This is best-effort: a
        # checkpoint without tokenizer files falls back to the base tokenizer.
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        try:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
            print("✓ Using tokenizer from checkpoint")
        except Exception as exc:
            print("ℹ Using base model tokenizer")
            print(f"  (checkpoint tokenizer not loaded: {type(exc).__name__})")
    else:
        # Load full fine-tuned model (uses the file-level transformers imports)
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path, torch_dtype=torch.bfloat16, device_map=device_map, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    # Apply the pad-token fallback on every path, including a tokenizer that was
    # just loaded from the LoRA checkpoint.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set to evaluation mode
    model.eval()

    return model, tokenizer

# Confirmation message printed when this cell runs, i.e. after the loader functions above are defined.
print("✓ Model loading functions defined")


✓ Model loading functions defined


In [6]:
# Load the trained model and print a short summary of what was loaded.
print("Loading trained SFT model...")
print("=" * 50)

try:
    model, tokenizer = load_sft_model(
        BASE_MODEL_NAME,
        CHECKPOINT_PATH,
        use_lora=USE_LORA,
        device_map=DEVICE_MAP,
    )

    print("=" * 50, "🎉 Model loaded successfully!", sep="\n")
    print(f"Model type: {type(model).__name__}")
    # Device of the first parameter only.
    print(f"Device: {next(model.parameters()).device}")

    # Parameter counts: everything, then only tensors with requires_grad set.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in filter(lambda q: q.requires_grad, model.parameters())
    )

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Print the model's own trainable-parameter summary when it offers one.
    if USE_LORA and hasattr(model, "print_trainable_parameters"):
        print("\nLoRA parameter info:")
        model.print_trainable_parameters()

except Exception as err:
    # Show a troubleshooting checklist, then re-raise so the failure stays visible.
    print(f"❌ Error loading model: {err}")
    print(
        "\nPlease check:",
        "1. Checkpoint path exists and contains model files",
        "2. USE_LORA setting matches your training setup",
        "3. Base model name is correct",
        sep="\n",
    )
    raise


Loading trained SFT model...
Loading base model: Qwen/Qwen3-4B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

✓ Base model and tokenizer loaded
Loading LoRA adapters from: ./sft_output
✓ LoRA adapters loaded
ℹ Using base model tokenizer
🎉 Model loaded successfully!
Model type: PeftModelForCausalLM
Device: cuda:0
Total parameters: 4,055,498,240
Trainable parameters: 0

LoRA parameter info:
trainable params: 0 || all params: 4,055,498,240 || trainable%: 0.0000


In [16]:
# Render a single user turn through the tokenizer's chat template and tokenize
# it straight into a PyTorch tensor.
tokens = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is the sum of the prime factors of 1024?"}],
    tokenize=True,
    return_tensors="pt",
    add_generation_prompt=True,  # prompt ends with the assistant header
    enable_thinking=True,        # presumably toggles the template's thinking mode - confirm against the model card
)

print(tokens)

tensor([[151644,    872,    198,   3838,    374,    279,   2629,    315,    279,
          10250,   9363,    315,    220,     16,     15,     17,     19,     30,
         151645,    198, 151644,  77091,    198]])


In [17]:
# Generate a completion for the prompt tokens built in the previous cell.
# - The prompt is moved to the device of the model's first parameter instead of
#   hard-coding CUDA, so the cell does not break on a non-CUDA setup.
# - Generation uses the inference settings from the configuration cell
#   (MAX_NEW_TOKENS, DO_SAMPLE, TEMPERATURE, TOP_P, TOP_K) rather than leaving
#   them unused. Lower MAX_NEW_TOKENS there for a quick smoke test.
out = model.generate(
    tokens.to(next(model.parameters()).device),
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=DO_SAMPLE,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
)

In [18]:
# Decode the first (only) generated sequence, keeping the special tokens visible.
print(tokenizer.decode(out[0], skip_special_tokens=False))

<|im_start|>user
What is the sum of the prime factors of 1024?<|im_end|>
<|im_start|>assistant
<think>

</think>

THE SUM OF THE PRIME FACTORS OF 1024 IS 2.<|im_end|>
