In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"  # specify which GPU(s) to be used

import html
import json
import re

import numpy as np
import pandas as pd
import torch

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from llm_steer import LLM_Steer_Manager
from measure_activation import Measure_Manager

from utils import Config, Data_Manager, model_selection

In [None]:
# Candidate chat models this notebook can be run against.
model_ids = [
    "google/gemma-3-270m-it",
    "Qwen/Qwen3-0.6B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "Qwen/Qwen3-1.7B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "google/gemma-3-4b-it",
    "meta-llama/Llama-3.1-8B-Instruct",
    "google/gemma-3-12b-it",
    "Qwen/Qwen3-14B",
    "google/gemma-3-27b-it",
    "Qwen/Qwen3-32B",
    "meta-llama/Llama-3.3-70B-Instruct",
]

# Select the model by name rather than by list position, so the choice is
# readable at a glance and does not silently change when the list is edited.
SELECTED_MODEL_ID = "Qwen/Qwen3-32B"
if SELECTED_MODEL_ID not in model_ids:
    raise ValueError(
        f"Unknown model id {SELECTED_MODEL_ID!r}; expected one of: {', '.join(model_ids)}"
    )

llm = SELECTED_MODEL_ID
print(f'Selected LLM: {llm}')

Selected LLM: Qwen/Qwen3-32B


