# Differential Vision: Start-Button Explorer

This variant adds a Start button so you can:
- Defer model loading until configuration is set
- Initialize the interactive slider + attention heatmap UI on demand
- Use the built-in Play control to step through the frames over time (set the playback speed with the Interval slider)

Configure the model path, image directory, and prompt below, then click Start.

In [None]:
import os, math
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

import ipywidgets as W
from IPython.display import display, clear_output

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.utils import disable_torch_init

from differential_vision import DifferentialVisionEncoder
from differential_vision.patch_utils import (
    compute_patch_diff, compute_patch_diff_values,
    visualize_changed_patches, overlay_patch_rects
)

# --- Config UI ---
# Input widgets read by on_start_clicked when Start is pressed; nothing is loaded
# before that. The default paths below are machine-specific — edit them for your setup.
model_path_txt = W.Text(
    value='/home/sdan/workspace/diffcua/checkpoints/llava-fastvithd_1.5b_stage3',
    description='Model',
    layout=W.Layout(width='60%'),
)
base_path_txt = W.Text(
    value='',
    description='Base',
    layout=W.Layout(width='40%'),
)
images_dir_txt = W.Text(
    value='/home/sdan/workspace/diffcua/differential_vision/agentnet_curated/traj_0000',
    description='Images',
    layout=W.Layout(width='60%'),
)
prompt_txt = W.Textarea(
    value='Describe salient changes and where the text points its focus.',
    description='Prompt',
    layout=W.Layout(width='80%', height='60px'),
)
force_pad_chk = W.Checkbox(value=True, description='Force pad aspect')
play_interval = W.IntSlider(
    description='Interval(ms)',
    min=100,
    max=2000,
    step=50,
    value=600,
    continuous_update=False,
)
start_btn = W.Button(description='Start', button_style='success')
# Output area that on_start_clicked renders the explorer into.
panel = W.Output()

display(
    W.VBox(
        [
            W.HBox([model_path_txt, base_path_txt]),
            images_dir_txt,
            prompt_txt,
            W.HBox([force_pad_chk, play_interval, start_btn]),
            panel,
        ]
    )
)

def _natural_key(path: str):
    import re
    base = os.path.basename(path)
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', base)]

def _gather_images(directory: str):
    """Return paths of the image files directly inside *directory*, naturally sorted.

    Only regular files (symlinks to files included) whose extension, compared
    case-insensitively, is .png/.jpg/.jpeg/.bmp/.gif are kept. Sub-directories are
    skipped even when their name ends in an image extension, so they never reach
    ``Image.open``. Order comes from ``_natural_key`` ("frame_2" before "frame_10").

    Raises:
        OSError: if *directory* cannot be listed (missing, not a directory, no permission).
    """
    supported = {'.png', '.jpg', '.jpeg', '.bmp', '.gif'}
    with os.scandir(directory) as entries:
        files = [
            entry.path  # == os.path.join(directory, entry.name)
            for entry in entries
            if entry.is_file() and os.path.splitext(entry.name)[1].lower() in supported
        ]
    return sorted(files, key=_natural_key)

