In [1]:
import torch
from transformers import Blip2Processor, Blip2Model, Blip2ForConditionalGeneration
from PIL import Image
import torchvision.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load BLIP-2 model and processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
decoder_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# Move models to GPU if available
device = 'mps'
model = model.to(device)
decoder_model = decoder_model.to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Fetching 2 files: 100%|██████████| 2/2 [04:03<00:00, 121.97s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.77it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.42it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 18.12 GB, other allocations: 1.70 MB, max allowed: 18.13 GB). Tried to allocate 9.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:

def embed_images(images):
    """Embed a batch of images using BLIP-2.
    
    Args:
        images (torch.Tensor): Batch of images with shape (B, C, H, W)
        
    Returns:
        torch.Tensor: Image embeddings
    """
    # Convert images to range [0, 1] if they're in [-1, 1]
    if images.min() < 0:
        images = (images + 1) / 2
    
    # Convert to PIL Images
    transform = T.ToPILImage()
    pil_images = [transform(img) for img in images]
    
    # Process images
    inputs = processor(images=pil_images, return_tensors="pt").to(device)
    
    # Get image embeddings
    with torch.no_grad():
        outputs = model(**inputs)
        image_embeddings = outputs.image_embeds
    
    return image_embeddings

In [None]:
def decode_embeddings(image_embeddings, max_length=50):
    """Decode image embeddings into text descriptions using BLIP-2.
    
    Args:
        image_embeddings (torch.Tensor): Image embeddings from BLIP-2
        max_length (int): Maximum length of generated text
        
    Returns:
        list: List of generated text descriptions
    """
    with torch.no_grad():
        outputs = decoder_model.generate(
            vision_hidden_states=image_embeddings,
            max_length=max_length,
            num_beams=5,
            min_length=5,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
        )
    
    # Decode the generated tokens to text
    generated_texts = processor.batch_decode(outputs, skip_special_tokens=True)
    return generated_texts

In [None]:
# Load and process test images
test_images = torch.load('test_imgs.pt')
print(f"Loaded images shape: {test_images.shape}")

# Get embeddings
embeddings = embed_images(test_images)
print(f"Generated embeddings shape: {embeddings.shape}")

# Generate descriptions for a few examples
descriptions = decode_embeddings(embeddings[:5])  # Process first 5 images as example
for i, desc in enumerate(descriptions):
    print(f"Image {i}: {desc}") 