In [1]:
import torch

import os
from pathlib import Path
import sys

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # uncomment to expose only GPU 0 to this kernel

# Put the current working directory on sys.path (presumably so the local
# `llava` package imported below can be resolved — confirm the notebook is
# started from the repository root). The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate entries.
repo_root = Path(".").resolve().as_posix()
if repo_root not in sys.path:
    sys.path.append(repo_root)

1. Load `VideoMambaVideoEncoder` from a checkpoint.
2. Load `VideoMambaVisionModel` and run a forward pass on it.
3. Load `VideoMambaTextModel` from a checkpoint.
4. Assemble and run `VideoMambaModel`.
5. Create a processor class to match `VideoMambaModel`.

In [2]:
# 1. Load VideoMambaVideoEncoder with weights

from llava.model.multimodal_encoder.videomamba2.modeling_videomamba import VideoMambaVideoEncoder, VideoMambaVisionModel, VideoMambaTextModel
from llava.model.multimodal_encoder.videomamba2.configuration_videomamba import VideoMambaVisionConfig, VideoMambaTextConfig
from llava.model.multimodal_encoder.videomamba2.video_processing_videomamba import VideoMambaVideoProcessor

In [None]:
# Create the vision config with its default values and instantiate the video
# encoder from it. At this point the encoder holds freshly initialised weights;
# the checkpoint weights are loaded into it in a later cell.
config = VideoMambaVisionConfig()
model = VideoMambaVideoEncoder(config)

In [None]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (consider weights_only=True on
# PyTorch versions that support it).
checkpoint = torch.load("/workspaces/gemamba/videomamba_m16_25M_f8_res224.pth", map_location="cpu")

# Keep only the vision-encoder weights and strip their prefix so the remaining
# keys can be loaded directly into VideoMambaVideoEncoder.
# The match is anchored to the full "vision_encoder." prefix: a plain substring
# test would also accept keys such as "vision_encoder_m.*" or keys that contain
# the name somewhere in the middle, and the fixed-length slice would then turn
# them into garbage keys.
vision_prefix = "vision_encoder."
new_state_dict = {
    key[len(vision_prefix):]: tensor
    for key, tensor in checkpoint.items()
    if key.startswith(vision_prefix)
}

In [None]:
# Load the prefix-stripped weights into the encoder. With strict=True,
# load_state_dict raises if any key is missing or unexpected, so reaching the
# print means the checkpoint keys matched the encoder's parameters exactly.
msg = model.load_state_dict(new_state_dict, strict=True)
print(msg)

In [None]:
processor = VideoMambaVideoProcessor()

# Sample clips used for the forward-pass smoke test (paths are relative to the
# working directory).
video_paths = [
    "videos_zero_shot/birds.mp4",
    "videos_zero_shot/fish.mp4",
    "videos_zero_shot/human.mp4",
    "videos_zero_shot/swamp.mp4",
]

# Single place to change the target device for this cell.
device = "cuda"

# Preprocess the clips into one tensor batch and move it to the device.
videos = processor(video_paths, return_tensors="pt")["pixel_values"].to(device)

# Show the batch shape; a bare `videos.shape` in the middle of a cell is never
# displayed, so print it explicitly.
# Presumably (batch, channels, frames, height, width) — TODO confirm.
print(videos.shape)

model = model.to(device)

# Smoke test only: the output is discarded, so run without autograd to avoid
# keeping activations for a backward pass that never happens.
with torch.no_grad():
    model(videos)

In [None]:
# 2. Build VideoMambaVisionModel

# Instantiate the wrapper model from the same config, then load the
# prefix-stripped checkpoint weights into its inner `vision_model` submodule
# (strict=True raises on any missing or unexpected key).
vision_model = VideoMambaVisionModel(config)
vision_model.vision_model.load_state_dict(new_state_dict, strict=True)
# NOTE(review): only the inner `vision_model` submodule is moved to the GPU here.
# If the wrapper owns any other parameters or buffers they stay on the CPU —
# confirm this is intended.
vision_model.vision_model = vision_model.vision_model.to("cuda")

In [None]:
# Run the wrapper on the preprocessed clips and display the shape of the first
# entry of the returned `vision_embeds`.
vision_model(videos).vision_embeds[0].shape

In [3]:
# 3. Load a text model

# Read the text-model configuration from a JSON file, addressed relative to the
# working directory.
# NOTE(review): this path points into `videomamba/`, while the classes above are
# imported from `videomamba2/` — confirm this is the intended config file.
text_config = VideoMambaTextConfig.from_json_file("llava/model/multimodal_encoder/videomamba/configs/config_bert.json")

In [7]:
# Initialise the text model from the "bert-base-uncased" weights, overriding the
# hub configuration with `text_config`. `add_pooling_layer=False` is an extra
# keyword passed along with the config — presumably it disables the pooling head
# (the forward output shown further down has pooler_output=None).
text_model = VideoMambaTextModel.from_pretrained("bert-base-uncased", config=text_config, add_pooling_layer=False)

In [15]:
from transformers import BertTokenizer

# Use the tokenizer that matches the "bert-base-uncased" weights loaded into
# the text model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Earlier attempt, kept for reference: passing the tokenizer output as a single
# positional argument.
# text_model(tokenizer("some text"))
# Request PyTorch tensors from the tokenizer and unpack the resulting mapping
# into keyword arguments for the model's forward call.
text_model(**tokenizer("some text", return_tensors='pt'))

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0390,  0.2326, -0.2814,  ..., -0.2142, -0.0064,  0.1941],
         [ 0.3832, -0.3164, -0.2389,  ..., -0.3462,  0.1316,  0.3012],
         [ 0.2595,  0.0545, -0.4032,  ..., -0.3395, -0.2595, -0.1076],
         [ 0.8996,  0.1475, -0.4212,  ..., -0.0248, -0.6638, -0.3487]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=None, hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)

In [None]:
# 4. Create VideoMambaModel (TODO: not implemented yet)

In [None]:
# 5. Create VideoMambaProcessor (TODO: not implemented yet)