In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForVision2Seq, AutoTokenizer, StoppingCriteria

def get_caption(img, blip_model, blip_image_processor, blip_tokenizer, query=""):
    """Generate a text response about ``img`` for the given ``query``.

    The image is preprocessed, the query is wrapped in the chat prompt
    template, both are moved to the device the model lives on, and the model
    decodes deterministically (no sampling, a single beam).

    Args:
        img: Image to caption. Must expose ``.size`` (e.g. a ``PIL.Image``),
            which is forwarded to ``generate`` as ``image_size``.
        blip_model: Model exposing ``generate(...)`` and a ``device``
            attribute (as Hugging Face ``PreTrainedModel`` instances do).
        blip_image_processor: Callable image processor returning a mapping of
            tensors when called with ``return_tensors="pt"``.
        blip_tokenizer: Tokenizer used both to encode the prompt and to decode
            the generated ids.
        query: User instruction inserted into the prompt after the ``<image>``
            placeholder. Defaults to an empty instruction.

    Returns:
        The decoded text of the first generated sequence, truncated at the
        first ``<|end|>`` marker if one survives decoding.
    """

    def apply_prompt_template(prompt):
        """Wrap ``prompt`` in the system/user/assistant chat template."""
        return (
            "<|system|>\nA chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
            f"<|user|>\n<image>\n{prompt}<|end|>\n<|assistant|>\n"
        )

    class EosListStoppingCriteria(StoppingCriteria):
        """Stop generation once a sequence ends with ``eos_sequence``."""

        def __init__(self, eos_sequence=None):
            # A None sentinel avoids sharing one mutable default list between
            # instances. 32007 is presumably the id of the "<|end|>" token in
            # this tokenizer -- TODO confirm against the tokenizer vocab.
            self.eos_sequence = [32007] if eos_sequence is None else list(eos_sequence)

        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            # Compare the trailing ids of every row with the stop sequence;
            # a match on any row stops generation.
            last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
            return self.eos_sequence in last_ids

    inputs = blip_image_processor(
        [img], return_tensors="pt", image_aspect_ratio="pad"
    )
    language_inputs = blip_tokenizer(
        [apply_prompt_template(query)], return_tensors="pt"
    )
    inputs.update(language_inputs)

    # Follow the model's own device instead of hard-coding CUDA so the inputs
    # always end up where the weights are.
    device = blip_model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        generated_text = blip_model.generate(
            **inputs,
            image_size=[img.size],
            pad_token_id=blip_tokenizer.pad_token_id,
            do_sample=False,
            max_new_tokens=77,
            top_p=None,
            num_beams=1,
            stopping_criteria=[EosListStoppingCriteria()],
        )

    # Only the first (and, for a single image, only) sequence is decoded.
    decoded = blip_tokenizer.decode(generated_text[0], skip_special_tokens=True)
    return decoded.split("<|end|>")[0]

# Model setup: load the image processor, tokenizer and model from one
# checkpoint. trust_remote_code=True is passed to every loader.
model_name_or_path = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1"

blip_image_processor = AutoImageProcessor.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)
blip_tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    use_fast=False,
    legacy=False,
)

blip_model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)
blip_model = blip_model.to("cuda")

# The model returns the tokenizer it should be used with from here on.
blip_tokenizer = blip_model.update_special_tokens(blip_tokenizer)

In [None]:
# Demo: caption a sample image, print the caption, then show the image.
img = Image.open("./amp-xgenmm.png")
caption = get_caption(
    img,
    blip_model,
    blip_image_processor,
    blip_tokenizer,
    query="Describe the image in 20 words or less",
)
print(caption)
# `display` is not imported in this file; presumably supplied by the
# IPython/Jupyter runtime -- confirm when running outside a notebook.
display(img)