In [3]:
import os
import uuid

import cv2
import requests
import torch
from PIL import Image
from transformers import LlavaProcessor, LlavaForConditionalGeneration

In [5]:
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"

# Choose the device first so the weight dtype can follow it: half precision is
# only used on CUDA, while the CPU fallback keeps float32 (float16 on CPU is
# slow and not implemented for every op in every PyTorch release).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=model_dtype)
model.to(device)

LlavaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(729, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-25): 26 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=115

In [6]:
def replace_video_with_images(text, frames):
    """Expand every ``<video>`` placeholder in ``text`` into ``frames`` ``<image>`` tokens.

    Args:
        text: Prompt text that may contain ``<video>`` placeholders.
        frames: Number of ``<image>`` tokens substituted for each placeholder.

    Returns:
        The prompt with each ``<video>`` replaced by ``"<image>"`` repeated
        ``frames`` times; text without a placeholder is returned unchanged.
    """
    image_tokens = "<image>" * frames
    # Splitting on the placeholder and re-joining with the expanded token run
    # substitutes every occurrence, left to right.
    pieces = text.split("<video>")
    return image_tokens.join(pieces)

def sample_frames(url, num_frames):
    """Download a video and return up to ``num_frames`` frames sampled from it.

    The file at ``url`` is fetched with ``requests``, written to a uniquely
    named ``.mp4`` in the current directory, decoded with OpenCV and deleted
    again before the function returns. Frames are taken at every
    ``total_frames // num_frames``-th index (starting at 0) and the first
    ``num_frames`` of them are kept.

    Args:
        url: HTTP(S) location of the video file.
        num_frames: Maximum number of frames to return; must be positive.

    Returns:
        A list of at most ``num_frames`` RGB ``PIL.Image`` objects in playback
        order. The list is shorter when the video has fewer readable frames
        than requested, and empty when OpenCV reports no frames at all.

    Raises:
        ValueError: If ``num_frames`` is not positive.
        requests.HTTPError: If the server answers with an error status.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    response = requests.get(url)
    # Fail loudly on 4xx/5xx instead of handing an error page to OpenCV.
    response.raise_for_status()

    path = f"./{uuid.uuid4()}.mp4"
    frames = []
    try:
        with open(path, "wb") as f:
            f.write(response.content)

        video = cv2.VideoCapture(path)
        try:
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            # Clamp to 1 so clips shorter than num_frames do not produce a
            # zero step (which would make the modulo below divide by zero).
            interval = max(total_frames // num_frames, 1)
            for i in range(total_frames):
                ret, frame = video.read()
                # A failed read yields frame=None; check before touching it.
                if not ret:
                    continue
                if i % interval == 0:
                    # Convert only the frames that are kept (OpenCV decodes as BGR).
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    if len(frames) == num_frames:
                        break
        finally:
            # Release before deleting: an open capture can keep the file locked.
            video.release()
    finally:
        if os.path.exists(path):
            os.remove(path)
    return frames

In [7]:
# Number of frames sampled from each clip.
FRAMES_PER_VIDEO = 6

# Keep the URLs under their own names so `video_1` / `video_2` only ever hold
# frame lists, never a URL string.
video_1_url = "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_1.mp4"
video_2_url = "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_2.mp4"

video_1 = sample_frames(video_1_url, FRAMES_PER_VIDEO)
video_2 = sample_frames(video_2_url, FRAMES_PER_VIDEO)

# Frames of both clips, first clip first.
videos = video_1 + video_2

videos

[<PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>,
 <PIL.Image.Image image mode=RGB size=1920x1080>]

In [8]:
user_prompt = "Are these two cats in these two videos doing the same thing?"
# One <image> placeholder per frame actually passed to the processor, so the
# prompt stays in sync with `videos` if the number of sampled frames changes.
toks = "<image>" * len(videos)
prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
# Move the tensors to the model's device and cast them to the model's dtype.
inputs = processor(text=prompt, images=videos, return_tensors="pt").to(model.device, model.dtype)

In [10]:
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# The returned sequence starts with the prompt tokens; drop exactly that many
# tokens instead of guessing token/character offsets, so only the newly
# generated answer is decoded and printed.
prompt_length = inputs["input_ids"].shape[1]
print(processor.decode(output[0][prompt_length:], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


KeyboardInterrupt: 