# Interactive Chat with Projection Tracking

This notebook steps through a multi-turn chat, one cell per turn, and reports the
projection of each response's activations (at the model's target layer) onto the
assistant axis.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
from IPython.display import display, clear_output, HTML
import ipywidgets as widgets
from huggingface_hub import hf_hub_download

from assistant_axis import (
    load_model,
    load_axis,
    get_config,
    project,
    generate_response,
    extract_response_activations
)

## Load Model and Axis

In [None]:
# --- Configuration ---
# Hugging Face model id, the short model name used as the folder prefix of
# the axis file, and the Hub repo the axis file is downloaded from.
MODEL_NAME = "google/gemma-2-27b-it"
MODEL_SHORT = "gemma-2-27b"
REPO_ID = "lu-christina/assistant-axis-vectors"

# Look up the per-model settings and keep the layer used for projections.
config = get_config(MODEL_NAME)
TARGET_LAYER = config["target_layer"]

# One print call, two lines of output.
print(f"Model: {MODEL_NAME}", f"Target layer: {TARGET_LAYER}", sep="\n")

In [None]:
# Instantiate the model and its tokenizer for MODEL_NAME, announcing
# progress before and after the call.
print("Loading model...")
(model, tokenizer) = load_model(MODEL_NAME)
print("Model loaded!")

In [None]:
# Download this model's axis file from the Hugging Face dataset repo,
# then load it from the path hf_hub_download hands back.
axis_path = hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename=f"{MODEL_SHORT}/assistant_axis.pt",
)
axis = load_axis(axis_path)

# Quick sanity check on what was loaded.
print(f"Axis shape: {axis.shape}")

## Chat Function with Projection

In [None]:
def chat_with_projection(
    user_input,
    conversation_history,
    system_prompt=None,
    max_new_tokens=256,
    temperature=0.7,
):
    """
    Generate one assistant turn and compute its projection onto the axis.

    Relies on the notebook-level globals ``model``, ``tokenizer``, ``axis``
    and ``TARGET_LAYER``.

    Args:
        user_input: Text of the new user message.
        conversation_history: List of ``{"role": ..., "content": ...}``
            message dicts from earlier turns, without the system prompt.
            The list itself is not modified.
        system_prompt: Optional system message placed before the history.
            Falsy values (``None``, ``""``) mean "no system message".
            NOTE(review): some chat templates reject a "system" role
            (Gemma 2 is believed to be one of them) -- confirm that
            ``generate_response`` copes with it for the loaded model.
        max_new_tokens: Forwarded to ``generate_response`` (default 256).
        temperature: Forwarded to ``generate_response`` (default 0.7).

    Returns:
        (response_text, projection_value, updated_conversation)

        ``projection_value`` is the result of ``project`` for
        ``TARGET_LAYER``, or ``None`` when no activations were returned for
        the conversation or when extraction/projection raised (a warning is
        printed in that case). ``updated_conversation`` is a new list: the
        given history plus the new user and assistant messages, without the
        system prompt.
    """
    # Messages sent to the model: optional system prompt, prior turns, and
    # the new user turn.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.extend(conversation_history)
    conversation.append({"role": "user", "content": user_input})

    # Generate response
    response = generate_response(
        model, tokenizer, conversation,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    # The activations are extracted from the conversation including the reply.
    full_conversation = conversation + [{"role": "assistant", "content": response}]

    # Projection is best-effort: the reply is still returned if it fails.
    projection = None
    try:
        activations = extract_response_activations(
            model, tokenizer, [full_conversation],
            layers=[TARGET_LAYER],
            show_progress=False
        )

        # A single conversation was passed in, so only entry 0 is relevant.
        if activations[0] is not None:
            projection = project(activations[0], axis, layer=TARGET_LAYER)
    except Exception as e:
        print(f"Warning: Could not compute projection: {e}")
        projection = None

    # Update conversation history (without system prompt). Fresh dicts are
    # used so the returned history does not alias the messages passed to the
    # model helpers above.
    updated_history = conversation_history + [
        {"role": "user", "content": user_input},
        {"role": "assistant", "content": response}
    ]

    return response, projection, updated_history

## Simple Chat Loop

In [None]:
# Initialize the conversation state used by the turn cells that follow.
conversation_history = []
system_prompt = None  # Set to a string like "You are a pirate." to role-play

print("Interactive Chat with Projection Tracking")
print("==========================================")
print(f"System prompt: {system_prompt or '(none)'}")
# This notebook has no input loop, so there are no 'quit'/'reset' commands;
# tell the reader what actually works instead.
print("Run the cells below to chat; re-run this cell to clear history")
print()

In [None]:
# First turn of the conversation.
user_input = "Hello! Tell me about yourself."
print(f"You: {user_input}")

# chat_with_projection returns the reply, its projection and the extended
# history; rebinding conversation_history carries this turn forward.
response, projection, conversation_history = chat_with_projection(user_input, conversation_history, system_prompt)
print(f"\nAssistant: {response}")

# projection is None when it could not be computed; nothing is printed then.
if projection is not None:
    print(f"\n[Projection onto axis: {projection:.4f}]")

In [None]:
# Second turn: continues from the history accumulated so far.
user_input = "What's your favorite thing to do?"
print(f"You: {user_input}")

# Rebinding conversation_history keeps the running history up to date.
response, projection, conversation_history = chat_with_projection(user_input, conversation_history, system_prompt)
print(f"\nAssistant: {response}")

# projection is None when it could not be computed; nothing is printed then.
if projection is not None:
    print(f"\n[Projection onto axis: {projection:.4f}]")

## With Role-Playing System Prompt

In [None]:
# Start a fresh conversation, this time with a role-play system prompt.
system_prompt = "You are a wise old wizard named Gandalf."
conversation_history = []  # drop the turns from the previous conversation

print(f"System: {system_prompt}")
print()

In [None]:
# First turn under the role-play system prompt.
user_input = "Who are you? Tell me about your adventures."
print(f"You: {user_input}")

# Rebinding conversation_history keeps the running history up to date.
response, projection, conversation_history = chat_with_projection(user_input, conversation_history, system_prompt)
print(f"\nAssistant: {response}")

# projection is None when it could not be computed; nothing is printed then.
if projection is not None:
    print(f"\n[Projection onto axis: {projection:.4f}]")
    print("(More negative = more role-playing, more positive = more assistant-like)")

## Interpretation

The projection value indicates where the model's response falls on the spectrum from role-playing to default assistant behavior:

- **Large negative values**: Strong role-playing behavior
- **Near zero**: Neutral/ambiguous
- **Large positive values**: Strong default assistant behavior

Try different system prompts and questions to see how the projection changes!