In [1]:
from einops import rearrange
import torch
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import ImageFolder
from ncut_pytorch import NCUT, rgb_from_tsne_3d
from matplotlib import pyplot as plt
import os
import glob
import matplotlib.pyplot as plt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoModel, AutoProcessor, CLIPTokenizer, CLIPTextModelWithProjection
from qwen_vl_utils import process_vision_info
import requests
from PIL import Image, ImageOps
import accelerate
import gc
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, SD3Transformer2DModel
import functools

  from .autonotebook import tqdm as notebook_tqdm
2025-03-28 23:27:23.602957: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-28 23:27:23.602996: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-28 23:27:23.603922: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-28 23:27:23.609421: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from huggingface_hub import login

# Never hard-code the token in a notebook that is published. Read it from the
# HF_TOKEN environment variable instead; when the variable is unset the value is
# None, and `login(token=None)` falls back to its interactive prompt. (An empty
# string is not None, so the previous `login("")` skipped the prompt.)
login(token=os.environ.get("HF_TOKEN"))

In [3]:
def resize(image, size=(448, 448), pad=(255, 255, 255)):
    """Letterbox `image` onto an RGB canvas of exactly `size`.

    The image is scaled (aspect ratio preserved, LANCZOS resampling) to fit
    inside `size`, then pasted centred on a canvas filled with `pad`.
    `Image.thumbnail` never enlarges, so an image already smaller than `size`
    keeps its resolution and is only padded.

    Args:
        image: PIL image to letterbox. It is NOT modified; a copy is scaled.
        size: (width, height) of the returned image.
        pad: RGB fill colour used for the border.

    Returns:
        A new PIL image in mode "RGB" whose size is exactly `size`.
    """
    width, height = size

    # `thumbnail` rescales in place; work on a copy so the caller's image is
    # left untouched.
    scaled = image.copy()
    scaled.thumbnail((width, height), Image.Resampling.LANCZOS)

    canvas = Image.new("RGB", (width, height), pad)

    # Centre the scaled image; any leftover border is filled with `pad`.
    x_offset = (width - scaled.size[0]) // 2
    y_offset = (height - scaled.size[1]) // 2
    canvas.paste(scaled, (x_offset, y_offset))

    return canvas

In [4]:
def compute_paligemma_features(images):
    """Extract patch features from the SigLIP vision tower of PaliGemma 2.

    Loads `google/paligemma2-3b-ft-docci-448` in bfloat16 on the GPU, runs only
    its vision tower on `images` and returns the last hidden state.

    Args:
        images: list of PIL RGB images, all of the same size. Resizing is
            disabled below, so they should already be at the model's input
            resolution (presumably 448x448, per the checkpoint name; see
            `resize`).

    Returns:
        float32 tensor of shape (len(images), side, side, hidden), where
        side * side is the number of patch tokens per image.

    Raises:
        ValueError: if the per-image token count is not a perfect square.
    """
    import math

    # Release memory held by a previously loaded model before loading this one.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    model_id = "google/paligemma2-3b-ft-docci-448"
    # Let `device_map` alone place the model on the GPU. Following
    # device_map="auto" with `.to("cuda")` would move an accelerate-dispatched
    # model, which accelerate discourages.
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    # Preprocessing flags belong to the wrapped image processor. Setting them on
    # the PaliGemmaProcessor wrapper does not appear to be forwarded, so target
    # the image processor directly.
    image_processor = getattr(processor, "image_processor", processor)
    image_processor.do_resize = False
    image_processor.do_center_crop = False

    # One bare "<image>" prompt per image; only `pixel_values` is used below.
    text = ["<image>"] * len(images)
    model_inputs = processor(text=text, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)

    with torch.no_grad():
        vision_outputs = model.vision_tower.vision_model(pixel_values=model_inputs["pixel_values"])
        features = vision_outputs.last_hidden_state.to(torch.float32)

    # Infer the square patch grid from the token count instead of hard-coding
    # 32x32. With reshape(..., 32, 32, -1), a different token count could be
    # silently folded into the feature axis.
    hidden = features.shape[-1]
    tokens_per_image = features.numel() // (hidden * len(images))
    side = math.isqrt(tokens_per_image)
    if side * side * hidden * len(images) != features.numel():
        raise ValueError(f"cannot arrange {tokens_per_image} tokens per image into a square grid")

    return features.reshape(len(images), side, side, hidden)

