In [1]:
import cv2
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
from mpi import load_mpi



In [2]:
root_dir = "/root/model/mpi/mpi-small"
language_model_path = "/root/model/distilbert-base-uncased"
device = "cuda:0"
model = load_mpi(root_dir, device, freeze=True, language_model_path=language_model_path)
model.eval()
model.to(device)

MPI(
  (patch2embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (encoder_blocks): ModuleList(
    (0-11): 12 x Block(
      (pre_norm_attn): RMSNorm()
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (layer_scale_attn): LayerScale()
      (pre_norm_mlp): RMSNorm()
      (mlp): Sequential(
        (0): SwishGLU(
          (act): SiLU()
          (project): Linear(in_features=384, out_features=3072, bias=True)
        )
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=1536, out_features=384, bias=True)
      )
      (layer_scale_mlp): LayerScale()
    )
  )
  (encoder_norm): RMSNorm()
  (temporal_attn): MultiScaleDeformableAttention(
    (sampling_offsets): Linear(in_features=384, out_features=48, bias=True)
    (attention_weights): Linear(in_feat

In [4]:
transforms = T.Compose([T.Resize(256),
                        T.CenterCrop(224),
                        T.ToTensor()])
image = cv2.imread("/root/code/BC-IB/third_party/methods/MPI/assets/example_franka_kitchen.jpg")
image = transforms(Image.fromarray(image.astype(np.uint8))).unsqueeze(0)
print(image)
visual_input = torch.stack((image, image), dim=1) # simply repeat the current observation in downstream downstask
visual_input = visual_input.to(device=device)
print(visual_input.shape)

tensor([[[[0.0706, 0.0588, 0.0627,  ..., 0.2039, 0.0863, 0.0667],
          [0.0706, 0.0902, 0.0667,  ..., 0.1608, 0.0745, 0.0745],
          [0.3725, 0.1922, 0.1020,  ..., 0.0588, 0.0706, 0.0706],
          ...,
          [0.5255, 0.5216, 0.5176,  ..., 0.7608, 0.7922, 0.8314],
          [0.5137, 0.5176, 0.5255,  ..., 0.7569, 0.7882, 0.8235],
          [0.5373, 0.5412, 0.5294,  ..., 0.7529, 0.7843, 0.8196]],

         [[0.0667, 0.0627, 0.0667,  ..., 0.1961, 0.0784, 0.0588],
          [0.0510, 0.0863, 0.0549,  ..., 0.1529, 0.0667, 0.0667],
          [0.3294, 0.1490, 0.0667,  ..., 0.0510, 0.0627, 0.0627],
          ...,
          [0.5255, 0.5216, 0.5176,  ..., 0.8118, 0.8431, 0.8824],
          [0.5137, 0.5176, 0.5255,  ..., 0.8078, 0.8392, 0.8745],
          [0.5373, 0.5412, 0.5294,  ..., 0.8039, 0.8353, 0.8706]],

         [[0.0510, 0.0471, 0.0510,  ..., 0.1922, 0.0745, 0.0549],
          [0.0392, 0.0706, 0.0392,  ..., 0.1490, 0.0627, 0.0627],
          [0.3137, 0.1333, 0.0510,  ..., 0

In [5]:
lang_input = ["turn on the knob", ]
embedding_with_lang_tokens = model.get_representations(visual_input, lang_input, with_lang_tokens = True)
embedding_without_lang_tokens = model.get_representations(visual_input, None, with_lang_tokens = False)
print(embedding_with_lang_tokens.shape, embedding_without_lang_tokens.shape) # (1, 218, 384), (1, 197, 384)

torch.Size([1, 218, 384]) torch.Size([1, 197, 384])
