## Imports

First, import the necessary packages and the selective-scan compression module (i.e., MambaCompressor).

In [1]:
import torch
from torch import nn
from llava.model.multimodal_resampler.mamba_ssm.modules.mamba_compressor import MambaCompressor



Suppose we are given a video of 64 frames, where each frame has 24x24 tokens. We want to compress the temporal dimension by a factor of 4 (64 frames down to 16) and the height and width by a factor of 2 each (24 down to 12). The compressed output tokens should therefore have shape (16, 12, 12).

In [3]:
batch_size = 1
input_shape = (64, 24, 24)
hidden_size = 1024
target_shape = (16, 12, 12)
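
As a quick sanity check on the shapes above, the per-axis compression factors and the overall token reduction can be computed directly. This is a minimal sketch; `compression_factors` is a hypothetical helper written for illustration and is not part of the MambaCompressor module.

```python
from math import prod


def compression_factors(input_shape, target_shape):
    """Return per-axis downsampling factors, requiring exact divisibility.

    Hypothetical helper for illustration only.
    """
    factors = []
    for axis, (src, dst) in enumerate(zip(input_shape, target_shape)):
        if src % dst != 0:
            raise ValueError(f"axis {axis}: {src} is not divisible by {dst}")
        factors.append(src // dst)
    return tuple(factors)


input_shape = (64, 24, 24)
target_shape = (16, 12, 12)

factors = compression_factors(input_shape, target_shape)
ratio = prod(input_shape) // prod(target_shape)  # 36864 // 2304
print(factors, ratio)  # (4, 2, 2) 16
```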

## Model Initialization

Now, let's initialize the MambaCompressor model. We initialize the output projection of the last Mamba layer to zero. This is very important for stable optimization when you start training from a pretrained checkpoint: the output of the newly added mixer is zero at the start, so it does not disrupt the behavior of the pretrained weights.

In [4]:
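model = MambaCompressor(d_model=hidden_size, n_layer=1).to("cuda")
# Zero the output projection of the last layer so the mixer outputs zero at the start.
torch.nn.init.constant_(model.layers[-1].mixer.out_proj.weight, 0)
# Print the sum of every parameter as a quick check on the initialization.
for n, p in model.named_parameters():
    if hasattr(p, "ds_numel"):
        # Parameter is partitioned by DeepSpeed ZeRO-3; read the local shard.
        print(n, torch.sum(p.ds_tensor).item())
    else:
        print(n, torch.sum(p).item())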

layers.0.norm.weight 1024.0
layers.0.mixer.A_log 62815.97265625
layers.0.mixer.D 2048.0
layers.0.mixer.A_b_log 62785.30078125
layers.0.mixer.D_b 2048.0
layers.0.mixer.in_proj.weight 21.868789672851562
layers.0.mixer.conv1d.weight -12.396291732788086
layers.0.mixer.conv1d.bias 9.923412322998047
layers.0.mixer.x_proj.weight -1.0851119756698608
layers.0.mixer.dt_proj.weight 24.908506393432617
layers.0.mixer.dt_proj.bias -9511.03515625
layers.0.mixer.conv1d_b.weight -21.952051162719727
layers.0.mixer.conv1d_b.bias 2.3076255321502686
layers.0.mixer.x_proj_b.weight -14.01053237915039
layers.0.mixer.dt_proj_b.weight -22.66807746887207
layers.0.mixer.dt_proj_b.bias -1.1077779531478882
layers.0.mixer.out_proj.weight 0.0


Here, layers.0.mixer.out_proj.weight sums to 0.0, as expected after zero initialization. Also check that none of the other weights are NaN. You may need to explicitly initialize the Mamba module weights if you insert this module into another model (e.g., LLaVA). For example, see lines 1755-1758 of BIMBA-LLaVA-NeXT/llava/train/train.py.
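
To illustrate why a zero output projection is safe, here is a minimal sketch on a toy block. It assumes the mixer output is added residually to its input, which is an assumption about the architecture rather than something this notebook confirms. `ToyResidualBlock` is a hypothetical stand-in and not the real Mamba mixer. The same snippet also shows one way to flag non-finite (NaN or Inf) parameters instead of reading the sums by eye.

```python
import torch
from torch import nn
import torch.nn.functional as F


class ToyResidualBlock(nn.Module):
    """Hypothetical stand-in for a mixer layer: y = x + out_proj(silu(in_proj(x)))."""

    def __init__(self, d_model):
        super().__init__()
        self.in_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(2 * d_model, d_model, bias=False)

    def forward(self, x):
        return x + self.out_proj(F.silu(self.in_proj(x)))


torch.manual_seed(0)
block = ToyResidualBlock(d_model=8)
nn.init.constant_(block.out_proj.weight, 0)  # same init call as in the cell above

x = torch.randn(2, 5, 8)
# With a zero projection, the residual branch adds nothing, so the block is the identity.
identity_at_init = torch.equal(block(x), x)

# Collect the names of parameters that contain NaN or Inf values.
bad_params = [n for n, p in block.named_parameters() if not torch.isfinite(p).all()]
print(identity_at_init, bad_params)
```

Gradients still reach `out_proj.weight` because its input activations are non-zero, so the block can move away from the identity during training.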

## Query Initialization

Then, we initialize the query tokens with the same shape as the output. We keep every fourth frame and average-pool each kept frame spatially, which gives a good initialization for the output tokens.

In [5]:
pooling = nn.AdaptiveAvgPool3d(target_shape)
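temporal_pooling = False  # We found that skipping temporal pooling works well for query initialization

# Assume space_time_tokens represents our input video.
space_time_tokens = torch.randn(batch_size, input_shape[0], input_shape[1], input_shape[2], hidden_size).to("cuda")
if temporal_pooling:
    # Keep all frames; the adaptive pooling below then averages over time as well.
    query_tokens = space_time_tokens
else:
    # Keep every 4th frame (64 -> 16) instead of averaging over time.
    query_tokens = space_time_tokens[:, ::input_shape[0] // target_shape[0]]
# [1, 16, 24, 24, 1024] when temporal_pooling is False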
query_tokens = query_tokens.permute(0, 4, 1, 2, 3)
# [1, 1024, 16, 24, 24]
query_tokens = pooling(query_tokens)
# [1, 1024, 16, 12, 12]
query_tokens = query_tokens.permute(0, 2, 3, 4, 1)
# [1, 16, 12, 12, 1024]
query_tokens = query_tokens.reshape(batch_size, target_shape[0], -1, hidden_size)
print(space_time_tokens.shape, query_tokens.shape)


torch.Size([1, 64, 24, 24, 1024]) torch.Size([1, 16, 144, 1024])
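
The query initialization above can be wrapped in a small function and checked on CPU with a tiny input. This is a sketch under stated assumptions: `init_query_tokens` is a hypothetical helper name, and the frame stride is derived as `T // t` instead of the hard-coded 4. It uses only `nn.AdaptiveAvgPool3d` and standard tensor operations.

```python
import torch
from torch import nn


def init_query_tokens(space_time_tokens, target_shape, temporal_pooling=False):
    """Hypothetical helper: (B, T, H, W, C) -> (B, t, h*w, C)."""
    batch_size, num_frames, _, _, hidden_size = space_time_tokens.shape
    t, h, w = target_shape
    x = space_time_tokens
    if not temporal_pooling:
        x = x[:, :: num_frames // t]           # keep every (T // t)-th frame
    x = x.permute(0, 4, 1, 2, 3)               # (B, C, T', H, W)
    x = nn.AdaptiveAvgPool3d(target_shape)(x)  # (B, C, t, h, w)
    x = x.permute(0, 2, 3, 4, 1)               # (B, t, h, w, C)
    return x.reshape(batch_size, t, h * w, hidden_size)


torch.manual_seed(0)
tokens = torch.randn(1, 8, 4, 4, 3)  # tiny stand-in for the video tokens
queries = init_query_tokens(tokens, target_shape=(2, 2, 2))
print(queries.shape)  # torch.Size([1, 2, 4, 3])

# With exact divisibility, each query is the mean of one 2x2 window of a kept frame.
manual = tokens[0, 0, :2, :2].mean(dim=(0, 1))
matches_manual = torch.allclose(queries[0, 0, 0], manual, atol=1e-6)
print(matches_manual)
```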


## Selective-Scan Compression

Now, we apply the MambaCompressor model, which captures fine-grained details from the space_time_tokens into the query_tokens.

In [6]:
query_tokens = model(space_time_tokens, query_tokens)
query_tokens = query_tokens.reshape(batch_size, target_shape[0], target_shape[1], target_shape[2], hidden_size)
print(query_tokens.shape)

torch.Size([1, 16, 12, 12, 1024])


query_tokens is now a compressed representation of space_time_tokens (36,864 tokens reduced to 2,304, a 16x compression ratio), which we can pass to a subsequent model (e.g., an LLM) for more efficient processing.