# Differential Vision + Attention Explorer

Play through an image sequence with a slider. For each frame, we:
- Run differential vision encoding (cache/partial/full)
- Prepare multimodal inputs and capture the attention from the last prompt token to the vision tokens (last layer, averaged over heads)
- Visualize changed patches and attention over the vision grid

Set the model path and the image directory, then call `run_explorer(...)`.

In [2]:
import os, re, math, inspect
import numpy as np
import torch
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from differential_vision import DifferentialVisionEncoder
from differential_vision.patch_utils import compute_patch_diff, visualize_changed_patches
from llava.mm_utils import process_images, tokenizer_image_token, expand2square
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.conversation import conv_templates
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX

def _natural_sort_key(path: str):
    basename = os.path.basename(path)
    return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r'(\d+)', basename)]

def _gather_images_from_dir(directory: str):
    """List the image files directly inside *directory* in natural order.

    A directory entry is kept when it is a regular file whose extension
    (case-insensitive) is one of .png, .jpg, .jpeg, .bmp or .gif.
    Sub-directories are not descended into.

    Returns:
        Full paths (``directory`` joined with the entry name), sorted with
        ``_natural_sort_key``.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def _is_image(name: str) -> bool:
        extension = os.path.splitext(name)[1].lower()
        return extension in image_exts and os.path.isfile(os.path.join(directory, name))

    found = [os.path.join(directory, name) for name in os.listdir(directory) if _is_image(name)]
    found.sort(key=_natural_sort_key)
    return found

def _prepare_prompt(tokenizer, model, prompt: str, conv_mode: str):
    """Build the full conversation prompt string for a single-image question.

    The user turn is the image placeholder token, a newline, then *prompt*.
    When ``model.config.mm_use_im_start_end`` is truthy the placeholder is
    wrapped in the image start/end tokens. The assistant turn is left open
    (appended with ``None``).

    Args:
        tokenizer: Unused here; kept for signature compatibility.
        model: Object whose ``config`` may carry ``mm_use_im_start_end``.
        prompt: The user's text question.
        conv_mode: Key into ``conv_templates``.

    Returns:
        The string produced by the conversation template's ``get_prompt()``.
    """
    wrap_image_token = getattr(model.config, 'mm_use_im_start_end', False)
    image_marker = (
        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        if wrap_image_token
        else DEFAULT_IMAGE_TOKEN
    )
    user_turn = image_marker + '\n' + prompt

    conversation = conv_templates[conv_mode].copy()
    conversation.append_message(conversation.roles[0], user_turn)
    # An empty assistant slot so the template ends where generation begins.
    conversation.append_message(conversation.roles[1], None)
    return conversation.get_prompt()

def _compute_vision_span(input_ids: torch.Tensor, num_image_tokens: int):
    """Locate the slice of sequence positions attributed to the image tokens.

    Only the first row of *input_ids* is inspected. The span starts at the
    first position equal to ``IMAGE_TOKEN_INDEX`` and covers
    ``num_image_tokens`` positions from there.

    Returns:
        ``(start, end)`` as Python ints with ``end`` exclusive, or
        ``(None, None)`` when no placeholder is present.
    """
    first_row = input_ids[0]
    placeholder_positions = torch.nonzero(first_row == IMAGE_TOKEN_INDEX, as_tuple=False)
    if placeholder_positions.numel() == 0:
        return None, None
    span_start = int(placeholder_positions[0].item())
    span_end = span_start + int(num_image_tokens)
    return span_start, span_end

def _heat_to_overlay(image_np: np.ndarray, heat: np.ndarray, alpha: float = 0.35, cmap: str = 'jet') -> np.ndarray:
    """Blend a coarse 2-D heat map over an image as a colour overlay.

    The heat map is min-max normalised to [0, 1], coloured with the named
    matplotlib colormap, upsampled by block repetition so each grid cell
    covers ``ceil(H / gh) x ceil(W / gw)`` pixels, cropped to the image size
    and alpha-blended with the image.

    Args:
        image_np: Image array; its first two axes are taken as height and
            width. Assumed to be (H, W, 3) so it broadcasts against the RGB
            heat layer.
        heat: 2-D array of shape (gh, gw).
        alpha: Weight of the heat layer; the image gets ``1 - alpha``.
        cmap: Name of a matplotlib colormap.

    Returns:
        The blended image as a ``uint8`` array.
    """
    H, W = image_np.shape[:2]
    gh, gw = heat.shape
    cell_h = math.ceil(H / gh)
    cell_w = math.ceil(W / gw)

    # Normalise heat to [0, 1]; a (near-)constant map keeps denom at 1 to
    # avoid dividing by ~0.
    hmin = float(heat.min())
    hmax = float(heat.max())
    spread = hmax - hmin
    denom = spread if spread > 1e-6 else 1.0
    heat_norm = (heat - hmin) / denom

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap
    # is available across versions.
    colormap = plt.get_cmap(cmap)
    heat_rgb_small = (colormap(heat_norm)[..., :3] * 255).astype(np.uint8)

    # Nearest-neighbour upsample: repeat every cell along rows, then columns,
    # and crop the overshoot introduced by the ceil() above.
    up = np.repeat(np.repeat(heat_rgb_small, cell_h, axis=0), cell_w, axis=1)
    up = up[:H, :W]

    blended = (image_np.astype(np.float32) * (1 - alpha) + up.astype(np.float32) * alpha).astype(np.uint8)
    return blended

def run_explorer(
    model_path: str,
    image_dir: str,
    prompt_text: str = 'Describe the image.',
    conv_mode: str = 'qwen_2',
    device: str = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'),
    diff_threshold: float = 0.05,
    diff_max_changed_patches: int = 50,
    diff_skip_small: bool = False,
):
    """Show an interactive frame-by-frame explorer for differential vision encoding.

    Loads the model, wraps its ``encode_images`` with a
    ``DifferentialVisionEncoder``, and displays a Play/slider/Reset widget.
    For every frame it shows (left) the patches that changed relative to the
    previous frame and (right) the last-layer attention from the last prompt
    token to the vision tokens, averaged over heads.

    Frames are processed strictly in order: jumping to frame *k* first
    processes every not-yet-processed frame up to *k*.

    Args:
        model_path: Checkpoint path (``~`` is expanded).
        image_dir: Directory holding the frame images (``~`` is expanded).
        prompt_text: Text question paired with each frame.
        conv_mode: Key into ``conv_templates`` used to format the prompt.
        device: Device string passed to the model loader and ``torch.device``.
        diff_threshold: Passed to the differential encoder and also used as
            the threshold for the visualised changed-patch mask.
        diff_max_changed_patches: Passed to the differential encoder as
            ``max_changed_patches``.
        diff_skip_small: Passed to the differential encoder as
            ``skip_small_updates``.

    Raises:
        AssertionError: If ``image_dir`` contains no supported image.
    """
    # Gather frames first so an empty directory fails before the (expensive)
    # model load.
    directory = os.path.expanduser(image_dir)
    frames = _gather_images_from_dir(directory)
    assert len(frames) > 0, f'No images found in {directory}'

    # Load model
    model_path = os.path.expanduser(model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, model_name, device=device)
    # Force pad aspect to keep single-image inputs for differential updates.
    # Deliberately best-effort: a config that rejects the attribute is left as is.
    try:
        setattr(model.config, 'image_aspect_ratio', 'pad')
    except Exception:
        pass
    torch_device = torch.device(device)

    # Differential encoder wrapping
    diff = DifferentialVisionEncoder(
        model,
        diff_threshold=diff_threshold,
        max_changed_patches=diff_max_changed_patches,
        skip_small_updates=diff_skip_small,
        device=torch_device,
    )
    original_encode = getattr(model, 'encode_images', None)

    def encode_images_with_cache(images, image_sizes=None, return_stats=False):
        # Prefer the differential encoder; on any failure fall back to the
        # model's own encoder (trying the extended signature first).
        try:
            return diff.encode(images, image_sizes=image_sizes, return_stats=return_stats)
        except Exception:
            if original_encode is not None:
                try:
                    return original_encode(images, image_sizes=image_sizes, return_stats=return_stats)
                except TypeError:
                    return original_encode(images)
            raise

    model.encode_images = encode_images_with_cache

    # Grid/token info
    vt = model.get_vision_tower()
    grid_h = grid_w = int(vt.num_patches_per_side)
    num_image_tokens = int(vt.num_patches)

    # State caches
    cache = {}  # idx -> dict with visuals and stats
    last_computed = -1

    # UI controls
    play = widgets.Play(interval=600, value=0, min=0, max=len(frames)-1, step=1, description='Play', disabled=False)
    slider = widgets.IntSlider(min=0, max=len(frames)-1, step=1, description='Frame', readout=True)
    widgets.jslink((play, 'value'), (slider, 'value'))
    reset_btn = widgets.Button(description='Reset', button_style='warning')
    out = widgets.Output()

    # Dtype for images
    try:
        model_dtype = next(model.parameters()).dtype
    except StopIteration:
        model_dtype = torch.float16 if torch_device.type != 'cpu' else torch.float32

    def _load_rgb(path):
        # Open inside a context manager so the file handle is released as
        # soon as the pixel data has been converted/copied.
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def _to_display_array(img):
        # Build display image identical to pad preprocessing for overlay:
        # pad to square with the processor's mean colour, then resize to the crop size.
        fill = tuple(int(x*255) for x in image_processor.image_mean)
        squared = expand2square(img, fill)
        target = (image_processor.crop_size['width'], image_processor.crop_size['height'])
        return np.array(squared.resize(target))

    def _process_frame(i: int):
        # Load and preprocess
        img = _load_rgb(frames[i])
        processed = process_images([img], image_processor, model.config)
        img_tensor = processed[0] if isinstance(processed, (list, tuple)) else processed
        if img_tensor.ndim == 3:
            img_tensor = img_tensor.unsqueeze(0)
        img_tensor = img_tensor.to(device=torch_device, dtype=model_dtype)

        # Run differential encode to update cache and get stats
        _features, info = diff.encode(img_tensor, image_sizes=[img.size], return_stats=True)

        # Build prompt and input_ids
        prompt = _prepare_prompt(tokenizer, model, prompt_text, conv_mode)
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch_device)

        # Prepare multimodal inputs (this calls model.encode_images again; cheap due to cache hit)
        (ii, position_ids, attention_mask, pkv, inputs_embeds, labels) = model.prepare_inputs_labels_for_multimodal(
            input_ids, None, None, None, None, img_tensor, image_sizes=[img.size]
        )

        # Forward to capture attentions on the prompt embeddings
        with torch.no_grad():
            outputs = model.forward(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=pkv,
                inputs_embeds=inputs_embeds,
                labels=None,
                output_attentions=True,
                return_dict=True,
            )

        # Aggregate attention: last layer, mean heads, query=last prompt token
        attn_heat = None
        attentions = outputs.attentions
        if attentions is not None and len(attentions) > 0 and attentions[-1] is not None:
            last = attentions[-1][0]  # (num_heads, seq, seq) for batch 0
            head_mean = last.mean(dim=0)  # (seq, seq)
            query_idx = head_mean.shape[0] - 1
            # NOTE(review): LLaVA-style prepare_inputs_labels_for_multimodal
            # normally returns None for input_ids once image embeddings are
            # merged, so fall back to the original ids that still hold the
            # image placeholder.
            span_ids = ii if ii is not None else input_ids
            vstart, vend = _compute_vision_span(span_ids, num_image_tokens)
            if vstart is not None:
                vweights = head_mean[query_idx, vstart:vend].detach().float().cpu().numpy()
                # Only reshape when the slice really holds one weight per grid
                # cell; otherwise report attention as unavailable.
                if vweights.size == grid_h * grid_w:
                    attn_heat = vweights.reshape(grid_h, grid_w)

        disp_np = _to_display_array(img)
        # Changed mask for visualization (recomputed on display-res grid)
        patch_size_px = disp_np.shape[0] // grid_h
        if i == 0:
            changed_mask = np.zeros((grid_h, grid_w), dtype=bool)
        else:
            prev_np = _to_display_array(_load_rgb(frames[i-1]))
            changed_mask = compute_patch_diff(prev_np, disp_np, patch_size=patch_size_px, threshold=diff_threshold)

        vis_changed = visualize_changed_patches(disp_np, changed_mask, patch_size=patch_size_px, alpha=0.35)
        attn_overlay = _heat_to_overlay(disp_np, attn_heat, alpha=0.35) if attn_heat is not None else None

        cache[i] = dict(
            info=info,
            disp=disp_np,
            vis_changed=vis_changed,
            attn_overlay=attn_overlay,
        )

    def _compute_to(target_idx: int):
        # Process frames sequentially; last_computed advances per frame so a
        # failure midway keeps the frames already done.
        nonlocal last_computed
        for k in range(last_computed + 1, target_idx + 1):
            _process_frame(k)
            last_computed = k

    def _on_reset(_):
        nonlocal last_computed
        diff.reset_cache()
        cache.clear()
        last_computed = -1
        with out:
            clear_output(wait=True)
            print('State reset. Move the slider to recompute.')

    def _on_change(change):
        with out:
            clear_output(wait=True)
            idx = int(change['new'])
            if idx > last_computed:
                _compute_to(idx)
            item = cache[idx]
            info = item['info']
            fig, axes = plt.subplots(1, 2, figsize=(11, 5))
            axes[0].imshow(item['vis_changed'])
            axes[0].set_title(f"Changed patches \u2014 {info['encoding_type']} ({info['changed_patches']}/{info['total_patches']})")
            axes[0].axis('off')
            if item['attn_overlay'] is not None:
                axes[1].imshow(item['attn_overlay'])
                axes[1].set_title('Cross-attention over vision tokens')
            else:
                axes[1].imshow(item['disp'])
                axes[1].set_title('Attention unavailable')
            axes[1].axis('off')
            plt.tight_layout()
            plt.show()
            print('Encoder stats:', diff.get_stats())

    reset_btn.on_click(_on_reset)
    slider.observe(_on_change, names='value')

    display(widgets.HBox([play, slider, reset_btn]))
    display(out)
    # Bootstrap first frame
    _compute_to(0)
    slider.value = 0
    # The slider already holds 0, so the assignment above fires no change
    # event; render the first frame explicitly.
    _on_change({'new': slider.value})

# Example usage (uncomment and set paths):
# run_explorer(model_path='./checkpoints/your-llava', image_dir='./data/frames', prompt_text='What happened?', conv_mode='qwen_2')


In [None]:
# Launch the explorer on one curated trajectory with the stage-3 checkpoint.
run_explorer(
    model_path='/home/sdan/storage/diffcua/checkpoints/llava-fastvithd_7b_stage3',
    image_dir='/home/sdan/storage/diffcua/differential_vision/agentnet_curated/traj_0000',
    prompt_text='What happened?',
    conv_mode='qwen_2',
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]