In [1]:
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline, StableDiffusion3Pipeline, FluxPipeline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import textwrap
import numpy as np

# Define test prompts
test_prompts = [
    'a white cat',
    'a white cat on the right of a black dog',
    'a cup with "Hello" written on it',
    'a cup with "Accommodate" written on it',
    'two computer screen of different sizes',
    '3 computer screens, the biggest one on the right, smallest on the middle',
    'a person showing peace sign',
    'a person, hand over head, doing the ok sign, and the peace sign with the other hands',
]

# Model configurations
model_configs = [
    {
        "name": "SD v1.4",
        "model_id": "CompVis/stable-diffusion-v1-4",
        "model_type": "SD",
        "special_case": False
    },
    {
        "name": "DDPO",
        "model_id": "kvablack/ddpo-alignment",
        "model_type": "SD",
        "special_case": False
    },
    {
        "name": "SDXL",
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "model_type": "SDXL",
        "special_case": False
    },
    {
        "name": "DPO-SDXL",
        "model_id": "mhdang/dpo-sdxl-text2image-v1",
        "model_type": "SDXL",
        "special_case": True,
        "base_model_id": "stabilityai/stable-diffusion-xl-base-1.0"
    },
    {
        "name": "SD 3.5",
        "model_id": "stabilityai/stable-diffusion-3-5-large",
        "model_type": "SD3",
        "special_case": False
    },
    # {
    #     "name": "FLUX.1",
    #     "model_id": "black-forest-labs/FLUX.1-schnell",
    #     "model_type": "FLUX",
    #     "special_case": False
    # },
]

# Set device
device = "cuda"

def load_model(config):
    """Load a model based on its configuration"""
    print(f"Loading {config['name']}...")
    
    if config['special_case']:
        # Special case for DPO-SDXL
        if config['model_type'] == 'SDXL' and 'dpo' in config['name'].lower():
            # First load the base SDXL model
            pipe = StableDiffusionXLPipeline.from_pretrained(
                config['base_model_id'], 
                torch_dtype=torch.float16, 
                variant="fp16", 
                use_safetensors=True
            ).to(device)
            
            # Then load and replace the UNet
            dpo_unet = UNet2DConditionModel.from_pretrained(
                config['model_id'],
                subfolder='unet',
                torch_dtype=torch.float16
            ).to(device)
            
            # Replace the UNet in the pipeline
            pipe.unet = dpo_unet
            return pipe
    else:
        # Standard model loading
        if config['model_type'] == 'SD':
            return StableDiffusionPipeline.from_pretrained(
                config['model_id'], 
                torch_dtype=torch.float16
            ).to(device)
        elif config['model_type'] == 'SDXL':
            return StableDiffusionXLPipeline.from_pretrained(
                config['model_id'], 
                torch_dtype=torch.float16, 
                variant="fp16", 
                use_safetensors=True
            ).to(device)
        elif config['model_type'] == 'SD3':
            return StableDiffusion3Pipeline.from_pretrained(
                config['model_id'], 
                torch_dtype=torch.bfloat16
            ).to(device)
        elif config['model_type'] == 'FLUX':
            return FluxPipeline.from_pretrained(
                config['model_id'], 
                torch_dtype=torch.bfloat16
            ).to(device)
    
    raise ValueError(f"Unknown model type: {config['model_type']}")

def generate_images(model, prompts, config):
    """Generate images using the provided model and prompts"""
    print(f"Generating images with {config['name']}...")
    images = []
    
    for prompt in prompts:
        with torch.cuda.amp.autocast():
            result = model(prompt).images[0]
            images.append(result)
    
    return images

def display_items(prompts, all_images, model_names, width=40):
    """
    Display a list of items containing text prompts and images in a subplot.
    
    Parameters:
    prompts (list): List of text prompts.
    all_images (list): List of lists of images for each model.
    model_names (list): List of model names for the columns.
    width (int): Width of the text wrapping.
    """
    num_items = len(prompts)
    num_models = len(model_names)
    
    fig, axs = plt.subplots(num_items, num_models, figsize=(3 * num_models, 3 * num_items))
    
    for i in range(num_items):
        if i == 0:
            for j in range(num_models):
                axs[i, j].set_title(model_names[j], fontsize=20)
                
        # Display the text prompt
        wrapped_text = textwrap.fill(prompts[i], width=width)
        axs[i, 0].text(0.5, 0.5, wrapped_text, ha='center', va='center', fontsize=12)
        axs[i, 0].axis('off')
        
        for j in range(1, num_models):
            if j-1 < len(all_images) and i < len(all_images[j-1]):
                axs[i, j].imshow(np.array(all_images[j-1][i]))
            axs[i, j].axis('off')
            
    plt.tight_layout()
    plt.show()
    return fig

# Process each model sequentially
all_images = []
model_names = ["Prompt"]  # First column is for prompts

for config in model_configs:
    # Load the model
    model = load_model(config)
    
    # Generate images
    images = generate_images(model, test_prompts, config)
    all_images.append(images)
    model_names.append(config['name'])
    
    # Clear GPU memory
    del model
    torch.cuda.empty_cache()

# Display all generated images
fig = display_items(test_prompts, all_images, model_names)

Loading SD v1.4...


model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

scheduler_config-checkpoint.json:   0%|          | 0.00/209 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]



Generating images with SD v1.4...


  with torch.cuda.amp.autocast():


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Loading DDPO...


model_index.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors not found


Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/604 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/520 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.58k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.64k [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

An error occurred while trying to fetch /nfs/stak/users/shressag/hpc-share/hf-cache/hub/models--kvablack--ddpo-alignment/snapshots/23c5dc41c49dbd9495759200dad1b8c6fd727d21/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /nfs/stak/users/shressag/hpc-share/hf-cache/hub/models--kvablack--ddpo-alignment/snapshots/23c5dc41c49dbd9495759200dad1b8c6fd727d21/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
The config attributes {'use_memory_efficient_attention': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
An error occurred while trying to fetch /nfs/stak/users/shressag/hpc-share/hf-cache/hub/models--kvablack--ddpo-alignment/snapshots/23c5dc41c49dbd9495759200dad1b8c6fd727d21/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /nfs/stak/users/shressag/hpc-share/hf-cache/hub/models--kvablack--ddpo-ali

Generating images with DDPO...


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Loading SDXL...


model_index.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

model.fp16.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/5.14G [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Generating images with SDXL...


  0%|          | 0/50 [00:00<?, ?it/s]

  images = (images * 255).round().astype("uint8")


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Loading DPO-SDXL...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/10.3G [00:00<?, ?B/s]

Generating images with DPO-SDXL...


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Couldn't connect to the Hub: 404 Client Error. (Request ID: Root=1-67e5852b-4a1031047251293215c749b5;1f8ebf04-4cd6-4670-a76a-27c0fc6eb85f)

Repository Not Found for url: https://huggingface.co/api/models/stabilityai/stable-diffusion-3-5-large.
Please make sure you specified the correct `repo_id` and `repo_type`.
If you are trying to access a private or gated repo, make sure you are authenticated..
Will try to load from local cache.


Loading SD 3.5...


OSError: Cannot load model stabilityai/stable-diffusion-3-5-large: model is not cached locally and an error occurred while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace above.