### TODOs
- Adjust the object scoring function to:
    - score directly on the `3DObject` class, once the matcher (`MasksMatcher`) creates it.
- Write a plotting function for debugging `ObjectInstance`s.

In [99]:
%load_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

# Ensure Plotly is set up to work with Jupyter notebooks
pio.renderers.default = 'notebook'

# Make the project's top-level modules (masks_finder, masks_matcher, dataloader, ...)
# importable. The root defaults to the original machine-specific absolute path but can
# be overridden with the LETSDOIT_ROOT environment variable, so the notebook also runs
# on other machines. Note: despite its name, `cwd` is the project root, not the
# working directory (name kept so existing references keep working).
# TODO: replace this sys.path hack with an installed package.
cwd = os.environ.get('LETSDOIT_ROOT', '/teamspace/studios/this_studio/letsdoit')
if cwd not in sys.path:
    sys.path.append(cwd)

from masks_finder import MasksFinder
from masks_matcher import MasksMatcher
from clip_retriever import ClipRetriever
from dataloader.dataloader import DataLoader
from object_scorer import ObjectScorer
from letsdoit.utils.object_instance import ObjectInstance, initialize_object_instances
from letsdoit.utils.misc import select_ids

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [100]:
# Dataset root and which asset stream the loader should read frames from.
path_dataset = '/teamspace/studios/this_studio/datasets'
ASSET_TYPE = 'wide'

# Data loader for the dev split, plus the retrieval and mask-finding models.
loader = DataLoader(path_dataset, split='dev')
retriever = ClipRetriever()
masks_finder = MasksFinder()

# Work on the first video of the first visit.
visit_ids = loader.visit_ids
visit_id = visit_ids[0]
video_ids = loader.get_video_ids(visit_id)
video_id = video_ids[0]

# Load every frame (sample_freq=1) of RGB and depth for that video.
images, image_paths, intrinsics, poses, orientations = loader.get_images(
    visit_id, video_id, asset_type=ASSET_TYPE, sample_freq=1)
depths, depth_paths, _, _, _ = loader.get_depths(
    visit_id, video_id, asset_type=ASSET_TYPE, sample_freq=1)

# Pre-compute the retriever's image features so later text queries can rank frames.
retriever.generate_image_features(images)

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight'])
<All keys matched successfully>


Loading rgb frames from visit 420683 and video 42445132: 100%|██████████| 159/159 [00:07<00:00, 21.27it/s]
Loading depth frames from visit 420683 and video 42445132: 100%|██████████| 159/159 [00:05<00:00, 31.23it/s]


In [101]:
# For a text query, retrieve the best-matching frames, keep the per-frame data of
# those frames, then run the mask finder on them.
# The query is stored in `object_name` rather than `object` so the builtin `object`
# is not shadowed for the rest of the kernel session.
object_name = 'drawer'
best_indices = retriever.retrieve_best_images_for_object(object_name)
best_images = select_ids(images, best_indices)
best_image_paths = select_ids(image_paths, best_indices)
best_intrinsics = select_ids(intrinsics, best_indices)
best_poses = select_ids(poses, best_indices)
best_orientations = select_ids(orientations, best_indices)
best_depths = select_ids(depths, best_indices)
best_depth_paths = select_ids(depth_paths, best_indices)
# presumably image_ids index into the best_* lists, one entry per detection — verify
image_ids, masks, bboxes, confidences, labels = masks_finder.get_masks_from_imgs(best_images, object_name)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:11<00:00,  1.20s/it]


In [115]:
# Assemble one entry per detection: frame-level data (images, depths, cameras) is
# re-indexed with `image_ids`, while the detector outputs are already per detection.
dict_object_instances = dict(
    images=select_ids(best_images, image_ids),
    depths=select_ids(best_depths, image_ids),
    bboxes=bboxes,
    masks=masks,
    labels=labels,
    confidences=confidences,
    intrinsics=select_ids(best_intrinsics, image_ids),
    extrinsics=select_ids(best_poses, image_ids),
    orientations=select_ids(best_orientations, image_ids),
)

object_instances = initialize_object_instances(**dict_object_instances)

In [116]:
# Number of object instances built for the query.
len(object_instances)

37

In [137]:
# Pick two detections to inspect. The indices are hand-picked for the current
# visit/video/query and must stay below len(object_instances).
# TODO: add a 2D mask plot (e.g. an ObjectInstance.plot_2d) alongside plot_3d.
INSTANCE_IDX, INSTANCE_IDX2 = 35, 1
obj = object_instances[INSTANCE_IDX]
obj2 = object_instances[INSTANCE_IDX2]

In [138]:
# 3D plot of the selected instance as stored, before the manual un-rotation that follows.
obj.plot_3d()

