In [None]:
!git clone https://github.com/siddk/voltron-robotics
%cd voltron-robotics
# Editable install from source. %pip (not !pip) installs into this kernel's environment.
# Run either this cell or the PyPI install cell below, not both.
%pip install -e .


In [None]:
# Alternative to the source install above: the released package from PyPI (%pip targets this kernel).
%pip install voltron-robotics

In [1]:
import voltron

# List the model names that can be passed to `voltron.load` (e.g. "v-cond", "v-dual" below).
# The models described in the paper are 'v-cond', 'v-dual' and 'v-gen'.
voltron.available_models()

['v-cond',
 'v-dual',
 'v-gen',
 'v-cond-base',
 'r-mvp',
 'r-r3m-vit',
 'r-r3m-rn50']

Below is the code for V-Cond, which embeds a single frame (with optional language context).

In [37]:
import torch
import torch.nn.functional as F
from torchvision.io import read_image
import voltron

# Code for V-Cond, which takes single-frame images (plus optional language context).
# Test images (from Yihe and BEHAVIOR) were 784x784 (H, W); `preprocess` resizes them to 224x224.
# `reshape_vcond` resizes its outputs to H, W = 224; 784 works as well.

def get_embeddings(path, vcond, preprocess, vector_extractor, lang=None, device="cuda"):
  """Run a V-Cond model on one image file and return its embeddings.

  Args:
    path: path to an image file readable by `torchvision.io.read_image`.
    vcond: model returned by `voltron.load("v-cond", ...)`.
    preprocess: transform returned alongside `vcond` by `voltron.load`.
    vector_extractor: module built by `voltron.instantiate_extractor(vcond)()`.
    lang: list of language strings conditioning the multimodal pass.
      Defaults to `[""]`, i.e. no language context.
    device: device the preprocessed image is moved to; should match the model's device.

  Returns:
    (multimodal_embeddings, visual_embeddings, representation). For V-Cond the
    representation was observed to be a (1, 384) tensor; the other shapes are
    whatever voltron returns (TODO confirm).
  """
  image_tensor = read_image(path)
  if image_tensor.shape[0] >= 3:
    # RGB, or RGBA such as the 4-channel BEHAVIOR images: keep the first three channels.
    image_tensor = image_tensor[:3]
  else:
    # Grayscale (1 channel) or grayscale+alpha (2): slicing [:3] alone would leave
    # fewer than 3 channels, so replicate the first channel into RGB.
    image_tensor = image_tensor[:1].repeat(3, 1, 1)

  img = preprocess(image_tensor)[None, ...].to(device)  # add a leading batch dimension
  if lang is None:
    lang = [""]  # empty language context; built per call to avoid a shared mutable default

  # Extract both multimodal AND vision-only embeddings
  multimodal_embeddings = vcond(img, lang, mode="multimodal")
  visual_embeddings = vcond(img, mode="visual")

  # Dense vector for downstream use (object detector, control policy, etc.)
  representation = vector_extractor(multimodal_embeddings)

  return multimodal_embeddings, visual_embeddings, representation


