# Differential Vision: Interactive Attention & Patch Updates

Use a slider to scrub through an image sequence. For each frame, this notebook:
- Runs differential vision encoding (cache / partial / full) using `DifferentialVisionEncoder`.
- Computes patch-change overlays and token delta maps.
- Captures text→vision attention and overlays it on the patch grid.
- Exposes controls for diff threshold, max changed patches, context radius, attention head/layer, and query offset.

Instructions:
- Set `model_path` and `images_dir` in the config block below.
- Run the cell to initialize the UI and use the controls to explore.


In [None]:
# Interactive differential vision + attention over a sequence

import os, time, math
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

import ipywidgets as W
from IPython.display import display, clear_output

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.utils import disable_torch_init

from differential_vision import DifferentialVisionEncoder
from differential_vision.patch_utils import (
    compute_patch_diff, compute_patch_diff_values,
    visualize_changed_patches, overlay_patch_rects
)

# ---------- Config ----------
# Pick the best available accelerator. `torch.backends.mps` does not exist on
# older torch builds, so probe for it instead of assuming the attribute is there.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
else:
    device = 'cpu'
torch_device = torch.device(device)

# Edit these before running the cell
model_path = '../checkpoints/llava-fastvithd_1.5b_stage3'
model_base = None  # set base if your checkpoint is projector-only; else keep None
images_dir = '../differential_vision/agentnet_curated/traj_0000'
prompt = 'Describe salient changes and where the text points its focus.'
force_pad_aspect = True  # keeps single-image path for differential updates

# ---------- Load model ----------
disable_torch_init()
_model_path_expanded = os.path.expanduser(model_path)
# normpath drops a trailing separator; without it basename('ckpt/') would be ''
# and the loader would receive an empty model name.
model_name = os.path.basename(os.path.normpath(_model_path_expanded))
tokenizer, model, image_processor, _ = load_pretrained_model(
    _model_path_expanded, model_base, model_name, device=device
)
if force_pad_aspect:
    # Best effort: keep going if the config rejects the attribute, but report it
    # instead of swallowing the failure silently.
    try:
        model.config.image_aspect_ratio = 'pad'
    except Exception as exc:
        print(f"[warn] could not set image_aspect_ratio='pad': {exc!r}")

# ---------- Build prompt/token ids (once) ----------
def build_prompt_and_ids(tokenizer, model, user_prompt, conv_mode='qwen_2'):
    """Build the chat prompt text and its token ids for a single-image query.

    The image placeholder is placed before ``user_prompt``; when
    ``model.config.mm_use_im_start_end`` is truthy it is wrapped in the
    image start/end tokens.

    Args:
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.
        model: Model whose ``config`` is consulted for ``mm_use_im_start_end``.
        user_prompt: The user's question, without any image placeholder.
        conv_mode: Key into ``conv_templates`` selecting the conversation
            template. Defaults to ``'qwen_2'``, the value that was previously
            hard-coded.

    Returns:
        ``(prompt_text, input_ids)`` where ``input_ids`` has a leading batch
        dimension of 1 and lives on ``torch_device``.

    Raises:
        KeyError: If ``conv_mode`` is not a key of ``conv_templates``.
    """
    if getattr(model.config, 'mm_use_im_start_end', False):
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    qs = image_prefix + '\n' + user_prompt

    # Copy the template so the shared entry in conv_templates is not mutated.
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)  # open assistant turn
    prompt_text = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(torch_device)
    return prompt_text, input_ids

# Built once at startup; `input_ids` is reused by every render() call.
prompt_text, input_ids = build_prompt_and_ids(tokenizer, model, prompt)

# ---------- Prepare frames ----------
def _natural_key(path: str):
    import re
    base = os.path.basename(path)
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', base)]

def gather_images(directory: str):
    """List the image files directly inside ``directory`` in natural order.

    An entry counts as an image when its extension, compared
    case-insensitively, is one of .png, .jpg, .jpeg, .bmp or .gif.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        Paths (``directory`` joined with each entry name) sorted with
        ``_natural_key``.
    """
    image_exts = {'.png', '.jpg', '.jpeg', '.bmp', '.gif'}
    matches = []
    for entry in os.listdir(directory):
        _, ext = os.path.splitext(entry)
        if ext.lower() in image_exts:
            matches.append(os.path.join(directory, entry))
    matches.sort(key=_natural_key)
    return matches

