# Run Gemma 3 (4B) Inference

In [None]:
import torch
from unsloth import FastLanguageModel
import os

# --- Paths and loader settings ---
# model_dir points at a local checkpoint of the base model. The two LoRA
# directories are where the training notebook is expected to have written its
# adapters (they may not exist; the adapter-loading step checks for that).
model_dir = '../Resources/gemma-3-4b-it'
lora_persona_path = "outputs_persona_lora"
lora_task_path = "outputs_task_lora"

max_seq_length = 2048
dtype = None  # NOTE(review): presumably lets Unsloth pick the dtype itself - confirm
load_in_4bit = True

# --- Base model ---
# Gather every loader argument in one mapping so the full configuration can be
# read (or tweaked) in a single place before the checkpoint is loaded.
base_model_kwargs = dict(
    model_name=model_dir,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # NOTE(review): presumably restricts loading to files already on disk
    # (no hub download) - confirm against the Unsloth/HF loader docs.
    local_files_only=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# --- LoRA adapters: load at most one, in priority order ---
# llm.md describes composing LoRAs (Base + Persona + Task). This cell does not
# compose anything: it loads the first adapter whose weights file is present,
# preferring the task adapter over the persona adapter, and otherwise leaves
# the base model untouched.
#
# Each entry: (adapter directory, adapter name, message before, message after).
adapter_candidates = (
    (
        lora_task_path,
        "task_continue",
        f"Loading Task LoRA adapter from: {lora_task_path}",
        "Task LoRA adapter loaded.",
    ),
    (
        lora_persona_path,
        "persona",
        f"Task LoRA not found, trying to load Persona LoRA adapter from: {lora_persona_path}",
        "Persona LoRA adapter loaded.",
    ),
)

for adapter_dir, adapter_name, announce_msg, done_msg in adapter_candidates:
    # An adapter counts as available only if its weights file exists.
    # NOTE(review): assumes training saved PEFT weights under this file name.
    weights_file = os.path.join(adapter_dir, "adapter_model.safetensors")
    if os.path.exists(weights_file):
        print(announce_msg)
        model.load_adapter(adapter_dir, adapter_name=adapter_name)
        # To make the choice of active adapter explicit, one could call
        # model.set_adapter(adapter_name) here.
        print(done_msg)
        break
else:
    # Runs only when the loop finished without loading anything.
    print("No LoRA adapters found at specified paths. Using base model for inference.")

# For the dynamic composition described in llm.md, both adapters would have to
# be loaded and then merged or combined. One possible route (untested here,
# depends on PEFT/Unsloth support) is:
#   model.add_weighted_adapter(["persona", "task_continue"], [0.5, 0.5], "combined_adapter")
#   model.set_adapter("combined_adapter")

# Put the model in evaluation mode before generating.
model.eval()

# --- Example 1: instruction-style prompt ---
# Alpaca-style template with two slots: the instruction and the response.
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""

# The response slot is filled with an empty string so the prompt ends right
# where the model is expected to start writing.
instruction_prompt = alpaca_prompt.format(
    "What is the capital of France?",
    "",
)

# Tokenize as a batch of one and move the tensors to the GPU.
inputs = tokenizer([instruction_prompt], return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
result_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result_texts[0])

# --- Example 2: free-form continuation ---
# No template here: the raw text is passed in and the model simply continues it.
prompt_continuation = "The weather today is surprisingly warm, and the sky is clear blue. I think I will"
inputs_continuation = tokenizer(prompt_continuation, return_tensors="pt").to("cuda")
outputs_continuation = model.generate(
    **inputs_continuation,
    max_new_tokens=50,
    use_cache=True,
)
# Only one sequence was generated, so decode the first row of the output.
result_continuation = tokenizer.decode(
    outputs_continuation[0],
    skip_special_tokens=True,
)
print("\n--- Continuation Example ---")
print(result_continuation)

# Apple MLX Inference (Placeholder)

In [None]:
import mlx.core as mx
from mlx_lm.utils import load, generate # Assuming these are available
import os

# --- Apple MLX Inference with LoRA ---
# Loads an MLX-converted copy of the model and runs one sampled generation.
# Every failure mode (missing checkpoint, load error) is reported with a
# message and the cell carries on, so the notebook keeps running on machines
# without the MLX checkpoint.
print("Setting up Apple MLX inference...")

mlx_model_path = '../Resources/gemma-3-4b-it-mlx' # Path to MLX-converted Gemma model
# mlx_lora_adapter_path = "mlx_lora_adapter_persona/adapters.safetensors" # Path to saved MLX LoRA adapter if saved from training

if not os.path.exists(mlx_model_path):
    print(f"MLX model not found at {mlx_model_path}. Please convert the HF model to MLX format.")
    print("Skipping MLX inference part.")
else:
    # Load the MLX model and tokenizer. Only the load call sits inside the
    # try body so unrelated errors are not reported as load failures.
    try:
        model, tokenizer = load(mlx_model_path) # Base model only, no adapter
    except Exception as e:
        # Deliberate best-effort: report the problem and skip generation.
        print(f"Error loading MLX model or adapter: {e}.")
        model = None
    else:
        print("MLX Model loaded.")

        # TODO: load the LoRA adapter saved by the MLX training part, if any.
        # NOTE(review): mlx-lm's load() appears to accept an adapter path
        # argument; confirm the exact keyword against the installed version,
        # e.g.:
        # model, tokenizer = load(mlx_model_path, adapter_path=...)

    # Compare against None explicitly: the truthiness of a model object is
    # not a reliable "load succeeded" signal.
    if model is not None:
        prompt_text = "Hello, this is a test with MLX. Please complete this sentence:"
        print(f"Prompt: {prompt_text}")

        generation_kwargs = {"prompt": prompt_text, "max_tokens": 50}
        # Current mlx-lm releases take sampling settings through a sampler
        # object and reject a bare `temp=` keyword; older releases only know
        # `temp=`. Prefer the sampler API and fall back when it is missing.
        try:
            from mlx_lm.sample_utils import make_sampler
        except ImportError:
            generation_kwargs["temp"] = 0.7
        else:
            generation_kwargs["sampler"] = make_sampler(temp=0.7)

        response = generate(model, tokenizer, **generation_kwargs)
        print(f"MLX Generated Response: {response}")

print("MLX inference script part finished.")