def reshape_vcond(path, vcond, preprocess, vector_extractor, size=(224, 224)):
  """Embed one image with V-Cond and stretch each output into a 2-D map of `size`.

  Args:
    path, vcond, preprocess, vector_extractor: forwarded unchanged to `get_embeddings`.
    size: (H, W) of the returned maps. Defaults to (224, 224), the resolution
      `preprocess` produces; e.g. (784, 784) matches the raw BEHAVIOR images.

  Returns:
    (multimodal_map, visual_map, vector_map), three tensors of shape `size`.

  NOTE(review): F.interpolate reads a 4-D input as (N, C, H, W). Assuming the
  embeddings are (batch, tokens, dim), they are padded to (batch, tokens, dim, 1) and
  resized along the embedding axis. `[0, 0]` then keeps only the first token of the
  first batch item, so each map is that token's embedding resampled down the rows and
  repeated across every column. The (batch, dim) vector becomes (batch, dim, 1, 1), so
  its map is a constant equal to the vector's first element. This is not a spatial
  feature map; for per-pixel features the patch tokens would need to be reshaped into
  their grid before interpolating. TODO confirm the intended output.
  """
  multimodal, visual, vec_rep = get_embeddings(path, vcond, preprocess, vector_extractor)

  def _to_map(embedding):
    # Append trailing singleton dims until 4-D, which bilinear interpolation requires.
    # This adds one dim to the 3-D embeddings and two to the 2-D vector.
    while embedding.dim() < 4:
      embedding = embedding.unsqueeze(-1)
    resized = F.interpolate(embedding, size=size, mode='bilinear', align_corners=False)
    return resized[0, 0, :, :]  # first batch item, first channel

  return _to_map(multimodal), _to_map(visual), _to_map(vec_rep)

# TODO: set `path` to the image file you want to embed.
path = ""

# Load a frozen Voltron (V-Cond) model & configure a vector extractor.
vcond, preprocess = voltron.load("v-cond", device="cuda", freeze=True)
# NOTE(review): the extractor is freshly constructed here; unless voltron loads weights
# for it, its output comes from an untrained pooling head. Confirm before relying on it.
vector_extractor = voltron.instantiate_extractor(vcond)().to("cuda")

if not path:
  raise ValueError("Set `path` to an image file before running this cell.")
m, v, r = reshape_vcond(path, vcond, preprocess, vector_extractor)
print(m.shape, v.shape, r.shape, sep="\n")



torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])


Below is the code for V-Dual and V-Gen, which embed pairs of frames (with optional language context).

In [34]:
from torchvision.transforms import ToTensor
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.io import read_image
import voltron

# V-Dual and V-Gen take a pair of frames (plus optional language). The input batch must be 5-D,
# stacked by pairs: (num_pairs, 2, C, H, W).
# Images are consumed in PAIRS (consecutive paths form one pair), so provide an even number of them.

def process_img(path, transform=None, device="cuda"):
  """Read one image file and return it preprocessed, with a leading batch dim, on `device`.

  Args:
    path: image file readable by `torchvision.io.read_image`.
    transform: preprocessing transform returned by `voltron.load`. When None, falls back
      to the global `preprocess` (earlier versions of this cell always used it, which
      silently reused the V-Cond cell's transform).
    device: device the result is moved to.
  """
  if transform is None:
    transform = preprocess  # global fallback; NameError if no `preprocess` was defined
  image_tensor = read_image(path)
  if image_tensor.shape[0] >= 3:
    image_tensor = image_tensor[:3]  # RGB or RGBA: keep the first three channels
  else:
    # Grayscale (+alpha): [:3] alone would leave < 3 channels, so replicate into RGB.
    image_tensor = image_tensor[:1].repeat(3, 1, 1)
  return transform(image_tensor)[None, ...].to(device)


def _check_pairable(image_paths):
  """Raise ValueError unless `image_paths` holds a non-empty, even number of paths."""
  if len(image_paths) == 0 or len(image_paths) % 2 != 0:
    raise ValueError(
        "Expected a non-empty, even number of image paths (consecutive paths form a pair); "
        f"got {len(image_paths)}."
    )


def create_image_pairs(image_paths, transform=None, device="cuda"):
  """Preprocess images and stack consecutive paths into a batch of frame pairs.

  Paths [p0, p1, p2, p3, ...] become pairs (p0, p1), (p2, p3), ...

  Args:
    image_paths: sequence of image paths with an even, non-zero length.
    transform, device: forwarded to `process_img`.

  Returns:
    5-D tensor of shape (num_pairs, 2, *image_shape) on `device`.

  Raises:
    ValueError: if `image_paths` is empty or has an odd length.
  """
  _check_pairable(image_paths)
  pairs = [
      # Each process_img result has a batch dim of 1, so cat along dim 0 gives (2, ...)
      torch.cat([process_img(first, transform, device), process_img(second, transform, device)], dim=0)
      for first, second in zip(image_paths[0::2], image_paths[1::2])
  ]
  return torch.stack(pairs, dim=0)


