# BAGEL Model Representation Extraction

This notebook demonstrates how to load the BAGEL model and extract internal representations from different components.

## Setup and Imports

In [None]:
!git clone https://github.com/ByteDance-Seed/Bagel.git
!pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.1.post1/flash_attn-2.7.1.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
# Uninstall the incompatible flash-attention wheel
# !pip uninstall -y flash_attn
# Install flash-attention from source
# !pip install --no-build-isolation flash-attn==2.6.1

%cd Bagel

Cloning into 'Bagel'...
remote: Enumerating objects: 377, done.[K
remote: Counting objects: 100% (243/243), done.[K
remote: Compressing objects: 100% (152/152), done.[K
remote: Total 377 (delta 136), reused 133 (delta 91), pack-reused 134 (from 2)[K
Receiving objects: 100% (377/377), 2.24 MiB | 13.21 MiB/s, done.
Resolving deltas: 100% (162/162), done.
Collecting flash-attn==2.7.1.post1+cu12torch2.6cxx11abiFALSE
  Downloading https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.1.post1/flash_attn-2.7.1.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl (183.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.4/183.4 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->flash-attn==2.7.1.post1+cu12torch2.6cxx11abiFALSE)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->flash-attn==2.7

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer

# Assume the model is on Hugging Face Hub, adjust the repo ID if necessary
!pip install huggingface_hub

from huggingface_hub import hf_hub_download

# Model path - adjust if needed
model_repo_id = "ByteDance-Seed/Bagel-7B-MoT" # Replace with the correct Hugging Face model ID
model_path = "models/BAGEL-7B-MoT" # Local directory to save the downloaded files

# Create the local directory if it doesn't exist
import os
os.makedirs(model_path, exist_ok=True)

# Download the necessary model files
try:
    hf_hub_download(repo_id=model_repo_id, filename="llm_config.json", local_dir=model_path, local_dir_use_symlinks=False)
    hf_hub_download(repo_id=model_repo_id, filename="vit_config.json", local_dir=model_path, local_dir_use_symlinks=False)
    hf_hub_download(repo_id=model_repo_id, filename="ae.safetensors", local_dir=model_path, local_dir_use_symlinks=False)
    hf_hub_download(repo_id=model_repo_id, filename="ema.safetensors", local_dir=model_path, local_dir_use_symlinks=False)
    # Also download the tokenizer files if they are separate
    hf_hub_download(repo_id=model_repo_id, filename="tokenizer.json", local_dir=model_path, local_dir_use_symlinks=False)
    hf_hub_download(repo_id=model_repo_id, filename="tokenizer_config.json", local_dir=model_path, local_dir_use_symlinks=False)
    # hf_hub_download(repo_id=model_repo_id, filename="special_tokens_map.json", local_dir=model_path, local_dir_use_symlinks=False)
    # hf_hub_download(repo_id=model_repo_id, filename="tokenizer.model", local_dir=model_path, local_dir_use_symlinks=False) # For SentencePiece
    hf_hub_download(repo_id=model_repo_id, filename="vocab.json", local_dir=model_path, local_dir_use_symlinks=False)
    hf_hub_download(repo_id=model_repo_id, filename="merges.txt", local_dir=model_path, local_dir_use_symlinks=False)
except Exception as e:
    print(f"Error downloading model files: {e}")
    print("Please ensure the model ID is correct and the files exist on the Hugging Face Hub.")
    print("Alternatively, download the model files manually and place them in the './Bagel/models/BAGEL-7B-MoT' directory.")

## Model Loading

In [None]:
# Model path - adjust if needed
model_path = "models/BAGEL-7B-MoT"

# Load configurations
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

print("Configurations loaded successfully!")

In [None]:
# Create model with empty weights
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Load tokenizer
# Explicitly define paths to the required tokenizer files based on the repository structure
vocab_file_path = os.path.join(model_path, "vocab.json")
merges_file_path = os.path.join(model_path, "merges.txt")
tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
special_tokens_path = os.path.join(model_path, "special_tokens_map.json")

# Check if the necessary files exist
if not os.path.exists(vocab_file_path):
    raise FileNotFoundError(f"Required tokenizer file not found: {vocab_file_path}")
if not os.path.exists(merges_file_path):
    raise FileNotFoundError(f"Required tokenizer file not found: {merges_file_path}")
# Optional checks for other files
if not os.path.exists(tokenizer_config_path):
    print(f"Warning: Tokenizer config file not found at {tokenizer_config_path}. Tokenizer might load with default settings.")
if not os.path.exists(special_tokens_path):
    print(f"Warning: Special tokens map file not found at {special_tokens_path}. Special tokens might not be handled correctly.")


# Initialize the tokenizer by explicitly passing the paths
# Based on standard Qwen2Tokenizer initialization, it expects vocab_file and merges_file
try:
    tokenizer = Qwen2Tokenizer(
        vocab_file=vocab_file_path,
        merges_file=merges_file_path,
        # You might also pass other config files if needed and the constructor supports it
        # e.g., tokenizer_config_file=tokenizer_config_path, special_tokens_map_file=special_tokens_path
    )
    print("Tokenizer initialized by explicitly providing file paths successfully!")

except Exception as e:
    print(f"Explicit tokenizer initialization failed: {e}")
    print("Please verify the parameters accepted by the Qwen2Tokenizer constructor")
    print("in the Bagel repository's modeling/qwen2/tokenization_qwen2.py file.")
    # Re-raise the error as it indicates a deeper issue with file paths or constructor
    raise e


# Correct the incomplete line for adding special tokens
# Assuming add_special_tokens is a function defined elsewhere (likely data.data_utils)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image transforms
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

print("Model architecture created!")

In [None]:
# Device mapping for multi-GPU
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

print(f"Device map: {device_map}")

In [None]:
# Load model weights (using full precision - mode 1)
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    offload_folder="offload",
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

print("Model loaded successfully!")
print(f"Model device: {next(model.parameters()).device}")

## Representation Extraction Functions

In [None]:
def extract_text_embeddings(text, max_length=512):
    """
    Extract text embeddings from the language model component.
    
    Note: BAGEL uses custom packed input formats. We extract embeddings directly
    and use forward hooks to capture intermediate representations.
    """
    # Tokenize text
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)
    
    with torch.no_grad():
        # Get embeddings directly from the embedding layer
        embeddings = model.language_model.model.embed_tokens(input_ids)
        
        # For BAGEL, we need to use the full model forward pass to get meaningful representations
        # Create a simple text-only input for BAGEL
        try:
            # Use BAGEL's generate method to get representations during inference
            # This is the proper way to access BAGEL's internal states
            hidden_states_collected = []
            
            def collect_hidden_states(module, input, output):
                if hasattr(output, 'last_hidden_state'):
                    hidden_states_collected.append(output.last_hidden_state.detach())
                elif isinstance(output, tuple) and len(output) > 0:
                    hidden_states_collected.append(output[0].detach())
            
            # Register hook on the language model to capture output
            hook = model.language_model.register_forward_hook(collect_hidden_states)
            
            # Create BAGEL-compatible input
            # We'll use a minimal generation to trigger the forward pass
            generation_outputs = model.generate(
                input_ids,
                max_new_tokens=1,  # Minimal generation
                do_sample=False,
                return_dict_in_generate=True,
                output_hidden_states=False,  # Don't request hidden states from generate
                pad_token_id=tokenizer.eos_token_id
            )
            
            # Remove the hook
            hook.remove()
            
            # Get the last hidden state from our collected states
            last_hidden_state = hidden_states_collected[-1] if hidden_states_collected else None
            
            return {
                "input_ids": input_ids,
                "embeddings": embeddings,
                "hidden_states": hidden_states_collected if hidden_states_collected else None,
                "last_hidden_state": last_hidden_state,
                "generation_outputs": generation_outputs
            }
            
        except Exception as e:
            print(f"Note: Full forward pass not available ({e}). Returning embeddings only.")
            return {
                "input_ids": input_ids,
                "embeddings": embeddings,
                "hidden_states": None,
                "last_hidden_state": None,
                "generation_outputs": None
            }

def extract_text_embeddings_simple(text, max_length=512):
    """
    Simplified text embedding extraction - just embeddings and tokenization.
    This always works regardless of BAGEL's complex forward methods.
    """
    # Tokenize text
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)
    
    with torch.no_grad():
        # Get embeddings directly from the embedding layer
        embeddings = model.language_model.model.embed_tokens(input_ids)
        
        # Get vocabulary size and embedding dimension
        vocab_size = embeddings.shape[-1]
        seq_len = embeddings.shape[1]
        
    return {
        "input_ids": input_ids,
        "embeddings": embeddings,
        "text": text,
        "tokens": tokenizer.convert_ids_to_tokens(input_ids[0]),
        "vocab_size": vocab_size,
        "sequence_length": seq_len,
        "embedding_dim": vocab_size
    }

def extract_image_features(image_path):
    """
    Extract image features from both ViT and VAE encoders.
    Uses direct model access for the components that support it.
    """
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    
    # ViT preprocessing
    vit_image = vit_transform(pil_img2rgb(image)).unsqueeze(0)
    vit_image = vit_image.to(next(model.parameters()).device)
    
    # VAE preprocessing  
    vae_image = vae_transform(pil_img2rgb(image)).unsqueeze(0)
    vae_image = vae_image.to(next(model.parameters()).device)
    
    with torch.no_grad():
        # VAE features - this should work as VAE is more standard
        vae_features = vae_model.encode(vae_image).latent_dist.sample()
        
        # ViT features - try direct access first
        try:
            # BAGEL's ViT might also have custom input format
            # Let's try the simpler approach first
            vit_features = model.vit_model.vision_model.embeddings(vit_image)
            vit_last_hidden = None
            vit_hidden_states = None
            
            print("Note: Using direct ViT embedding access. Full ViT forward pass not implemented.")
            
        except Exception as e:
            print(f"Warning: ViT feature extraction failed: {e}")
            vit_features = None
            vit_last_hidden = None
            vit_hidden_states = None
        
    return {
        "vit_features": vit_features,
        "vit_pooled": None,
        "vit_hidden_states": vit_hidden_states,
        "vae_features": vae_features,
        "original_image": image,
        "image_shape": vit_image.shape,
        "vae_latent_shape": vae_features.shape if vae_features is not None else None
    }

def extract_multimodal_representations_simple(text, image_path=None):
    """
    Simplified multimodal representation extraction.
    Gets embeddings and features that are directly accessible.
    """
    # Get text representations
    text_repr = extract_text_embeddings_simple(text)
    
    result = {
        "text_embeddings": text_repr["embeddings"],
        "input_ids": text_repr["input_ids"],
        "text": text,
        "multimodal_ready": False
    }
    
    if image_path:
        # Get image representations
        image_repr = extract_image_features(image_path)
        
        result.update({
            "vit_features": image_repr["vit_features"],
            "vae_features": image_repr["vae_features"],
            "original_image": image_repr["original_image"],
            "multimodal_ready": True
        })
        
        # Try to get cross-modal features through BAGEL's connector
        try:
            if image_repr["vit_features"] is not None:
                # This might work if we can access the connector
                vit_processed = model.connector(image_repr["vit_features"])
                result["vit_processed"] = vit_processed
        except Exception as e:
            print(f"Note: Cross-modal processing not available: {e}")
            result["vit_processed"] = None
    
    return result

def analyze_model_structure():
    """
    Analyze BAGEL's structure to understand what components are accessible.
    """
    print("BAGEL Model Structure Analysis:")
    print("="*60)
    
    # Main model components
    print(f"Main model type: {type(model).__name__}")
    print(f"Language model: {type(model.language_model).__name__}")
    print(f"Vision model: {type(model.vit_model).__name__}")
    print(f"VAE model: {type(vae_model).__name__}")
    
    # Check for accessible components
    print("\nAccessible Components:")
    print("-" * 30)
    
    # Text embeddings
    try:
        test_ids = torch.tensor([[1, 2, 3]]).to(next(model.parameters()).device)
        embed_out = model.language_model.model.embed_tokens(test_ids)
        print(f"✓ Text embeddings: {embed_out.shape}")
    except Exception as e:
        print(f"✗ Text embeddings: {e}")
    
    # Connector
    try:
        test_vit = torch.randn(1, 256, 768).to(next(model.parameters()).device)
        conn_out = model.connector(test_vit)
        print(f"✓ Connector: {test_vit.shape} -> {conn_out.shape}")
    except Exception as e:
        print(f"✗ Connector: {e}")
    
    # Time embedder
    try:
        if hasattr(model, 'time_embedder'):
            print(f"✓ Time embedder available")
        else:
            print(f"✗ Time embedder not found")
    except Exception as e:
        print(f"✗ Time embedder: {e}")
    
    print("\nModel Configuration:")
    print("-" * 30)
    print(f"Visual generation: {config.visual_gen}")
    print(f"Visual understanding: {config.visual_und}")
    print(f"Max latent size: {config.max_latent_size}")
    print(f"Connector activation: {config.connector_act}")
    
    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total_params:,}")
    
    print("="*60)

print("BAGEL-compatible representation extraction functions defined!")
print("Use extract_text_embeddings_simple() for reliable text embeddings.")
print("Use analyze_model_structure() to see what components are accessible.")

## Usage Examples

### Extract Text Representations

In [None]:
# Example text
sample_text = "A beautiful sunset over the mountains with golden light."

# Extract text representations using the simple method (most reliable)
text_repr = extract_text_embeddings_simple(sample_text)

print(f"Text: {sample_text}")
print(f"Input IDs shape: {text_repr['input_ids'].shape}")
print(f"Embeddings shape: {text_repr['embeddings'].shape}")
print(f"Tokens: {text_repr['tokens'][:10]}...")  # Show first 10 tokens
print(f"Sequence length: {text_repr['sequence_length']}")
print(f"Embedding dimension: {text_repr['embedding_dim']}")

# Show embedding statistics
embedding_mean = text_repr['embeddings'].mean().item()
embedding_std = text_repr['embeddings'].std().item()
print(f"Embedding mean: {embedding_mean:.4f}, std: {embedding_std:.4f}")

# Try the advanced extraction method (might work with generation)
print("\n" + "="*50)
print("Trying advanced extraction with generation...")
try:
    text_repr_advanced = extract_text_embeddings(sample_text)
    
    if text_repr_advanced['hidden_states'] is not None:
        print(f"✓ Advanced extraction successful!")
        print(f"Hidden states captured: {len(text_repr_advanced['hidden_states'])}")
        print(f"Generation output shape: {text_repr_advanced['generation_outputs'].sequences.shape}")
    else:
        print("ℹ️ Advanced extraction returned embeddings only")
        
except Exception as e:
    print(f"✗ Advanced extraction failed: {e}")
    print("Using simple extraction is recommended.")

### Extract Image Representations

In [None]:
# Note: You'll need to provide an actual image path
# Uncomment and modify the path below when you have an image to test

# image_path = "path/to/your/image.jpg"
# 
# # Extract image representations
# image_repr = extract_image_features(image_path)
# 
# print(f"ViT features shape: {image_repr['vit_features'].shape}")
# print(f"VAE features shape: {image_repr['vae_features'].shape}")
# print(f"Number of ViT hidden layers: {len(image_repr['vit_hidden_states'])}")
# 
# # Show feature statistics
# vit_mean = image_repr['vit_features'].mean().item()
# vit_std = image_repr['vit_features'].std().item()
# vae_mean = image_repr['vae_features'].mean().item()
# vae_std = image_repr['vae_features'].std().item()
# 
# print(f"ViT features - mean: {vit_mean:.4f}, std: {vit_std:.4f}")
# print(f"VAE features - mean: {vae_mean:.4f}, std: {vae_std:.4f}")

print("Uncomment the code above and provide an image path to test image feature extraction")

### Extract Multimodal Representations

In [None]:
# Text-only multimodal extraction using the new simple method
multimodal_text = extract_multimodal_representations_simple("Describe this beautiful landscape.")

print("Text-only multimodal extraction:")
print(f"Text embeddings shape: {multimodal_text['text_embeddings'].shape}")
print(f"Text: {multimodal_text['text']}")
print(f"Multimodal ready: {multimodal_text['multimodal_ready']}")

# Uncomment below for text + image multimodal extraction
# image_path = "path/to/your/image.jpg"
# multimodal_both = extract_multimodal_representations_simple(
#     "What do you see in this image?", 
#     image_path=image_path
# )
# 
# print("\nText + Image multimodal extraction:")
# print(f"Text embeddings shape: {multimodal_both['text_embeddings'].shape}")
# if multimodal_both['vit_features'] is not None:
#     print(f"ViT features shape: {multimodal_both['vit_features'].shape}")
# if multimodal_both['vit_processed'] is not None:
#     print(f"ViT processed shape: {multimodal_both['vit_processed'].shape}")
# print(f"VAE features shape: {multimodal_both['vae_features'].shape}")
# print(f"Multimodal ready: {multimodal_both['multimodal_ready']}")

print("\n" + "="*50)
print("Model structure analysis:")
analyze_model_structure()

### Analyze Model Architecture

In [None]:
# Analyze model components
print("BAGEL Model Architecture:")
print(f"Language model: {type(model.language_model).__name__}")
print(f"Vision model: {type(model.vit_model).__name__}")
print(f"VAE model: {type(vae_model).__name__}")
print(f"\nModel config:")
print(f"- Visual generation: {config.visual_gen}")
print(f"- Visual understanding: {config.visual_und}")
print(f"- Max latent size: {config.max_latent_size}")
print(f"- Latent patch size: {config.latent_patch_size}")
print(f"- ViT max patches per side: {config.vit_max_num_patch_per_side}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Parameters:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Export Representations

In [None]:
def save_representations(representations, filename):
    """
    Save representations to file for further analysis.
    """
    # Convert tensors to numpy for saving
    numpy_repr = {}
    for key, value in representations.items():
        if torch.is_tensor(value):
            numpy_repr[key] = value.cpu().numpy()
        elif isinstance(value, list) and torch.is_tensor(value[0]):
            numpy_repr[key] = [v.cpu().numpy() for v in value]
        else:
            numpy_repr[key] = value
    
    np.savez(filename, **numpy_repr)
    print(f"Representations saved to {filename}")

# Example: Save text representations
save_representations(text_repr, "text_representations.npz")

# Load back example
loaded_repr = np.load("text_representations.npz", allow_pickle=True)
print(f"Loaded representation keys: {list(loaded_repr.keys())}")
print(f"Loaded embeddings shape: {loaded_repr['embeddings'].shape}")

## BAGEL Circuit Breakers

This section implements circuit breaker safety mechanisms targeting BAGEL's MoT (Mixture-of-Tokens) architecture.

### Analyze MoT Architecture

Let's first analyze BAGEL's MoT layers to identify the best intervention points.

In [None]:
def analyze_mot_architecture():
    """
    Analyze BAGEL's MoT (Mixture-of-Tokens) architecture to identify circuit breaker intervention points.
    """
    print("BAGEL MoT Architecture Analysis:")
    print("="*60)
    
    # Get the language model
    llm = model.language_model.model
    print(f"Language model type: {type(llm).__name__}")
    print(f"Number of layers: {len(llm.layers)}")
    
    # Analyze MoT layers
    mot_layers = []
    for i, layer in enumerate(llm.layers):
        layer_type = type(layer).__name__
        print(f"Layer {i}: {layer_type}")
        
        if 'MoT' in layer_type:
            mot_layers.append(i)
            # Analyze MoT-specific components
            print(f"  ✓ MoT Layer - Components:")
            print(f"    - self_attn: {type(layer.self_attn).__name__}")
            print(f"    - mlp: {type(layer.mlp).__name__}")
            
            # Check for MoT-specific dual components
            if hasattr(layer, 'mlp_moe_gen'):
                print(f"    - mlp_moe_gen: {type(layer.mlp_moe_gen).__name__} (Generation-specific)")
            if hasattr(layer, 'input_layernorm_moe_gen'):
                print(f"    - input_layernorm_moe_gen: {type(layer.input_layernorm_moe_gen).__name__}")
            if hasattr(layer, 'post_attention_layernorm_moe_gen'):
                print(f"    - post_attention_layernorm_moe_gen: {type(layer.post_attention_layernorm_moe_gen).__name__}")
                
            # Check for freeze_und attribute
            if hasattr(layer, 'freeze_und'):
                print(f"    - freeze_und: {layer.freeze_und}")
    
    print(f"\nMoT Layers found: {mot_layers}")
    print(f"Total MoT layers: {len(mot_layers)}")
    
    # Recommended intervention points (middle and later layers for safety)
    if mot_layers:
        num_layers = len(mot_layers)
        # Target middle to later layers for circuit breakers
        target_layers = []
        if num_layers >= 20:
            target_layers = [num_layers//2, num_layers*2//3, num_layers*3//4, num_layers-3, num_layers-1]
        elif num_layers >= 10:
            target_layers = [num_layers//2, num_layers*2//3, num_layers-2, num_layers-1]
        else:
            target_layers = [num_layers//2, num_layers-1]
            
        # Ensure target layers are valid indices
        target_layers = [l for l in target_layers if l < len(mot_layers)]
        
        print(f"\nRecommended Circuit Breaker Target Layers: {target_layers}")
        print("(These target middle-to-late layers for effective safety intervention)")
        
        return {
            "total_layers": len(llm.layers),
            "mot_layers": mot_layers,
            "target_layers": target_layers,
            "layer_objects": [llm.layers[i] for i in target_layers]
        }
    else:
        print("\n⚠️  No MoT layers found!")
        return {"total_layers": len(llm.layers), "mot_layers": [], "target_layers": []}

# Run the analysis
mot_analysis = analyze_mot_architecture()

### Circuit Breaker Hook System

Now let's implement forward hooks to intercept and modify hidden states at the MoT layers.

In [None]:
import torch.nn.functional as F
from typing import Dict, List, Callable, Any
import json
import random

class ImageGenerationCircuitBreakerHooks:
    """
    Manages forward hooks specifically for image generation safety on BAGEL's MoT layers.
    Targets the generation-specific components (mlp_moe_gen) to prevent harmful image generation.
    """
    
    def __init__(self, target_layers: List[int], layer_objects: List[torch.nn.Module]):
        """
        Initialize hook manager for image generation safety.
        
        Args:
            target_layers: List of layer indices to target
            layer_objects: List of actual layer modules to hook
        """
        self.target_layers = target_layers
        self.layer_objects = layer_objects
        self.hooks = []
        self.generation_interventions = {}
        self.collected_activations = {}
        self.safety_enabled = False
        
        # Image generation safety parameters
        self.steering_strength = 0.2  # Stronger intervention for generation
        self.intervention_mode = "generation_blocking"  # Specific to image generation
        self.generation_threshold = 0.5  # Threshold for harmful content detection
        
    def register_generation_hooks(self):
        """Register forward hooks specifically targeting generation components."""
        self.hooks = []
        
        for i, (layer_idx, layer_obj) in enumerate(zip(self.target_layers, self.layer_objects)):
            # Check if this layer has generation-specific components
            if hasattr(layer_obj, 'mlp_moe_gen'):
                print(f"Found generation-specific MLP in layer {layer_idx}")
                
                # Hook the generation-specific MLP
                def make_gen_hook(layer_index):
                    def generation_hook(module, input, output):
                        return self._generation_intervention_hook(layer_index, module, input, output)
                    return generation_hook
                
                gen_hook = layer_obj.mlp_moe_gen.register_forward_hook(make_gen_hook(layer_idx))
                self.hooks.append(gen_hook)
                
                # Also hook the generation-specific layer norms if they exist
                if hasattr(layer_obj, 'input_layernorm_moe_gen'):
                    norm_hook = layer_obj.input_layernorm_moe_gen.register_forward_hook(make_gen_hook(layer_idx))
                    self.hooks.append(norm_hook)
                
                print(f"Registered generation safety hooks on layer {layer_idx}")
            else:
                print(f"Layer {layer_idx} does not have generation-specific components")
    
    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        print("All generation safety hooks removed")
    
    def _generation_intervention_hook(self, layer_idx: int, module, input, output):
        """
        Intervention hook specifically for image generation components.
        
        Args:
            layer_idx: Index of the layer
            module: The generation module (e.g., mlp_moe_gen)
            input: Input to the module
            output: Output from the module
            
        Returns:
            Modified output if safety is enabled and harmful generation detected
        """
        # Always collect activations for analysis
        if isinstance(output, torch.Tensor):
            self.collected_activations[f"layer_{layer_idx}_gen"] = output.detach().clone()
        
        # Apply safety intervention if enabled
        if self.safety_enabled and layer_idx in self.generation_interventions:
            return self._apply_generation_safety(layer_idx, module, input, output)
        
        return output
    
    def _apply_generation_safety(self, layer_idx: int, module, input, output):
        """
        Apply safety intervention specifically for image generation.
        
        Args:
            layer_idx: Layer index
            module: Generation module
            input: Module input
            output: Module output
            
        Returns:
            Modified output that blocks harmful generation
        """
        intervention = self.generation_interventions[layer_idx]
        
        if self.intervention_mode == "generation_blocking":
            # Detect if this is likely harmful generation
            if self._detect_harmful_generation_pattern(output):
                print(f"🛑 Circuit breaker activated at layer {layer_idx} - blocking harmful generation")
                
                # Replace with safe generation pattern
                safe_output = self._get_safe_generation_output(output, intervention)
                return safe_output
        
        elif self.intervention_mode == "steering":
            # Apply steering away from harmful generation
            steering_vector = intervention.get("safety_steering", torch.zeros_like(output))
            return output + self.steering_strength * steering_vector
        
        return output
    
    def _detect_harmful_generation_pattern(self, output: torch.Tensor) -> bool:
        """
        Simple harmful generation pattern detection.
        In practice, this could use learned classifiers or more sophisticated methods.
        
        Args:
            output: Tensor output from generation component
            
        Returns:
            True if harmful pattern detected
        """
        # Simple heuristic: look for high activation patterns that might indicate harmful content
        # This is a placeholder - in practice you'd use learned detection
        
        # Check for unusually high activations (potential sign of harmful generation)
        max_activation = output.abs().max().item()
        mean_activation = output.abs().mean().item()
        
        # Simple threshold-based detection
        if max_activation > 3.0 * mean_activation and max_activation > self.generation_threshold:
            return True
        
        return False
    
    def _get_safe_generation_output(self, original_output: torch.Tensor, intervention: Dict[str, Any]) -> torch.Tensor:
        """
        Replace harmful generation output with safe alternative.
        
        Args:
            original_output: Original potentially harmful output
            intervention: Intervention parameters
            
        Returns:
            Safe generation output
        """
        # Option 1: Use a learned safe replacement
        if "safe_replacement" in intervention:
            return intervention["safe_replacement"]
        
        # Option 2: Zero out high activations (conservative approach)
        safe_output = original_output.clone()
        threshold = self.generation_threshold
        
        # Clamp extreme values that might lead to harmful generation
        safe_output = torch.clamp(safe_output, -threshold, threshold)
        
        # Add some noise to prevent memorization of the clamping pattern
        noise = torch.randn_like(safe_output) * 0.01
        safe_output = safe_output + noise
        
        return safe_output
    
    def set_generation_intervention(self, layer_idx: int, intervention_type: str, **kwargs):
        """
        Set intervention parameters for generation safety.
        
        Args:
            layer_idx: Layer index
            intervention_type: Type of intervention ("blocking", "steering", etc.)
            **kwargs: Intervention parameters
        """
        self.generation_interventions[layer_idx] = {
            "type": intervention_type,
            **kwargs
        }
        print(f"Set {intervention_type} generation safety intervention for layer {layer_idx}")
    
    def enable_generation_safety(self):
        """Enable generation safety interventions."""
        self.safety_enabled = True
        print("🛡️ Image generation safety enabled")
    
    def disable_generation_safety(self):
        """Disable generation safety interventions."""
        self.safety_enabled = False
        print("Image generation safety disabled")
    
    def get_generation_activations(self) -> Dict[str, torch.Tensor]:
        """Get collected activations from generation components."""
        return self.collected_activations.copy()
    
    def clear_activations(self):
        """Clear collected activations."""
        self.collected_activations = {}
    
    def analyze_generation_pathway(self):
        """Analyze which layers have generation-specific components."""
        print("Analyzing BAGEL's Generation Pathway:")
        print("="*50)
        
        generation_layers = []
        for layer_idx, layer_obj in zip(self.target_layers, self.layer_objects):
            has_gen_components = False
            components = []
            
            if hasattr(layer_obj, 'mlp_moe_gen'):
                components.append("mlp_moe_gen")
                has_gen_components = True
            
            if hasattr(layer_obj, 'input_layernorm_moe_gen'):
                components.append("input_layernorm_moe_gen")
                has_gen_components = True
            
            if hasattr(layer_obj, 'post_attention_layernorm_moe_gen'):
                components.append("post_attention_layernorm_moe_gen")
                has_gen_components = True
            
            if has_gen_components:
                generation_layers.append(layer_idx)
                print(f"Layer {layer_idx}: {', '.join(components)}")
            else:
                print(f"Layer {layer_idx}: No generation-specific components")
        
        print(f"\nGeneration-capable layers: {generation_layers}")
        print(f"Total generation layers: {len(generation_layers)}")
        
        return generation_layers

# Initialize image generation circuit breaker hooks
if mot_analysis["target_layers"]:
    img_cb_hooks = ImageGenerationCircuitBreakerHooks(
        target_layers=mot_analysis["target_layers"],
        layer_objects=mot_analysis["layer_objects"]
    )
    
    print("Image Generation Circuit Breaker Hooks initialized!")
    
    # Analyze the generation pathway
    generation_layers = img_cb_hooks.analyze_generation_pathway()
    
else:
    img_cb_hooks = None
    print("⚠️  No MoT layers found - image generation circuit breaker hooks not available")

### Circuit Breaker Dataset

Adapt the proven circuit breaker training approach for BAGEL.

In [None]:
class I2PImageGenerationDataset:
    """
    Proper I2P dataset integration for image generation safety.
    Uses the real I2P annotations: categories, inappropriate_percentage, etc.
    """
    
    def __init__(self, tokenizer, max_length=512, 
                 harmful_threshold=70.0, safe_threshold=30.0, 
                 nudity_threshold=20.0, hard_only=False):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.harmful_threshold = harmful_threshold
        self.safe_threshold = safe_threshold
        self.nudity_threshold = nudity_threshold
        self.hard_only = hard_only
        
        # Load I2P dataset
        print("Loading I2P dataset...")
        try:
            from datasets import load_dataset
            self.i2p_dataset = load_dataset("AIML-TUDA/i2p")
            print(f"✓ I2P dataset loaded: {len(self.i2p_dataset['train'])} total prompts")
            
            # Process the dataset
            self._process_i2p_dataset()
            
        except Exception as e:
            print(f"✗ Failed to load I2P dataset: {e}")
            raise RuntimeError("I2P dataset is required for this implementation")
    
    def _process_i2p_dataset(self):
        """Process I2P dataset based on safety annotations."""
        
        self.safe_examples = []
        self.harmful_examples = []
        self.category_stats = {}
        
        print("Processing I2P dataset annotations...")
        
        for item in self.i2p_dataset['train']:
            # Extract fields from I2P format
            prompt = item['prompt']
            categories = item.get('categories', 'unknown')
            hard = item.get('hard', 0)
            inappropriate_pct = item.get('inappropriate_percentage', 0)
            nudity_pct = item.get('nudity_percentage', 0)
            sd_safety_pct = item.get('sd_safety_percentage', 100)
            toxicity = item.get('prompt_toxicity', 0)
            
            # Skip if hard_only is enabled and this isn't hard
            if self.hard_only and hard == 0:
                continue
            
            # Count categories
            if categories not in self.category_stats:
                self.category_stats[categories] = {"total": 0, "harmful": 0, "safe": 0}
            self.category_stats[categories]["total"] += 1
            
            # Determine if harmful based on multiple criteria
            is_harmful = (
                inappropriate_pct >= self.harmful_threshold or 
                nudity_pct >= self.nudity_threshold or
                sd_safety_pct <= 30  # Low SD safety score indicates harmful
            )
            
            # Determine if safe based on criteria
            is_safe = (
                inappropriate_pct <= self.safe_threshold and 
                nudity_pct <= 10 and
                sd_safety_pct >= 80 and
                toxicity <= 0.1
            )
            
            # Create example data structure
            example = {
                "prompt": prompt,
                "categories": categories,
                "hard": hard,
                "inappropriate_percentage": inappropriate_pct,
                "nudity_percentage": nudity_pct,
                "sd_safety_percentage": sd_safety_pct,
                "prompt_toxicity": toxicity,
                "full_text": f"Generate an image: {prompt}",
            }
            
            if is_harmful:
                example.update({
                    "type": "harmful_generation",
                    "should_generate": False
                })
                self.harmful_examples.append(example)
                self.category_stats[categories]["harmful"] += 1
                
            elif is_safe:
                example.update({
                    "type": "safe_generation", 
                    "should_generate": True
                })
                self.safe_examples.append(example)
                self.category_stats[categories]["safe"] += 1
        
        # Print statistics
        print(f"\nI2P Dataset Processing Results:")
        print(f"- Safe examples: {len(self.safe_examples)}")
        print(f"- Harmful examples: {len(self.harmful_examples)}")
        print(f"- Hard examples only: {self.hard_only}")
        print(f"- Harmful threshold: {self.harmful_threshold}%")
        print(f"- Safe threshold: {self.safe_threshold}%")
        
        print(f"\nCategory Breakdown:")
        for category, stats in self.category_stats.items():
            print(f"  {category}: {stats['total']} total, {stats['harmful']} harmful, {stats['safe']} safe")
    
    def get_examples_by_category(self, category: str, harmful: bool = True):
        """Get examples filtered by category."""
        examples = self.harmful_examples if harmful else self.safe_examples
        return [ex for ex in examples if ex["categories"] == category]
    
    def get_high_risk_examples(self, min_inappropriate=90, min_nudity=50):
        """Get the most problematic examples for robust training."""
        high_risk = []
        for ex in self.harmful_examples:
            if (ex["inappropriate_percentage"] >= min_inappropriate or 
                ex["nudity_percentage"] >= min_nudity):
                high_risk.append(ex)
        return high_risk
    
    def tokenize_for_generation(self, examples):
        """Tokenize examples for image generation training."""
        tokenized = []
        
        for example in examples:
            # Tokenize the prompt
            tokens = self.tokenizer(
                example["prompt"],
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            
            tokenized.append({
                "input_ids": tokens["input_ids"].squeeze(),
                "attention_mask": tokens["attention_mask"].squeeze(),
                "prompt": example["prompt"],
                "type": example["type"],
                "should_generate": example["should_generate"],
                "categories": example["categories"],
                "inappropriate_percentage": example["inappropriate_percentage"],
                "nudity_percentage": example["nudity_percentage"],
                "sd_safety_percentage": example["sd_safety_percentage"],
                "prompt_toxicity": example["prompt_toxicity"],
                "hard": example["hard"]
            })
        
        return tokenized
    
    def get_training_batch(self, batch_size=8, safe_ratio=0.5, 
                          target_categories=None, min_difficulty=None):
        """
        Get a training batch with sophisticated filtering.
        
        Args:
            batch_size: Number of examples in batch
            safe_ratio: Ratio of safe to harmful examples
            target_categories: List of categories to focus on (e.g., ['sexual', 'violence'])
            min_difficulty: Minimum inappropriate_percentage for harmful examples
        """
        n_safe = int(batch_size * safe_ratio)
        n_harmful = batch_size - n_safe
        
        # Filter examples based on criteria
        available_safe = self.safe_examples.copy()
        available_harmful = self.harmful_examples.copy()
        
        # Filter by categories if specified
        if target_categories:
            available_safe = [ex for ex in available_safe if ex["categories"] in target_categories]
            available_harmful = [ex for ex in available_harmful if ex["categories"] in target_categories]
        
        # Filter by minimum difficulty
        if min_difficulty:
            available_harmful = [ex for ex in available_harmful 
                               if ex["inappropriate_percentage"] >= min_difficulty]
        
        # Sample examples
        safe_batch = random.sample(
            available_safe, 
            min(n_safe, len(available_safe))
        )
        harmful_batch = random.sample(
            available_harmful, 
            min(n_harmful, len(available_harmful))
        )
        
        # Tokenize
        safe_tokenized = self.tokenize_for_generation(safe_batch)
        harmful_tokenized = self.tokenize_for_generation(harmful_batch)
        
        return {
            "safe_generation": safe_tokenized,
            "harmful_generation": harmful_tokenized,
            "safe_ratio": safe_ratio,
            "batch_categories": target_categories,
            "min_difficulty": min_difficulty
        }
    
    def get_test_prompts(self, n_safe=5, n_harmful=5, categories=None):
        """Get test prompts with category filtering."""
        available_safe = self.safe_examples
        available_harmful = self.harmful_examples
        
        if categories:
            available_safe = [ex for ex in available_safe if ex["categories"] in categories]
            available_harmful = [ex for ex in available_harmful if ex["categories"] in categories]
        
        safe_test = random.sample(available_safe, min(n_safe, len(available_safe)))
        harmful_test = random.sample(available_harmful, min(n_harmful, len(available_harmful)))
        
        return {
            "safe_prompts": [{"prompt": ex["prompt"], "metadata": ex} for ex in safe_test],
            "harmful_prompts": [{"prompt": ex["prompt"], "metadata": ex} for ex in harmful_test]
        }
    
    def analyze_dataset(self):
        """Analyze the I2P dataset characteristics."""
        print("I2P Dataset Analysis:")
        print("="*50)
        
        # Overall statistics
        total_examples = len(self.safe_examples) + len(self.harmful_examples)
        print(f"Total usable examples: {total_examples}")
        print(f"Safe examples: {len(self.safe_examples)} ({len(self.safe_examples)/total_examples*100:.1f}%)")
        print(f"Harmful examples: {len(self.harmful_examples)} ({len(self.harmful_examples)/total_examples*100:.1f}%)")
        
        # Difficulty distribution
        if self.harmful_examples:
            inappropriateness_scores = [ex["inappropriate_percentage"] for ex in self.harmful_examples]
            print(f"\nInappropriateness Distribution (Harmful Examples):")
            print(f"  Mean: {np.mean(inappropriateness_scores):.1f}%")
            print(f"  Median: {np.median(inappropriateness_scores):.1f}%")
            print(f"  Min: {min(inappropriateness_scores):.1f}%")
            print(f"  Max: {max(inappropriateness_scores):.1f}%")
        
        # Category analysis
        print(f"\nCategory Distribution:")
        sorted_categories = sorted(self.category_stats.items(), 
                                 key=lambda x: x[1]["total"], reverse=True)
        for category, stats in sorted_categories[:10]:  # Top 10 categories
            total = stats["total"]
            harmful_pct = stats["harmful"] / total * 100 if total > 0 else 0
            print(f"  {category}: {total} total ({harmful_pct:.1f}% harmful)")

# Create proper I2P-based dataset
print("Creating I2P-based Image Generation Safety Dataset...")
i2p_dataset = I2PImageGenerationDataset(
    tokenizer, 
    max_length=512, 
    harmful_threshold=70.0,  # 70%+ inappropriate content = harmful
    safe_threshold=30.0,     # <30% inappropriate content = safe
    nudity_threshold=20.0,   # 20%+ nudity = harmful
    hard_only=False          # Use both easy and hard examples
)

### Image Generation Safety Testing

Test the circuit breaker system with safe and harmful image generation prompts.

In [None]:
def test_i2p_image_generation_safety():
    """
    Test the image generation circuit breaker system with real I2P data.
    """
    if img_cb_hooks is None:
        print("⚠️ Image generation circuit breakers not available")
        return
    
    print("Testing BAGEL Image Generation Safety with I2P Dataset")
    print("="*70)
    
    # Analyze the I2P dataset first
    i2p_dataset.analyze_dataset()
    
    # Get test prompts from real I2P data
    test_prompts = i2p_dataset.get_test_prompts(n_safe=3, n_harmful=3)
    
    print(f"\n{'='*70}")
    print("🟢 Safe Prompts from I2P (low inappropriate %):")
    for i, item in enumerate(test_prompts["safe_prompts"]):
        prompt = item["prompt"]
        meta = item["metadata"]
        print(f"  {i+1}. {prompt[:80]}...")
        print(f"     Category: {meta['categories']}, Inappropriate: {meta['inappropriate_percentage']}%")
    
    print(f"\n🔴 Harmful Prompts from I2P (high inappropriate %):")
    for i, item in enumerate(test_prompts["harmful_prompts"]):
        prompt = item["prompt"]
        meta = item["metadata"]
        print(f"  {i+1}. {prompt[:80]}...")
        print(f"     Category: {meta['categories']}, Inappropriate: {meta['inappropriate_percentage']}%")
        print(f"     Nudity: {meta['nudity_percentage']}%, Toxicity: {meta['prompt_toxicity']:.3f}")
    
    # Test category-specific examples
    print(f"\n{'='*70}")
    print("Category-Specific Testing:")
    print("="*70)
    
    # Test different categories
    test_categories = ['sexual', 'violence', 'hate', 'shocking']
    for category in test_categories:
        category_examples = i2p_dataset.get_examples_by_category(category, harmful=True)
        if category_examples:
            example = category_examples[0]
            print(f"\n🎯 {category.upper()} category example:")
            print(f"   Prompt: {example['prompt'][:100]}...")
            print(f"   Inappropriate: {example['inappropriate_percentage']}%")
            print(f"   SD Safety: {example['sd_safety_percentage']}%")
    
    # Test high-risk examples
    high_risk = i2p_dataset.get_high_risk_examples(min_inappropriate=95, min_nudity=80)
    print(f"\n⚠️ High-Risk Examples (95%+ inappropriate OR 80%+ nudity): {len(high_risk)} found")
    if high_risk:
        example = high_risk[0]
        print(f"   Most problematic: {example['prompt'][:100]}...")
        print(f"   Inappropriate: {example['inappropriate_percentage']}%, Nudity: {example['nudity_percentage']}%")
    
    # Test circuit breaker training batch generation
    print(f"\n{'='*70}")
    print("Training Batch Generation:")
    print("="*70)
    
    # Test different training configurations
    configs = [
        {"safe_ratio": 0.5, "target_categories": None, "min_difficulty": None},
        {"safe_ratio": 0.3, "target_categories": ['sexual', 'violence'], "min_difficulty": 80},
        {"safe_ratio": 0.7, "target_categories": ['hate'], "min_difficulty": 70}
    ]
    
    for i, config in enumerate(configs):
        print(f"\nTraining Config {i+1}: {config}")
        try:
            batch = i2p_dataset.get_training_batch(batch_size=6, **config)
            print(f"  ✓ Generated batch: {len(batch['safe_generation'])} safe, {len(batch['harmful_generation'])} harmful")
            
            # Show sample from batch
            if batch['harmful_generation']:
                sample = batch['harmful_generation'][0]
                print(f"    Sample harmful: {sample['prompt'][:60]}... ({sample['categories']}, {sample['inappropriate_percentage']}%)")
        except Exception as e:
            print(f"  ✗ Batch generation failed: {e}")
    
    # Test safety system integration
    print(f"\n{'='*70}")
    print("Circuit Breaker Integration Test:")
    print("="*70)
    
    # Enable circuit breaker hooks
    try:
        img_cb_hooks.register_generation_hooks()
        img_cb_hooks.enable_generation_safety()
        
        # Test with a real harmful prompt from I2P
        if test_prompts["harmful_prompts"]:
            harmful_item = test_prompts["harmful_prompts"][0]
            harmful_prompt = harmful_item["prompt"]
            harmful_meta = harmful_item["metadata"]
            
            print(f"\\nTesting circuit breakers with real harmful prompt:")
            print(f"Prompt: {harmful_prompt[:100]}...")
            print(f"I2P Rating: {harmful_meta['inappropriate_percentage']}% inappropriate")
            
            # Tokenize and test
            tokens = tokenizer(harmful_prompt, return_tensors="pt", max_length=128, truncation=True)
            input_ids = tokens["input_ids"].to(next(model.parameters()).device)
            
            with torch.no_grad():
                img_cb_hooks.clear_activations()
                # This would normally go through full generation pipeline
                embeddings = model.language_model.model.embed_tokens(input_ids)
                activations = img_cb_hooks.get_generation_activations()
                
                print(f"✓ Processed with safety monitoring")
                print(f"  Activations captured: {len(activations)} generation layers")
                print(f"  Ready for circuit breaker training with I2P annotations")
        
        # Clean up
        img_cb_hooks.remove_hooks()
        img_cb_hooks.disable_generation_safety()
        
    except Exception as e:
        print(f"✗ Circuit breaker integration error: {e}")
    
    print(f"\n{'='*70}")
    print("I2P Integration Summary:")
    print("✓ Real inappropriate image prompts loaded from I2P dataset")
    print("✓ Multi-dimensional safety ratings (inappropriate %, nudity %, toxicity)")
    print("✓ Category-specific training capability") 
    print("✓ Difficulty-based filtering for robust training")
    print("✓ Integration with MoT generation circuit breakers")
    print("✓ Ready for progressive training with retain vs circuit breaker loss")
    print("="*70)

# Run the I2P-based safety test
test_i2p_image_generation_safety()