In [1]:
from pathlib import Path
import os
import sys

# Make the repository root importable so the local `llava` package resolves.
sys.path.append(Path(".").resolve().as_posix())

# SECURITY: never commit credentials to a notebook. The Hugging Face token is
# read from the HF_TOKEN environment variable, which must be exported before
# the kernel starts (e.g. `export HF_TOKEN=...` in the launching shell).
# The token that used to be hardcoded here has been published with the notebook
# and should be treated as compromised and revoked.
if not os.environ.get("HF_TOKEN"):
    print(
        "HF_TOKEN is not set; export it before starting the kernel if the model "
        "repository requires authentication.",
        file=sys.stderr,
    )

In [2]:
import torch
from transformers import AutoTokenizer
import math
import os

from llava.model import LlavaGemmaForCausalLM
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path


def split_list(lst, n):
    """Split a sequence into at most ``n`` contiguous, roughly equal-sized chunks.

    Every chunk holds ``ceil(len(lst) / n)`` items except possibly the last,
    which may be shorter. Because the chunk size is rounded up, fewer than
    ``n`` chunks can be returned (e.g. 5 items with ``n=4`` give 3 chunks).

    Args:
        lst: Sliceable sequence to split.
        n: Requested number of chunks; must be a positive integer.

    Returns:
        A list of slices of ``lst``; an empty list when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    if len(lst) == 0:
        # ceil(0 / n) == 0 would be used as the range() step, which is an error.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk number ``k`` (zero-based) of ``split_list(lst, n)``."""
    return split_list(lst, n)[k]



[2024-04-16 06:25:58,687] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# --- Model selection ---
model_path = "google/gemma-2b"
model_base = None  # not referenced by the cells visible in this notebook

# --- Prompt format: key into conv_templates ---
conv_mode = "gemma"

# --- Chunking (not referenced by the visible cells; presumably for get_chunk) ---
num_chunks, chunk_idx = 1, 0

# --- Sampling parameters forwarded to model.generate ---
temperature, top_p, num_beams = 0.1, None, 1

In [4]:
# Load the tokenizer and the language model by hand; the steps are lifted from
# llava's load_pretrained_model.

disable_torch_init()

model_path = os.path.expanduser(model_path)
model_name = get_model_name_from_path(model_path)

# Shared loading options: automatic device placement, fp16 weights.
device_map = "auto"
kwargs = dict(device_map=device_map, torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaGemmaForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    **kwargs,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Vision tower with pretrain: build the argument objects that the vision
# initialisation and setup cells read from.

from llava.train.train import DataArguments, ModelArguments, TrainingArguments

# Alternatives that were tried and are currently disabled:
#   vision_tower="LanguageBind/LanguageBind_Video_merge"
#   pretrain_mm_mlp_adapter="mm_projector.bin"
model_args = ModelArguments(
    vision_tower="videomamba",
    mm_vision_select_layer=-2,
    mm_use_im_start_end=False,
    mm_use_im_patch_token=False,
)

data_args = DataArguments()
training_args = TrainingArguments(output_dir="./llava_output")

In [6]:
# Run the inner model's vision-module initialisation with the arguments built
# above, then fetch the vision tower and show it as the cell output.
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)

vision_tower = model.get_vision_tower()
vision_tower

LOADING WHAT SEEMS LIKE THE CORRECT STATE DICT
<class 'llava.model.multimodal_encoder.videomamba.models.backbones.videomamba.videomamba.PretrainVideoMamba'>
_IncompatibleKeys(missing_keys=['norm.weight', 'clip_decoder.0.head.weight', 'clip_decoder.0.head.bias', 'clip_decoder.0.norm.weight', 'clip_decoder.0.norm.bias', 'pool_norm.weight', 'pool_norm.bias'], unexpected_keys=['head.weight', 'head.bias', 'norm_f.weight'])
DONE LOADING


