In [1]:
import omegaconf
import hydra
import torch
import torchvision.transforms as T
import numpy as np
from PIL import Image

from r3m import load_r3m

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "resnet50" R3M variant from the local checkpoint path and put it
# into eval mode, since this notebook only runs inference.
r3m = load_r3m("resnet50", load_path='/root/model/r3m/r3m-rn50')
r3m.eval()
r3m.to(device)  # bare last expression: the notebook displays the module repr



DataParallel(
  (module): R3M(
    (cs): CosineSimilarity()
    (bce): BCELoss()
    (sigm): Sigmoid()
    (convnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1

In [5]:
## DEFINE PREPROCESSING
# Resize to 256, center-crop to 224x224, then convert to a float tensor.
# ToTensor() divides by 255, so the resulting values lie in [0, 1].
transforms = T.Compose([T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor()])

## ENCODE IMAGE
# Random stand-in for a real RGB image. randint's upper bound is exclusive,
# so it must be 256 for 255 to be a possible pixel value.
image = np.random.randint(0, 256, (500, 500, 3), dtype=np.uint8)
# unsqueeze(0) adds the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
preprocessed_image = transforms(Image.fromarray(image)).unsqueeze(0)
# Tensor.to() returns a new tensor instead of moving the data in place, so the
# result has to be rebound; otherwise the input silently stays on the CPU.
preprocessed_image = preprocessed_image.to(device)
preprocessed_image

tensor([[[[0.3176, 0.3294, 0.5569,  ..., 0.4000, 0.4941, 0.4980],
          [0.4627, 0.4667, 0.6000,  ..., 0.2824, 0.3059, 0.4941],
          [0.6863, 0.5843, 0.7020,  ..., 0.5098, 0.5255, 0.4706],
          ...,
          [0.5569, 0.3765, 0.4314,  ..., 0.4392, 0.4510, 0.5216],
          [0.4353, 0.3176, 0.7137,  ..., 0.6078, 0.4314, 0.6196],
          [0.3725, 0.4118, 0.4667,  ..., 0.4549, 0.5843, 0.6353]],

         [[0.3725, 0.5255, 0.5608,  ..., 0.3882, 0.5216, 0.6118],
          [0.5098, 0.4314, 0.4588,  ..., 0.4941, 0.5647, 0.6510],
          [0.4549, 0.3882, 0.4392,  ..., 0.3490, 0.5255, 0.5961],
          ...,
          [0.4549, 0.4157, 0.2863,  ..., 0.6078, 0.4902, 0.5333],
          [0.4039, 0.3647, 0.3294,  ..., 0.5098, 0.4627, 0.6706],
          [0.4118, 0.4667, 0.5059,  ..., 0.5451, 0.4588, 0.3882]],

         [[0.3843, 0.3725, 0.5137,  ..., 0.5294, 0.5098, 0.5804],
          [0.5804, 0.4824, 0.3059,  ..., 0.5804, 0.4902, 0.5255],
          [0.6118, 0.5176, 0.4275,  ..., 0

In [6]:
## R3M expects image input to be [0-255], so the [0, 1] output of ToTensor()
## is scaled back up before the forward pass.
with torch.no_grad():  # inference only: skip gradient tracking
    embedding = r3m(torch.mul(preprocessed_image, 255.0))
print(embedding.shape)  # [1, 2048]

torch.Size([1, 2048])