def _pick_device():
    """Return 'cuda' if available, otherwise 'mps' if available, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


def _vision_patch_size(vt, default=24):
    """Return the vision tower's patch size in pixels, or *default* if it is not exposed.

    Two config layouts are supported: an object with a ``patch_size`` attribute, or a
    dict with ``config['image_cfg']['patch_size']``.
    """
    cfg = getattr(vt, 'config', None)
    if hasattr(cfg, 'patch_size'):
        return int(cfg.patch_size)
    if isinstance(cfg, dict):
        return int(cfg['image_cfg']['patch_size'])
    return default


def on_start_clicked(_):
    """Load the model and frames using the config widgets, then build the explorer UI.

    Everything is rendered into ``panel``, which is cleared first. If the model or
    images fail to load, the error is printed and setup stops.

    NOTE(review): pressing Start again loads a second model. The previous one stays
    referenced by the old widgets' observers, so GPU memory is not released.
    """
    with panel:
        clear_output(wait=True)
        # --- Resolve config ---
        device = _pick_device()
        torch_device = torch.device(device)
        model_path = os.path.expanduser(model_path_txt.value.strip())
        model_base = base_path_txt.value.strip() or None
        images_dir = os.path.expanduser(images_dir_txt.value.strip())
        prompt = prompt_txt.value.strip()

        # --- Load model lazily ---
        disable_torch_init()
        try:
            # normpath drops a trailing '/', which would otherwise make basename() return ''.
            model_name = os.path.basename(os.path.normpath(model_path))
            tokenizer, model, image_processor, _ = load_pretrained_model(model_path, model_base, model_name, device=device)
            # Eager attention is needed so the forward pass can return attention maps.
            model.set_attn_implementation('eager')
        except Exception as e:
            print('Error loading model:', e)
            return
        if force_pad_chk.value:
            try:
                setattr(model.config, 'image_aspect_ratio', 'pad')
            except Exception as e:
                # Best effort: keep going with the model's own aspect-ratio setting.
                print('Could not force pad aspect ratio:', e)

        # --- Build prompt + input ids ---
        def build_prompt_and_ids(user_prompt):
            """Wrap *user_prompt* with the image token(s) in the 'qwen_2' template; return (text, ids)."""
            qs = user_prompt
            if getattr(model.config, 'mm_use_im_start_end', False):
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
            conv = conv_templates['qwen_2'].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            text = conv.get_prompt()
            ids = tokenizer_image_token(text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch_device)
            return text, ids
        _, input_ids = build_prompt_and_ids(prompt)

        # --- Gather images ---
        try:
            frame_paths = _gather_images(images_dir)
        except Exception as e:
            print('Error reading images:', e)
            return
        if not frame_paths:
            print('No images found in', images_dir)
            return

        # --- Pre-load tensors ---
        model_dtype = next(model.parameters()).dtype

        def load_and_process(path):
            """Return (RGB PIL image, preprocessed tensor with a leading batch dim on the model device/dtype)."""
            pil = Image.open(path).convert('RGB')
            processed = process_images([pil], image_processor, model.config)
            tensor = processed[0] if isinstance(processed, (list, tuple)) else processed
            if tensor.ndim == 3:
                tensor = tensor.unsqueeze(0)
            return pil, tensor.to(device=torch_device, dtype=model_dtype)
        frames_pil, frames_tensor = [], []
        for p in frame_paths:
            try:
                pil, ten = load_and_process(p)
            except Exception as e:
                print('Error loading image', p, e)
                return
            frames_pil.append(pil)
            frames_tensor.append(ten)

        # --- Grid info ---
        vt = model.get_vision_tower()
        grid_side = getattr(vt, 'num_patches_per_side', None)
        if grid_side is None:
            # Fall back to inferring a square token grid from the encoder output length.
            with torch.inference_mode():
                tmp_feats = model.encode_images(frames_tensor[0])
            grid_side = int(round(math.sqrt(int(tmp_feats.shape[1]))))
        grid_side = int(grid_side)
        patch_px = _vision_patch_size(vt, default=24)

        # --- Differential encoder ---
        def build_encoder(thr, maxchg, radius):
            return DifferentialVisionEncoder(
                model, diff_threshold=float(thr), max_changed_patches=int(maxchg), context_radius=int(radius), skip_small_updates=False, device=torch_device
            )
        encoder = build_encoder(0.05, 50, 1)
        encoder.reset_cache()
        # Index of the frame most recently fed to `encoder`. -1 means none yet;
        # None means the cache state is unknown (e.g. an encode raised midway).
        last_encoded = -1

        # --- Widgets ---
        frame_slider = W.IntSlider(description='Frame', min=0, max=len(frame_paths)-1, step=1, value=0, continuous_update=False)
        play = W.Play(interval=int(play_interval.value), value=0, min=0, max=len(frame_paths)-1, step=1, description='Play')
        W.jslink((play, 'value'), (frame_slider, 'value'))
        # Keep the config-panel Interval slider live after Start, not just read once.
        W.jslink((play_interval, 'value'), (play, 'interval'))
        diff_thr = W.FloatSlider(description='Diff thr', min=0.0, max=1.0, step=0.01, value=0.05, readout_format='.2f', continuous_update=False)
        max_changed = W.IntSlider(description='Max Δ patches', min=1, max=grid_side*grid_side, step=1, value=50, continuous_update=False)
        ctx_radius = W.IntSlider(description='Context r', min=0, max=3, step=1, value=1, continuous_update=False)
        att_layer_mode = W.Dropdown(description='Layer', options=['last', 'mean_all'], value='last')
        att_head_mode = W.Dropdown(description='Head', options=['mean', 'index'], value='mean')
        att_head_idx = W.IntSlider(description='Head idx', min=0, max=15, step=1, value=0, continuous_update=False)
        q_offset = W.IntSlider(description='Q offset', min=0, max=8, step=1, value=0, continuous_update=False)
        out_vis = W.Output()
        display(W.VBox([W.HBox([play, frame_slider]), W.HBox([diff_thr, max_changed, ctx_radius]), W.HBox([att_layer_mode, att_head_mode, att_head_idx, q_offset]), out_vis]))

        # --- Helpers ---
        def get_vision_span(ids, image_features):
            """Return (start, end, N): where the N image features sit in the model's sequence.

            Assumes the single IMAGE_TOKEN_INDEX placeholder in *ids* is replaced by the
            N = image_features.shape[1] vision tokens, starting at the placeholder's position.
            """
            N = int(image_features.shape[1])
            pos = (ids[0] == IMAGE_TOKEN_INDEX).nonzero(as_tuple=False)
            if pos.numel() == 0:
                raise ValueError('prompt contains no image token')
            start = int(pos[0, 0].item())
            return start, start + N, N

        def fuse_attention(atts, q_index, v_start, v_end, layer_mode='last', head_mode='mean', head_idx=0):
            """Return the attention of query *q_index* over keys [v_start, v_end).

            Heads are averaged, or a single head is picked when head_mode == 'index'.
            Layers use either the last one or the mean over all.
            """
            layers = [atts[-1]] if layer_mode == 'last' else atts
            vals = []
            for A in layers:
                A = A[0]  # drop batch -> [H, Q, K]
                Aq = A[:, q_index, :]
                if head_mode == 'index':
                    Aq = Aq[head_idx:head_idx+1, :].mean(dim=0)
                else:
                    Aq = Aq.mean(dim=0)
                vals.append(Aq)
            V = torch.stack(vals, dim=0).mean(dim=0)
            return V[v_start:v_end]

        def sync_encoder(i):
            """Leave `encoder` with exactly frames 0..i-1 cached, so encoding frame i diffs against i-1.

            The encoder is rebuilt if its sliders changed. The cache is re-warmed whenever
            it does not already end at frame i-1 (scrubbing, jumping backwards, or
            re-rendering the same frame).
            """
            nonlocal encoder, last_encoded
            params_changed = (
                encoder.diff_threshold != float(diff_thr.value)
                or encoder.max_changed_patches != int(max_changed.value)
                or encoder.context_radius != int(ctx_radius.value)
            )
            if params_changed:
                encoder = build_encoder(diff_thr.value, max_changed.value, ctx_radius.value)
            if params_changed or last_encoded != i - 1:
                encoder.reset_cache()
                last_encoded = None  # mid-rebuild; forces a fresh warm-up if this raises
                for t in range(i):
                    encoder.encode(frames_tensor[t], return_stats=False)
                last_encoded = i - 1

        # --- Render ---
        def render(i):
            """Compute the diff mask, token deltas and attention for frame *i*, and draw them into out_vis."""
            nonlocal last_encoded
            # no_grad in case the encoder does not disable autograd itself.
            with torch.no_grad():
                sync_encoder(i)
                last_encoded = None
                feats, info = encoder.encode(frames_tensor[i], return_stats=True)
                last_encoded = i

            # Pixel-space change mask between consecutive frames.
            # NOTE(review): this works on the raw frames, not the preprocessed model
            # input, so its grid may not match grid_side — confirm against patch_utils.
            img_cur = np.array(frames_pil[i]).astype(np.float32)/255.0
            if i > 0:
                img_prev = np.array(frames_pil[i-1]).astype(np.float32)/255.0
                mask_bool = compute_patch_diff(img_prev, img_cur, patch_size=patch_px, threshold=float(diff_thr.value))
            else:
                mask_bool = np.zeros((grid_side, grid_side), dtype=bool)
            vis_changed = visualize_changed_patches(img_cur, mask_bool, patch_size=patch_px, alpha=0.25)
            idx_i, idx_j = np.where(mask_bool)
            overlay_rects = overlay_patch_rects(img_cur, list(zip(idx_i.tolist(), idx_j.tolist())), patch_size=patch_px, color=(255,0,0), thickness=2)

            # Per-token L2 change vs. a full encode of the previous frame (all zeros for frame 0).
            flat_cur = feats.reshape(-1, feats.shape[-1]).detach().float().cpu()
            if i > 0:
                with torch.inference_mode():
                    prev_full = encoder._full_encode(frames_tensor[i-1])
                flat_prev = prev_full.reshape(-1, prev_full.shape[-1]).detach().float().cpu()
            else:
                flat_prev = flat_cur
            delta = torch.linalg.norm(flat_cur - flat_prev, dim=-1)
            delta_grid = delta.reshape(grid_side, grid_side).numpy()

            # Attention from a chosen query position onto the vision tokens.
            with torch.inference_mode():
                img_feats = model.encode_images(frames_tensor[i])
                v_start, v_end, N = get_vision_span(input_ids, img_feats)
                out = model(input_ids=input_ids, images=frames_tensor[i], image_sizes=[frames_pil[i].size], output_attentions=True, return_dict=True)
                atts = out.attentions
                _, H, Q, _ = atts[-1].shape
                att_head_idx.max = max(0, H-1)
                # The attention sequence length Q already includes the N vision tokens
                # (the same expansion get_vision_span relies on). The last query is
                # therefore Q - 1, not len(input_ids) - 1.
                q_idx = max(0, Q - 1 - int(q_offset.value))
                att_vec = fuse_attention(
                    atts, q_idx, v_start, v_end,
                    layer_mode=att_layer_mode.value,
                    head_mode=('index' if att_head_mode.value == 'index' else 'mean'),
                    head_idx=int(att_head_idx.value),
                ).detach().float().cpu().numpy()  # .float(): numpy has no bfloat16
                att_grid = att_vec.reshape(grid_side, grid_side)

            # Built outside the f-string: nesting same-type quotes in an f-string is a
            # SyntaxError before Python 3.12.
            if att_head_mode.value == 'mean':
                head_label = att_head_mode.value
            else:
                head_label = f'{att_head_mode.value}[{int(att_head_idx.value)}]'

            # Draw
            with out_vis:
                clear_output(wait=True)
                fig, ax = plt.subplots(2, 3, figsize=(16, 9))
                ax[0,0].imshow(frames_pil[i]); ax[0,0].set_title(f'Frame {i}: {os.path.basename(frame_paths[i])}'); ax[0,0].axis('off')
                ax[0,1].imshow(vis_changed); ax[0,1].set_title(f'Changed patches (thr={float(diff_thr.value):.2f})'); ax[0,1].axis('off')
                ax[0,2].imshow(overlay_rects); ax[0,2].set_title('Changed patch rectangles'); ax[0,2].axis('off')
                im = ax[1,0].imshow(delta_grid, cmap='magma'); ax[1,0].set_title('Token Δ (L2 per patch)'); fig.colorbar(im, ax=ax[1,0], fraction=0.046, pad=0.04)
                im2 = ax[1,1].imshow(att_grid, cmap='viridis'); ax[1,1].set_title(f'Attention (layer={att_layer_mode.value}, head={head_label}, q={q_idx})'); fig.colorbar(im2, ax=ax[1,1], fraction=0.046, pad=0.04)
                ax[1,2].axis('off')
                text = (
                    f"Decision: {info['encoding_type']} | Changed: {info['changed_patches']}/{info['total_patches']}\n"
                    f"Context radius: {int(ctx_radius.value)} | Patch(px): {patch_px} | Grid: {grid_side}x{grid_side}\n"
                    f"Vision span: [{v_start},{v_end}) N={N} | Tower: {type(vt).__name__}"
                )
                ax[1,2].text(0.02, 0.5, text, fontsize=11, va='center')
                plt.tight_layout(); plt.show()

        def show(i):
            # Render inside the Output context so that exceptions raised from widget
            # callbacks show their traceback next to the plots instead of only in the
            # kernel/log console.
            with out_vis:
                render(i)

        # --- Bind ---
        def on_frame_change(change):
            show(int(change['new']))

        def on_param_change(change):
            show(int(frame_slider.value))
        frame_slider.observe(on_frame_change, names='value')
        for w in (diff_thr, max_changed, ctx_radius, att_layer_mode, att_head_mode, att_head_idx, q_offset):
            w.observe(on_param_change, names='value')

        # --- Initial ---
        show(0)

start_btn.on_click(on_start_clicked)