In [1]:
from pathlib import Path
import sys

sys.path.append(Path(".").resolve().as_posix())

In [2]:
# !pip3 install -e causal-conv1d
# !pip3 install -e mamba

In [6]:
# Setup configuration
# from videomamba.video_mm.exp_zs.msrvtt.config

from models.umt_videomamba import UMT_VIDEOMAMBA
from utils.easydict import EasyDict

num_frames = 8
img_size = 224
batch_size = 64
max_txt_l = 32

# model_pth = "videomamba_m16_k400_mask_pt_f8_res224.pth"  # default in all configs
# model_pth = "videomamba_m16_5M_f8_res224.pth"  # broken pos
model_pth = "videomamba_m16_k400_mask_ft_f8_res224.pth"

config_dict = {
    "num_frames": num_frames,
    "num_frames_test": num_frames,
    "batch_size": batch_size,
    "max_txt_l": max_txt_l,
    "inputs": {
        "image_res": img_size,
        "video_input": {
            "num_frames": num_frames,
            "sample_type": "rand",
            "num_frames_test": num_frames,
            "sample_type_test": "middle",
            "random_aug": False
        },
        "max_txt_l": {
            "image": max_txt_l,
            "video": max_txt_l
        },
        "batch_size": {
            "image": batch_size,
            "video": batch_size
        },
        "batch_size_test": {
            "image": batch_size,
            "video": batch_size
        }
    },
    "text_enc": "bert",
    "model": {
        "model_cls": UMT_VIDEOMAMBA,
        "vision_encoder": {
            "name": "videomamba_middle",
            "img_size": img_size,
            "patch_size": 16,
            "depth": 32,
            "embed_dim": 576,
            "drop_path_rate": 0.25,
            "ssm_cfg": None,
            "norm_epsilon": 1e-5,
            "fused_add_norm": True,
            "rms_norm": True,
            "residual_in_fp32": True,
            "bimamba": True,
            "pool_type": "cls+avg",
            "kernel_size": 1,
            "num_frames": num_frames,
            "ckpt_num_frame": 8,
            "use_checkpoint": False,
            "checkpoint_num": 0,
            "clip_decoder_embed_dim": 576,
            "clip_output_dim": 512,
            "clip_norm_type": "l2",
            "clip_return_layer": 1,
            "clip_student_return_interval": 1,
            "pretrained": model_pth,
            "clip_teacher": "none",
            "clip_img_size": img_size,
            "clip_return_interval": 1,
            "video_mask_type": "none",
            "video_mask_ratio": 0.0,
            "video_double_mask_ratio": 0.0,
            "image_mask_type": "none",
            "image_mask_ratio": 0.0,
            "image_double_mask_ratio": 0.0,
            "keep_temporal": True
        },
        "text_encoder": {
            "name": "bert_base",
            "pretrained": "bert-base-uncased",
            "config": "configs/config_bert.json",
            "d_model": 768,
            "fusion_layer": 9
        },
        "multimodal": {"enable": True},
        "embed_dim": 512,
        "temp": 0.07
    },
    "criterion": {
        "loss_weight": {
            "vtc": 1.0,
            "mlm": 1.0,
            "vtm": 1.0,
            "uta": 0.0
        },
        "vtm_hard_neg": True,
        "mlm_masking_prob": 0.5,
        "uta_norm_type": "l2",
        "uta_loss_type": "l2"
    },
    "optimizer": {
        "opt": "adamW",
        "lr": 1e-5,
        "opt_betas": [0.9, 0.999],
        "weight_decay": 0.02,
        "max_grad_norm": -1,
        "different_lr": {"enable": False, "module_names": [], "lr": 4e-3}
    },
    "scheduler": {
        "sched": "cosine",
        "epochs": 2,
        "min_lr_multi": 0.01,
        "warmup_epochs": 0.2,
        "num_warmup_steps": 8,
        "num_training_steps": 1000,
    },
    "evaluate": False,
    "deep_fusion": False,
    "evaluation": {
        "eval_frame_ensemble": "concat",
        "eval_x_only": False,
        "k_test": 128,
        "eval_offload": False
    },
    "fp16": True,
    "bf16": True,
    "gradient_checkpointing": True,
    "device": "cuda",
    "mode": "pt",
    "output_dir": None,
    "resume": False,
    "debug": False,
    "log_freq": 1,
    "seed": 42,
    "zero_shot": True,
    "save_latest": False,
    "auto_resume": False,
    "pretrained_path": model_pth,
    "distributed": False,
}

config = EasyDict(config_dict)

In [4]:
# !wget https://huggingface.co/OpenGVLab/VideoMamba/resolve/main/videomamba_m16_k400_mask_pt_f8_res224.pth
# !wget https://huggingface.co/OpenGVLab/VideoMamba/resolve/main/videomamba_m16_5M_f8_res224.pth
# !wget https://huggingface.co/OpenGVLab/VideoMamba/resolve/main/videomamba_m16_k400_mask_ft_f8_res224.pth

