In [None]:
import data
import torch
from models import imagebind_model
from models.imagebind_model import ModalityType

# Three aligned samples per modality, in the same order (dog, car, bird), so
# row i of one modality should match column i of another.
text_list = ["A dog.", "A car", "A bird"]
image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
audio_paths = [".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate model with pretrained weights, in inference mode, on `device`.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

# Load data: one batch per modality, already moved to `device`.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

with torch.no_grad():
    embeddings = model(inputs)

# Cross-modal similarity: for each item of the first modality (rows), a softmax
# over the dot products with every item of the second modality (columns).
# One loop instead of three copy-pasted print blocks; output is identical.
for label, query, key in (
    ("Vision x Text: ", ModalityType.VISION, ModalityType.TEXT),
    ("Audio x Text: ", ModalityType.AUDIO, ModalityType.TEXT),
    ("Vision x Audio: ", ModalityType.VISION, ModalityType.AUDIO),
):
    print(label, torch.softmax(embeddings[query] @ embeddings[key].T, dim=-1))

# Expected output:
#
# Vision x Text:
# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
#         [3.3836e-05, 9.9994e-01, 2.4118e-05],
#         [4.7997e-05, 1.3496e-02, 9.8646e-01]])
#
# Audio x Text:
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])
#
# Vision x Audio:
# tensor([[0.8070, 0.1088, 0.0842],
#         [0.1036, 0.7884, 0.1079],
#         [0.0018, 0.0022, 0.9960]])


In [None]:
# Vision x Audio similarity matrix again, as a rich-displayed cell result.
(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T).softmax(dim=-1)