In [5]:
def compute_qwen_features(images):
    """Extract visual-token features from the Qwen2.5-VL-3B vision encoder.

    Loads `Qwen/Qwen2.5-VL-3B-Instruct` in bfloat16 on the GPU and runs only
    `model.visual` on the preprocessed pixel patches.

    Args:
        images: list of PIL RGB images, all of the same size (e.g. 896x896
            from `resize`). Resizing is disabled below, so the side lengths
            presumably need to suit the encoder's patch/merge layout. TODO
            confirm if changing the size.

    Returns:
        float32 tensor of shape (len(images), side, side, hidden), where
        side * side is the number of visual tokens per image.

    Raises:
        ValueError: if the per-image token count is not a perfect square.
    """
    import math

    # Release memory held by a previously loaded model before loading this one.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    # Let `device_map` alone place the model on the GPU rather than following
    # device_map="auto" with an explicit `.to("cuda")`.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)

    # Preprocessing flags belong to the wrapped image processor, not the
    # processor wrapper, so set them there.
    image_processor = getattr(processor, "image_processor", processor)
    image_processor.do_resize = False
    image_processor.do_center_crop = False

    # An empty prompt is enough: only `pixel_values` and `image_grid_thw` are used.
    model_inputs = processor(text="", images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)

    with torch.no_grad():
        vision_outputs = model.visual(hidden_states=model_inputs["pixel_values"], grid_thw=model_inputs["image_grid_thw"])
        features = vision_outputs.to(torch.float32)

    # Assumes the tokens of all images are laid out one image after another, in
    # image order (as the original reshape did). Infer the square grid instead
    # of hard-coding 32x32, so a mismatched token count is not silently folded
    # into the feature axis.
    hidden = features.shape[-1]
    tokens_per_image = features.numel() // (hidden * len(images))
    side = math.isqrt(tokens_per_image)
    if side * side * hidden * len(images) != features.numel():
        raise ValueError(f"cannot arrange {tokens_per_image} tokens per image into a square grid")

    return features.reshape(len(images), side, side, hidden)

In [6]:
def compute_dino_features(images):
    """Extract DINOv2-base patch features, with the leading CLS token removed.

    Args:
        images: list of PIL RGB images, all of the same size (e.g. 448x448
            from `resize`). Resizing and center-cropping are disabled below,
            so the images are used at their own resolution.

    Returns:
        float32 tensor of shape (len(images), side, side, hidden), where
        side * side is the number of patch tokens per image.

    Raises:
        ValueError: if the per-image token count is not a perfect square.
    """
    import math

    # Release memory held by a previously loaded model before loading this one.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    model_id = "facebook/dinov2-base"
    # Let `device_map` alone place the model on the GPU rather than following
    # device_map="auto" with an explicit `.to("cuda")`.
    model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda").eval()
    processor = AutoProcessor.from_pretrained(model_id)

    # Set the flags on whichever object does image preprocessing: a processor
    # wrapper exposes it as `.image_processor`, otherwise `processor` is it.
    image_processor = getattr(processor, "image_processor", processor)
    image_processor.do_resize = False
    image_processor.do_center_crop = False

    model_inputs = processor(images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)

    with torch.no_grad():
        vision_outputs = model(**model_inputs)
        # Drop token 0 (the CLS token) and keep only the patch tokens.
        features = vision_outputs.last_hidden_state.to(torch.float32)[:, 1:]

    # Infer the square patch grid instead of hard-coding 32x32. A different
    # token count (e.g. if preprocessing still resized the input) would
    # otherwise be silently folded into the feature axis.
    hidden = features.shape[-1]
    tokens_per_image = features.numel() // (hidden * len(images))
    side = math.isqrt(tokens_per_image)
    if side * side * hidden * len(images) != features.numel():
        raise ValueError(f"cannot arrange {tokens_per_image} tokens per image into a square grid")

    return features.reshape(len(images), side, side, hidden)

