In [1]:
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# ─── SETTINGS ────────────────────────────────────────────────────────────────
# Model identifier handed to load_model_and_tokenizer() by start().
# NOTE(review): Llama checkpoints on the Hub may require accepting a license
# and being logged in — confirm access before running.
MODEL_NAME = "meta-llama/Llama-3.2-1B"

In [8]:
def load_model_and_tokenizer(model_name: str):
    """Load a tokenizer and causal-LM model prepared for inference.

    The model is placed on CUDA when ``torch.cuda.is_available()`` is true,
    otherwise on the CPU.  Weights are loaded in float16 on CUDA (to reduce
    memory) and in float32 on CPU.

    Args:
        model_name: Model id or local path accepted by ``from_pretrained``.

    Returns:
        ``(tokenizer, model, device)`` where ``model`` already lives on
        ``device`` and is in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    model.to(device)
    # eval() disables dropout and similar training-only behavior.
    model.eval()
    return tokenizer, model, device


In [4]:
def select_text():
    """Return a fixed encyclopedia-style passage to score.

    The passage is about 95 English words, which is roughly 100 tokens for a
    typical subword tokenizer (approximate; the exact count depends on the
    tokenizer).  Keeping it fixed makes runs comparable.

    Returns:
        The passage as a single string.
    """
    return (
        "Photosynthesis is the process by which green plants, algae, and some "
        "bacteria convert light energy into chemical energy. During this "
        "process, light captured by chlorophyll drives the splitting of water "
        "molecules, releasing oxygen as a by-product. The energy is stored in "
        "molecules such as ATP and NADPH, which power the Calvin cycle. In "
        "that cycle, carbon dioxide from the air is fixed into three-carbon "
        "sugars that the plant later uses to build glucose, starch, and "
        "cellulose. Almost all life on Earth depends, directly or indirectly, "
        "on the oxygen and organic compounds that photosynthesis produces."
    )

In [5]:
def tokenize_with_labels(tokenizer, text: str):
    """Tokenize ``text`` into PyTorch tensors and attach language-model labels.

    Labels are a copy of ``input_ids``.  A separate tensor is used so that
    later in-place edits to one of them cannot silently change the other.

    Args:
        tokenizer: Callable tokenizer accepting ``(text, return_tensors="pt")``
            and returning a mapping that contains ``"input_ids"``.
        text: The passage to tokenize.

    Returns:
        ``(inputs, input_len)``: a plain ``dict`` of tensors (including
        ``"labels"``) and the sequence length along the last dimension.

    Raises:
        ValueError: if the text yields fewer than two tokens, since no
            next-token target exists and the loss would be undefined.
    """
    encoded = tokenizer(text, return_tensors="pt")
    inputs = dict(encoded)
    input_ids = inputs["input_ids"]
    input_len = int(input_ids.shape[-1])
    if input_len < 2:
        raise ValueError(
            f"need at least 2 tokens to compute next-token loss, got {input_len}"
        )
    inputs["labels"] = input_ids.clone()
    return inputs, input_len

In [6]:
def compute_peak_memory_and_loss(model, inputs, device):
    """Run one forward pass and report peak device memory and the loss.

    On CUDA, peak-allocation stats are reset before the pass, so the reading
    starts from what is already allocated.  That includes resident tensors
    such as the model weights.  The device is synchronized before the stats
    are read, so asynchronous kernels have finished.  On other devices no
    memory is measured and 0.0 is reported.

    Args:
        model: Callable model; ``model(**inputs)`` must return an object with
            a ``loss`` tensor (requires ``"labels"`` in ``inputs``).
        inputs: Mapping of keyword tensors, already on ``device``.
        device: ``torch.device`` or device string the model runs on.

    Returns:
        ``(peak_mem_mib, loss_value)`` as Python floats.

    Raises:
        ValueError: if the model output carries no loss.
    """
    use_cuda = torch.device(device).type == "cuda"
    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    with torch.no_grad():
        outputs = model(**inputs)

    if use_cuda:
        torch.cuda.synchronize(device)
        peak_mem_mib = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    else:
        peak_mem_mib = 0.0

    loss = getattr(outputs, "loss", None)
    if loss is None:
        raise ValueError("model output has no loss; did you pass 'labels'?")
    return peak_mem_mib, float(loss.item())

In [7]:
def compute_perplexity(loss: float) -> float:
    """Return the perplexity ``exp(loss)`` for a cross-entropy loss.

    Very large losses would make ``math.exp`` raise ``OverflowError``.  In
    that case ``math.inf`` is returned instead.

    Args:
        loss: Loss value in nats.

    Returns:
        ``math.exp(loss)``, or ``math.inf`` on overflow.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return math.inf

In [None]:
def start():
    """Run the end-to-end measurement and print each result.

    Loads the model, tokenizes a sample passage with labels, measures peak
    memory and loss for a single forward pass, then prints the perplexity.
    """
    # Load the tokenizer, model and the device they run on.
    tok, lm, dev = load_model_and_tokenizer(MODEL_NAME)
    print(f"Using device: {dev}")

    # Build the labelled batch and move every tensor next to the model.
    passage = select_text()
    batch, n_tokens = tokenize_with_labels(tok, passage)
    batch = {name: tensor.to(dev) for name, tensor in batch.items()}
    print(f"Tokenized length: {n_tokens}")

    # A single forward pass yields the memory peak (MiB) and the loss.
    peak_mib, loss_value = compute_peak_memory_and_loss(lm, batch, dev)
    print(f"Peak GPU memory: {peak_mib:.1f} MiB")

    # Perplexity is derived from the loss.
    perplexity = compute_perplexity(loss_value)
    print(f"Next-token perplexity: {perplexity:.3f}")


In [None]:
# Execute the full pipeline: load, tokenize, measure memory/loss, report perplexity.
start()