In [None]:
import torch
from datasets import load_dataset
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the BLIP-2 checkpoint named below and run one test image through its
# vision encoder to inspect the hidden states it produces.
model_name = "Salesforce/blip2-opt-2.7b"
# fp16 halves memory on GPU, but some CPU kernels lack a Half implementation
# (notably in older torch releases), so use fp32 when CUDA is unavailable.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=dtype, device_map="auto"
)

dataset = load_dataset("huggingface/cats-image")
# Index the row first so only this one image is decoded, not the whole column.
image = dataset["test"][0]["image"]

inputs = processor(images=image, return_tensors="pt").to(model.device, dtype=dtype)

# Inference only: parameters of a freshly loaded model require grad by default,
# so without no_grad autograd would record this forward pass and keep activations.
with torch.no_grad():
    image_embeds = model.vision_model(
        inputs["pixel_values"], return_dict=True
    ).last_hidden_state

print(model.vision_model)
print(image_embeds)
print(image_embeds.shape)

In [None]:
# Q-Former step: the model's learned query tokens cross-attend to the image
# hidden states (passed as encoder_hidden_states) to produce query outputs.
with torch.no_grad():  # inference only; avoid building an autograd graph
    # Every position of image_embeds may be attended to, so the mask is all ones.
    image_attention_mask = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long, device=model.device
    )
    # Broadcast the learned queries across the batch dimension (a view, no copy).
    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)

    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        return_dict=True,
    )
    query_output = query_outputs.last_hidden_state

print(image_attention_mask.shape)
print(query_tokens.shape)
print(query_output.shape)

In [None]:
# Project the Q-Former outputs to the language model's input-embedding width,
# prepend them to a single BOS token embedding, and generate from that prefix.
with torch.no_grad():  # inference only; avoid building an autograd graph
    language_model_inputs = model.language_projection(query_output)
    language_attention_mask = torch.ones(
        language_model_inputs.size()[:-1],
        dtype=torch.long,
        device=model.device,
    )
    # One BOS token per image in the batch starts the text sequence.
    input_ids = (
        torch.LongTensor([[model.config.text_config.bos_token_id]])
        .repeat(inputs["pixel_values"].shape[0], 1)
        .to(model.device)
    )
    # Attend to every projected query position plus the BOS token.
    attention_mask = torch.cat(
        [language_attention_mask, torch.ones_like(input_ids)], dim=1
    )

    inputs_embeds = model.get_input_embeddings()(input_ids)
    inputs_embeds = torch.cat(
        [language_model_inputs, inputs_embeds.to(model.device)], dim=1
    )

    # When only inputs_embeds is passed, recent transformers releases count the
    # prompt (projected queries + BOS) against `max_length`, leaving few tokens
    # for the text itself; `max_new_tokens` bounds the generated tokens directly.
    outputs = model.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        max_new_tokens=50,
    )

print(language_model_inputs.shape)
print(input_ids)
print(outputs)
print(outputs.shape)
# Turn the generated token ids into readable text.
print(processor.batch_decode(outputs, skip_special_tokens=True))