# Running Inference on Habana Gaudi (HPU)

This notebook demonstrates how to run inference on the Habana Gaudi processor (HPU) with the GPT model fine-tuned for the classification task in Chapter 6.

## 1. Construct model (same as in Chapter 6)


In [None]:
# import necessary libraries
import habana_frameworks.torch # import Habana PyTorch framework first
import torch
import tiktoken

In [None]:
# initialize tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

In [None]:
# model configuration and parameters
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

In [None]:
# use functions from previous chapters to download and load the model
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt

model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)

# add a new output head to the model (same as in Chapter 6)
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=2)
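
The model name above does double duty: it selects the architecture settings and, via the text in parentheses, the size string passed to the download function. The sketch below illustrates both steps in isolation. The helper name `resolve` is hypothetical, and the config values are copied from the cell above.

```python
# Minimal sketch: derive the merged config and the size string from a model name.
# `resolve` is a hypothetical helper, not part of the book's code.
base_config = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True,
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
}


def resolve(name):
    # Build a merged copy so that base_config itself is left unchanged
    config = {**base_config, **model_configs[name]}
    # "gpt2-medium (355M)" -> "(355M)" -> "355M"
    size = name.split(" ")[-1].strip("()")
    return config, size


config, size = resolve("gpt2-medium (355M)")
print(size, config["emb_dim"], config["n_layers"])
```

Unlike `BASE_CONFIG.update(...)`, this variant returns a new dictionary, which can be convenient when several model sizes are tried in one session.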

## 2. Load model weights
We use the weights saved when we trained the model on the spam classification task in Chapter 6.

In [None]:
model_state_dict = torch.load("review_classifier.pth", map_location=torch.device("cpu"), weights_only=True) # load weights to CPU to avoid memory issues
model.load_state_dict(model_state_dict) 

## 3. Move model to HPU
This step requires access to an HPU device.

In [None]:
device = torch.device("hpu")
model.to(device)
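
If the notebook may also be run on a machine without an HPU, the device can be chosen at runtime instead of hard-coded. The following is a minimal sketch, not part of the original notebook. The helper name `pick_device_name` is hypothetical, and the sketch assumes that the Habana bridge exposes `habana_frameworks.torch.hpu.is_available()`.

```python
import importlib.util


def pick_device_name():
    """Return "hpu" if the Habana bridge is installed and reports a device, else "cpu"."""
    # Check for the package without importing it, so this also works
    # on machines where the Habana software stack is not installed.
    if importlib.util.find_spec("habana_frameworks") is None:
        return "cpu"
    try:
        # Assumption: this module path and function exist in the installed bridge
        import habana_frameworks.torch.hpu as hthpu
        return "hpu" if hthpu.is_available() else "cpu"
    except Exception:
        # Any import or driver problem falls back to the CPU
        return "cpu"


device_name = pick_device_name()
print(device_name)
```

The result can then be passed to PyTorch as `torch.device(pick_device_name())` in place of the hard-coded `"hpu"` string.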

## 4. Classify reviews function
This is the same function as in Chapter 6.

In [None]:
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    # Prepare inputs to the model
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

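    # Truncate sequences if they are too long
    # (fall back to the model's context length if max_length is not given)
    if max_length is None:
        max_length = supported_context_length
    max_length = min(max_length, supported_context_length)
    input_ids = input_ids[:max_length]

    # Pad sequences to max_length
    input_ids += [pad_token_id] * (max_length - len(input_ids))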
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"
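
The truncation and padding steps in this function can be tried without a model or tokenizer. The sketch below works on plain lists of token IDs. The helper name `truncate_and_pad` is hypothetical, and the pad ID 50256 is the same default used above.

```python
def truncate_and_pad(input_ids, max_length, context_length, pad_token_id=50256):
    """Return a new list with exactly min(max_length, context_length) token IDs."""
    target = min(max_length, context_length)
    ids = list(input_ids[:target])               # truncate (copy, caller's list is untouched)
    ids += [pad_token_id] * (target - len(ids))  # pad on the right up to the target length
    return ids


short = truncate_and_pad([10, 20, 30], max_length=5, context_length=1024)
long = truncate_and_pad(list(range(10)), max_length=5, context_length=1024)
print(short)  # [10, 20, 30, 50256, 50256]
print(long)   # [0, 1, 2, 3, 4]
```

Because every input ends up with the same length, the position read by `[:, -1, :]` is always the same index, regardless of how long the original message was.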

## 5. Test the model

In [None]:
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

print(classify_review(
    text_1, model, tokenizer, device, max_length=120
))

In [None]:
text_2 = (
    "Hey, just wanted to check if we're still on"
    " for dinner tonight? Let me know!"
)

print(classify_review(
    text_2, model, tokenizer, device, max_length=120
))

It works!
Now let's compare the performance of the model on CPU and HPU.

## 6. Performance comparison
We will use the `time` library to measure the time it takes to classify a review.

In [None]:
import time

# Let's create a couple of spam/not spam messages
messages = [
    "You are a winner you have been specially selected to receive $1000 cash or a $2000 award.", # Spam
    "Please send me your bank details so I can transfer you the money.", # Spam
    "I'm going to the gym now, want to join?", # Not spam
    "This is not spam, it's a test message.", # Not spam
    "Congratulations! You've won a free iPhone! Click here to claim.", # Spam
    "Hey, are we still on for lunch tomorrow?",  # Not spam
    "URGENT! Your account has been compromised. Reset your password now!",  # Spam
    "Meeting rescheduled to 3 PM. Let me know if that works.",  # Not spam
    "Limited time offer! Buy now and get 50% off!",  # Spam
    "Can you review this document and send me your feedback?",  # Not spam
    "FREE MONEY! Click this link to receive your reward.",  # Spam
    "Your order has been shipped. Track it here.",  # Not spam
    "Earn $$$ from home! No experience needed. Sign up today.",  # Spam
    "Reminder: Your dentist appointment is on Monday at 10 AM.",  # Not spam
    "Exclusive deal just for you! Unlock your discount now!",  # Spam
    "Happy Birthday! Hope you have an amazing day!",  # Not spam
    "Claim your prize before it's too late! Act now!",  # Spam
    "Thanks for your help with the project. Appreciate it!",  # Not spam
    "Your PayP@l account needs verification! Click here immediately!",  # Spam
    "Hey, can you send me that report by EOD?",  # Not spam
    "Final warning! Your subscription will be canceled unless you act now!",  # Spam
    "Dinner plans tonight? Let me know.",  # Not spam
    "You have been selected for a special offer! Open now!",  # Spam
    "Let's catch up over coffee next week.",  # Not spam
    "Instant weight loss! See the miracle solution here.",  # Spam
    "Just checking in—how are you doing?",  # Not spam
    "Hurry! Stocks are running out. Order yours today!",  # Spam
    "Can you confirm the schedule for tomorrow?",  # Not spam
    "Dear user, your acc0unt has suspicious activity. Verify now!",  # Spam
    "Great job on the presentation today!",  # Not spam
    "Double your profits in just 7 days! Guaranteed!",  # Spam
    "I'll be late for the meeting, stuck in traffic.",  # Not spam
    "Secret investment opportunity—make millions fast!",  # Spam
    "Let's finalize the contract details this afternoon.",  # Not spam
    "Congratulations, you are the chosen winner of our lottery!",  # Spam
    "Thanks for your help with the budget analysis.",  # Not spam
    "Y0ur p@ckage is d3layed. Cl!ck here to f!x.",  # Spam
    "I’ll send over the revised slides shortly.",  # Not spam
    "Work from home and make $$$ instantly!",  # Spam
    "Can you join the call at 2 PM instead of 3?",  # Not spam
    "Limited seats available! Enroll in our exclusive program today.",  # Spam
    "See you at the event later!",  # Not spam
    "Hurry! This deal won’t last long. Act fast!",  # Spam
    "Your invoice for last month is attached.",  # Not spam
    "Boost your credit score instantly! Click here.",  # Spam
    "Looking forward to our meeting tomorrow.",  # Not spam
    "Your social media account has been hacked! Reset password now!",  # Spam
    "Let’s schedule a team lunch next week.",  # Not spam
    "Easy way to make extra cash online—start today!",  # Spam
    "Can you review the proposal before we submit?",  # Not spam
    "F!nal rem!nder: Update y0ur b@nk details NOW!",  # Spam
    "Thanks for the update on the project.",  # Not spam
    "Your subscription has been successfully renewed.",  # Not spam
    "Meet singles in your area now!",  # Spam
    "I left my laptop at the office, can you bring it?",  # Not spam
    "You are pre-approved for a low-interest loan!",  # Spam
    "Looking forward to your presentation next week!",  # Not spam
    "WIN a brand-new car! Just sign up!",  # Spam
    "Hope you’re feeling better today!",  # Not spam
    "Act fast! This offer expires soon!",  # Spam
    "Thanks for the great conversation earlier.",  # Not spam
    "Your p@ssw0rd will expire soon! Cl!ck here to reset.",  # Spam
    "Don’t forget our dinner plans tonight!",  # Not spam
    "Limited-time deal! Get yours now before it’s gone!",  # Spam
    "Have a safe flight!",  # Not spam
    "FREE investment tips! Join our webinar today!",  # Spam
    "Let me know if you need any help with the project.",  # Not spam
    "This is not a scam! You have won $1,000,000!",  # Spam
    "Excited to see you at the conference!",  # Not spam
    "You won’t believe this shocking weight loss secret!",  # Spam
    "Are you available for a quick call?",  # Not spam
    "Act n0w! Your acc0unt has been compromised!",  # Spam
    "Let’s meet at the usual coffee shop.",  # Not spam
    "Earn p@ssive inc0me with this one simple trick!",  # Spam
    "Thanks for helping me with the move.",  # Not spam
    "Your cl@im has been approved! Cl!ck here to get it.",  # Spam
    "Don’t miss out on this exclusive deal!",  # Spam
    "Let’s touch base later today.",  # Not spam
    "This is your last chance to claim your reward!",  # Spam
    "Great catching up with you yesterday!",  # Not spam
    "Cl@im y0ur refund now! L!mited time offer!",  # Spam
    "Important update regarding your bank account.",  # Spam
    "See you at the meeting in 10 minutes.",  # Not spam
    "Get rich quick with this foolproof method!",  # Spam
    "I’ll send you the details by email.",  # Not spam
    "Unbelievable investment opportunity—act now!",  # Spam
    "Don’t forget about the deadline tomorrow.",  # Not spam
    "Y0ur Netflix account is locked! Verify now!",  # Spam
    "Grab your free sample today!",  # Spam
    "I’ll share the report with you later.",  # Not spam
    "This stock is about to skyrocket! Invest today!",  # Spam
    "Reminder: Submit your expense report by Friday.",  # Not spam
    "Claim your Bitcoin bonus now!",  # Spam
    "XXX WEBSITE XXX", # Spam
    ]

num_messages = len(messages)

# Test on CPU (move the model to the CPU first; it was moved to the HPU in section 3)
cpu_device = torch.device("cpu")
model.to(cpu_device)
start_time = time.time()
cpu_results = [classify_review(msg, model, tokenizer, cpu_device, max_length=120) for msg in messages]
end_time = time.time()
cpu_time = end_time - start_time
print(f"CPU time: {cpu_time / num_messages:.4f} seconds per message")

# Test on HPU (move the model back to the HPU)
model.to(device)
start_time = time.time()
hpu_results = [classify_review(msg, model, tokenizer, device, max_length=120) for msg in messages]
end_time = time.time()
hpu_time = end_time - start_time
print(f"HPU time: {hpu_time / num_messages:.4f} seconds per message")

# Compare results
print(f"HPU faster by {cpu_time / hpu_time:.2f}x")
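
A single timed pass can be skewed by one-time costs: the first calls on an accelerator often include setup work such as graph compilation. One common way to reduce this effect is to run a few untimed warm-up calls before measuring. The sketch below is a generic illustration, not part of the original notebook. The helper name `seconds_per_call` and the stand-in function `fake_classify` are hypothetical.

```python
import time


def seconds_per_call(fn, inputs, warmup=3):
    """Average wall-clock seconds per call of fn over inputs, after untimed warm-up calls."""
    for item in inputs[:warmup]:
        fn(item)  # warm-up: results are discarded and not timed
    start = time.perf_counter()
    results = [fn(item) for item in inputs]
    elapsed = time.perf_counter() - start
    return elapsed / len(inputs), results


# Stand-in workload so the sketch runs without a model
def fake_classify(text):
    return "spam" if "!" in text else "not spam"


avg, labels = seconds_per_call(fake_classify, ["Act now!", "See you later."])
print(f"{avg:.6f} seconds per message", labels)
```

With the names from this notebook, the same helper could be called as `seconds_per_call(lambda m: classify_review(m, model, tokenizer, device, max_length=120), messages)`, once per device.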