In [None]:
import os
# Expose only GPU 0 to this process. Assigned before `import torch` below so
# the setting is already in the environment when CUDA is first initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

def get_caption(
    img,
    cogvlm_model,
    cogvlm_tokenizer,
    query="",
    *,
    max_new_tokens=77,
    device="cuda",
    dtype=torch.bfloat16,
):
    """Generate a caption for ``img`` with a CogVLM chat model.

    Args:
        img: Image handed to ``cogvlm_model.build_conversation_input_ids``.
            If it has a ``mode`` other than ``"RGB"`` (e.g. an RGBA or
            palette PNG) it is converted to RGB first.
        cogvlm_model: Model exposing ``build_conversation_input_ids`` and
            ``generate`` (e.g. ``THUDM/cogvlm-chat-hf`` loaded with
            ``trust_remote_code=True``).
        cogvlm_tokenizer: Tokenizer used to build the prompt and to decode
            the generated tokens; must define ``eos_token``.
        query: Text prompt sent together with the image.
        max_new_tokens: Maximum number of tokens to generate (greedy
            decoding, ``do_sample=False``).
        device: Device the input tensors are moved to; should match the
            device the model lives on.
        dtype: dtype the image tensor is cast to; should match the model
            weights.

    Returns:
        The decoded generated text (prompt tokens excluded) with every
        occurrence of the tokenizer's EOS token removed.
    """
    # The official CogVLM example converts images to RGB; PNGs frequently
    # load as RGBA/P, which the 3-channel image preprocessing is presumably
    # not built for. Objects without a `mode` are passed through unchanged.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    built = cogvlm_model.build_conversation_input_ids(
        cogvlm_tokenizer,
        query=query,
        images=[img],
    )
    # The builder returns unbatched tensors: add a batch dimension of 1 and
    # move everything to the target device.
    inputs = {
        "input_ids": built["input_ids"].unsqueeze(0).to(device),
        "token_type_ids": built["token_type_ids"].unsqueeze(0).to(device),
        "attention_mask": built["attention_mask"].unsqueeze(0).to(device),
        # One list of image tensors per batch entry.
        "images": [[built["images"][0].to(device=device, dtype=dtype)]],
    }
    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": False}

    with torch.no_grad():
        outputs = cogvlm_model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    caption = cogvlm_tokenizer.decode(new_tokens[0])

    return caption.replace(cogvlm_tokenizer.eos_token, "")

# Model setup: load CogVLM-chat in bfloat16, switch it to inference mode and
# place it on the GPU, then load the tokenizer used alongside it.
dtype = torch.bfloat16
cogvlm_model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    # The checkpoint ships its own modeling code (e.g. the
    # build_conversation_input_ids helper used by get_caption).
    trust_remote_code=True,
)
# eval() and cuda() both return the module itself, so this rebinding keeps
# the same object, now in eval mode and on the GPU.
cogvlm_model = cogvlm_model.eval().cuda()

# Llama tokenizer from Vicuna-7B v1.5, paired with the CogVLM model above.
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

In [None]:
# Load the sample image as RGB: PNGs often open in RGBA or palette mode,
# while CogVLM's usage example feeds RGB images. Opening via a context
# manager closes the file handle that Image.open otherwise keeps open lazily;
# convert() loads the pixels and returns an independent copy first.
with Image.open("./amp-cogvlm.png") as src:
    img = src.convert("RGB")
caption = get_caption(img, cogvlm_model, tokenizer, query="Describe the image in 20 words or less")
print(caption)
# `display` is provided by IPython/Jupyter.
display(img)