In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def create_model(
    lm_model_name = "google/gemma-3-4b-it", 
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
):
    """Load a causal-LM checkpoint and its tokenizer, ready for inference.

    Args:
        lm_model_name: Hugging Face Hub id (or local path) passed to
            ``from_pretrained`` for both the tokenizer and the model.
        device: Device the whole model is loaded onto. The default is
            ``'cuda'`` when available, otherwise ``'cpu'``; note it is
            evaluated once, when this function is defined.

    Returns:
        ``(model, tokenizer)`` with the model already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
    # Load the weights straight onto `device`. Combining device_map="auto"
    # with a later `.to(device)` breaks when accelerate shards the model over
    # several devices or offloads it: `.to()` then raises or undoes that plan.
    model = AutoModelForCausalLM.from_pretrained(
        lm_model_name,
        device_map=device,
    ).eval()

    return model, tokenizer

In [3]:
def lm_template(
    text: str, 
    system_prompt: str = "You are a helpful assistant.", 
):
    """Build a two-message chat: a system turn followed by a user turn.

    Each message stores its ``content`` as a list of typed parts; here every
    message carries exactly one ``{"type": "text", ...}`` part.

    Args:
        text: The user's message.
        system_prompt: Instruction placed in the leading system turn.

    Returns:
        A new list ``[system_message, user_message]`` of plain dicts.
    """
    def _turn(role, body):
        # A single chat message wrapping one text part.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    return [_turn("system", system_prompt), _turn("user", text)]

In [4]:
@torch.inference_mode()
def generate(
    prompt, 
    tokenizer, 
    model, 
    max_new_tokens: int = 256, 
    temperature: float = 1,
):
    """Sample a reply from ``model`` for a chat-formatted ``prompt``.

    Args:
        prompt: List of chat messages (e.g. from ``lm_template``) handed to
            ``tokenizer.apply_chat_template``.
        tokenizer: Object providing ``apply_chat_template`` and ``decode``.
        model: Object providing ``generate``, ``device``, ``dtype`` and
            ``config``.
        max_new_tokens: Upper bound on generated tokens. It is lowered if
            needed so that prompt + reply stays within the model's
            ``max_position_embeddings``.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded reply only: prompt tokens are stripped and special
        tokens are skipped.

    Raises:
        ValueError: If the prompt alone reaches the model's position limit,
            leaving no room to generate anything.
    """
    inputs = tokenizer.apply_chat_template(
        prompt, 
        add_generation_prompt=True, 
        tokenize=True,
        return_dict=True, 
        return_tensors="pt",
    )
    # Floating-point inputs must match the model's dtype; integer tensors
    # (token ids, masks) are only moved to the model's device.
    inputs = {
        k: (
            v.to(model.device, dtype=model.dtype)
            if v.dtype.is_floating_point else v.to(model.device)
        )
        for k, v in inputs.items()
    }

    input_len = inputs["input_ids"].shape[-1]

    # Some configs (e.g. Gemma 3) nest the language-model settings under
    # `text_config`; text-only configs keep them at the top level.
    text_config = getattr(model.config, "text_config", model.config)
    max_len = int(text_config.max_position_embeddings)
    if input_len >= max_len:
        raise ValueError(
            f"Input length {input_len} reaches or exceeds maximum allowed length of {max_len} tokens."
        )
    # Never let prompt + reply run past the position-embedding limit.
    max_new_tokens = min(max_new_tokens, max_len - input_len)

    generation = model.generate(
        **inputs, 
        max_new_tokens=max_new_tokens, 
        do_sample=True, 
        temperature=temperature, 
    )
    # Keep only the newly generated tokens of the first (only) sequence.
    generation = generation[0][input_len:]

    response = tokenizer.decode(
        generation, 
        skip_special_tokens=True
    )

    return response

In [5]:
# Load the `google/gemma-3-4b-it` checkpoint once; the demo cells below reuse
# the resulting `model` and `tokenizer` globals.
model, tokenizer = create_model("google/gemma-3-4b-it")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


In [6]:
# Demo prompts: small talk, two arithmetic questions, and one question that
# depends on real-time information the model is not given.
test_cases = [
    "Hello! How are you today?",        # small talk
    "What is 123 multiplied by 456?",   # single-step arithmetic
    "What's the weather like today?",   # needs live, location-specific data
    "Calculate (987 + 654) * 321 / 2",  # multi-step arithmetic
]

In [7]:
# Run every test prompt through the model and print its reply.
# `generate` always samples (do_sample=True), so seed torch's RNG first;
# otherwise every Restart & Run All prints different answers. Exact
# reproducibility still assumes the same hardware/software stack.
torch.manual_seed(0)

print("Start Demo!\n")

for i, query in enumerate(test_cases, 1):
    # Header width (14 chars + 50) matches the 64-char footer below.
    print(f"Test Case ({i}) {'=' * 50}")
    print(f"user input: {query}")

    prompt = lm_template(
        text=query
    )
    response = generate(
        prompt=prompt, 
        tokenizer=tokenizer, 
        model=model, 
    )
    print(f"Model response: {response}\n{'=' * 64}\n")

print("\nDemo Completed!")

Start Demo!

user input: Hello! How are you today?
Model response: Hello there! I’m doing well, thank you for asking! As an AI, I don’t really *feel* in the same way humans do, but my systems are running smoothly and I’m ready to help you with whatever you need. 😊 

How are *you* doing today? Is there anything I can assist you with?

user input: What is 123 multiplied by 456?
Model response: 123 multiplied by 456 is 56,088.

Here's how to calculate that:

123 * 456 = (123 * 400) + (123 * 50) + (123 * 6)
          = 49200 + 6150 + 738
          = 56088


user input: What's the weather like today?
Model response: Please tell me your location! I need to know where you are to give you an accurate weather forecast. 😊 

For example, you could tell me:

*   Your city and state (e.g., "London, England")
*   Your zip code (e.g., "90210")

user input: Calculate (987 + 654) * 321 / 2
Model response: Okay, let's break this down step-by-step:

1. **Calculate the sum in parentheses:** 987 + 654 = 16