In [7]:
def compute_stable_diffusion_features(images):
    """Extract image-token features from SD 3.5 Medium's transformer block 22.

    Each image is encoded with the VAE. The latent is normalized the way the
    SD3 transformer expects and passed through the transformer with all-zero
    text conditioning at timestep 0. The output of `transformer_blocks[22]` is
    captured with a temporary forward hook.

    Args:
        images: list of PIL RGB images, all of the same size (e.g. 512x512
            from `resize`), so that they can be stacked into one batch.

    Returns:
        float32 tensor of shape (len(images), side, side, hidden), where
        side * side is the number of image tokens per image.

    Raises:
        ValueError: if the per-image token count is not a perfect square.
    """
    import math

    # Release memory held by a previously loaded model before loading this one.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    model_id = "stabilityai/stable-diffusion-3.5-medium"
    model_vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
    model_transformer = SD3Transformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")

    # ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    image_tensors = torch.stack([transform(image) for image in images], dim=0).to(device=model_vae.device, dtype=torch.bfloat16)

    with torch.no_grad():
        latent_dist = model_vae.encode(image_tensors).latent_dist
        # Use the distribution's mode instead of `.sample()` so the extracted
        # features are deterministic across runs.
        latents = latent_dist.mode()
        # Apply the VAE's latent normalization before the transformer, as the
        # SD3 pipelines do: (latents - shift_factor) * scaling_factor.
        shift_factor = model_vae.config.shift_factor or 0.0
        latents = (latents - shift_factor) * model_vae.config.scaling_factor

    vision_outputs = []

    def save_vision_outputs(module, inputs, output):
        vision_outputs.append(output)

    handle = model_transformer.transformer_blocks[22].register_forward_hook(save_vision_outputs)

    # Unconditional pass: timestep 0 and all-zero pooled / per-token text embeddings.
    timestep = torch.tensor([0], dtype=torch.long, device=latents.device)
    pooled_projections = torch.zeros((latents.shape[0], 2048), dtype=latents.dtype, device=latents.device)
    text_embeddings = torch.zeros((latents.shape[0], 77, 4096), dtype=latents.dtype, device=latents.device)

    try:
        with torch.no_grad():
            model_transformer(hidden_states=latents, timestep=timestep, pooled_projections=pooled_projections, encoder_hidden_states=text_embeddings)
    finally:
        # Detach the hook even if the forward pass fails.
        handle.remove()

    # Index 1 of the block output is taken as the image-token stream. This
    # assumes the block returns (encoder_hidden_states, hidden_states); confirm
    # against the installed diffusers version. Cast to float32 to match the
    # other feature extractors.
    features = vision_outputs[0][1].to(torch.float32)

    # Infer the square token grid instead of hard-coding 32x32.
    hidden = features.shape[-1]
    tokens_per_image = features.numel() // (hidden * len(images))
    side = math.isqrt(tokens_per_image)
    if side * side * hidden * len(images) != features.numel():
        raise ValueError(f"cannot arrange {tokens_per_image} tokens per image into a square grid")

    return features.reshape(len(images), side, side, hidden)

In [8]:
image_files = sorted(glob.glob("data/*_base.png") + glob.glob("data/*_test.png"))
# Fail fast with a clear message; an empty list would only fail later, inside
# the feature extractor.
if not image_files:
    raise FileNotFoundError("no data/*_base.png or data/*_test.png images found")

# PaliGemma input: letterbox every image to 448x448.
images = []
for image_file in image_files:
    # `with` closes the file handle as soon as the pixels have been read.
    with Image.open(image_file) as image:
        images.append(resize(image.convert("RGB"), size=(448, 448)))

paligemma_features = compute_paligemma_features(images)

Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.89s/it]


In [9]:
image_files = sorted(glob.glob("data/*_base.png") + glob.glob("data/*_test.png"))
# Fail fast with a clear message; an empty list would only fail later, inside
# the feature extractor.
if not image_files:
    raise FileNotFoundError("no data/*_base.png or data/*_test.png images found")

# Qwen2.5-VL input: letterbox every image to 896x896.
images = []
for image_file in image_files:
    # `with` closes the file handle as soon as the pixels have been read.
    with Image.open(image_file) as image:
        images.append(resize(image.convert("RGB"), size=(896, 896)))

qwen_features = compute_qwen_features(images)

Loading checkpoint shards: 100%|██████████| 2/2 [00:16<00:00,  8.40s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [10]:
image_files = sorted(glob.glob("data/*_base.png") + glob.glob("data/*_test.png"))
# Fail fast with a clear message; an empty list would only fail later, inside
# the feature extractor.
if not image_files:
    raise FileNotFoundError("no data/*_base.png or data/*_test.png images found")

# DINOv2 input: letterbox every image to 448x448.
images = []
for image_file in image_files:
    # `with` closes the file handle as soon as the pixels have been read.
    with Image.open(image_file) as image:
        images.append(resize(image.convert("RGB"), size=(448, 448)))

dino_features = compute_dino_features(images)

In [11]:
image_files = sorted(glob.glob("data/*_base.png") + glob.glob("data/*_test.png"))
# Fail fast with a clear message; an empty list would only fail later, inside
# the feature extractor.
if not image_files:
    raise FileNotFoundError("no data/*_base.png or data/*_test.png images found")

# Stable Diffusion 3.5 input: letterbox every image to 512x512.
images = []
for image_file in image_files:
    # `with` closes the file handle as soon as the pixels have been read.
    with Image.open(image_file) as image:
        images.append(resize(image.convert("RGB"), size=(512, 512)))

stable_diffusion_features = compute_stable_diffusion_features(images)

In [12]:
# The extractors return tensors on the GPU. Move them to the CPU before saving,
# so the checkpoint can be loaded with `torch.load` on a machine without CUDA
# (no `map_location` needed). `.cpu()` is a no-op for tensors already on the CPU.
all_features = {
    "paligemma_features": paligemma_features.cpu(),
    "qwen_features": qwen_features.cpu(),
    "dino_features": dino_features.cpu(),
    "stable_diffusion_features": stable_diffusion_features.cpu(),
}

torch.save(all_features, "all_features.pt")