In [139]:
# Manually undo the 90-degree rotation of `obj` and rebuild it with orientation=0.
# NOTE(review): the counter-clockwise turn is hard-coded rather than derived from
# obj.orientation — confirm it matches this video's orientation.
# Presumably intrinsic/extrinsic describe the un-rotated (sensor) frame — TODO confirm.
img_unrotated = cv2.rotate(obj.image, cv2.ROTATE_90_COUNTERCLOCKWISE)
depth_unrotated = cv2.rotate(obj.depth, cv2.ROTATE_90_COUNTERCLOCKWISE)
# np.rot90(k=1) turns the same way as ROTATE_90_COUNTERCLOCKWISE (it also works on
# boolean masks, which cv2.rotate does not accept). rot90 returns a strided *view* of
# obj.mask, so copy it: the new instance must not alias the original mask, and it
# gets a C-contiguous array like the cv2.rotate outputs above.
mask_unrotated = np.rot90(obj.mask, k=1, axes=(0, 1)).copy()
sketchy_obj_unrotated = ObjectInstance(image=img_unrotated,
                                       depth=depth_unrotated,
                                       # NOTE(review): bbox is passed through unrotated, so it
                                       # is still in the rotated image's pixel frame.
                                       bbox=obj.bbox,
                                       mask=mask_unrotated,
                                       label=obj.label,
                                       confidence=obj.confidence,
                                       intrinsic=obj.intrinsic,
                                       extrinsic=obj.extrinsic,
                                       orientation=0)

In [140]:
# Same manual un-rotation as for `obj`, applied to `obj2`.
# NOTE(review): the counter-clockwise turn is hard-coded rather than derived from
# obj2.orientation — confirm it matches this video's orientation.
img_unrotated2 = cv2.rotate(obj2.image, cv2.ROTATE_90_COUNTERCLOCKWISE)
depth_unrotated2 = cv2.rotate(obj2.depth, cv2.ROTATE_90_COUNTERCLOCKWISE)
# np.rot90 returns a view of obj2.mask; copy it so the new instance does not alias the
# original mask and gets a C-contiguous array like the cv2.rotate outputs above.
mask_unrotated2 = np.rot90(obj2.mask, k=1, axes=(0, 1)).copy()
sketchy_obj_unrotated2 = ObjectInstance(image=img_unrotated2,
                                        depth=depth_unrotated2,
                                        # NOTE(review): bbox is passed through unrotated, so
                                        # it is still in the rotated image's pixel frame.
                                        bbox=obj2.bbox,
                                        mask=mask_unrotated2,
                                        label=obj2.label,
                                        confidence=obj2.confidence,
                                        intrinsic=obj2.intrinsic,
                                        extrinsic=obj2.extrinsic,
                                        orientation=0)

In [141]:
# 3D plot of the manually un-rotated copy, for comparison with obj.plot_3d() above.
sketchy_obj_unrotated.plot_3d()

In [109]:
# Overlay the 3D mask points and 3D centres of the two un-rotated instances in one
# interactive figure, to check whether the two detections land in the same place.


def _rgb_string(rgb):
    """Convert three floats in [0, 1] into a Plotly 'rgb(r, g, b)' colour string.

    Plotly treats a numeric array given as `marker.color` as per-point values mapped
    through a colourscale, not as a single RGB colour, so a string is required.
    """
    r, g, b = (int(round(255 * float(c))) for c in rgb)
    return f'rgb({r}, {g}, {b})'


def add_instance_traces(fig, instance, name, max_points=10_000, subsample_factor=100, seed=None):
    """Add an instance's 3D mask points and its 3D centre to a Plotly figure.

    Args:
        fig: plotly.graph_objects.Figure to draw into (modified in place).
        instance: object exposing `mask_3d` (rows 0/1/2 are x/y/z, one column per
            point) and `center_3d` (a single 3D point).
        name: legend prefix so traces of different instances can be told apart.
        max_points: clouds with more points than this are subsampled for display.
        subsample_factor: when subsampling, keep about 1/subsample_factor of the points.
        seed: seed for the random trace colours; None gives new colours on every run.

    Returns:
        The same `fig`, for chaining.
    """
    points = instance.mask_3d
    num_points = points.shape[1]
    if num_points > max_points:
        # Evenly spaced columns keep the cloud's extent while cutting the point count.
        keep = np.linspace(0, num_points - 1, max(1, num_points // subsample_factor), dtype=int)
        points = points[:, keep]

    rng = np.random.default_rng(seed)
    mask_color = _rgb_string(rng.random(3))
    center_color = _rgb_string(rng.random(3))

    fig.add_trace(go.Scatter3d(
        x=points[0, :],
        y=points[1, :],
        z=points[2, :],
        mode='markers',
        marker=dict(size=2, color=mask_color, opacity=0.8),
        name=f'{name}: mask points',
    ))

    # Shape the single centre point as one column, like the mask points.
    center = np.reshape(instance.center_3d, (3, -1))
    fig.add_trace(go.Scatter3d(
        x=center[0, :],
        y=center[1, :],
        z=center[2, :],
        mode='markers',
        marker=dict(size=5, color=center_color, opacity=1.0),
        name=f'{name}: mask center',
    ))
    return fig


fig = go.Figure()
add_instance_traces(fig, sketchy_obj_unrotated, name='obj', seed=0)
add_instance_traces(fig, sketchy_obj_unrotated2, name='obj2', seed=1)
# aspectmode='data' keeps x/y/z on the same scale so the clouds are not distorted.
fig.update_layout(title='3D mask points and centres of the un-rotated instances',
                  scene=dict(aspectmode='data'))
fig