frame_paths = gather_images(os.path.expanduser(images_dir))
if not frame_paths:
    # Explicit raise rather than `assert`: assertions are skipped under
    # `python -O`, which would defer the failure to a confusing later error.
    raise FileNotFoundError(f'No images found in {images_dir}')

def load_and_process(path):
    """Load one frame and preprocess it for the vision tower.

    Args:
        path: Path of the image file.

    Returns:
        ``(pil, tensor)``: the RGB ``PIL.Image`` and the preprocessed tensor,
        given a leading batch dimension when the processor output is 3-D and
        moved to ``torch_device`` with the dtype of the model's parameters.
    """
    # Image.open keeps the file handle open until the image is closed.
    # convert() returns a separate copy, so the source can be closed right
    # away instead of leaking one handle per frame.
    with Image.open(path) as src:
        pil = src.convert('RGB')

    processed = process_images([pil], image_processor, model.config)
    # process_images may hand back a list/tuple or a single tensor; use the
    # first item in the former case.
    image_tensor = processed[0] if isinstance(processed, (list, tuple)) else processed
    if image_tensor.ndim == 3:
        image_tensor = image_tensor.unsqueeze(0)

    model_dtype = next(model.parameters()).dtype
    return pil, image_tensor.to(device=torch_device, dtype=model_dtype)

# Load every frame up front; both lists are indexed by frame number.
_loaded_frames = [load_and_process(frame_path) for frame_path in frame_paths]
frames_pil = [pil_img for pil_img, _ in _loaded_frames]
frames_tensor = [tensor for _, tensor in _loaded_frames]

# ---------- Vision grid ----------
vt = model.get_vision_tower()

# Patches per side: use the tower's own value when it exposes one, otherwise
# derive it from the token count of a real encode, treating the grid as square.
_declared_side = getattr(vt, 'num_patches_per_side', None)
if _declared_side is not None:
    grid_side = int(_declared_side)
else:
    with torch.inference_mode():
        tmp_feats = model.encode_images(frames_tensor[0])
    grid_side = int(round(math.sqrt(int(tmp_feats.shape[1]))))

# Patch size in pixels: attribute-style config first, then dict-style config,
# then a fixed fallback of 24.
_vt_cfg = getattr(vt, 'config', None)
if _vt_cfg is not None and hasattr(_vt_cfg, 'patch_size'):
    patch_px_default = int(_vt_cfg.patch_size)
elif isinstance(_vt_cfg, dict):
    patch_px_default = int(_vt_cfg['image_cfg']['patch_size'])
else:
    patch_px_default = 24

# ---------- Widgets ----------
# Sliders only fire once the drag is released, so each change renders once.
_on_release = dict(continuous_update=False)

# Encoder / sequence controls
frame_slider = W.IntSlider(
    value=0, min=0, max=len(frame_paths) - 1, step=1,
    description='Frame', **_on_release,
)
diff_thr = W.FloatSlider(
    value=0.05, min=0.0, max=1.0, step=0.01,
    description='Diff thr', readout_format='.2f', **_on_release,
)
max_changed = W.IntSlider(
    value=50, min=1, max=grid_side * grid_side, step=1,
    description='Max Δ patches', **_on_release,
)
ctx_radius = W.IntSlider(
    value=1, min=0, max=3, step=1,
    description='Context r', **_on_release,
)

# Attention controls
att_layer_mode = W.Dropdown(options=['last', 'mean_all'], value='last', description='Layer')
att_head_mode = W.Dropdown(options=['mean', 'index'], value='mean', description='Head')
att_head_idx = W.IntSlider(
    value=0, min=0, max=15, step=1,
    description='Head idx', **_on_release,
)
q_offset = W.IntSlider(
    value=0, min=0, max=8, step=1,
    description='Q offset (from end)', **_on_release,
)

reset_seq_btn = W.Button(description='Reset sequence state', button_style='warning')
output_widget = W.Output()

# Layout: one row per control group, output area last.
ui_top = W.HBox([frame_slider, diff_thr, max_changed, ctx_radius])
ui_attn = W.HBox([att_layer_mode, att_head_mode, att_head_idx, q_offset])
ui_ctrl = W.HBox([reset_seq_btn])

display(ui_top, ui_attn, ui_ctrl, output_widget)

# ---------- State ----------
# Mutable state shared by the render / replay helpers.
state = {
    'encoder': None,     # current DifferentialVisionEncoder (set by reset_sequence)
    'feats_prev': None,  # features returned by the most recent encode
    'last_idx': None,    # index of the most recently encoded frame
    'vt': vt,            # vision tower, used for grid / patch-size lookups
}

