In [1]:
import os
from pathlib import Path
import sys

os.environ["HF_TOKEN"] = "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"
sys.path.append(Path(".").resolve().as_posix())

In [2]:
import torch
from llava.model.multimodal_encoder.videomamba.vision_tower import (
    DEFAULT_VIDEOMAMBA_CONFIG,
)
from llava.model.multimodal_encoder.videomamba.models.backbones.videomamba.videomamba import (
    PretrainVideoMamba,
)
from llava.model.multimodal_encoder.videomamba.models.backbones.bert.builder import (
    build_bert,
)



[2024-04-22 07:04:01,655] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status




In [3]:
checkpoint = torch.load("videomamba_m16_5M_f8_res224.pth", map_location="cpu")

In [4]:
from llava.model.multimodal_encoder.videomamba.models.criterions import get_sim


class VideoMambaEncoder(torch.nn.Module):
    def __init__(self, vision_encoder, vision_proj, text_encoder, text_proj):
        super().__init__()

        self.vision_encoder = vision_encoder
        self.vision_proj = vision_proj

        self.text_encoder = text_encoder
        self.text_proj = text_proj

    def forward(self, video_input, text_input):
        vision_feats, pooled_vision_feats, _ = self.vision_encoder(video_input)
        # vision_feats = self.vision_proj(pooled_vision_feats)

        text_output = self.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            return_dict=True,
            mode="text",
        )

        # text_feats = text_output.last_hidden_state

        text_embeds = text_output.last_hidden_state
        pooled_text_embeds = text_embeds[:, 0]

        i2t_scores, t2i_scores = get_sim(
            self.vision_proj(pooled_vision_feats), self.text_proj(pooled_text_embeds)
        )

        return i2t_scores

In [5]:
new_state_dict = {}

for k, v in checkpoint.items():
    if "vision_encoder" in k or "vision_proj" in k or "text_encoder" in k or "text_proj" in k:
        if "bert." in k:
            k = k.replace("bert.", "")
        if "cls.predictions" in k:
            continue
        new_state_dict[k] = v

# new_state_dict.keys()

In [6]:
config = DEFAULT_VIDEOMAMBA_CONFIG

vision_encoder = PretrainVideoMamba(
    img_size=config.model.vision_encoder.img_size,
    patch_size=config.model.vision_encoder.patch_size,
    depth=config.model.vision_encoder.depth,
    embed_dim=config.model.vision_encoder.embed_dim,
    drop_path_rate=config.model.vision_encoder.drop_path_rate,
    ssm_cfg=config.model.vision_encoder.ssm_cfg,
    norm_epsilon=config.model.vision_encoder.norm_epsilon,
    fused_add_norm=config.model.vision_encoder.fused_add_norm,
    rms_norm=config.model.vision_encoder.rms_norm,
    residual_in_fp32=config.model.vision_encoder.residual_in_fp32,
    bimamba=config.model.vision_encoder.bimamba,
    pool_type=config.model.vision_encoder.pool_type,
    kernel_size=config.model.vision_encoder.kernel_size,
    num_frames=config.model.vision_encoder.num_frames,
    use_checkpoint=config.model.vision_encoder.use_checkpoint,
    checkpoint_num=config.model.vision_encoder.checkpoint_num,
    clip_decoder_embed_dim=config.model.vision_encoder.clip_decoder_embed_dim,
    clip_output_dim=config.model.vision_encoder.clip_output_dim,
    clip_return_layer=config.model.vision_encoder.clip_return_layer,
    clip_student_return_interval=config.model.vision_encoder.clip_student_return_interval,
    add_pool_norm=True,  # TO GET POOLED FEATURES
)


vision_proj = torch.nn.Linear(config.model.vision_encoder.embed_dim, config.model.embed_dim)

In [7]:
encoder_name = config.model.text_encoder.name

text_encoder = build_bert(
    config.model,
    False,
    config.gradient_checkpointing,
)

text_proj = torch.nn.Linear(config.model.text_encoder.d_model, config.model.embed_dim)

In [8]:
model = VideoMambaEncoder(vision_encoder, vision_proj, text_encoder, text_proj)

In [9]:
model.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

In [10]:
# Load and preprocess video
# from videomamba.video_mm.dataset.__init__ -> create_dataset()
# from videomamba.video_mm.dataset.base_dataset -> ImageVideoBaseDataset.load_and_transform_media_data_video()

from llava.model.multimodal_encoder.videomamba.hf_parts.processing_videomamba import VideoMambaVideoProcessor
from llava.model.multimodal_encoder.videomamba.models.backbones.bert.tokenization_bert import (
    BertTokenizer,
)

tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained)
processor = VideoMambaVideoProcessor(config, tokenizer=tokenizer)

video = "/data/vlm_sandbox/videos/lie1.mp4"
text = "Person behind a table"

In [None]:
# inputs = processor(images=video, text=text, return_tensors="pt")

# for key, tensor in inputs.items():
#     inputs[key] = tensor.to("cuda")

# model.to("cuda")
# inputs.keys()

In [None]:
# print(inputs["pixel_values"].device)
# print(inputs["input_ids"].device)
# print(inputs["token_type_ids"].device)
# print(inputs["attention_mask"].device)

# inputs.keys()

In [None]:
# sim = model(inputs["pixel_values"], inputs)
# sim

In [12]:
from collections import defaultdict


model.to("cuda")

sample_texts = ["birds", "fish", "human", "swamp"]
sample_videos = [
    "videos_zero_shot/birds.mp4",
    "videos_zero_shot/fish.mp4",
    "videos_zero_shot/human.mp4",
    "videos_zero_shot/swamp.mp4",
]

res = defaultdict(list)

for video in sample_videos:
    for text in sample_texts:
        inputs = processor(images=video, text=text, return_tensors="pt")

        for key, tensor in inputs.items():
            inputs[key] = tensor.to("cuda")

        sim = model(inputs["pixel_values"], inputs)[0][0]
        res[video].append(sim)

for video, sims in res.items():
    result_str = ", ".join([f"{label} = {prob}" for label, prob in zip(sample_texts, sims)])
    print(f"{video} :: {result_str}")

videos_zero_shot/birds.mp4 :: birds = 0.46033063530921936, fish = 0.3552638292312622, human = 0.20799045264720917, swamp = 0.3916945159435272
videos_zero_shot/fish.mp4 :: birds = 0.27160006761550903, fish = 0.4055670499801636, human = 0.20728805661201477, swamp = 0.13486051559448242
videos_zero_shot/human.mp4 :: birds = 0.10954683274030685, fish = 0.08383047580718994, human = 0.2469262033700943, swamp = -0.02614482119679451
videos_zero_shot/swamp.mp4 :: birds = 0.28155460953712463, fish = 0.280953973531723, human = 0.15472877025604248, swamp = 0.4352746605873108