VideoMambaVisionTower(
  (vision_tower): PretrainVideoMamba(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(3, 576, kernel_size=(1, 16, 16), stride=(1, 16, 16))
    )
    (drop_path): DropPath(drop_prob=0.250)
    (layers): ModuleList(
      (0-1): 2 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=576, out_features=2304, bias=False)
          (conv1d): Conv1d(1152, 1152, kernel_size=(4,), stride=(1,), padding=(3,), groups=1152)
          (act): SiLU()
          (x_proj): Linear(in_features=1152, out_features=68, bias=False)
          (dt_proj): Linear(in_features=36, out_features=1152, bias=True)
          (conv1d_b): Conv1d(1152, 1152, kernel_size=(4,), stride=(1,), padding=(3,), groups=1152)
          (x_proj_b): Linear(in_features=1152, out_features=68, bias=False)
          (dt_proj_b): Linear(in_features=36, out_features=1152, bias=True)
          (out_proj): Linear(in_features=1152, out_features=576, bias=False)
        )
        (norm): RMSNorm()
  

In [7]:
# Wire the vision tower, tokenizer and model config together before inference.

# Cast the vision tower to bf16 when training_args.bf16 is set (fp16 otherwise)
# and move it to training_args.device.
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

# Expose the tower's video processor through data_args and mark the data as multimodal.
data_args.image_processor = vision_tower.video_processor
data_args.is_multimodal = True

# Mirror data/tokenizer settings onto the model config.
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length

# When only the projector is tuned: disable gradients on the whole model, then
# re-enable them for the mm_projector parameters.
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
    model.requires_grad_(False)
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True

# Optionally freeze the projector parameters.
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

# Propagate the image-token flags from model_args to the config, data_args and
# training_args, then let the model set up its vision tokenizer.
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

# Placeholder; reassigned to the tower's video processor further down.
image_processor = None

# Read the flags back from the config (defaults apply only if the attribute is absent).
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)

# Register the special tokens selected by the flags and resize the embedding
# table to the tokenizer's current length.
# NOTE(review): initialize_vision_tokenizer above may already add these tokens —
# confirm whether this second pass is redundant.
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    )
model.resize_token_embeddings(len(tokenizer))

# Re-fetch the tower from the model and load it if that has not happened yet.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model(device_map=device_map)
# device_map is "auto" as set in the model-loading cell, so this branch does not
# run as written.
if device_map != "auto":
    vision_tower.to(device=device_map, dtype=torch.float16)
image_processor = vision_tower.video_processor

# Context length: taken from the config when present, otherwise 2048.
# context_len is not used by the cells visible in this notebook.
if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

In [8]:
# Build the multimodal prompt, preprocess one video and generate an answer.
# NOTE(review): the recorded run of this cell ends with
# "AttributeError: 'Tensor' object has no attribute 'hidden_states'" —
# presumably raised inside model.generate(); TODO confirm and fix upstream.

qs = "What is unusual about this image?"
cur_prompt = qs

# Prefix the question with num_frames image placeholders, each wrapped in
# start/end markers when the config asks for them.
image_tokens = (
    " ".join(
        [
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            if model.config.mm_use_im_start_end
            else DEFAULT_IMAGE_TOKEN
        ]
        * model.get_vision_tower().config.num_frames
    )
    + "\n"
)
qs = image_tokens + qs

# Conversation prompt: the question in the first role, an open slot for the second.
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize with image placeholders, add a batch dimension and move to the GPU.
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
input_ids = input_ids.unsqueeze(0).cuda()

# Earlier single-image variant, kept for reference:
#   url = "https://www.ilankelman.org/stopsigns/australia.jpg"
#   image = Image.open(requests.get(url, stream=True).raw)
#   image_tensor = process_images([image], image_processor, model.config)[0]

video = "/data/vlm_sandbox/videos/lie1.mp4"
video_tensor = image_processor(video, return_tensors="pt")["pixel_values"]
video_tensor = video_tensor.to("cuda")
model.get_model().mm_projector.to("cuda", dtype=torch.float16)

with torch.inference_mode():
    with torch.amp.autocast("cuda"):
        # Options left disabled: image_sizes=[image.size], no_repeat_ngram_size=3
        output_ids = model.generate(
            input_ids,
            images=video_tensor.half().cuda(),
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=1024,
            use_cache=True,
        )

# Decode the first sequence of the batch and trim surrounding whitespace.
outputs = tokenizer.batch_decode(
    output_ids,
    skip_special_tokens=True,
)[0].strip()



torch.Size([1, 8, 3, 224, 224])
torch.Size([1, 1568, 576])


AttributeError: 'Tensor' object has no attribute 'hidden_states'

In [9]:
# Show the vision tower's hidden_size attribute (576 in the recorded output).
model.get_vision_tower().hidden_size

576