def build_encoder():
    """Create a ``DifferentialVisionEncoder`` from the current widget values.

    Reads the diff threshold, max-changed-patches and context-radius sliders
    at call time, so a fresh encoder reflects whatever the UI shows.
    """
    encoder_kwargs = {
        'diff_threshold': float(diff_thr.value),
        'max_changed_patches': int(max_changed.value),
        'context_radius': int(ctx_radius.value),
        'skip_small_updates': False,
        'device': torch_device,
    }
    return DifferentialVisionEncoder(model, **encoder_kwargs)

def reset_sequence():
    """Install a fresh encoder and forget all per-sequence progress.

    After this call ``state['last_idx']`` is ``None``, so the next render
    replays the sequence from the first frame.
    """
    fresh_encoder = build_encoder()
    state['encoder'] = fresh_encoder
    fresh_encoder.reset_cache()
    state.update(feats_prev=None, last_idx=None)

# Create the initial encoder from the widgets' starting values.
reset_sequence()

# ---------- Attention helpers ----------
def get_vision_span(input_ids, image_features):
    """Locate the vision tokens as a half-open span ``[start, end)``.

    The span starts at the first ``IMAGE_TOKEN_INDEX`` in the first row of
    ``input_ids`` and is as long as ``image_features.shape[1]``.

    Args:
        input_ids: Token ids with a leading batch dimension; only row 0 is read.
        image_features: Encoded image features; dimension 1 is taken as the
            number of vision tokens.

    Returns:
        ``(start, end, num_tokens)`` as plain ints.

    Raises:
        AssertionError: If the prompt contains no image sentinel.
    """
    num_tokens = int(image_features.shape[1])
    sentinel_hits = torch.nonzero(input_ids[0] == IMAGE_TOKEN_INDEX, as_tuple=False)
    assert sentinel_hits.numel() > 0, 'Prompt must contain an IMAGE_TOKEN sentinel'
    start = int(sentinel_hits[0, 0].item())
    end = start + num_tokens
    return start, end, num_tokens

def fuse_attention(attentions, q_index, v_start, v_end, layer_mode='last', head_mode='mean', head_idx=0):
    """Reduce per-layer attention to one weight per vision token for one query.

    Args:
        attentions: Sequence of per-layer attention tensors indexed as
            ``[batch, head, query, key]``; only batch item 0 is used.
        q_index: Query position whose attention row is extracted.
        v_start: Start (inclusive) of the key slice to return.
        v_end: End (exclusive) of the key slice to return.
        layer_mode: ``'last'`` uses only the final layer; any other value
            averages over all layers.
        head_mode: ``'index'`` selects head ``head_idx``; any other value
            averages over heads.
        head_idx: Head to select when ``head_mode == 'index'``.

    Returns:
        1-D float32 tensor of length ``v_end - v_start``.

    Raises:
        IndexError: If ``head_mode == 'index'`` and ``head_idx`` is outside
            ``[0, num_heads)``. Previously this sliced to an empty tensor and
            silently produced NaN.
    """
    layers = [attentions[-1]] if layer_mode == 'last' else list(attentions)

    per_layer = []
    for layer_att in layers:
        # Drop the batch dim, pick the query row -> [heads, keys]. Averaging is
        # done in float32 so half-precision inputs do not lose precision.
        rows = layer_att[0][:, q_index, :].float()
        if head_mode == 'index':
            num_heads = rows.shape[0]
            if not 0 <= head_idx < num_heads:
                raise IndexError(
                    f'head_idx {head_idx} out of range for {num_heads} attention heads'
                )
            per_layer.append(rows[head_idx])
        else:
            per_layer.append(rows.mean(dim=0))

    fused = torch.stack(per_layer, dim=0).mean(dim=0)  # [keys]
    return fused[v_start:v_end]

# ---------- Rendering ----------
def replay_to_index(idx):
    """Rebuild encoder state by encoding frames ``0..idx`` from scratch.

    Resets the sequence first, then feeds every frame up to and including
    ``idx`` through the encoder in order. Afterwards ``state['feats_prev']``
    holds the features of the last frame encoded (``None`` if none were) and
    ``state['last_idx']`` equals ``idx``.
    """
    reset_sequence()
    encoder = state['encoder']
    latest_feats = None
    for frame_no in range(idx + 1):
        latest_feats, _stats = encoder.encode(frames_tensor[frame_no], return_stats=True)
    state.update(feats_prev=latest_feats, last_idx=idx)