In [3]:
class InteractiveChat:
    """ipywidgets chat front-end for an LLM with optional per-layer latent steering.

    The widget keeps a running chat history, generates each assistant reply
    through ``LLM_Steer_Manager`` (optionally adding an SAE decoder direction
    scaled by per-layer slider values), and annotates every reply with the
    five highest sentence-max-pooled activations reported by
    ``Measure_Manager``.
    """

    def __init__(self, cfg, model, tokenizer, dm, symp_label_dict, sae_dict, actmax_dict, device_dict, std_dict, generation_kwargs):
        """Store collaborators, create the steer/measure managers and build the UI.

        Args:
            cfg: Project config; ``cfg.hook_layers`` is iterated to create one
                strength slider per layer and to index the per-layer dicts.
            model: Language model handed to the steer and measure managers.
            tokenizer: Tokenizer used to decode generated token ids.
            dm: Data manager (stored; not otherwise used by this class).
            symp_label_dict: Mapping from latent label to the index used to
                select a row of ``decoder.weight.T``.
            sae_dict: Per-layer objects exposing ``decoder.weight``.
            actmax_dict: Passed through to ``Measure_Manager``.
            device_dict: Per-layer device used for steering tensors.
            std_dict: Passed through to ``Measure_Manager``.
            generation_kwargs: Passed through to ``LLM_Steer_Manager``.
        """
        self.cfg = cfg
        self.model = model
        self.tokenizer = tokenizer
        self.dm = dm
        self.sae_dict = sae_dict
        self.actmax_dict = actmax_dict
        self.device_dict = device_dict
        self.std_dict = std_dict

        self.symp_label_dict = symp_label_dict
        self.symp_keys = list(symp_label_dict.keys())

        self.generation_kwargs = generation_kwargs
        self.steer_manager = LLM_Steer_Manager(cfg, model, tokenizer, device_dict, self.generation_kwargs)
        self.measure_manager = Measure_Manager(cfg, model, tokenizer, device_dict, std_dict, sae_dict, actmax_dict, self.symp_label_dict)

        # The conversation always starts with the system prompt as its first turn.
        self.system_prompt = "You are a helpful AI assistant."
        self.chat_history = [{"role": "system", "content": self.system_prompt}]

        self.themes = {
            'Default': {'user_color': '#0000ff', 'assistant_color': '#008000', 'bg_color': '#ffffff', 'text_color': '#000000'},
            'Dark': {'user_color': '#4fc3f7', 'assistant_color': '#81c784', 'bg_color': '#1e1e1e', 'text_color': '#ffffff'},
            'Blue': {'user_color': '#1565c0', 'assistant_color': '#00695c', 'bg_color': '#e3f2fd', 'text_color': '#0d47a1'}
        }
        self.current_theme = 'Default'

        self.setup_ui()

    @staticmethod
    def _to_html(text):
        """Escape ``text`` for safe embedding in HTML and turn newlines into ``<br>``."""
        return html.escape(str(text)).replace('\n', '<br>')

    def setup_ui(self):
        """Create every widget, wire the callbacks and assemble ``self.main_layout``."""
        # Theme selector
        self.theme_selector = widgets.Dropdown(
            options=list(self.themes.keys()),
            value='Default',
            description='Theme:',
            layout={'width': '200px'}
        )
        self.theme_selector.observe(self.change_theme, names='value')

        # Chat transcript: an HTML widget whose value is rebuilt on every refresh,
        # wrapped in a fixed-height box that provides the border and scrolling.
        self.chat_content = widgets.HTML(value="", layout={'width': '100%'})
        self.chat_container = widgets.VBox(
            [self.chat_content],
            layout={
                'border': '1px solid #ccc',
                'height': '400px',
                'overflow_y': 'auto',  # Ensures scrollbar appears
                'padding': '10px',
                'margin': '0 0 20px 0',
                'display': 'block'     # Helps VS Code render block correctly
            }
        )

        # Input area
        self.user_input = widgets.Textarea(
            placeholder='Type your message here...',
            layout={'width': '60%', 'height': '60px'}
        )

        # Controls
        self.send_button = widgets.Button(description="Send", button_style='primary', icon='paper-plane', layout={'width': '100px'})
        self.send_button.on_click(self.handle_submit)

        self.quit_button = widgets.Button(description="Quit", button_style='danger', icon='times', layout={'width': '100px'})
        self.quit_button.on_click(self.handle_quit)

        # Intervention controls (disabled until the checkbox is ticked)
        self.use_intervention = widgets.Checkbox(value=False, description='Apply Intervention')
        self.latent_selector = widgets.Dropdown(
            options=self.symp_keys,
            description='Latent:',
            disabled=True,
            layout={'width': '300px'}
        )

        # One strength slider per hooked layer, keyed by layer.
        self.sliders = {}
        for layer in self.cfg.hook_layers:
            self.sliders[layer] = widgets.FloatSlider(
                value=0.0, min=0.0, max=1.0, step=0.01,
                description=f'Layer {layer}:', disabled=True,
                continuous_update=False, orientation='horizontal',
                readout=True, readout_format='.2f', layout={'width': '300px'}
            )

        self.sliders_box = widgets.VBox(list(self.sliders.values()))
        self.use_intervention.observe(self.toggle_intervention, names='value')

        # Layout
        header = widgets.HBox([widgets.HTML("<h3>Interactive Chat with S3AE Intervention</h3>"), self.theme_selector], layout={'justify_content': 'space-between', 'align_items': 'center', 'margin': '0 0 10px 0'})

        self.input_box = widgets.HBox([self.user_input, self.send_button, self.quit_button], layout={'align_items': 'center', 'margin': '10px 0'})

        self.intervention_controls = widgets.VBox([
            widgets.HBox([self.use_intervention, self.latent_selector]),
            widgets.Label("Intervention Strength per Layer:"),
            self.sliders_box
        ], layout={'border': '1px solid #eee', 'padding': '10px', 'margin': '10px 0', 'background_color': '#f9f9f9'})

        self.main_layout = widgets.VBox([
            header,
            self.chat_container,
            self.input_box,
            self.intervention_controls
        ])

    def get_max_pooled_measurements(self, text):
        """Measure ``text`` sentence by sentence and max-pool the results.

        The text is split after ``.``, ``!`` or ``?`` followed by whitespace,
        every non-empty sentence is measured in one call to
        ``measure_manager.sae_measure_no_json``, and the element-wise maximum
        over sentences is returned as a list.

        Assumes the measure manager returns one equally sized row of
        activations per input sentence — TODO confirm against Measure_Manager.
        """
        # The lookbehind keeps the terminating punctuation attached to its sentence.
        pieces = re.split(r'(?<=[.!?])\s+', text)
        sentences = [piece.strip() for piece in pieces if piece.strip()]

        # Nothing left after stripping (blank text): measure the raw text as one item.
        if not sentences:
            sentences = [text]

        sent_activations = self.measure_manager.sae_measure_no_json(sentences)

        # Collapse the sentence axis (axis 0), keeping the per-latent maximum.
        return np.max(np.array(sent_activations), axis=0).tolist()

    def toggle_intervention(self, change):
        """Enable or disable the latent dropdown and sliders to follow the checkbox."""
        disabled = not change['new']
        self.latent_selector.disabled = disabled
        for slider in self.sliders.values():
            slider.disabled = disabled

    def change_theme(self, change):
        """Switch to the newly selected theme and redraw the transcript."""
        self.current_theme = change['new']
        self.refresh_chat_display()

    def handle_quit(self, b):
        """Disable every interactive control and mark the chat as ended."""
        self.user_input.disabled = True
        self.send_button.disabled = True
        self.quit_button.disabled = True
        self.use_intervention.disabled = True
        self.theme_selector.disabled = True
        # The steering controls must be locked too; otherwise they stay active
        # after quitting whenever the intervention checkbox was ticked.
        self.latent_selector.disabled = True
        for slider in self.sliders.values():
            slider.disabled = True
        self.chat_content.value += "<div style='margin-top:20px; text-align:center; color: #888;'><strong>--- Chat Ended ---</strong></div>"

    def handle_submit(self, b):
        """Send the typed message, generate a reply and redraw the transcript.

        If generation raises, the user turn is removed from the history again
        (so the next request does not contain two consecutive user turns), the
        typed text is put back into the input box, and the error is shown in
        red below the transcript.
        """
        user_text = self.user_input.value.strip()
        if not user_text:
            return

        self.user_input.value = ''

        self.chat_history.append({"role": "user", "content": user_text})
        self.refresh_chat_display(typing=True)

        itv_settings = None
        if self.use_intervention.value:
            itv_settings = {
                'latent': self.latent_selector.value,
                'strengths': {layer: slider.value for layer, slider in self.sliders.items()},
            }

        try:
            if itv_settings:
                response, measurements = self.generate_with_intervention(itv_settings)
            else:
                response, measurements = self.generate_normal()
        except Exception as e:
            # Roll back the turn that could not be answered and let the user retry.
            self.chat_history.pop()
            self.user_input.value = user_text
            # Redraw without the typing indicator, then append the escaped error.
            self.refresh_chat_display()
            self.chat_content.value += (
                f"<div style='color: red; margin: 10px 0;'><strong>Error:</strong> {self._to_html(e)}</div>"
            )
            return

        self.chat_history.append({"role": "assistant", "content": response, "measurements": measurements})
        self.refresh_chat_display()

    def refresh_chat_display(self, typing=False):
        """Rebuild the transcript HTML from ``chat_history`` using the current theme.

        System turns are skipped. Message text is HTML-escaped before being
        embedded so that ``<``, ``>`` and ``&`` in user or model text are shown
        literally instead of being interpreted as markup.

        Args:
            typing: When true, append an "Assistant is typing..." indicator.
        """
        theme = self.themes[self.current_theme]
        bg_color = theme['bg_color']
        text_color = theme['text_color']
        user_color = theme['user_color']
        asst_color = theme['assistant_color']

        parts = [f'<div style="background-color: {bg_color}; color: {text_color}; padding: 10px; min-height: 100%; font-family: sans-serif;">']

        for msg in self.chat_history:
            role = msg['role']
            if role == 'system':
                continue

            content = self._to_html(msg['content'])

            if role == 'user':
                parts.append(f'<div style="margin: 10px 0;"><strong><span style="color: {user_color}">User:</span></strong> {content}</div>')
            elif role == 'assistant':
                parts.append(f'<div style="margin: 10px 0;"><strong><span style="color: {asst_color}">Assistant:</span></strong> {content}</div>')

                measurements = msg.get('measurements')
                if measurements:
                    # Pairs labels with values positionally; presumably the
                    # measurement order matches symp_label_dict's key order — verify.
                    ranked = sorted(zip(self.symp_keys, measurements), key=lambda pair: pair[1], reverse=True)
                    top_str = ", ".join(f"{html.escape(str(k))}: {v:.2f}" for k, v in ranked[:5])
                    parts.append(f'<div style="font-size: 0.9em; color: #888; margin-left: 20px;"><em>Top Activations: {top_str}</em></div>')
                    parts.append('<hr style="border-color: #eee; margin: 5px 0;">')

        if typing:
            parts.append('<div style="margin: 10px 0; color: #888;"><em>Assistant is typing...</em></div>')

        parts.append('</div>')

        self.chat_content.value = "".join(parts)

    def get_clean_history(self):
        """Return the history with only ``role`` and ``content`` per turn (no measurements)."""
        return [{'role': m['role'], 'content': m['content']} for m in self.chat_history]

    def _decode_and_measure(self, output_tokens):
        """Decode the first generated sequence and compute its max-pooled measurements."""
        output_text = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        return output_text, self.get_max_pooled_measurements(output_text)

    def generate_with_intervention(self, settings):
        """Generate a reply while steering along one latent's decoder direction.

        Args:
            settings: Dict with ``'latent'`` (a key of ``symp_label_dict``) and
                ``'strengths'`` (layer -> float; missing layers default to 0.0).

        Returns:
            ``(output_text, measurements)``.
        """
        latent_idx = self.symp_label_dict[settings['latent']]
        strengths = settings['strengths']

        itv_W_dict = {}
        itv_str_dict = {}

        for layer in self.cfg.hook_layers:
            device = self.device_dict[layer]

            # Steering direction: the selected latent's row of the transposed decoder weight.
            steering_vec = self.sae_dict[layer].decoder.weight.T[latent_idx].to(device)

            # Two leading singleton axes are added; presumably LLM_Steer_Manager
            # expects (batch_size, 1, hidden_size) — confirm.
            itv_W_dict[layer] = steering_vec.unsqueeze(0).unsqueeze(0).to(torch.bfloat16)

            # Strength as a 1x1 tensor on the same device as the layer.
            itv_str_dict[layer] = torch.tensor([[strengths.get(layer, 0.0)]], device=device, dtype=torch.bfloat16)

        output_tokens = self.steer_manager.generate_text_w_itv([self.get_clean_history()], itv_W_dict, itv_str_dict)
        return self._decode_and_measure(output_tokens)

    def generate_normal(self):
        """Generate a reply without steering; returns ``(output_text, measurements)``."""
        output_tokens = self.steer_manager.generate_text([self.get_clean_history()])
        return self._decode_and_measure(output_tokens)

    def start(self):
        """Render the chat UI in the notebook output area."""
        display(self.main_layout)