In [None]:
from models.multimodal_preprocessors import PatchEmbedGeneric,PadIm2Video
import torch.nn as nn
import numpy as np
from typing import Optional
class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra.

    This definition shadows the ``PatchEmbedGeneric`` imported from
    ``models.multimodal_preprocessors`` at the top of this cell, adding shape
    tracing so the patch layout can be inspected.

    Args:
        proj_stem: Non-empty sequence of modules applied to the input. A single
            module is used directly; several are wrapped in ``nn.Sequential``.
        norm_layer: Optional module applied to the flattened tokens.
        verbose: If True (default), print tensor shapes at each step.

    Raises:
        ValueError: if ``proj_stem`` is empty.
    """

    def __init__(
        self,
        proj_stem,
        norm_layer: Optional[nn.Module] = None,
        verbose: bool = True,
    ):
        super().__init__()

        if len(proj_stem) == 0:
            raise ValueError("proj_stem must contain at least one module")
        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer
        self.verbose = verbose

    def _log(self, *args) -> None:
        """Print ``args`` only when shape tracing is enabled."""
        if self.verbose:
            print(*args)

    def get_patch_layout(self, img_size):
        """
        Probe ``self.proj`` with a zero input to discover its output layout.

        Args:
            img_size: Per-sample input shape without the batch dimension
                (list, tuple or ``torch.Size``).

        Returns:
            ``(patches_layout, num_patches, embed_dim)``, where ``patches_layout``
            is the tuple of output dims after the channel dim, ``num_patches`` is
            their product and ``embed_dim`` is the output channel count.
        """
        with torch.no_grad():
            # Unpacking accepts both lists and tuples; ``[1] + img_size`` raised
            # TypeError for tuples.
            dummy_img = torch.zeros(1, *img_size)
            self._log(dummy_img.shape)
            dummy_out = self.proj(dummy_img)
        self._log(dummy_out.shape)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        """Project ``x`` and flatten the spatial/temporal dims into a token axis."""
        self._log(x.shape)
        x = self.proj(x)
        self._log(x.shape)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        self._log(x.shape)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
    
# Probe the patch layout of the vision stem: PadIm2Video followed by a Conv3d
# whose kernel and stride are both `kernel_size` (non-overlapping patches).
# `kernel_size` and `vision_embed_dim` are reused by the next cell.
kernel_size = (2, 14, 14)
vision_embed_dim = 1024
proj_stem = [
    PadIm2Video(pad_type="repeat", ntimes=2),
    nn.Conv3d(
        in_channels=3,
        out_channels=vision_embed_dim,
        kernel_size=kernel_size,
        stride=kernel_size,
        bias=False,
    ),
]
PatchEmbedGeneric(proj_stem=proj_stem, norm_layer=None).get_patch_layout([3, 2, 224, 224])

In [None]:
from models.helpers import VerboseNNModule
from typing import Tuple, Optional, Callable
from models.helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize,
                            SelectElement, SelectEOSAndProject)
from models.multimodal_preprocessors import (AudioPreprocessor,
                                             IMUPreprocessor, PadIm2Video,
                                             PatchEmbedGeneric,
                                             RGBDTPreprocessor,
                                             SpatioTemporalPosEmbeddingHelper,
                                             TextPreprocessor,
                                             ThermalPreprocessor)
from models.transformer import MultiheadAttention, SimpleTransformer

import logging
import os
from functools import partial
from types import SimpleNamespace
from typing import Dict

import torch
import torch.nn as nn

class RGBDTPreprocessor(VerboseNNModule):
    """
    Turn RGB(-T) and/or depth inputs into one token sequence for a transformer trunk.

    This definition shadows the ``RGBDTPreprocessor`` imported from
    ``models.multimodal_preprocessors`` at the top of this cell.

    Args:
        rgbt_stem: Patch-embedding stem applied to the ``vision`` input, or None.
        depth_stem: Patch-embedding stem applied to the ``depth`` input, or None.
        img_size: Per-sample input shape (no batch dimension) used to probe the
            stem for its patch layout. Lists and tuples are both accepted.
        num_cls_tokens: Number of learnable class tokens prepended to the patches.
        pos_embed_fn: Optional factory called with ``patches_layout``,
            ``num_cls_tokens``, ``num_patches`` and ``embed_dim``; the object it
            returns must expose ``pos_embed`` and ``get_pos_embedding(input, tokens)``.
        use_type_embed: If True, add one learnable embedding to every token.
        init_param_style: ``"openclip"`` or ``"vit"``.

    Raises:
        ValueError: if both stems are None or ``init_param_style`` is unknown.
    """

    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: Optional[PatchEmbedGeneric],
        img_size: Tuple = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        # The patch layout is probed on whichever stem exists (RGB preferred).
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        if stem is None:
            raise ValueError("At least one of rgbt_stem or depth_stem must be provided")
        # get_patch_layout (see the local PatchEmbedGeneric in the previous cell)
        # builds its probe with ``[1] + img_size``, which raises TypeError for a
        # tuple -- including this constructor's own default -- so pass a list.
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
        ) = stem.get_patch_layout(list(img_size))
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """
        Initialize the cls token, positional and type embeddings in place.

        Raises:
            ValueError: for an unknown ``init_param_style``.
        """
        if init_param_style == "openclip":
            # OpenCLIP style initialization: standard normal scaled by embed_dim**-0.5
            scale = self.embed_dim**-0.5
            if self.use_pos_embed:
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token is only created when num_cls_tokens > 0.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        """
        Patch-embed ``input`` with ``stem``, then add cls / positional / type embeddings.

        ``mask`` is accepted for interface compatibility but is not used.

        Returns:
            A B x L x D tensor with D == ``self.embed_dim``; L includes any cls tokens.
        """
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        """
        Tokenize ``vision`` and/or ``depth``; when both are given, tokens are summed.

        Returns:
            ``{"trunk": {"tokens": tokens}, "head": {}}``.

        Raises:
            NotImplementedError: if ``patch_mask`` is given.
            ValueError: if neither ``vision`` nor ``depth`` is given.
        """
        if patch_mask is not None:
            raise NotImplementedError()
        if vision is None and depth is None:
            # Previously this fell through to an unbound ``depth_tokens``.
            raise ValueError("At least one of vision or depth must be provided")

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict

# Same vision stem as probed in the previous cell (reuses its `kernel_size` and
# `vision_embed_dim`). PadIm2Video(pad_type="repeat", ntimes=2) presumably
# repeats single-frame inputs to 2 frames -- TODO confirm in
# models.multimodal_preprocessors.
rgbt_stem = PatchEmbedGeneric(
    proj_stem=[
        PadIm2Video(pad_type="repeat", ntimes=2),
        nn.Conv3d(
            in_channels=3,
            kernel_size=kernel_size,
            out_channels=vision_embed_dim,
            stride=kernel_size,
            bias=False,
        ),
    ]
)

rgbt_preprocessor = RGBDTPreprocessor(
    img_size=[3, 2, 224, 224],
    num_cls_tokens=1,
    pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
    rgbt_stem=rgbt_stem,
    depth_stem=None,
)

# Build the probe input explicitly: `dummy_img` is not defined at this point in a
# fresh kernel, and `dummy_img[0]` would also drop the batch dimension. This is
# [1] + img_size, the same shape the patch-layout probe uses. With a (2, 14, 14)
# kernel/stride the result should be (1, 1 + 1*16*16, vision_embed_dim).
rgbt_preprocessor(vision=torch.zeros(1, 3, 2, 224, 224))["trunk"]["tokens"].shape

In [1]:
from models.imagebind_model import ImageBindModel

# Model built with default constructor arguments; no checkpoint is loaded here,
# it is only used for the shape tracing in the next cell.
imageBind = ImageBindModel()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

# Single-channel 224x224 dummy input for the thermal branch, shape (1, 1, 224, 224)
# -- presumably (B, C, H, W).
dummy_img = torch.zeros(1, 1, 224, 224)
imageBind.layer_shapes(dummy_img, modality_type="thermal")

input torch.Size([1, 1, 224, 224])
trunk_inputs shape torch.Size([1, 197, 768])
after trunk shape torch.Size([1, 197, 768])
after head shape torch.Size([1, 768])
postprocessor shape torch.Size([1, 768])


In [4]:
import torch

path = ".checkpoints/imagebind_huge.pth"
# map_location="cpu": the checkpoint can be inspected on machines without a GPU
# (otherwise CUDA-saved tensors fail to deserialize) and stays off the GPU anyway.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints, or
# pass weights_only=True on torch versions that support it.
state_dict = torch.load(path, map_location="cpu")

In [6]:
# Shape of the thermal head's projection weight (output below: 1024 x 768).
state_dict["modality_heads.thermal.2.weight"].size()

torch.Size([1024, 768])

In [10]:
# Every checkpoint key that mentions the thermal modality, in checkpoint order.
[key for key in state_dict if "thermal" in key]

['modality_preprocessors.thermal.cls_token',
 'modality_preprocessors.thermal.rgbt_stem.proj.weight',
 'modality_preprocessors.thermal.rgbt_stem.norm_layer.weight',
 'modality_preprocessors.thermal.rgbt_stem.norm_layer.bias',
 'modality_preprocessors.thermal.pos_embedding_helper.pos_embed',
 'modality_trunks.thermal.blocks.0.attn.in_proj_weight',
 'modality_trunks.thermal.blocks.0.attn.in_proj_bias',
 'modality_trunks.thermal.blocks.0.attn.bias_k',
 'modality_trunks.thermal.blocks.0.attn.bias_v',
 'modality_trunks.thermal.blocks.0.attn.out_proj.weight',
 'modality_trunks.thermal.blocks.0.attn.out_proj.bias',
 'modality_trunks.thermal.blocks.0.norm_1.weight',
 'modality_trunks.thermal.blocks.0.norm_1.bias',
 'modality_trunks.thermal.blocks.0.mlp.fc1.weight',
 'modality_trunks.thermal.blocks.0.mlp.fc1.bias',
 'modality_trunks.thermal.blocks.0.mlp.fc2.weight',
 'modality_trunks.thermal.blocks.0.mlp.fc2.bias',
 'modality_trunks.thermal.blocks.0.norm_2.weight',
 'modality_trunks.thermal.blo

In [15]:
# Only the thermal preprocessor's parameters (cls token, stem, positional embedding).
list(filter(lambda name: name.startswith("modality_preprocessors.thermal"), state_dict))

['modality_preprocessors.thermal.cls_token',
 'modality_preprocessors.thermal.rgbt_stem.proj.weight',
 'modality_preprocessors.thermal.rgbt_stem.norm_layer.weight',
 'modality_preprocessors.thermal.rgbt_stem.norm_layer.bias',
 'modality_preprocessors.thermal.pos_embedding_helper.pos_embed']

In [3]:
from models.events import EventModel

# `e` is referenced by the following cells; constructed without loading weights.
e = EventModel()
e

<models.events.EventModel at 0x7f975b4e1220>

In [4]:
# Inspect the event preprocessor's cls_token Parameter. This ran as In[4], before
# e.load_weights (In[5]), so these values presumably come from initialization.
e.event_preprocessor.cls_token

Parameter containing:
tensor([[[ 1.6901e-03,  2.0244e-02,  2.6660e-02,  5.9328e-02,  3.1017e-02,
          -1.8562e-02,  1.1770e-02,  5.9572e-02, -7.6173e-02,  5.7392e-03,
           4.9086e-02,  1.0775e-02,  4.0713e-02, -4.3928e-02, -3.2772e-02,
           2.1218e-02, -3.3414e-02,  3.4800e-02,  3.6663e-02,  2.5046e-05,
          -6.2629e-03, -7.0296e-02, -5.2900e-02,  2.8473e-04, -1.8450e-02,
          -4.5848e-02,  8.9453e-03, -5.9607e-03,  1.8311e-02,  1.0183e-02,
           4.7936e-02,  2.5173e-02,  3.6072e-03,  5.9513e-02,  1.7499e-02,
           7.0098e-02,  5.5672e-03,  8.9034e-04,  7.3013e-02, -1.8981e-02,
          -5.8576e-03,  5.1198e-03,  3.6699e-02,  1.9843e-02,  2.8855e-02,
           1.4209e-02,  7.4288e-03, -3.2217e-02,  8.0838e-02, -2.0358e-02,
          -2.2422e-02, -3.4644e-02, -1.6200e-02,  2.9877e-02, -2.0436e-02,
           4.5344e-02, -1.2675e-02,  3.2115e-02, -2.5863e-02, -2.4318e-02,
          -8.2837e-04,  2.7650e-02, -6.2197e-02,  3.4397e-03, -2.9415e-02,
   

In [16]:
# state_dict() returns detached tensors, so this displays as a plain tensor rather
# than "Parameter containing:". It ran as In[16], i.e. after e.load_weights (In[5]),
# which presumably explains why the values differ from the In[4] output above.
e.event_preprocessor.state_dict()["cls_token"]

tensor([[[ 0.0216, -0.0204, -0.0031,  0.0263,  0.0095, -0.0296, -0.0817,
          -0.0247,  0.0244, -0.0231,  0.0479, -0.0058,  0.0232,  0.0399,
          -0.0922, -0.0184,  0.0244, -0.0361,  0.0175,  0.0034,  0.0123,
           0.0208, -0.0182, -0.0053,  0.0406,  0.0445,  0.0111,  0.0472,
          -0.0185,  0.0164, -0.0453, -0.0149, -0.0017, -0.0231,  0.0250,
          -0.0359,  0.0543, -0.0092, -0.0140, -0.0057, -0.0096,  0.0006,
           0.0372,  0.0211,  0.0089, -0.0156,  0.0044,  0.0141, -0.0540,
          -0.0087,  0.0678, -0.0273, -0.0088,  0.0241,  0.0194,  0.0472,
          -0.0228,  0.0368,  0.0102, -0.0241, -0.0740,  0.0107, -0.0082,
           0.0195,  0.0530,  0.0257, -0.0187,  0.0066,  0.0205, -0.0340,
           0.0405,  0.0110,  0.0159, -0.0156, -0.0035, -0.0348,  0.0279,
           0.0346,  0.0485,  0.0110,  0.0020, -0.0161, -0.0078, -0.0630,
          -0.0265, -0.0172,  0.0715,  0.0302, -0.0890,  0.0183, -0.0018,
          -0.0184,  0.0354, -0.0040,  0.0428, -0.05

In [5]:
# Load the ImageBind-huge checkpoint (same file as inspected via torch.load above).
e.load_weights(path=".checkpoints/imagebind_huge.pth")