def _encode_frame_for_render(idx):
    """Bring the encoder to frame ``idx`` and return that frame's ``(features, stats)``.

    Frame ``idx`` is encoded exactly once, directly after frame ``idx - 1``,
    so the returned stats describe the real decision for that transition.
    Encoding it a second time, as was done before, reports the outcome of
    comparing the frame with itself (presumably always a cache hit; confirm
    against DifferentialVisionEncoder).

    The result is remembered in ``state['render_cache']`` so that re-renders
    of the same frame (e.g. after changing an attention control) reuse it
    instead of feeding the frame to the encoder again.
    """
    enc = state['encoder']
    params_changed = (
        enc is None
        or enc.diff_threshold != float(diff_thr.value)
        or enc.max_changed_patches != int(max_changed.value)
        or enc.context_radius != int(ctx_radius.value)
    )
    last_idx = state['last_idx']
    cached = state.get('render_cache')

    if (
        not params_changed
        and last_idx == idx
        and cached is not None
        and cached['encoder'] is enc
        and cached['idx'] == idx
    ):
        return cached['feats'], cached['info']

    if params_changed or last_idx is None or idx <= last_idx:
        # Encoder settings changed, or we must go backwards: rebuild the
        # state up to the frame *before* idx.
        if idx > 0:
            replay_to_index(idx - 1)
        else:
            reset_sequence()
    else:
        # Moving forward: feed the skipped frames so the encoder sees the
        # sequence in order.
        for t in range(last_idx + 1, idx):
            feats_t, _ = state['encoder'].encode(frames_tensor[t], return_stats=True)
            state['feats_prev'] = feats_t

    feats_cur, info = state['encoder'].encode(frames_tensor[idx], return_stats=True)
    state['feats_prev'] = feats_cur
    state['last_idx'] = idx
    state['render_cache'] = {
        'encoder': state['encoder'], 'idx': idx, 'feats': feats_cur, 'info': info,
    }
    return feats_cur, info


