# BAGEL Model Representation Extraction

This notebook demonstrates how to load the BAGEL model and extract internal representations from its language model, ViT encoder, VAE encoder, and the ViT-to-LLM connector.

## Setup and Imports

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer

## Model Loading

In [None]:
# Model path - adjust if needed.
# This directory is expected to contain the files read below:
# llm_config.json, vit_config.json and ae.safetensors (plus ema.safetensors,
# which is loaded in a later cell).
model_path = "models/BAGEL-7B-MoT"

# Load configurations
# Language-model config: read from JSON, then three fields are overridden in place.
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
# The same class name is listed in `no_split_module_classes` when the device map is inferred.
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# Vision config: read from JSON, then `rope` is disabled and the layer count is
# reduced by one. The config is re-read on every run of this cell, so the
# decrement does not accumulate on re-execution.
# NOTE(review): the reason for dropping one ViT layer is not visible here — confirm
# against the upstream loading code.
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

# load_ae returns both the autoencoder and its config; the autoencoder is kept
# as a separate object (`vae_model`) and is not passed to `Bagel(...)` later.
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Top-level BAGEL config combining the three sub-configs, with both the visual
# generation and visual understanding paths enabled.
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config, 
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

print("Configurations loaded successfully!")

In [None]:
# Create model with empty weights
# Everything inside this context is built without real weights; the checkpoint
# is loaded into `model` by `load_checkpoint_and_dispatch` in a later cell.
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    # Swap the ViT patch embedding from conv2d to linear while still on empty weights.
    # NOTE(review): this suggests the ViT expects flattened patches rather than a
    # (batch, channel, height, width) image — confirm before calling it directly.
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Load tokenizer
# add_special_tokens returns the updated tokenizer, the ids of the added tokens,
# and a third value that this notebook discards.
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image transforms
# One transform per image encoder (VAE and ViT).
# NOTE(review): the positional arguments are presumably (max size, min size, stride) —
# confirm against ImageTransform's signature.
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

print("Model architecture created!")