def reshape_v(m, v, r, size=(224, 224)):
  """Stretch V-Dual / V-Gen outputs into 2-D maps of shape `size`.

  Args:
    m: multimodal embeddings returned by the model.
    v: visual embeddings returned by the model.
    r: dense vector produced by the vector extractor.
    size: (H, W) of the returned maps. Defaults to (224, 224), the resolution
      `preprocess` produces; e.g. (784, 784) matches the raw BEHAVIOR images.

  Returns:
    (m_map, v_map, r_map), three tensors of shape `size`.

  NOTE(review): F.interpolate reads a 4-D input as (N, C, H, W). Assuming m and v are
  (batch, tokens, dim), they are padded to (batch, tokens, dim, 1) and resized along the
  embedding axis. `[0, 0]` keeps only the first batch item's first token, repeated across
  every column; all other pairs in the batch are discarded. r, assumed (batch, dim),
  becomes a constant map equal to its first element. This is not a spatial feature map.
  TODO confirm the intended output.
  """
  def _to_map(embedding):
    # Append trailing singleton dims until 4-D, which bilinear interpolation requires.
    while embedding.dim() < 4:
      embedding = embedding.unsqueeze(-1)
    resized = F.interpolate(embedding, size=size, mode='bilinear', align_corners=False)
    return resized[0, 0, :, :]  # first batch item, first channel

  return _to_map(m), _to_map(v), _to_map(r)


def get_reshaped_tensors(list_of_images, model_name, device="cuda"):
  """Embed consecutive image pairs with a dual-frame Voltron model and reshape the outputs.

  Args:
    list_of_images: non-empty, even-length sequence of image paths.
    model_name: a dual-frame model name such as "v-dual" or "v-gen".
    device: device for the model, the extractor and the input batch.

  Returns:
    The three maps produced by `reshape_v`.

  Raises:
    ValueError: if `list_of_images` is empty or has an odd length.
  """
  _check_pairable(list_of_images)  # fail fast, before the expensive model load
  model, preprocess = voltron.load(model_name, device=device, freeze=True)
  # NOTE(review): freshly constructed extractor; confirm whether it carries trained weights.
  vector_extractor = voltron.instantiate_extractor(model)().to(device)
  # Use this model's own transform rather than a global left over from another cell.
  batch_tensor = create_image_pairs(list_of_images, transform=preprocess, device=device)

  # NOTE(review): one language string is passed for the whole batch; confirm voltron
  # broadcasts it when there are several pairs.
  lang = [""]

  # Extract both multimodal AND vision-only embeddings
  multimodal_embeddings = model(batch_tensor, lang, mode="multimodal")
  visual_embeddings = model(batch_tensor, mode="visual")

  # Dense vector for downstream use (object detector, control policy, etc.)
  representation = vector_extractor(multimodal_embeddings)
  return reshape_v(multimodal_embeddings, visual_embeddings, representation)

# TODO: fill `list_of_images` with image paths, e.g. every image in a folder (sorted so that
# frames pair up correctly). Consecutive paths form one pair, so provide an even number of them.
list_of_images = []
model_names = ["v-dual", "v-gen"]  # Models that use dual frames

if not list_of_images or len(list_of_images) % 2:
  raise ValueError(
      f"`list_of_images` must hold a non-empty, even number of image paths; got {len(list_of_images)}."
  )
multimodal, visual, representation = get_reshaped_tensors(list_of_images, model_names[0])
print(multimodal.shape, visual.shape, representation.shape, sep="\n")


torch.Size([224, 224])
torch.Size([224, 224])
torch.Size([224, 224])