def render(idx):
    """Encode frame ``idx`` and draw the six-panel figure into ``output_widget``.

    Panels: the frame, changed-patch overlay, changed-patch rectangles,
    per-token feature delta against a full encode of the previous frame,
    text-to-vision attention for the selected query/layer/head, and a text
    summary of the encoder decision.

    Args:
        idx: Index into ``frames_pil`` / ``frames_tensor``.

    Raises:
        RuntimeError: If the model returns no attention weights.
    """
    feats_cur, info = _encode_frame_for_render(idx)

    vt = state['vt']
    grid_side_local = int(getattr(vt, 'num_patches_per_side', grid_side))
    if hasattr(vt, 'config') and hasattr(vt.config, 'patch_size'):
        patch_px = int(vt.config.patch_size)
    elif isinstance(getattr(vt, 'config', None), dict):
        patch_px = int(vt.config['image_cfg']['patch_size'])
    else:
        patch_px = patch_px_default

    # Pixel-level patch diff against the previous frame (nothing for frame 0).
    img_cur = np.array(frames_pil[idx]).astype(np.float32)/255.0
    if idx > 0:
        img_prev = np.array(frames_pil[idx-1]).astype(np.float32)/255.0
        mask_bool = compute_patch_diff(img_prev, img_cur, patch_size=patch_px, threshold=float(diff_thr.value))
    else:
        mask_bool = np.zeros((grid_side_local, grid_side_local), dtype=bool)

    vis_changed = visualize_changed_patches(img_cur, mask_bool, patch_size=patch_px, alpha=0.25)
    idx_i, idx_j = np.where(mask_bool)
    overlay_rects = overlay_patch_rects(img_cur, list(zip(idx_i.tolist(), idx_j.tolist())), patch_size=patch_px, color=(255,0,0), thickness=2)

    # Token delta grid: compare to previous frame full-encode for stable baseline
    if idx > 0:
        with torch.inference_mode():
            feats_prev_full = state['encoder']._full_encode(frames_tensor[idx-1])
        flat_prev = feats_prev_full.view(-1, feats_prev_full.shape[-1]).detach().float().cpu()
    else:
        flat_prev = feats_cur.view(-1, feats_cur.shape[-1]).detach().float().cpu()
    flat_cur = feats_cur.view(-1, feats_cur.shape[-1]).detach().float().cpu()
    delta = torch.linalg.norm(flat_cur - flat_prev, dim=-1)
    delta_grid = delta.view(grid_side_local, grid_side_local).numpy()

    # Attention
    with torch.inference_mode():
        img_feats = model.encode_images(frames_tensor[idx])
        v_start, v_end, N = get_vision_span(input_ids, img_feats)
        model_out = model(
            input_ids=input_ids,
            images=frames_tensor[idx],
            image_sizes=[frames_pil[idx].size],
            output_attentions=True,
            return_dict=True
        )
        atts = model_out.attentions
        if atts is None or atts[-1] is None:
            raise RuntimeError(
                'Model returned no attention weights; output_attentions=True '
                'is not honoured by the active attention implementation.'
            )
        _, H, Q, _ = atts[-1].shape
        # Take the query position from the attention tensor itself. The span
        # [v_start, v_end) treats the single image sentinel in input_ids as
        # occupying N key positions, so a position counted in input_ids would
        # index the wrong attention row (typically one inside the vision span).
        q_last = Q - 1
        q_idx = max(0, q_last - int(q_offset.value))
        att_head_idx.max = max(0, H-1)
        att_vec = fuse_attention(
            atts, q_idx, v_start, v_end,
            layer_mode=att_layer_mode.value,
            head_mode=('index' if att_head_mode.value == 'index' else 'mean'),
            head_idx=int(att_head_idx.value)
        ).detach().float().cpu().numpy()  # float(): numpy has no bfloat16
        att_grid = att_vec.reshape(grid_side_local, grid_side_local)

    # Built outside the title f-string: reusing the same quote character inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    head_suffix = '' if att_head_mode.value == 'mean' else f'[{int(att_head_idx.value)}]'
    att_title = f'Attention (q={q_idx}, layer={att_layer_mode.value}, head={att_head_mode.value}{head_suffix} )'

    # Render
    with output_widget:
        clear_output(wait=True)
        fig, ax = plt.subplots(2, 3, figsize=(16, 9))
        ax[0,0].imshow(frames_pil[idx]); ax[0,0].set_title(f'Frame {idx}: {os.path.basename(frame_paths[idx])}'); ax[0,0].axis('off')
        ax[0,1].imshow(vis_changed); ax[0,1].set_title(f'Changed patches (thr={float(diff_thr.value):.2f})'); ax[0,1].axis('off')
        ax[0,2].imshow(overlay_rects); ax[0,2].set_title('Changed patch rectangles'); ax[0,2].axis('off')
        im = ax[1,0].imshow(delta_grid, cmap='magma'); ax[1,0].set_title('Token Δ (L2 per patch)'); fig.colorbar(im, ax=ax[1,0], fraction=0.046, pad=0.04)
        im2 = ax[1,1].imshow(att_grid, cmap='viridis'); ax[1,1].set_title(att_title); fig.colorbar(im2, ax=ax[1,1], fraction=0.046, pad=0.04)
        ax[1,2].axis('off')
        text = (
            f"Decision: {info['encoding_type']} | Changed: {info['changed_patches']}/{info['total_patches']}\n"
            f"Context radius: {int(ctx_radius.value)} | Patch(px): {patch_px} | Grid: {grid_side_local}x{grid_side_local}\n"
            f"Vision span: [{v_start},{v_end}) N={N} | Tower: {type(vt).__name__}"
        )
        ax[1,2].text(0.02, 0.5, text, fontsize=11, va='center')
        plt.tight_layout(); plt.show()
        # Release the figure so repeated renders do not accumulate open figures.
        plt.close(fig)

# ---------- Callbacks ----------
def on_frame_change(change):
    """Observer for the frame slider: render the newly selected frame."""
    new_idx = int(change['new'])
    render(new_idx)
def on_param_change(change):
    """Observer for the parameter widgets: re-render the current frame.

    The change payload is ignored; render() reads the widget values itself.
    """
    current_idx = int(frame_slider.value)
    render(current_idx)
def on_reset_clicked(btn):
    """Click handler for the reset button: drop encoder state, then redraw."""
    reset_sequence()
    current_idx = int(frame_slider.value)
    render(current_idx)

# The frame slider gets its own observer; every other control just triggers a
# re-render of the current frame.
frame_slider.observe(on_frame_change, names='value')
_param_widgets = [
    diff_thr, max_changed, ctx_radius,
    att_layer_mode, att_head_mode, att_head_idx, q_offset,
]
for _widget in _param_widgets:
    _widget.observe(on_param_change, names='value')
reset_seq_btn.on_click(on_reset_clicked)

# Initial render
render(int(frame_slider.value))