--2024-04-15 07:58:10--  https://huggingface.co/OpenGVLab/VideoMamba/resolve/main/videomamba_m16_k400_mask_ft_f8_res224.pth
Resolving huggingface.co (huggingface.co)... 18.154.227.7, 18.154.227.69, 18.154.227.67, ...
Connecting to huggingface.co (huggingface.co)|18.154.227.7|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/60/50/60507f0fe1a4d5abb54b48fe1e032af809ede0f905d09ec0dd581d9a9e9e6b3f/a101883c7ce8794b887469bcbbce7e09c0c6962dfae23d30223f0dd96275272f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27videomamba_m16_k400_mask_ft_f8_res224.pth%3B+filename%3D%22videomamba_m16_k400_mask_ft_f8_res224.pth%22%3B&Expires=1713427090&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMzQyNzA5MH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzYwLzUwLzYwNTA3ZjBmZTFhNGQ1YWJiNTRiNDhmZTFlMDMyYWY4MDllZGUwZjkwNWQwOWVjMGRkNTgxZDlhOWU5ZTZiM2YvYTEw

In [7]:
# Load the model
# from videomamba.video_mm.tasks.retrieval

from tasks.shared_utils import setup_model

(
    model,
    model_without_ddp,
    optimizer,
    scheduler,
    scaler,
    tokenizer,
    start_epoch,
    global_step,
) = setup_model(
    config,
    model_cls=config.model.model_cls,
    has_decoder=False,
    pretrain=False,
    # find_unused_parameters=True,
    find_unused_parameters=False,
)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [8]:
# Load and preprocess video
# from videomamba.video_mm.dataset.__init__ -> create_dataset()
# from videomamba.video_mm.dataset.base_dataset -> ImageVideoBaseDataset.load_and_transform_media_data_video()

from torchvision import transforms
from dataset.video_utils import read_frames_decord

mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

normalize = transforms.Normalize(mean, std)

# loaded images and videos are torch.Tensor of torch.uint8 format,
# ordered as (T, 1 or 3, H, W) where T=1 for image
type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

transform = transforms.Compose(
    [
        transforms.Resize(
            (config.inputs.image_res, config.inputs.image_res),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        type_transform,
        normalize,
    ]
)

video_reader = read_frames_decord

max_num_frames = -1
frames, frame_indices, video_duration = video_reader(
    "/data/vlm_sandbox/videos/lie1.mp4",
    config.inputs.video_input.num_frames,
    config.inputs.video_input.sample_type,
    max_num_frames=max_num_frames,
    client=None,
    trimmed30=False,
)

frames = transform(frames).unsqueeze(0)
frames.shape



torch.Size([1, 8, 3, 224, 224])

In [9]:
# Encode image
# from videomamba.video_mm.tasks.retrieval_utils -> extract_vision_feats()

import torch
from einops import rearrange

device = torch.device("cuda")

image_feats_all = []
pooled_image_feats_all = []

with torch.cuda.amp.autocast():
    image = frames.to(device, non_blocking=True)
    image_feat, pooled_image_feat = model.encode_vision(image, test=True)

if len(image_feat.shape) == 4:
    image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
image_feat = image_feat.unsqueeze(1)  # (bsz, 1, #frm*L, d)

if config.evaluation.eval_offload:
    image_feats_all.append(image_feat.cpu())
    pooled_image_feats_all.append(pooled_image_feat.cpu())
else:
    image_feats_all.append(image_feat)
    pooled_image_feats_all.append(pooled_image_feat)

image_feats_all = torch.cat(image_feats_all, dim=0)
pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)

In [10]:
# Encode text
# from videomamba.video_mm.tasks.retrieval_utils -> extract_text_feats()

text = "Person behind a table"
text_bs = 256
text_feats = []
text_atts = []

text_input = tokenizer(
    text,
    padding="max_length",
    truncation=True,
    max_length=max_txt_l,
    return_tensors="pt",
).to(device)

text_feat = model.encode_text(text_input)[0]
text_feats.append(text_feat)
text_atts.append(text_input.attention_mask)

text_feats = torch.cat(text_feats, dim=0)
text_atts = torch.cat(text_atts, dim=0)

In [11]:
print(image_feats_all.shape)
print(pooled_image_feats_all.shape)
print(text_feats.shape)

torch.Size([1, 1, 1568, 576])
torch.Size([1, 8, 576])
torch.Size([1, 32, 768])


In [12]:
# Calculate similarity
# from videomamba.video_mm.tasks.retrieval_utils -> evaluation()

from models.criterions import get_sim

_pooled_image_feats = (
    pooled_image_feats_all.to(device, non_blocking=True)
    if config.evaluation.eval_offload
    else pooled_image_feats_all
)
with torch.cuda.amp.autocast():
    i2t_scores, t2i_scores = get_sim(
        model.vision_proj(_pooled_image_feats), model.text_proj(text_feats[:, 0])
    )

print(float(i2t_scores))

-0.0031070709228515625
