In [11]:
import time
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda"), "CUDA"
    elif torch.backends.mps.is_available():
        return torch.device("mps"), "MPS"
    else:
        return torch.device("cpu"), "CPU"

def print_device_info(device_type):
    if device_type == "CUDA":
        num_gpus = torch.cuda.device_count()
        print(f"Using CUDA. Number of GPUs available: {num_gpus}")
    elif device_type == "MPS":
        print("Using MPS (Apple Silicon GPU).")
    else:
        num_cpus = torch.get_num_threads()
        print(f"Using CPU. Number of CPUs available: {num_cpus}")

In [12]:
device, device_type = get_device()
print_device_info(device_type)

Using MPS (Apple Silicon GPU).


In [13]:
# Measure start time
start_time = time.time()

# Initialize tokenizer and model with the specified device
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [14]:
print(tokenizer.model_max_length)

1024


In [15]:
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [16]:
# Generate text based on a prompt
prompt = "When is the best time to visit Yosemite"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
generated_outputs = model.generate(**inputs, max_length=50, num_return_sequences=1)
# Calculate and print the time taken
end_time = time.time()
print(f"Time taken to generate response: {end_time - start_time:.2f} seconds")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Time taken to generate response: 3.05 seconds


In [17]:
# Decode and print the generated text
generated_text = tokenizer.decode(generated_outputs[0], skip_special_tokens=True)
print(generated_text)

When is the best time to visit Yosemite?

The best time to visit Yosemite is when you're ready to go.

The best time to visit Yosemite is when you're ready to go.

The best time to visit Yosemite is
