In [44]:
!pip install accelerate

python(57672) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.




In [2]:
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor

# Use Apple's Metal backend ("mps") when this PyTorch build reports it as
# available; otherwise run everything on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"‚úÖ Using device: {device}")

# Checkpoint to load; the same id is used for both the model and its processor.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
print("‚è≥ Loading model...")

# Weights are loaded in float16 (2 bytes per parameter, i.e. roughly 22 GB for an
# ~11B-parameter model) and then moved to `device` in a separate step.
# attn_implementation="eager" is presumably chosen so the inference cell can get
# attention weights back (output_attentions=True). NOTE(review): the setting applies
# to the whole model, vision encoder included, and eager attention materialises full
# attention matrices -- a plausible contributor to the MPS out-of-memory error
# recorded for the inference cell; confirm before changing it.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype=torch.float16,
)
model = model.to(device)
print("‚úÖ Model loaded successfully!")

# The processor is used later both to render the chat template and to turn the
# image + prompt into model inputs.
processor = AutoProcessor.from_pretrained(model_id)
print("‚úÖ Processor loaded successfully!")

‚úÖ Using device: mps
‚è≥ Loading model...


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

‚úÖ Model loaded successfully!
‚úÖ Processor loaded successfully!


In [3]:
from PIL import Image

# Load the input picture from disk and normalise it to 3-channel RGB.
image_path = "./image.jpg"  # CHANGE THIS to your actual image path

# Open the file in a context manager so its handle is closed right away;
# `convert` returns a new, fully decoded image that no longer needs the file.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

print(f"‚úÖ Loaded image: {image_path}, Size: {image.size}")

# `Image.show()` hands the picture to an external viewer application; `display`
# renders it inline in the notebook output instead.
from IPython.display import display

display(image)

‚úÖ Loaded image: ./image.jpg, Size: (960, 504)


In [4]:
# One user turn: an image placeholder followed by the question about that image.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the person's face on the left?"},
        ],
    }
]
# Render the conversation with the model's chat template; add_generation_prompt
# appends the prefix that cues the assistant's reply.
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
print("‚úÖ Processed input text for the model!")

# Preprocess the image and tokenize the rendered prompt in one call, then move the
# resulting tensors to `device`. add_special_tokens=False presumably avoids
# duplicating special tokens the chat template already inserted -- confirm.
inputs = processor(
    images=image,
    text=input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(device)
print("‚úÖ Inputs prepared, ready for inference!")

‚úÖ Processed input text for the model!
‚úÖ Inputs prepared, ready for inference!


In [5]:
# Generate a short answer and request the attention weights for this call.
# NOTE(review): the recorded run failed here with "MPS backend out of memory" while
# requesting 2.47 GB on top of ~25 GB already allocated. 2.47 GiB equals one float32
# tensor of 16 x 6432 x 6432, which would match an eager attention matrix in the
# vision encoder (16 heads over 4 tiles x 1608 patches) -- presumably the spike comes
# from encoding the image rather than from the returned decoder attentions.
# Confirm (e.g. by profiling) before tuning this cell.
print("‚è≥ Running inference...")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=30,
        # Requested for this call only. Setting model.config.output_attentions
        # instead would persist on the model object and make every later call
        # return (and hold in memory) attention weights as well.
        output_attentions=True,
        # Return a ModelOutput so the weights are reachable as `outputs.attentions`.
        return_dict_in_generate=True,
    )

# For `generate`, `outputs.attentions` is nested per generated step:
# attentions[step][layer] is a (batch, heads, query_len, key_len) tensor, and step 0
# is the forward pass over the whole prompt. It is None when no weights came back.
attentions = getattr(outputs, "attentions", None)
if not attentions:
    raise ValueError("‚ùå Attention extraction failed. The model did not return attention weights.")
num_layers = len(attentions[0])
print(
    f"‚úÖ Inference completed! Extracted attentions for {len(attentions)} "
    f"generation steps x {num_layers} layers."
)

‚è≥ Running inference...


RuntimeError: MPS backend out of memory (MPS allocated: 25.22 GB, other allocations: 2.80 MB, max allowed: 27.20 GB). Tried to allocate 2.47 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch

# Draw one attention heatmap per decoder layer, overlaid on the input image, and
# save each figure as attention_layer_<n>.png.
#
# `model.generate(..., output_attentions=True)` nests attentions per generated step:
# attentions[step][layer] -> (batch, heads, query_len, key_len), with step 0 being the
# pass over the full prompt. Use that step's per-layer tuple; if `attentions` is
# already a flat per-layer tuple of tensors (e.g. from a plain forward call), use it
# directly.
#
# NOTE(review): in Mllama the image presumably enters the language model through
# cross-attention layers, so decoder self-attention keys are text tokens; laying such
# a row out on a square grid over the photo is not a spatial map of the image. Layers
# whose key count is not a perfect square are skipped below. TODO confirm how (and
# whether) vision-token attention maps back onto image tiles before trusting a plot.
if attentions and isinstance(attentions[0], (tuple, list)):
    layer_attentions = attentions[0]
else:
    layer_attentions = attentions

# The target size is the same for every layer, so build the resize transform once.
resize_to_image = transforms.Resize((image.height, image.width))

for layer_idx, layer_attn in enumerate(layer_attentions, start=1):
    print(f"Processing attention map for Layer {layer_idx}")

    # Expected (batch, heads, query_len, key_len); anything else (including None)
    # cannot be reduced to a per-key weight vector here.
    if getattr(layer_attn, "ndim", None) != 4:
        print(f"‚ö†Ô∏è Unexpected shape {getattr(layer_attn, 'shape', None)}. Skipping Layer {layer_idx}.")
        continue

    # Average over heads for the first batch item and keep the last query position
    # (the token whose successor is being predicted): one weight per key token.
    # Upcast to float32 so the 1e-9 epsilon below is not rounded away in float16.
    attn_row = layer_attn[0].mean(dim=0)[-1].float().detach().cpu().numpy()

    # Min-max normalise to [0, 1]; the epsilon guards against a constant row.
    attn_row = (attn_row - attn_row.min()) / (attn_row.max() - attn_row.min() + 1e-9)

    # Lay the key weights out on a square grid; skip if the key count is not square.
    side = int(attn_row.size ** 0.5)
    if side * side != attn_row.size:
        print(f"‚ö†Ô∏è Skipping Layer {layer_idx}: Non-square attention shape {attn_row.shape}")
        continue
    attn_grid = attn_row.reshape(side, side)

    # Upsample the grid to the image's (height, width); Resize expects (..., H, W).
    attn_resized = resize_to_image(torch.from_numpy(attn_grid)[None, None]).squeeze().numpy()

    fig, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(12, 6))

    ax_image.imshow(image)
    ax_image.set_axis_off()
    ax_image.set_title("Original Image")

    ax_overlay.imshow(image)
    ax_overlay.imshow(attn_resized, cmap="jet", alpha=0.5)
    ax_overlay.set_axis_off()
    ax_overlay.set_title(f"Attention Heatmap - Layer {layer_idx}")

    heatmap_filename = f"attention_layer_{layer_idx}.png"
    fig.savefig(heatmap_filename, dpi=300)
    print(f"‚úÖ Saved heatmap: {heatmap_filename}")

    # Show this figure now and release it, rather than keeping every layer's
    # 300-dpi figure open until the cell finishes.
    plt.show()
    plt.close(fig)