In [None]:
# Device mapping for multi-GPU
# Let accelerate propose a placement, allowing up to 80GiB on every visible GPU.
device_map = infer_auto_device_map(
    model,
    max_memory=dict.fromkeys(range(torch.cuda.device_count()), "80GiB"),
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

# Modules that are forced onto the device of the first entry in this list.
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    # Single GPU: every listed module gets an entry; ones that were missing
    # from the inferred map are pinned to "cuda:0".
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    device_map.update({
        k: (first_device if k in device_map else "cuda:0")
        for k in same_device_modules
    })
else:
    # Otherwise only modules that already have an entry are re-pointed.
    # NOTE(review): first_device is None if the first module has no explicit
    # entry in the inferred map — confirm that cannot happen in practice.
    first_device = device_map.get(same_device_modules[0])
    device_map.update({
        k: first_device
        for k in same_device_modules
        if k in device_map
    })

print(f"Device map: {device_map}")

In [None]:
# Load model weights ("mode 1").
# Despite the "full precision" label this mode carries, the dtype requested
# below is bfloat16, so activations produced by `model` are bfloat16 as well.
# The returned model is switched to eval mode immediately.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    offload_folder="offload",
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

print("Model loaded successfully!")
# Reports the device of the first parameter only; with a multi-device map other
# modules may live on different devices.
print(f"Model device: {next(model.parameters()).device}")

## Representation Extraction Functions

In [None]:
def extract_text_embeddings(text, max_length=512):
    """
    Tokenize ``text`` and collect language-model representations for it.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        text: Input handed to the tokenizer.
        max_length: Truncation limit passed to the tokenizer.

    Returns:
        dict with four entries:
        ``input_ids`` (token ids on the model's device),
        ``embeddings`` (output of ``embed_tokens`` for those ids),
        ``hidden_states`` and ``last_hidden_state`` (attributes of the output
        of ``model.language_model.model`` called with
        ``output_hidden_states=True``).
    """
    # Inputs go to the device of the model's first parameter.
    target_device = next(model.parameters()).device

    encoded = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    input_ids = encoded["input_ids"].to(target_device)

    backbone = model.language_model.model
    with torch.no_grad():
        # Raw token embeddings, before any decoder layer.
        token_embeddings = backbone.embed_tokens(input_ids)
        # Full pass through the language-model backbone.
        # NOTE(review): confirm the backbone accepts plain `input_ids` with
        # `output_hidden_states` — its forward signature is not visible here.
        lm_outputs = backbone(input_ids, output_hidden_states=True)

    result = dict(input_ids=input_ids, embeddings=token_embeddings)
    result["hidden_states"] = lm_outputs.hidden_states
    result["last_hidden_state"] = lm_outputs.last_hidden_state
    return result

def extract_image_features(image_path):
    """
    Extract image features from both the ViT and VAE encoders.

    Relies on the module-level ``model``, ``vae_model``, ``vit_transform`` and
    ``vae_transform``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        dict with:
        ``vit_features``      - ``last_hidden_state`` of the ViT output,
        ``vit_pooled``        - ``pooler_output`` of the ViT output, or None if absent,
        ``vit_hidden_states`` - ``hidden_states`` of the ViT output,
        ``vae_features``      - a sample drawn from the VAE's latent distribution,
        ``original_image``    - the RGB PIL image that was encoded.
    """
    # Open through a context manager so the underlying file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Both transforms consume the same RGB image, so normalise it once.
    rgb_image = pil_img2rgb(image)

    device = next(model.parameters()).device
    # The checkpoint is loaded as bfloat16, while the transform's output dtype
    # is not guaranteed to match; align the ViT input with the ViT weights to
    # avoid a dtype-mismatch error inside the encoder.
    vit_dtype = next(model.vit_model.parameters()).dtype

    vit_image = vit_transform(rgb_image).unsqueeze(0).to(device=device, dtype=vit_dtype)

    # NOTE(review): vae_model is created by load_ae and is never moved or cast
    # in this notebook — confirm it lives on `device` and accepts this dtype.
    vae_image = vae_transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # NOTE(review): confirm the ViT accepts a batched image tensor and that
        # vae_model.encode returns an object exposing `latent_dist`.
        vit_outputs = model.vit_model(vit_image, output_hidden_states=True)
        # sample() is stochastic, so repeated calls yield different VAE features.
        vae_features = vae_model.encode(vae_image).latent_dist.sample()

    return {
        "vit_features": vit_outputs.last_hidden_state,
        "vit_pooled": getattr(vit_outputs, "pooler_output", None),
        "vit_hidden_states": vit_outputs.hidden_states,
        "vae_features": vae_features,
        "original_image": image
    }

def extract_multimodal_representations(text, image_path=None, max_length=512):
    """
    Extract text embeddings and, when an image is given, image-side features.

    Relies on the module-level ``tokenizer``, ``model``, ``vae_model``,
    ``vit_transform`` and ``vae_transform``.

    Args:
        text: Input handed to the tokenizer.
        image_path: Optional path accepted by ``PIL.Image.open``. Any falsy
            value (None, empty string) selects the text-only path.
        max_length: Truncation limit passed to the tokenizer (defaults to the
            previously hard-coded 512).

    Returns:
        Text-only: dict with ``text_embeddings``, ``input_ids`` and
        ``multimodal_ready=False``.
        With an image: dict with ``text_embeddings``, ``vit_features``
        (ViT ``last_hidden_state``), ``vit_processed`` (that state passed
        through ``model.connector``), ``vae_features`` (a sample from the VAE
        latent distribution), ``input_ids`` and ``multimodal_ready=True``.
    """
    has_image = bool(image_path)
    device = next(model.parameters()).device

    vit_image = vae_image = None
    if has_image:
        # Context manager releases the file handle; convert() returns an
        # independent in-memory copy that outlives the `with` block.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        rgb_image = pil_img2rgb(image)

        # Weights are loaded as bfloat16; align the ViT input dtype with the
        # ViT weights to avoid a dtype-mismatch error inside the encoder.
        vit_dtype = next(model.vit_model.parameters()).dtype
        vit_image = vit_transform(rgb_image).unsqueeze(0).to(device=device, dtype=vit_dtype)

        # NOTE(review): vae_model is never moved or cast in this notebook —
        # confirm it lives on `device` and accepts this dtype.
        vae_image = vae_transform(rgb_image).unsqueeze(0).to(device)

    text_inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    input_ids = text_inputs["input_ids"].to(device)

    with torch.no_grad():
        text_embeds = model.language_model.model.embed_tokens(input_ids)

        if has_image:
            # NOTE(review): confirm the ViT accepts a batched image tensor and
            # that vae_model.encode returns an object exposing `latent_dist`.
            vit_outputs = model.vit_model(vit_image, output_hidden_states=True)
            vae_features = vae_model.encode(vae_image).latent_dist.sample()

            # Pass ViT features through the connector module.
            vit_features_processed = model.connector(vit_outputs.last_hidden_state)

    if not has_image:
        return {
            "text_embeddings": text_embeds,
            "input_ids": input_ids,
            "multimodal_ready": False
        }

    return {
        "text_embeddings": text_embeds,
        "vit_features": vit_outputs.last_hidden_state,
        "vit_processed": vit_features_processed,
        "vae_features": vae_features,
        "input_ids": input_ids,
        "multimodal_ready": True
    }

print("Representation extraction functions defined!")

## Usage Examples

### Extract Text Representations

In [None]:
# Example text
sample_text = "A beautiful sunset over the mountains with golden light."

# Extract text representations
# `text_repr` is reused by the export cell at the end of the notebook.
text_repr = extract_text_embeddings(sample_text)

print(f"Text: {sample_text}")
print(f"Input IDs shape: {text_repr['input_ids'].shape}")
print(f"Embeddings shape: {text_repr['embeddings'].shape}")
print(f"Last hidden state shape: {text_repr['last_hidden_state'].shape}")
# Reports the length of the returned `hidden_states` sequence as-is.
# NOTE(review): if the backbone follows the Hugging Face convention this count
# also includes the embedding output, i.e. it is layers + 1 — confirm.
print(f"Number of hidden layers: {len(text_repr['hidden_states'])}")

# Show embedding statistics
# Mean and standard deviation are taken over every element of the embedding tensor.
embedding_mean = text_repr['embeddings'].mean().item()
embedding_std = text_repr['embeddings'].std().item()
print(f"Embedding mean: {embedding_mean:.4f}, std: {embedding_std:.4f}")

### Extract Image Representations

In [None]:
# Note: this example is disabled because it needs a real image file.
# To run it, uncomment the block below and set `image_path` to an existing image.

# image_path = "path/to/your/image.jpg"
# 
# # Extract image representations
# image_repr = extract_image_features(image_path)
# 
# print(f"ViT features shape: {image_repr['vit_features'].shape}")
# print(f"VAE features shape: {image_repr['vae_features'].shape}")
# print(f"Number of ViT hidden layers: {len(image_repr['vit_hidden_states'])}")
# 
# # Show feature statistics
# vit_mean = image_repr['vit_features'].mean().item()
# vit_std = image_repr['vit_features'].std().item()
# vae_mean = image_repr['vae_features'].mean().item()
# vae_std = image_repr['vae_features'].std().item()
# 
# print(f"ViT features - mean: {vit_mean:.4f}, std: {vit_std:.4f}")
# print(f"VAE features - mean: {vae_mean:.4f}, std: {vae_std:.4f}")

print("Uncomment the code above and provide an image path to test image feature extraction")

### Extract Multimodal Representations

In [None]:
# Text-only multimodal extraction
# With no image_path only the text branch runs, so the result carries
# `text_embeddings`, `input_ids` and `multimodal_ready=False`.
multimodal_text = extract_multimodal_representations("Describe this beautiful landscape.")

print("Text-only multimodal extraction:")
print(f"Text embeddings shape: {multimodal_text['text_embeddings'].shape}")
print(f"Multimodal ready: {multimodal_text['multimodal_ready']}")

# Text + image extraction is disabled because it needs a real image file.
# Uncomment below and set `image_path` to an existing image to run it.
# image_path = "path/to/your/image.jpg"
# multimodal_both = extract_multimodal_representations(
#     "What do you see in this image?", 
#     image_path=image_path
# )
# 
# print("\nText + Image multimodal extraction:")
# print(f"Text embeddings shape: {multimodal_both['text_embeddings'].shape}")
# print(f"ViT features shape: {multimodal_both['vit_features'].shape}")
# print(f"ViT processed shape: {multimodal_both['vit_processed'].shape}")
# print(f"VAE features shape: {multimodal_both['vae_features'].shape}")
# print(f"Multimodal ready: {multimodal_both['multimodal_ready']}")

### Analyze Model Architecture

In [None]:
# Report the concrete classes behind each component.
print("BAGEL Model Architecture:")
print(f"Language model: {type(model.language_model).__name__}")
print(f"Vision model: {type(model.vit_model).__name__}")
print(f"VAE model: {type(vae_model).__name__}")

# Echo the settings chosen when `config` was built.
print("\nModel config:")
print(f"- Visual generation: {config.visual_gen}")
print(f"- Visual understanding: {config.visual_und}")
print(f"- Max latent size: {config.max_latent_size}")
print(f"- Latent patch size: {config.latent_patch_size}")
print(f"- ViT max patches per side: {config.vit_max_num_patch_per_side}")

# Tally parameters in a single pass over the model.
# Only `model` is counted; `vae_model` is a separate object and is not included.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print("\nModel Parameters:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Export Representations

In [None]:
def save_representations(representations, filename):
    """
    Save a dict of representations to a NumPy ``.npz`` archive.

    Conversion rules, applied per entry:
      * A tensor is detached, moved to CPU and converted to a NumPy array.
        bfloat16 tensors are upcast to float32 first, because NumPy has no
        bfloat16 dtype and ``Tensor.numpy()`` raises on them.
      * A non-empty list or tuple made up solely of tensors (e.g. per-layer
        hidden states) is converted element-wise. Equal-shaped elements are
        stacked along a new leading axis; differently shaped elements are
        stored as a 1-D object array (loading it back needs
        ``allow_pickle=True``).
      * Anything else is handed to ``np.savez`` unchanged.

    Args:
        representations: Mapping from name to tensor, sequence of tensors, or
            any value ``np.savez`` can store.
        filename: Target path; ``np.savez`` appends ``.npz`` if it is missing.
    """
    def _to_numpy(tensor):
        # Detach so tensors that require grad can be converted as well.
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    numpy_repr = {}
    for key, value in representations.items():
        if torch.is_tensor(value):
            numpy_repr[key] = _to_numpy(value)
        elif (
            isinstance(value, (list, tuple))
            and len(value) > 0  # an empty sequence has no first element to inspect
            and all(torch.is_tensor(item) for item in value)
        ):
            arrays = [_to_numpy(item) for item in value]
            if all(arr.shape == arrays[0].shape for arr in arrays):
                numpy_repr[key] = np.stack(arrays)
            else:
                # Ragged shapes cannot form a regular array; keep each element intact.
                ragged = np.empty(len(arrays), dtype=object)
                for index, arr in enumerate(arrays):
                    ragged[index] = arr
                numpy_repr[key] = ragged
        else:
            numpy_repr[key] = value

    np.savez(filename, **numpy_repr)
    print(f"Representations saved to {filename}")

# Example: Save text representations
# Uses `text_repr` produced by the text example cell earlier in the notebook.
save_representations(text_repr, "text_representations.npz")

# Load back example
# allow_pickle=True lets NumPy unpickle object arrays stored in the archive;
# only use it on files you produced yourself, since unpickling can run arbitrary code.
# The archive handle stays open until `loaded_repr.close()` is called.
loaded_repr = np.load("text_representations.npz", allow_pickle=True)
print(f"Loaded representation keys: {list(loaded_repr.keys())}")
print(f"Loaded embeddings shape: {loaded_repr['embeddings'].shape}")

## Summary

This notebook provides tools to:

1. **Load the BAGEL model** following the official loading pattern
2. **Extract text representations** from the language model component
3. **Extract image features** from both ViT and VAE encoders
4. **Extract multimodal representations** combining text and image processing
5. **Analyze model architecture** and parameter counts
6. **Save representations** for further analysis

The extracted representations can be used for:
- Understanding model internal states
- Building interpretability tools
- Creating safety mechanisms
- Analyzing multimodal alignment
- Fine-tuning experiments