In [4]:
# Build the project config for the model id selected above, and a data manager for it.
cfg = Config(llm)
dm = Data_Manager(cfg)

# Obtain the model and tokenizer for this config, then the per-layer dictionaries.
model, tokenizer = model_selection(cfg)
sae_dict = dm.load_dict(dict_type='sae')
actmax_dict = dm.load_dict(dict_type='actmax')
# The device dictionary is the only one that is given the loaded model.
device_dict = dm.load_dict(dict_type='device', model=model)
std_dict = dm.load_dict(dict_type='act-std')
# Move each hooked layer's SAE to the device recorded for that layer.
for layer in cfg.hook_layers:
    sae_dict[layer] = sae_dict[layer].to(device_dict[layer])
    
# load_dict('label') yields four values; only the first and third are kept here.
symp_label_dict, _, symp_keys, _ = dm.load_dict('label')

# NOTE(review): 'tmp' is presumably the sampling-temperature key that
# LLM_Steer_Manager reads — confirm; it is not spelled 'temperature'.
generation_kwargs = {'max_new_tokens': 300, 'tmp': 0.5}

Loading model...


Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [None]:
# Build the chat widget from the objects prepared in the setup cell and render it.
chat_system = InteractiveChat(
    cfg=cfg,
    model=model,
    tokenizer=tokenizer,
    dm=dm,
    symp_label_dict=symp_label_dict,
    sae_dict=sae_dict,
    actmax_dict=actmax_dict,
    device_dict=device_dict,
    std_dict=std_dict,
    generation_kwargs=generation_kwargs,
)
chat_system.start()