In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

def get_caption(img, llava_model, llava_processor, query=""):
    """Ask a LLaVA model about an image and return its answer text.

    The query is wrapped in the ``USER: <image>\\n...\\nASSISTANT:`` chat
    template, at most 77 new tokens are generated, and the text following
    the first ``ASSISTANT:`` marker of the decoded output is returned.

    Args:
        img: Image handed to the processor as ``images``.
        llava_model: Model exposing ``generate(**inputs, max_new_tokens=...)``.
        llava_processor: Callable processor that also provides
            ``batch_decode``.
        query: Instruction/question placed after the image placeholder.

    Returns:
        The decoded text after the first ``ASSISTANT:`` marker (leading
        whitespace is preserved). If the marker is absent from the decoded
        text, the whole decoded text is returned unchanged.
    """
    prompt = f"USER: <image>\n{query}\nASSISTANT:"

    # Put the inputs on whatever device the model lives on rather than
    # assuming CUDA; fall back to "cuda" for objects without `.device`.
    device = getattr(llava_model, "device", "cuda")
    inputs = llava_processor(text=prompt, images=img, return_tensors="pt").to(
        device
    )

    # Only generation needs gradient tracking disabled.
    with torch.no_grad():
        generate_ids = llava_model.generate(**inputs, max_new_tokens=77)

    caption = llava_processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # partition() splits at the first marker, like find(), but also tells us
    # when the marker is missing. The old find()-based slice silently dropped
    # the first characters of the caption in that case (find() returned -1).
    _, marker, answer = caption.partition("ASSISTANT:")
    return answer if marker else caption

# Model setup: load LLaVA-1.5 (7B) in half precision, then move it to the GPU.
model_id = "llava-hf/llava-1.5-7b-hf"
dtype = torch.float16

# Load the weights first ...
llava_model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=dtype, low_cpu_mem_usage=True
)
# ... then transfer the model to the CUDA device.
llava_model = llava_model.to("cuda")

# The processor is loaded from the same checkpoint id as the model.
llava_processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Caption a sample image, print the answer, then show the image itself.
img = Image.open("./amp-llava.png")
caption = get_caption(
    img=img,
    llava_model=llava_model,
    llava_processor=llava_processor,
    query="Describe the image in 20 words or less",
)
print(caption)
display(img)  # `display` is not defined in this file; presumably the notebook built-in