# Lesson 3, Exercise 3: Tuning Speculation with GPT-2 - Impact of Draft Length (`K`)

**Goal:**
In this exercise you will implement a simplified speculative decoding loop with off-the-shelf GPT-2 models and investigate how a key hyperparameter, the draft length `K` (the number of tokens the draft model proposes per step), affects generation efficiency. You will measure performance both as wall-clock time and as the number of computationally expensive forward passes through the target model.
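
To make the loop concrete before wiring up real models, here is a minimal sketch of one draft-then-verify step on plain Python lists. The function name `greedy_speculative_step` and the toy token IDs are hypothetical. The sketch assumes greedy (argmax) decoding, where a drafted token is kept only if it equals the token the target model would have picked at that position.

```python
from typing import List


def greedy_speculative_step(draft_tokens: List[int], target_choices: List[int]) -> List[int]:
    """Return the tokens kept after one draft-then-verify step.

    draft_tokens:   the K tokens proposed by the draft model.
    target_choices: K + 1 greedy picks from the target model; entry i is the
                    target's choice given the context plus draft_tokens[:i].
    """
    assert len(target_choices) == len(draft_tokens) + 1
    num_matched = 0
    for drafted, preferred in zip(draft_tokens, target_choices):
        if drafted != preferred:
            break
        num_matched += 1
    # Keep the matching prefix, then exactly one token from the target:
    # a correction on mismatch, or a bonus token when every draft matched.
    return draft_tokens[:num_matched] + [target_choices[num_matched]]


print(greedy_speculative_step([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
print(greedy_speculative_step([5, 7, 9], [5, 7, 9, 4]))  # [5, 7, 9, 4]
```

Each step therefore yields between 1 and `K + 1` tokens for a single target forward pass, which is the quantity the rest of the exercise measures.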

## 2. Imports and Configuration

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import pandas as pd

TARGET_MODEL_NAME = "gpt2-medium"
DRAFT_MODEL_NAME = "gpt2" # Standard small GPT-2

PROMPT_TEXT = "Artificial intelligence is rapidly transforming our world by"
MAX_TOTAL_TOKENS_TO_GENERATE = 100 # Total new tokens to generate for each run
K_VALUES_TO_TEST = [1, 2, 3, 4, 5, 8]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3. Load Models and Tokenizer

In [None]:
tokenizer = None
target_model = None
draft_model = None

### TODO: Load the tokenizer (gpt2 and gpt2-medium share the same tokenizer)
tokenizer = None
# GPT-2 defines no pad token by default, so reuse EOS for padding.
if tokenizer is not None and tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

### TODO: Load the Target Model (gpt2-medium) and move to device
target_model = None
if target_model is not None:
    target_model.eval()  # inference mode: disables dropout

### TODO: Load the Draft Model (gpt2) and move to device
draft_model = None
if draft_model is not None:
    draft_model.eval()

# Compare against None explicitly rather than relying on object truthiness.
if any(obj is None for obj in (tokenizer, target_model, draft_model)):
    raise ValueError("One or more models/tokenizer failed to load. Check TODOs.")
print("Models and tokenizer loaded successfully.")

initial_prompt_ids = tokenizer.encode(PROMPT_TEXT, return_tensors="pt").to(device)

## 4. Baseline Autoregressive Generation (Using Target Model)

In [None]:
print("\n--- Running Baseline Autoregressive Generation ---")
baseline_generated_ids = initial_prompt_ids.clone()
baseline_target_passes = 0
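# Synchronize so pending GPU work does not leak into the timed region,
# and use perf_counter to match the timing of the speculative runs.
if device.type == 'cuda': torch.cuda.synchronize()
start_time_baseline = time.perf_counter()

with torch.no_grad():
    for _ in range(MAX_TOTAL_TOKENS_TO_GENERATE):
        # TODO: Write the generation loop that produces one token at a time:
        # 1. Run the target model on the current sequence (baseline_generated_ids).
        # 2. Extract the logits for the next-token position, e.g. outputs.logits[:, -1, :].
        # 3. Select the most likely next token with argmax.
        # 4. Append that token to baseline_generated_ids (on the first iteration this is just the prompt).
        # 5. Increment baseline_target_passes.
        next_token_id = None # Implement the logic above
        if next_token_id is None:
            raise NotImplementedError("Implement the baseline generation step (see TODO above).")
        if next_token_id.item() == tokenizer.eos_token_id:
            break
if device.type == 'cuda': torch.cuda.synchronize()
end_time_baseline = time.perf_counter()
baseline_time = end_time_baseline - start_time_baseline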
baseline_output_text = tokenizer.decode(baseline_generated_ids[0], skip_special_tokens=True)

print(f"Baseline Time: {baseline_time:.4f} s")
print(f"Baseline Target Model Passes: {baseline_target_passes}")
print(f"Baseline Output: {baseline_output_text}")

## 5. Speculative Decoding Experiment

In [None]:
all_experiment_results = []

if target_model and draft_model and tokenizer and initial_prompt_ids is not None:
    print("\n--- Running Speculative Decoding Experiments for different K values ---")

    for K_val in K_VALUES_TO_TEST:
        print(f"\nStarting Speculative Decoding with K = {K_val}")
        current_context_ids_spec = initial_prompt_ids.clone()
        count_total_new_tokens_generated_spec = 0
        count_target_model_passes_spec = 0
        count_total_draft_tokens_verified_correctly = 0
        count_verification_steps_spec = 0 # How many times the target model was called to verify

        # Timing Start
        if device.type == 'cuda': torch.cuda.synchronize()
        time_start_speculative = time.perf_counter()

        with torch.no_grad():
            while count_total_new_tokens_generated_spec < MAX_TOTAL_TOKENS_TO_GENERATE:
                if current_context_ids_spec.shape[1] >= target_model.config.max_position_embeddings - K_val -1: # Ensure space for K draft + 1 target
                    print(f"K={K_val}: Approaching model's maximum length. Stopping early to prevent overflow.")
                    break

                # --- Draft Phase ---
                draft_tokens_generated_ids = torch.empty((1,0), dtype=torch.long, device=device)
                temp_context_for_draft = current_context_ids_spec.clone()
                for _ in range(K_val):
                    if temp_context_for_draft.shape[1] >= draft_model.config.max_position_embeddings:
                        break # Draft model also has a max length
                    ### TODO: Generate one token using the 'draft_model' from 'temp_context_for_draft'.
                    # Append it to 'draft_tokens_generated_ids' and update 'temp_context_for_draft'.
                    # Break if EOS is generated by the draft model.
                    pass # Replace with draft generation step
                
                num_tokens_actually_drafted = draft_tokens_generated_ids.shape[1]
                if num_tokens_actually_drafted == 0:
                    # If no tokens could be drafted (e.g., context too long or K=0 was somehow missed),
                    # fall back to a single autoregressive step with the target model to make progress.
                    # ... (Implement target model single step here, update counts, and 'continue' the outer while loop) ...
                    print(f"K={K_val}: No tokens drafted, performing target step. (Implement fallback)")
                    # For starter, just break to prevent infinite loop if not implemented
                    break 

                # --- Verification Phase ---
                verification_input_ids = torch.cat([current_context_ids_spec, draft_tokens_generated_ids], dim=1)
                
                ### TODO: Pass 'verification_input_ids' to the 'target_model' and store the result in 'target_verification_outputs'.
                # Increment 'count_target_model_passes_spec'.
                # Increment 'count_verification_steps_spec'.
                # Placeholder below: remove the dummy logits and dummy increments once the real call is in place,
                # otherwise each verification step is counted twice.
                target_verification_outputs_logits = torch.rand(1, verification_input_ids.shape[1], tokenizer.vocab_size, device=device) # Dummy logits
                count_target_model_passes_spec += 1 # Dummy increment
                count_verification_steps_spec += 1 # Dummy increment

                ### TODO: Extract the target model's preferred tokens for each of the 'num_tokens_actually_drafted' positions.
                # These are the tokens the target model *would have chosen* if it were generating autoregressively at those steps.
                # You'll need to look at the logits from 'target_verification_outputs.logits' at the correct indices.
                # Store these preferred token IDs in a list or tensor called 'target_preferred_tokens_at_draft_positions'.
                # Careful with indexing: logits at index `t` predict token `t+1`, so the prediction for the first
                # drafted token sits at index current_context_ids_spec.shape[1] - 1.
                # Dummy: assumes target agrees perfectly. Index [0] rather than squeeze() so K=1 still yields a list.
                target_preferred_tokens_at_draft_positions = draft_tokens_generated_ids[0].tolist()

                # --- Acceptance Logic ---
                num_matched_tokens = 0
                ### TODO: Compare 'draft_tokens_generated_ids' with 'target_preferred_tokens_at_draft_positions' token by token.
                # Count how many tokens match consecutively from the beginning. Store this in 'num_matched_tokens'.
                # for i in range(num_tokens_actually_drafted):
                #    if ___ == ___:
                #        num_matched_tokens += 1
                #    else: break
                num_matched_tokens = num_tokens_actually_drafted # Dummy: assume all match
                count_total_draft_tokens_verified_correctly += num_matched_tokens

                accepted_tokens_for_this_step = draft_tokens_generated_ids[0, :num_matched_tokens]

                # Determine the next token: either the one from target model at mismatch, or one beyond matched sequence.
                if num_matched_tokens < num_tokens_actually_drafted:
                    # Mismatch occurred: use target's preferred token at the point of mismatch.
                    # next_token_id_after_match = torch.tensor([[target_preferred_tokens_at_draft_positions[num_matched_tokens]]], device=device)
                    pass # Replace, placeholder below
                    next_token_id_after_match = torch.randint(0, tokenizer.vocab_size, (1,1), device=device) # Dummy next token
                else: # All K draft tokens matched
                    # Try to get one more token from the target model (the (K+1)th token).
                    # This requires looking at the logits from 'target_verification_outputs' for the position *after* the K draft tokens.
                    # verification_idx_for_bonus_token = current_context_ids_spec.shape[1] + num_tokens_actually_drafted -1
                    # if verification_idx_for_bonus_token < target_verification_outputs.logits.shape[1]:
                    #    next_token_id_after_match = torch.argmax(target_verification_outputs.logits[0, verification_idx_for_bonus_token, :]).unsqueeze(0).unsqueeze(0)
                    # else: # Cannot get bonus token, just proceed with matched ones (or handle as error/fallback)
                    #    next_token_id_after_match = None 
                    pass # Replace, placeholder below
                    next_token_id_after_match = torch.randint(0, tokenizer.vocab_size, (1,1), device=device) # Dummy next token
                
                if next_token_id_after_match is not None:
                    # reshape(-1) gives a 1-D tensor; squeeze() would yield a 0-dim tensor, which torch.cat rejects.
                    accepted_tokens_for_this_step = torch.cat([accepted_tokens_for_this_step, next_token_id_after_match.reshape(-1)])

                if accepted_tokens_for_this_step.numel() == 0:
                    print(f"K={K_val}: No tokens were accepted in this step. Performing a single target model step to advance.")
                    # Fallback: Generate one token with target model to ensure progress.
                    # ... (Implement target model single step here, update counts, and 'continue' the outer while loop) ...
                    break # Simplified for starter, real fallback needed
                
                # reshape(1, -1) adds the batch dimension whether the accepted tokens are 1-D or already (1, n).
                current_context_ids_spec = torch.cat([current_context_ids_spec, accepted_tokens_for_this_step.reshape(1, -1)], dim=1)
                count_total_new_tokens_generated_spec += accepted_tokens_for_this_step.numel()

                if tokenizer.eos_token_id in accepted_tokens_for_this_step:
                    print(f"K={K_val}: EOS token generated.")
                    break
        
        # Timing End
        if device.type == 'cuda': torch.cuda.synchronize()
        time_end_speculative = time.perf_counter()

        duration_speculative = time_end_speculative - time_start_speculative
        text_output_speculative = tokenizer.decode(current_context_ids_spec[0], skip_special_tokens=True)
        
        # TODO: Calculate the average number of tokens accepted per verification step for this K_val run.
        # Per step this is the number of matched draft tokens plus the one corrective/bonus token from the target.
        # A simpler proxy: count_total_new_tokens_generated_spec / count_target_model_passes_spec
        # (guard against division by zero in case no target pass ran).
        avg_accepted_this_K = None # Replace with actual calculation

        all_experiment_results.append({
            "K": K_val,
            "Time (s)": duration_speculative,
            "Target Passes": count_target_model_passes_spec,
            "Avg Accepted Tokens per Verification": avg_accepted_this_K,
            "Output Text Sample": text_output_speculative[:150] + "..."
        })
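        # Formatting None with :.2f raises TypeError, so fall back to a marker until the TODO above is done.
        avg_accepted_str = f"{avg_accepted_this_K:.2f}" if avg_accepted_this_K is not None else "n/a (TODO)"
        print(f"K={K_val}: Time={duration_speculative:.4f}s, Target Passes={count_target_model_passes_spec}, Avg Accepted={avg_accepted_str}")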
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
else:
    print("Skipping speculative decoding experiment as models/tokenizer were not properly loaded.")
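
The indexing note in the verification TODO ("logits at index `t` predict token `t+1`") is easy to get off by one. The small sketch below uses a hypothetical helper, `verification_logit_indices`, to list which rows of the target logits score each drafted token and which row yields the bonus token. It assumes the verification input is the current context followed by the drafted tokens, as in the cell above.

```python
def verification_logit_indices(context_len: int, num_drafted: int):
    """Map drafted positions to rows of the target model's logits.

    The verification input is the context (context_len tokens) followed by
    the drafted tokens. Row t of the logits predicts the token at position
    t + 1, so drafted token i (at position context_len + i) is scored by
    row context_len - 1 + i.
    """
    draft_rows = [context_len - 1 + i for i in range(num_drafted)]
    # The last row predicts the token that follows all drafted tokens.
    bonus_row = context_len + num_drafted - 1
    return draft_rows, bonus_row


rows, bonus = verification_logit_indices(context_len=10, num_drafted=3)
print(rows, bonus)  # [9, 10, 11] 12
```

The bonus row here matches the `verification_idx_for_bonus_token` expression in the commented hint of the acceptance logic.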

## 6. Display Final Results Summary

In [None]:
df_spec_results = pd.DataFrame(all_experiment_results)
print("\n\n--- Speculative Decoding Experiment Results Summary ---")
print(df_spec_results.to_string())

## 7. Analysis and Discussion

Based on the 'Speculative Decoding Experiment Results Summary' table:

1.  **Impact of `K` on Target Model Passes:**
    *   TODO: Analyze how varying `K` affected the total number of forward passes made by the `gpt2-medium` (Target Model) compared to the baseline.

2.  **Impact of `K` on Wall-Clock Time:**
    *   TODO: Discuss the effect of `K` on the overall wall-clock generation time. Was there an apparent optimal value of `K` for this `gpt2-medium`/`gpt2` pairing in your setup? Explain why time might increase or decrease.

3.  **Impact of `K` on Average Accepted Tokens:**
    *   TODO: Explain how the average number of tokens accepted per target model verification step changed as `K` varied. What does this metric tell you about the efficiency of the speculation?

4.  **Trade-offs of `K`:**
    *   TODO: Conclude by summarizing the trade-offs involved in selecting a small versus a large value for `K`. Consider factors like draft model accuracy, overhead of draft generation, and the probability of successful verification.

5.  **Comparison to Baseline:**
    *   TODO: How did the best speculative decoding configuration compare to the baseline autoregressive generation in terms of target model passes and wall-clock time?
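
As a starting point for the trade-off discussion, a common back-of-the-envelope model assumes every drafted token is accepted independently with the same probability `alpha`. This is an idealisation, since real acceptance varies with position and prompt. Under that assumption the expected number of tokens per target pass is `1 + alpha + ... + alpha**K`. A rough speedup estimate divides this by the relative cost of one step, `1 + K * c`, where `c` is the cost of a draft forward pass relative to a target pass. The function names, `alpha = 0.7`, and `c = 0.3` below are illustrative placeholders, not measurements; substitute the values you observe.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens produced per target pass under i.i.d. acceptance.

    Equals 1 + alpha + ... + alpha**k: the accepted draft prefix plus the
    one corrective/bonus token the target always contributes.
    """
    return sum(alpha ** i for i in range(k + 1))


def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Rough speedup over plain autoregressive decoding.

    draft_cost_ratio is the (assumed) cost of one draft-model pass relative
    to one target-model pass; each step costs k draft passes + 1 target pass.
    """
    return expected_tokens_per_pass(alpha, k) / (1.0 + k * draft_cost_ratio)


# Placeholder inputs for illustration only.
for k in [1, 2, 3, 4, 5, 8]:
    tokens = expected_tokens_per_pass(0.7, k)
    speedup = estimated_speedup(0.7, k, 0.3)
    print(f"K={k}: expected tokens/pass={tokens:.2f}, estimated speedup={speedup:.2f}")
```

The numerator saturates as `K` grows while the denominator keeps rising, which is one way to explain why wall-clock time can get worse again for large `K`. Comparing this estimate against your measured table is a useful sanity check, not a substitute for it.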