In [None]:
import os
from pathlib import Path
import sys

# Pin this notebook to GPU 0. The assignment runs before the cell that imports
# torch, so the restriction is already in place when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Make the current working directory importable (presumably so the `llava`
# package imported in the next cell resolves from the repository checkout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
_repo_root = Path(".").resolve().as_posix()
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [None]:
import torch

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    tokenizer_image_token2,
    process_images,
    get_model_name_from_path,
)
from llava.model import (
    LlavaConfig,
    LlavaMistralForCausalLM,
    LlavaLlamaForCausalLM,
    LlavaGemmaForCausalLM,
    LlavaGemmaConfig,
    LlavaPhi3Config,
    LlavaPhi3ForCausalLM,
)

In [None]:
# Placeholder: replace with the path of the checkpoint to load.
# A leading "~" is allowed; the loading cell expands it with os.path.expanduser.
model_path = "path/to/checkpoint"

In [None]:
# Invoke the project's init-disabling helper before the model is constructed.
disable_torch_init()

# Expand "~" in the configured path, then derive the model name from that path.
model_path = os.path.expanduser(model_path)
model_name = get_model_name_from_path(model_path)

# Load everything the later cells need in one call.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,  # NOTE(review): presumably the base-model path; confirm against load_pretrained_model
    model_name,
)

In [None]:
# Move the inner model returned by get_model() to the GPU and cast it to bfloat16.
backbone = model.get_model()
backbone.to("cuda", dtype=torch.bfloat16)
# The projector is moved/cast explicitly as well.
# NOTE(review): if mm_projector is a registered submodule of the inner model,
# this second call is redundant — confirm before removing.
backbone.mm_projector.to("cuda", dtype=torch.bfloat16)

In [None]:
# Generation settings shared by the cells below.
# Key into `conv_templates`; build_prompt uses it to select the conversation template.
conv_mode = "gemma"
# num_chunks / chunk_idx are not referenced by any other cell visible in this notebook.
num_chunks = 1
chunk_idx = 0
# Forwarded to model.generate; sampling is enabled there only when this is > 0.
temperature = 0.5
# Forwarded to model.generate unchanged (None here — presumably leaves the
# library default in effect; confirm).
top_p = None
# Forwarded to model.generate as `num_beams`.
num_beams = 1

In [None]:
def build_prompt(text):
    """Render a single-turn prompt for ``text`` using the ``conv_mode`` template.

    The image placeholder token is put on its own line in front of ``text``.
    That string is appended as the message of the template's first role,
    followed by an empty (``None``) message for the second role, and the
    template's rendered prompt is returned.

    Args:
        text: The question or instruction to ask about the visual input.

    Returns:
        Whatever ``get_prompt()`` produces for the resulting conversation.
    """
    # Work on a copy so the shared template in `conv_templates` is not modified.
    conversation = conv_templates[conv_mode].copy()
    conversation.append_message(
        conversation.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{text}"
    )
    # A None message leaves the second role's turn open.
    conversation.append_message(conversation.roles[1], None)
    return conversation.get_prompt()


def get_text_batch(texts, tokenizer, device="cuda"):
    """Build prompts for ``texts``, tokenize them as a batch and move them to ``device``.

    Args:
        texts: Iterable of question/instruction strings; each is wrapped by
            ``build_prompt``.
        tokenizer: Tokenizer handed to ``tokenizer_image_token2``.
        device: Target device for every returned tensor. Defaults to ``"cuda"``,
            which matches the previous hard-coded behaviour.

    Returns:
        The mapping produced by ``tokenizer_image_token2`` with each value
        replaced by its copy on ``device``. The same mapping object is returned,
        so it can be unpacked directly into ``model.generate(**inputs)``.
    """
    prompts = [build_prompt(text) for text in texts]

    # Tokenize the whole batch of prompts in one call.
    inputs = tokenizer_image_token2(
        prompts, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )

    # Reassign existing keys in place: the mapping keeps its original type and
    # its key set is unchanged while iterating.
    for key, value in inputs.items():
        inputs[key] = value.to(device)

    return inputs


def get_video_batch(videos, image_processor, device="cuda", dtype=torch.bfloat16):
    """Preprocess a batch of videos into one tensor on the target device.

    Args:
        videos: Batch handed unchanged to ``image_processor`` (this notebook
            passes a list of video file paths).
        image_processor: Callable invoked as
            ``image_processor(videos, return_tensors="pt")``; its result must
            provide a ``"pixel_values"`` entry.
        device: Target device. Defaults to ``"cuda"`` (the previous hard-coded value).
        dtype: Target dtype. Defaults to ``torch.bfloat16`` (the previous
            hard-coded value).

    Returns:
        The ``"pixel_values"`` tensor moved to ``device`` and cast to ``dtype``.
    """
    processed = image_processor(videos, return_tensors="pt")
    return processed["pixel_values"].to(device, dtype=dtype)

In [None]:
# Each entry pairs a video file with the question asked about it; the two
# lists derived below keep the same order.
samples = [
    ("videos_zero_shot/birds.mp4", "How many birds are there?"),
    ("videos_zero_shot/fish.mp4", "Describe what's on the video."),
    ("videos_zero_shot/human.mp4", "What facial expression does this person have?"),
    ("videos_zero_shot/swamp.mp4", "Describe the scene."),
]

texts = [question for _, question in samples]
videos = [video_path for video_path, _ in samples]

# Tokenized prompts and preprocessed video tensor for the generation cell.
text_inputs = get_text_batch(texts, tokenizer)
video_tensor = get_video_batch(videos, image_processor)

In [None]:
# Generate without autograd tracking. Autocast is pinned to bfloat16 so it
# matches the dtype the model and the video tensor were explicitly cast to;
# when no dtype is given, CUDA autocast falls back to float16.
with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
    output_ids = model.generate(
        **text_inputs,
        images=video_tensor,
        # Sample only when a positive temperature is configured.
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        # Optional knob, currently disabled:
        # no_repeat_ngram_size=3,
        max_new_tokens=1024,
        use_cache=True,
    )

# Decode every generated sequence to text, dropping special tokens.
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

In [None]:
# Print each decoded generation on its own line, in batch order.
for generated_text in outputs:
    print(generated_text)