# ShareGPT4V-7B Model Analysis and Exploration

This notebook explores the architecture and capabilities of the ShareGPT4V-7B model, a multimodal model that can process both text and images.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F

from share4v.model.builder import load_pretrained_model
from share4v.mm_utils import (
    process_images,
    tokenizer_image_token,
)

from share4v.constants import (
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_INDEX,
)

## 1. Model Loading and Configuration

Let's load the pretrained ShareGPT4V-7B model with its components:
- Tokenizer: Converts text to tokens
- Model: The main neural network architecture
- Image processor: Handles image preprocessing
- Context length: Maximum sequence length the model can process

In [None]:
# Model configuration parameters
model_path = "Lin-Chen/ShareGPT4V-7B"  # HuggingFace repository path
model_name = "share4v-7b"               # Model identifier

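# Load the pretrained model and its components.
# Positional arguments: model_path, model_base (None: no separate base model),
# model_name, load_8bit, load_4bit
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, False, False
)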

print(f"Model loaded successfully. Context length: {context_len}")

## 2. Model Architecture Overview

Let's inspect the high-level architecture of the model to understand its components.

In [None]:
model

## 3. Text Processing Pipeline

Now, we'll explore how the model processes text input. This involves:
1. Converting text to token IDs
2. Examining token representations
3. Visualizing embeddings

This helps us understand what the model "sees" when processing text.

In [None]:
# Define a simple text prompt
prompt = "Tell me a joke about programming."
# Alternative prompt with image: "Tell me something interesting about this image: <image>"
stop_str = "<image>"  # String stripped from the end of the streamed output in Section 8

# Tokenize the input prompt
# tokenizer_image_token handles special image tokens if present
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

# Print the tokenized shape
print(f"Tokenized shape: {input_ids.shape}")

# Add a batch dimension and move tensor to the model's device
input_ids = input_ids.unsqueeze(0).to(model.device)
print(f"Input shape after batching: {input_ids.shape}")

In [None]:
# Get the vocabulary size for reference
vocabulary = tokenizer.get_vocab()
print(f"Vocabulary size: {len(vocabulary)}")

# Display the actual tensor of token IDs
print("Token IDs tensor:")
input_ids

In [None]:
# Decode each token ID to see how the text was tokenized
print(f"Tokenized representation of: '{prompt}'")
print("-" * 50)

for token_id in input_ids[0]:
    if token_id < 0:
        # Only the image placeholder (IMAGE_TOKEN_INDEX) is negative; it is not a
        # vocabulary entry, so it cannot be decoded
        print(f"Token {token_id.item():5}: IMAGE_PLACEHOLDER")
        continue
    print(f"Token {token_id:5}: '{tokenizer.decode(token_id)}'")
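To see where a negative ID comes from, here is a minimal sketch of how an image placeholder can be spliced into a token sequence, in the spirit of `tokenizer_image_token`: split the prompt on `<image>`, tokenize each text chunk, and insert `IMAGE_TOKEN_INDEX` between chunks. The whitespace "tokenizer", its toy vocabulary, and the helper `splice_image_token` are hypothetical; the value -200 is assumed to match `share4v.constants.IMAGE_TOKEN_INDEX` (as in LLaVA), and the real function uses the model's SentencePiece tokenizer.

```python
IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200  # assumed to match share4v.constants.IMAGE_TOKEN_INDEX

# Hypothetical toy vocabulary and whitespace tokenizer, for illustration only
TOY_VOCAB = {"<s>": 1, "Tell": 10, "me": 11, "about": 12, "this": 13, "image:": 14}


def toy_tokenize(text):
    """Map whitespace-separated words to toy IDs (0 for unknown words)."""
    return [TOY_VOCAB.get(word, 0) for word in text.split()]


def splice_image_token(prompt, bos_id=1):
    """Tokenize each text chunk and insert IMAGE_TOKEN_INDEX where <image> appeared."""
    chunks = prompt.split(IMAGE_TOKEN)
    ids = [bos_id]
    for i, chunk in enumerate(chunks):
        if i > 0:
            # Placeholder that the model later replaces with projected image features
            ids.append(IMAGE_TOKEN_INDEX)
        ids.extend(toy_tokenize(chunk))
    return ids


ids = splice_image_token("Tell me about this image: <image>")
print(ids)  # [1, 10, 11, 12, 13, 14, -200]
```

Because -200 is not a valid row of the embedding table, it cannot be decoded or embedded directly; the model has to substitute image features at that position.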

## 4. Image Processing Pipeline

Next, we'll explore how the model processes image input:
1. Loading and preprocessing an image
2. Converting it to tensor format
3. Moving it to the appropriate device
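The preprocessing can be approximated in NumPy. This is a sketch that assumes the LLaVA-style `'pad'` aspect-ratio mode (centering the image on a square canvas filled with the CLIP mean colour) followed by CLIP's rescale by 1/255 and per-channel normalization. It omits the resize to 336x336 done by the real `CLIPImageProcessor`, and the helper names are hypothetical.

```python
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711])


def pad_to_square(img, fill):
    """Center an HxWx3 uint8 image on a square canvas filled with `fill` (RGB, 0-255)."""
    h, w, c = img.shape
    side = max(h, w)
    canvas = np.empty((side, side, c), dtype=img.dtype)
    canvas[:] = fill
    top = (side - h) // 2
    left = (side - w) // 2
    canvas[top:top + h, left:left + w] = img
    return canvas


def rescale_and_normalize(img):
    """Scale to [0, 1], normalize per channel, and return a CxHxW float array."""
    x = img.astype(np.float64) / 255.0
    x = (x - CLIP_MEAN) / CLIP_STD
    return x.transpose(2, 0, 1)


# A tall, narrow dummy image (height 6, width 2)
dummy = np.full((6, 2, 3), 200, dtype=np.uint8)
fill = tuple(int(m * 255) for m in CLIP_MEAN)
square = pad_to_square(dummy, fill)
tensor = rescale_and_normalize(square)
print(square.shape, tensor.shape)  # (6, 6, 3) (3, 6, 6)
```

Padding with the mean colour means the padded region normalizes to values close to zero.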

In [None]:
# Load an example image
image_path = "../examples/photo.png"  # Change this path to your desired image
image = Image.open(image_path).convert("RGB")

# Display the original image
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('off')
plt.title("Original Image")
plt.show()

In [None]:
# Reload the example image, converting to RGB so downstream processing sees 3 channels
image = Image.open(image_path).convert("RGB")

In [None]:
image.size

In [None]:
image_processor

In [None]:
# Create a tall, narrow copy (PIL sizes are (width, height)) to test aspect-ratio handling
img = image.resize((100, 762))
img

In [None]:
# Preprocess three images with different aspect ratios in one batch
pimage = process_images([image, img, image.resize((200, 50))], image_processor, model.config)

In [None]:
pimage.shape

In [None]:
def tensor_to_pil_image(tensor):
    """
    Convert a CLIP-processed tensor back to a PIL image.

    Args:
        tensor (torch.Tensor): Processed image tensor of shape [C, H, W]
            (rescaled to [0, 1] and normalized with the CLIP mean/std)

    Returns:
        PIL.Image: The reconstructed PIL image
    """
    # Move to CPU as float32 and copy, so the original tensor is not modified
    tensor = tensor.detach().cpu().float().clone()

    # Default CLIP image mean and std (should match image_processor.image_mean / image_std)
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(-1, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(-1, 1, 1)

    # Undo normalization
    tensor = tensor * std + mean

    # Undo rescaling: the processor multiplied by rescale_factor = 1/255
    tensor = tensor * 255.0

    # Round, then clamp to the valid pixel range [0, 255]
    tensor = torch.clamp(tensor.round(), 0, 255)

    # Convert CxHxW to HxWxC uint8 for PIL
    image_np = tensor.numpy().transpose(1, 2, 0).astype(np.uint8)

    return Image.fromarray(image_np)

tensor_to_pil_image(pimage[1])
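As a sanity check on the inversion logic in `tensor_to_pil_image`, this NumPy sketch applies CLIP's rescale-and-normalize to random pixel values and then undoes it in the same order as the function (multiply by std, add mean, multiply by 255). It uses no model code; the random test image is arbitrary.

```python
import numpy as np

mean = np.array([0.48145466, 0.4578275, 0.40821073]).reshape(3, 1, 1)
std = np.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(3, 4, 4)).astype(np.uint8)  # CxHxW

# Forward: what the CLIP processor does (rescale by 1/255, then normalize)
normalized = (pixels / 255.0 - mean) / std

# Inverse: undo normalization, undo rescaling, round, and clamp to [0, 255]
restored = np.clip(np.rint((normalized * std + mean) * 255.0), 0, 255).astype(np.uint8)

print(np.array_equal(restored, pixels))  # True
```

Rounding before the cast matters: casting a value such as 199.99999 straight to `uint8` truncates it to 199.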

In [None]:

tensor_to_pil_image(pimage[0])

In [None]:

tensor_to_pil_image(pimage[2])

In [None]:
# Load an example image (converted to RGB)
images = [
    Image.open(image_path).convert("RGB"),
]
print(f"Loaded image with size: {images[0].size}")

# Process images using ShareGPT4V's image processor
# (optional padding to a square, resizing, rescaling, and normalization)
images = process_images(images, image_processor, model.config)

# Move to the model's device and dtype so the vision layers can be called directly
if isinstance(images, list):
    images = [img.to(model.device, dtype=model.dtype) for img in images]
    print(f"Processed image tensor shapes: {[img.shape for img in images]}")
else:
    images = images.to(model.device, dtype=model.dtype)
    print(f"Processed image tensor shape: {images.shape}")

## 5. Vision Model Analysis

Now we'll examine how the model processes images through its vision tower.

### 5.1 Vision Tower Architecture

First, let's look at the structure of the vision components.

In [None]:
# Examine the vision tower's embedding components
print("Vision embedding components:")
print(model.model.vision_tower.vision_tower.vision_model.embeddings)

In [None]:
# Explore the patch embedding process
# The image is divided into patches that are individually embedded
patch_embeddings = model.model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding(images)
print(f"Patch embedding shape: {patch_embeddings.shape}")
# This shows how the image is divided into spatial patches and projected to the embedding space

# Get the full vision embeddings (patches + position embeddings)
vision_embeddings = model.model.vision_tower.vision_tower.vision_model.embeddings(images)
print(f"Vision embedding shape (with positional info): {vision_embeddings.shape}")
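For a 336x336 input and 14x14 patches there are 336 / 14 = 24 patches per side, so 24 * 24 = 576 patch tokens (577 once the embeddings module prepends the CLS token). The patch embedding is a convolution whose kernel size and stride both equal the patch size, which is equivalent to flattening each non-overlapping patch and applying one shared linear map. This NumPy sketch checks that equivalence on a tiny random example; the sizes and weights are arbitrary.

```python
import numpy as np


def patchify(img, p):
    """Split a CxHxW array into (num_patches, C*p*p) rows, row-major over the patch grid."""
    c, h, w = img.shape
    gh, gw = h // p, w // p
    x = img[:, :gh * p, :gw * p].reshape(c, gh, p, gw, p)
    return x.transpose(1, 3, 0, 2, 4).reshape(gh * gw, c * p * p)


def strided_conv(img, kernel, p):
    """Naive convolution with stride equal to kernel size p; kernel is (D, C, p, p)."""
    d = kernel.shape[0]
    c, h, w = img.shape
    gh, gw = h // p, w // p
    out = np.zeros((d, gh, gw))
    for i in range(gh):
        for j in range(gw):
            patch = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]
            out[:, i, j] = np.tensordot(kernel, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


rng = np.random.default_rng(0)
p, dim = 14, 8
img = rng.standard_normal((3, 28, 42))  # a 2x3 grid of 14x14 patches
kernel = rng.standard_normal((dim, 3, p, p))

conv_out = strided_conv(img, kernel, p)                    # (dim, 2, 3)
linear_out = patchify(img, p) @ kernel.reshape(dim, -1).T  # (6, dim)

print(np.allclose(conv_out.reshape(dim, -1).T, linear_out))  # True
print((336 // 14) ** 2)  # 576
```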

In [None]:
# Patches per side: image_size // patch_size = 336 // 14 = 24, so 24 * 24 = 576 patches
336 // 14

### 5.2 Image Feature Extraction

Let's extract and analyze image features from the vision tower.

In [None]:
# Process the image for the model
images_tensor = process_images([image], image_processor, model.config)
if isinstance(images_tensor, list):
    images_tensor = [img.to(model.device, dtype=model.dtype) for img in images_tensor]
else:
    images_tensor = images_tensor.to(model.device, dtype=model.dtype)

# Extract image features using the vision tower
with torch.no_grad():
    image_features = model.model.vision_tower(images_tensor)

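In [None]:
# How process_images handles non-square images: "pad" pads to a square first;
# otherwise the processor's default resize and center crop is used
getattr(model.config, "image_aspect_ratio", None)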

In [None]:
# Process the image through the full vision tower
emb = model.model.vision_tower(images_tensor)
print(f"Vision tower output shape: {emb.shape}")
# This represents the image features extracted by the vision transformer

In [None]:
# Analyze the distribution of values in the image features
plt.figure(figsize=(10, 6))
plt.hist(
    emb.detach().cpu().numpy().flatten(),
    log=True,
    bins=200,
)
plt.title("Distribution of Vision Features")
plt.xlabel("Feature Value")
plt.ylabel("Log Count")
plt.grid(alpha=0.3)
plt.show()

In [None]:
# Visualize the image embeddings
plt.figure(figsize=(16, 8))
plt.imshow(
    emb.detach().cpu().numpy()[0, :, :],
    vmin=-5,
    vmax=5,
    cmap="bwr",
    aspect="auto"
)
plt.colorbar(label="Feature Value")
plt.title("Vision Tower Output Features")
plt.xlabel("Feature Dimension")
plt.ylabel("Image Patch")
plt.show()

## 6. Multimodal Integration Analysis

Now, let's examine how visual and textual information are aligned and integrated.

### 6.1 Vision-Language Projection

The model uses a projection layer to map visual features into the same space as text embeddings.
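A minimal sketch of what such a projector can look like, assuming the two-layer MLP with GELU used by LLaVA-1.5 (`mlp2x_gelu`), which maps 1024-d CLIP ViT-L features to the 4096-d hidden size of a 7B LLaMA. Whether ShareGPT4V-7B uses exactly this layout is an assumption; printing `model.model.mm_projector` shows the real layers. The weights are random and the sizes are shrunk so the example runs instantly.

```python
import math

import numpy as np


def gelu(x):
    """Exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))), applied elementwise."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))


def mlp_projector(features, w1, b1, w2, b2):
    """Linear -> GELU -> Linear, applied independently to each patch feature row."""
    return gelu(features @ w1 + b1) @ w2 + b2


rng = np.random.default_rng(0)
num_patches, vision_dim, text_dim = 576, 64, 128  # shrunk stand-ins for 576, 1024, 4096
features = rng.standard_normal((num_patches, vision_dim))
w1 = rng.standard_normal((vision_dim, text_dim)) * 0.02
b1 = np.zeros(text_dim)
w2 = rng.standard_normal((text_dim, text_dim)) * 0.02
b2 = np.zeros(text_dim)

projected = mlp_projector(features, w1, b1, w2, b2)
print(projected.shape)  # (576, 128)
```

The projector keeps the number of rows unchanged: each image patch becomes one "visual token" in the language model's input.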

In [None]:
# Map image features to text token space with the projector
with torch.no_grad():
    projected_features = model.model.mm_projector(image_features)

# Create a function to find the closest token for each embedding
def find_closest_tokens(embeddings, token_embeddings, top_k=1):
    # Compute cosine similarity between embeddings and token embeddings
    normalized_embeddings = F.normalize(embeddings, p=2, dim=-1)
    normalized_token_embeddings = F.normalize(token_embeddings, p=2, dim=-1)
    similarities = torch.matmul(normalized_embeddings, normalized_token_embeddings.T)
    
    # Get the top-k token indices with highest similarity
    if top_k == 1:
        closest_token_indices = similarities.argmax(dim=-1)
        return closest_token_indices
    else:
        top_k_values, top_k_indices = torch.topk(similarities, k=top_k, dim=-1)
        return top_k_indices, top_k_values
    
# Get the token embeddings from the model
token_embeddings = model.model.embed_tokens.weight
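Before applying `find_closest_tokens` to image features, it helps to see why the text sanity check that follows should recover the input exactly: every embedding row has cosine similarity 1 with itself, the maximum possible, so looking up a token's own embedding returns that token (barring duplicate or parallel rows). This NumPy sketch uses a random stand-in for the embedding table, not the model's weights.

```python
import numpy as np


def closest_rows(queries, table):
    """Index of the row of `table` with the highest cosine similarity to each query."""
    q = queries / np.linalg.norm(queries, axis=-1, keepdims=True)
    t = table / np.linalg.norm(table, axis=-1, keepdims=True)
    return (q @ t.T).argmax(axis=-1)


rng = np.random.default_rng(0)
vocab_size, dim = 1000, 64
table = rng.standard_normal((vocab_size, dim))  # stand-in for embed_tokens.weight

token_ids = np.array([5, 42, 999, 0])
recovered = closest_rows(table[token_ids], table)
print(recovered)  # matches token_ids

# A vector that is not an embedding row (like a projected image patch) still gets a
# nearest token, but its best similarity is typically far below 1
query = rng.standard_normal((1, dim))
q = query / np.linalg.norm(query)
best = closest_rows(query, table)[0]
sim = float(q[0] @ (table[best] / np.linalg.norm(table[best])))
print(sim)
```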

#### Sanity Check with Text Embeddings

In [None]:
# Re-tokenize the text prompt (1-D tensor, no batch dimension) on the model's device
prompt = "Tell me a joke about programming."
stop_str = "<image>"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').to(model.device)
embedded_tokens = model.model.embed_tokens(input_ids)
embedded_tokens.shape

In [None]:
# Get the closest vocabulary token for each text token embedding (should reproduce the prompt)
closest_token_indices = find_closest_tokens(embedded_tokens, token_embeddings)

# Decode the tokens
closest_tokens = [tokenizer.decode(idx.item()).strip() for idx in closest_token_indices]
closest_tokens

#### Applying It to Image Features

In [None]:
# Get the closest token for each patch embedding
closest_token_indices = find_closest_tokens(projected_features[0, :], token_embeddings)

# Decode the tokens
closest_tokens = [tokenizer.decode(idx.item()).strip() for idx in closest_token_indices]
closest_tokens[:10]

In [None]:
# Project vision features to language embedding space
mm_emb = model.model.mm_projector(emb)
print(f"After projection: {mm_emb.shape}")
print(f"Compare with text embedding shape: {embedded_tokens.shape}")

In [None]:
# Analyze the distribution of values after projection
plt.figure(figsize=(10, 6))
plt.hist(
    mm_emb.detach().cpu().numpy().flatten(),
    log=True,
    bins=200,
)
plt.title("Distribution of Projected Vision Features")
plt.xlabel("Feature Value")
plt.ylabel("Log Count")
plt.grid(alpha=0.3)
plt.show()

In [None]:
# Visualize the projected image embeddings
plt.figure(figsize=(16, 8))
plt.imshow(
    mm_emb.detach().cpu().numpy()[0, ::6, :512],  # Sample every 6th patch for clarity
    cmap="bwr",
    aspect="auto"
)
plt.colorbar(label="Feature Value")
plt.title("Vision Features After Projection to Language Space")
plt.xlabel("Feature Dimension (first 512)")
plt.ylabel("Image Patch (sampled)")
plt.tight_layout()

### 6.2 Vision-Language Alignment

To understand the alignment between vision and language, we can map image features to the closest text tokens.

In [None]:
# Define a helper function to map embeddings back to token IDs
# This helps us understand what embedding vectors "mean" in token space
def embedding_to_token_id(embedding, embed_tokens):
    """
    Maps embedding vectors to token IDs by finding closest tokens in the embedding space.
    
    Args:
        embedding: The embedding vectors to map back to tokens
        embed_tokens: The embedding layer that maps tokens to vectors
    
    Returns:
        Token IDs that are closest to the given embeddings in vector space
    """
    # Get the embedding weight matrix (shape: vocab_size x embedding_dim)
    weight = embed_tokens.weight  # shape: (vocab_size, embedding_dim)
    
    # If embedding is a single vector, add batch dimension
    if embedding.dim() == 1:
        embedding = embedding.unsqueeze(0)
    
    # Compute cosine similarity between the embedding and all rows in the weight matrix
    # Normalize embeddings along the embedding dimension
    normalized_embedding = F.normalize(embedding, p=2, dim=-1)
    normalized_weight = F.normalize(weight, p=2, dim=-1)
    similarities = torch.matmul(normalized_embedding, normalized_weight.T)  # shape: (batch_size, vocab_size)
    
    # Get the token id with the highest similarity for each embedding in the batch
    token_ids = similarities.argmax(dim=-1)
    
    return token_ids

In [None]:
# Map projected vision features back to token IDs to see "what the image says"
token_ids_recovered = embedding_to_token_id(mm_emb, model.model.embed_tokens)

# Display some of the recovered tokens
print("Image 'translated' to text (full):")
print(tokenizer.decode(token_ids_recovered[0]))

In [None]:
# Sample some tokens to see what the model "sees" in the image
print("Sample tokens from image (every 10th patch):")
print("-" * 50)
for token_id in token_ids_recovered[0][::100]:
    if token_id < 0:
        print(f"Token {token_id:5}: RESERVED_TOKEN")
        continue
    print(f"Token {token_id:5}: '{tokenizer.decode(token_id)}'")

## 7. Text Embedding Analysis

Let's analyze how the model embeds text tokens and processes them through its layers.

### 7.1 Token Embedding

First, we'll examine the initial embedding of input tokens.

In [None]:
# Get the token embeddings from the embedding layer.
# Flatten input_ids first so the result is always (sequence_length, embedding_dim),
# whether or not input_ids carries a batch dimension.
embedded_tokens = model.model.embed_tokens(input_ids.view(-1).to(model.device))
print(f"Token embedding shape: {embedded_tokens.shape}")

In [None]:
# Map token embeddings back to tokens to verify our understanding
token_ids_recovered = embedding_to_token_id(embedded_tokens, model.model.embed_tokens)

# Decode the whole recovered sequence (flatten handles both 1-D and batched results)
print("Recovered tokens from embeddings:")
print(tokenizer.decode(token_ids_recovered.flatten()))

In [None]:
embedded_tokens.shape

In [None]:
# Visualize the token embeddings
plt.figure(figsize=(16, 2))
plt.imshow(
    embedded_tokens.detach().cpu().numpy()[:,:512],  # Take first 512 dimensions of each token embedding
    cmap="bwr",
    aspect="auto",
)
plt.colorbar(label="Embedding Value")
plt.title("Token Embedding Visualization")
plt.xlabel("Embedding Dimension (first 512)")
plt.ylabel("Token Position")
plt.tight_layout()

### 7.2 Layer-by-Layer Processing

Let's trace how the token representations change as they pass through the model's transformer layers.

In [None]:
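# Pass the embeddings through each decoder layer of the language model.
# Add the batch dimension once: decoder layers expect (batch, seq_len, hidden_dim).
enc_output = embedded_tokens.unsqueeze(0) if embedded_tokens.dim() == 2 else embedded_tokens
print("Tracing embeddings through model layers:")
print("-" * 50)

# Note: no attention mask is passed, so depending on the transformers version the
# layers may not apply causal masking; treat these activations as approximate.
for i, layer in enumerate(model.model.layers, start=1):
    with torch.no_grad():
        # Element 0 of the returned tuple is the hidden states
        enc_output = layer(enc_output)[0]
    print(f"Layer {i}: Shape = {enc_output.shape}")

    # Visualize every 8th layer's output
    if i % 8 == 0:
        plt.figure(figsize=(16, 2))
        plt.imshow(
            enc_output.detach().cpu().numpy()[0, :, :512],
            cmap="bwr",
            aspect="auto",
        )
        plt.colorbar(label="Activation Value")
        plt.title(f"Layer {i} Output")
        plt.xlabel("Hidden Dimension (first 512)")
        plt.ylabel("Token Position")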
        plt.tight_layout()
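Calling the decoder layers directly, as in the loop above, leaves out the attention mask that the full model builds. Here is a sketch of the standard additive causal mask in NumPy, assuming single-head scaled dot-product attention with random queries and keys: entries above the diagonal are set to -inf, so after the softmax position i attends only to positions 0..i.

```python
import numpy as np


def causal_mask(seq_len):
    """Additive mask: 0 on and below the diagonal, -inf above it."""
    upper = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    return np.where(upper, -np.inf, 0.0)


def softmax(x, axis=-1):
    # Subtract the row max for stability; each row has a finite max (the diagonal)
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
seq_len, head_dim = 5, 8
q = rng.standard_normal((seq_len, head_dim))
k = rng.standard_normal((seq_len, head_dim))

scores = q @ k.T / np.sqrt(head_dim)
weights = softmax(scores + causal_mask(seq_len))
print(np.round(weights, 2))  # lower-triangular attention weights
```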

## 8. Text Generation Demo

Now we'll demonstrate the model's text generation capabilities:
1. Setting generation parameters
2. Using a streamer for real-time output
3. Generating text based on our prompt
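The two sampling parameters used below combine as follows in a common formulation of nucleus sampling (a sketch, not the exact `transformers` implementation): divide the logits by `temperature`, take the softmax, keep the smallest set of most probable tokens whose cumulative probability reaches `top_p`, renormalize, and sample from that set. The logits are made up.

```python
import numpy as np


def top_p_filter(logits, temperature, top_p):
    """Return renormalized probabilities after temperature scaling and nucleus filtering."""
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]  # most probable first
    cumulative = np.cumsum(probs[order])
    # Keep tokens up to and including the first one whose cumulative total reaches top_p
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()


logits = np.array([2.0, 1.0, 0.5, -1.0])
warm = top_p_filter(logits, temperature=1.0, top_p=0.95)  # keeps the top 3 tokens
cold = top_p_filter(logits, temperature=0.1, top_p=0.95)  # keeps only the top token

rng = np.random.default_rng(0)
next_token = rng.choice(len(logits), p=cold)
```

With a low temperature such as 0.1, the distribution becomes so sharp that sampling behaves almost like greedy decoding.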

In [None]:
from transformers import TextIteratorStreamer
from threading import Thread

# Generation parameters
temperature = 0.1    # Lower values make output more deterministic (less random)
top_p = 0.95         # Nucleus sampling parameter (higher = more diversity)
max_new_tokens = 256  # Maximum number of tokens to generate

# Create a streamer for generating text progressively
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,      # Don't include the prompt in the output
    skip_special_tokens=True,  # Don't include special tokens in the output
    timeout=150,           # Timeout in seconds
)

# Generate text using the model
print("Generating response to:", prompt)
print("-" * 50)

# model.generate already runs without gradient tracking; a torch.inference_mode()
# block here would not reach the worker thread anyway, since grad mode is thread-local
thread = Thread(target=model.generate, kwargs=dict(
    inputs=input_ids.view(1, -1).to(model.device),  # generate expects (batch, seq_len)
    do_sample=True,           # Use sampling instead of greedy decoding
    temperature=temperature,
    top_p=top_p,
    max_new_tokens=max_new_tokens,
    streamer=streamer,
    use_cache=True,           # Use KV cache for faster generation
    # images=images_tensor,   # Uncomment (with "<image>" in the prompt) to use image input
))
thread.start()

# Collect generated text from the streamer
print("Generated text:")
generated_text = ""
for new_text in streamer:
    generated_text += new_text
    # Strip the stop string from the accumulated text if it appears at the end
    if generated_text.endswith(stop_str):
        generated_text = generated_text[:-len(stop_str)]
    print(new_text, end="")
print()

In [None]:
# Clean up the thread
thread.join()
del thread
print("Generation completed.")
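The streaming loop above only checks whether the accumulated text ends with `stop_str`. A slightly more robust pattern, sketched here in plain Python with made-up chunks, searches the accumulated text for the stop string, so it is caught even when it is split across chunks or followed by more text, and truncates there. The function name and the `</s>` stop string are illustrative only.

```python
def collect_until_stop(chunks, stop_str):
    """Concatenate streamed chunks, truncating at the first occurrence of stop_str."""
    text = ""
    for chunk in chunks:
        text += chunk
        idx = text.find(stop_str)
        if idx != -1:
            # Stop string found; the caller can stop consuming the stream
            return text[:idx], True
    return text, False


# The stop string arrives split across two chunks and is followed by extra text
chunks = ["Hello, ", "world.</", "s> extra"]
text, stopped = collect_until_stop(chunks, "</s>")
print(stopped, repr(text))  # True 'Hello, world.'
```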

## 9. Visualize Image Patches with Closest Text Tokens

Let's create a visualization that shows the input image with the closest text tokens written on top of each patch. This helps us understand how the vision-language model processes and interprets different regions of the image.
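The visualization relies on the vision tower emitting patch tokens in row-major order over the patch grid, which is consistent with how a patch-embedding output of shape (dim, rows, cols) is flattened. Under that assumption, a patch index maps to a grid cell and to a pixel position for its label as sketched below; `patch_center` is a hypothetical helper.

```python
def patch_center(index, grid_size, patch_size):
    """Return (row, col, x, y) for a row-major patch index on a square grid."""
    row, col = divmod(index, grid_size)
    x = col * patch_size + patch_size // 2  # horizontal pixel coordinate of the center
    y = row * patch_size + patch_size // 2  # vertical pixel coordinate of the center
    return row, col, x, y


grid_size, patch_size = 336 // 14, 14  # 24 x 24 grid for a 336px input
print(patch_center(0, grid_size, patch_size))    # (0, 0, 7, 7)
print(patch_center(25, grid_size, patch_size))   # (1, 1, 21, 21)
print(patch_center(575, grid_size, patch_size))  # (23, 23, 329, 329)
```

Matplotlib's `ax.text(x, y, ...)` takes the horizontal coordinate first, which is why the cell below passes the column-based offset `j * patch_size + ...` before the row-based one.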

In [None]:
# Visualize the closest text tokens on the image patches
def visualize_tokens_on_patches(image, tokens, patch_size=14, grid_size=None):
    # If grid_size is not provided, calculate it from the number of tokens
    if grid_size is None:
        # The sqrt of the number of tokens gives us the grid size
        # (assuming square patches)
        grid_size = int(np.sqrt(len(tokens)))
    
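    # Resize (stretching if necessary) to the square patch grid. If image_aspect_ratio is
    # "pad", the model saw a padded square image instead, so labels on a non-square
    # input only approximately line up with the patches.
    target_size = (grid_size * patch_size, grid_size * patch_size)
    resized_img = image.resize(target_size, Image.LANCZOS)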
    
    # Create a figure and axis for plotting
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(resized_img)
    
    # Plot the tokens on top of each patch
    for i in range(grid_size):
        for j in range(grid_size):
            token_idx = i * grid_size + j
            if token_idx < len(tokens):
                token = tokens[token_idx]
                # Limit token display length to avoid overcrowding
                if len(token) > 5:
                    token = token[:4] + '...'
                ax.text(j * patch_size + patch_size // 2, i * patch_size + patch_size // 2, 
                        token, color='white', fontsize=8, ha='center', va='center',
                        bbox=dict(boxstyle="round,pad=0.2", fc='black', alpha=0.5))
    
    ax.set_title("Image Patches with Closest Text Tokens", fontsize=16)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

# Get the vision model configuration to determine patch size and grid size
vision_config = model.model.vision_tower.vision_tower.config
patch_size = vision_config.patch_size
image_size = vision_config.image_size
grid_size = image_size // patch_size

# Visualize the tokens on the image
visualize_tokens_on_patches(image, closest_tokens, patch_size=patch_size, grid_size=grid_size)

In [None]:
# Create a more detailed visualization with the top-k tokens and their relative scores
def visualize_top_k_tokens_on_patches(image, token_embeddings, patch_embeddings, tokenizer,
                                      patch_size=14, grid_size=None, top_k=3):
    if grid_size is None:
        grid_size = int(np.sqrt(len(patch_embeddings)))
    
    # Get top-k tokens for each patch
    top_indices, top_scores = find_closest_tokens(patch_embeddings, token_embeddings, top_k=top_k)
    
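    # Softmax over the top-k cosine similarities: these are relative scores among the
    # k candidates, not probabilities produced by the language model
    top_probs = F.softmax(top_scores, dim=-1)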
    
    # Decode the top tokens
    top_tokens = []
    for i in range(len(top_indices)):
        tokens = [tokenizer.decode(idx.item()).strip() for idx in top_indices[i]]
        probs = top_probs[i].tolist()
        top_tokens.append(list(zip(tokens, probs)))
    
    # Resize the image to match the grid
    target_size = (grid_size * patch_size, grid_size * patch_size)
    resized_img = image.resize(target_size, Image.LANCZOS)
    
    # Create a figure for visualization
    fig, ax = plt.subplots(figsize=(30, 30))
    ax.imshow(resized_img)
    
    # Plot the top tokens on each patch
    for i in range(grid_size):
        for j in range(grid_size):
            patch_idx = i * grid_size + j
            if patch_idx < len(top_tokens):
                token_info = top_tokens[patch_idx]
                # Format the text: token (prob%)
                text = '\n'.join([f"{t[:4]}.. ({p:.0%})" if len(t) > 5 else f"{t} ({p:.0%})" 
                                  for t, p in token_info])
                ax.text(j * patch_size + patch_size // 2, i * patch_size + patch_size // 2, 
                       text, color='white', fontsize=7, ha='center', va='center',
                       bbox=dict(boxstyle="round,pad=0.2", fc='black', alpha=0.6))
    
    ax.set_title(f"Image Patches with Top-{top_k} Closest Text Tokens (Relative Scores)", fontsize=16)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

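# Get the top-2 tokens for each patch.
# Pass all patch features: the vision tower's default "patch" feature selection is
# assumed to have already dropped the CLS token, as in the closest_tokens cell above.
with torch.no_grad():
    visualize_top_k_tokens_on_patches(
        image,
        token_embeddings,
        projected_features[0],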
        tokenizer,
        patch_size=patch_size,
        grid_size=grid_size,
        top_k=2
    )
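The percentages in this plot are a softmax over only the top-k cosine similarities. Because cosine similarities lie in [-1, 1] and the top candidates are usually close to each other, that softmax tends to be nearly uniform, as this small sketch with made-up similarity values shows. Dividing by a small temperature before the softmax (an optional tweak, not something the cell above does) makes the ranking easier to see.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()


top2_similarities = np.array([0.12, 0.10])  # made-up cosine similarities

p_plain = softmax(top2_similarities)         # nearly uniform
p_sharp = softmax(top2_similarities / 0.01)  # temperature 0.01 exaggerates the gap
print(np.round(p_plain, 3))
print(np.round(p_sharp, 3))
```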

## 10. Conclusion

This notebook has explored the ShareGPT4V-7B model's architecture and capabilities:

1. **Model Structure**: The model combines a LLaMA-based language model with a vision transformer
2. **Text Processing**: Text is tokenized and embedded into a high-dimensional space
3. **Image Processing**: Images are divided into patches and processed through a vision transformer
4. **Multimodal Integration**: Visual features are projected into the language embedding space
5. **Generation**: The model can generate text based on either text or text+image inputs

The key to multimodal capabilities is the projection of visual features into the language embedding space, allowing the model to process both modalities coherently.

Further explorations could include:
- Analyzing attention patterns between image and text tokens
- Testing the model's performance on various visual reasoning tasks
- Examining cross-modal transfer and emergent capabilities