In [1]:
import os
import torch

print(f"PyTorch Version: {torch.__version__}")
print(f"PyTorch CUDA Version: {torch.version.cuda}")
print(f"CuDNN Version:        {torch.backends.cudnn.version()}")

PyTorch Version: 2.9.1+cu128
PyTorch CUDA Version: 12.8
CuDNN Version:        91002


In [4]:
import time
import torch
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM

# --- 1. Setup Device ---
# Strix Halo (8060S) works best with float16 on ROCm 6.2+
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16  
    print(f"✅ GPU Detected: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    dtype = torch.float32
    print("⚠️  GPU Not Detected. CPU mode.")

model_id = "gpt2"
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_id = "mistralai/Ministral-3-3B-Instruct-2512"
model_id = "Qwen/Qwen2.5-3B-Instruct"

print(f"\nLoading {model_id}...")

# TRUST_REMOTE_CODE=True is the key fix here
tokenizer = AutoTokenizer.from_pretrained(
    model_id, 
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=dtype,    # Standard HF uses torch_dtype, but some custom models prefer dtype
    trust_remote_code=True, # Allow the model to define its own config class
    device_map=device       # Auto-moves to GPU
)

# --- 3. Run Inference ---
messages = [
    {"role": "user", "content": "Tell me a short story."}
]

# Apply Mistral's chat template
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Tokenize and move to device
inputs = tokenizer(prompt, return_tensors="pt").to(device)

generation_kwargs = dict(
    inputs=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    streamer=streamer,
    max_new_tokens=300,    
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id
)

print(f"\nPrompt: {messages[0]['content']}")
print("-" * 30)

t0 = time.time()
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# --- 4. Stream Output ---
generated_text = ""
first_token_received = False
ttft = 0

for new_text in streamer:
    if not first_token_received:
        ttft = time.time() - t0
        first_token_received = True
        print(new_text, end="", flush=True)
    else:
        print(new_text, end="", flush=True)
    generated_text += new_text

t_end = time.time()

# --- 5. Stats ---
total_new_tokens = len(tokenizer.encode(generated_text))
decoding_time = t_end - (t0 + ttft)

print("\n" + "-" * 30)
print(f"Time to First Token: {ttft:.4f} s")
if decoding_time > 0:
    print(f"Generation Speed:    {(total_new_tokens-1)/decoding_time:.2f} tokens/sec")
print(f"Total Tokens:        {total_new_tokens}")

✅ GPU Detected: NVIDIA GeForce RTX 4070 Ti

Loading Qwen/Qwen2.5-3B-Instruct...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Prompt: Tell me a short story.
------------------------------
Once upon a time, in a small village nestled between rolling hills and whispering forests, there lived an old storyteller named Gideon. Gideon was known far and wide for his tales that could make the bravest heart tremble or bring tears to the eyes of the meekest soul.

One day, as Gideon sat by the fireplace in his cozy cottage, he noticed a young girl named Elara wandering lost in the woods near the village. She had a kind face and bright eyes, but she looked frightened and uncertain. Gideon, feeling the call of his calling, approached her gently.

"Elara," he said softly, "why do you stray so far from the village today?"

Elara hesitated before speaking, her voice barely above a whisper. "I'm looking for my grandmother. She disappeared a few days ago, and I've been searching everywhere."

Moved by her plight, Gideon decided to help her. He shared stories of his own adventures and the wonders he had encountered throughout