In [None]:
from google.cloud import storage
from PIL import Image
import numpy as np
import hashlib
import os
import sys
import torch
import io
import cv2

import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.tf_record_reader as tfrr

# Keypoints stored per annotated object instance. decode() reshapes each
# record's 'point_2d' into (num_instances, NUM_KEYPOINTS, 3), and the box
# edges in EDGES connect indices 1-8 (index 0 is presumably the box center —
# confirm against the Objectron schema).
NUM_KEYPOINTS: int = 9

In [None]:
def glob_objectron(bucket='objectron', prefix='v1/records_shuffled/cup/cup_train'):
    """List the names of objects in a public GCS bucket under a prefix.

    The defaults select the Objectron "cup" training shards; pass another
    ``prefix`` to list a different category or split.

    Args:
        bucket: Name of the GCS bucket to list (accessed anonymously).
        prefix: Only objects whose name starts with this string are returned.

    Returns:
        list[str]: Object names relative to the bucket (no ``gs://`` scheme).
    """
    # An anonymous client needs no credentials, which suffices for public buckets.
    client = storage.Client.create_anonymous_client()
    return [blob.name for blob in client.list_blobs(bucket, prefix=prefix)]

In [None]:
# Preview the shard listing instead of dumping every object name into the output.
cup_train_shards = glob_objectron()
print(f'{len(cup_train_shards)} shards found')
cup_train_shards[:5]

In [None]:
def decode(example):
  """Convert one TFRecord entry into torch-compatible values.

  Args:
    example: Mapping returned by ``TfRecordReader.read_example()``. The keys
      read here are 'instance_num', 'point_2d' and 'image/encoded'; values are
      expected to expose ``.item()`` / ``.numpy()`` (presumably torch tensors —
      confirm against the reader).

  Returns:
    (image, points, num_instances) where
      image: torch tensor wrapping a writable copy of the decoded image array.
      points: numpy array reshaped to (num_instances, NUM_KEYPOINTS, 3).
      num_instances: number of annotated object instances in the example.
  """
  num_instances = example['instance_num'].item()
  points = example['point_2d'].numpy().reshape(num_instances, NUM_KEYPOINTS, 3)
  image_data = example['image/encoded'].numpy().tobytes()
  # np.array (unlike np.asarray) returns a writable copy of the pixels, so
  # torch.from_numpy does not wrap a read-only buffer and callers may draw on
  # the result in place. The `with` releases the PIL image handle.
  with Image.open(io.BytesIO(image_data)) as image:
    npa = np.array(image)
  return torch.from_numpy(npa), points, num_instances


In [None]:
%matplotlib inline
# https://stackoverflow.com/questions/11159436/multiple-figures-in-a-single-window
import matplotlib.pyplot as plt


In [None]:
# Sanity check that inline plotting works: a line from (0, 0) to (1, 1).
# plot() takes all x-coordinates first, then all y-coordinates.
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
ax.set(title='Inline plotting check', xlabel='x', ylabel='y')
plt.show()

In [None]:

# Radius, in pixels, of the filled circle drawn at each projected keypoint.
RADIUS = 10

# Per-instance drawing colors; an instance uses colors[object_id % len(colors)].
colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (128, 128, 0),
    (128, 0, 128),
    (0, 128, 128),
    (255, 255, 255),
    (0, 0, 0),
]

# Keypoint-index pairs joined by a line to draw the 12 edges of a 3D box.
EDGES = [
    [1, 5], [2, 6], [3, 7], [4, 8],  # lines along x-axis
    [1, 3], [5, 7], [2, 4], [6, 8],  # lines along y-axis
    [1, 2], [3, 4], [5, 6], [7, 8],  # lines along z-axis
]

def load_dataset(path, num_samples=None, transforms=None):
    """Yield decoded examples from a TFRecord file.

    The previous version referenced undefined names (``path``, ``num_samples``,
    ``transforms``) and discarded every decoded sample; it is now a generator.

    Args:
        path: TFRecord path passed to ``tfrr.TfRecordReader`` (e.g. a gs:// URL).
        num_samples: Maximum number of examples to yield; ``None`` reads until
            the reader is exhausted.
        transforms: Optional transforms mapping forwarded to the reader;
            defaults to an empty dict.

    Yields:
        The ``(image, points, num_instances)`` tuple produced by ``decode``.
    """
    reader = tfrr.TfRecordReader(
        path, compression='', transforms={} if transforms is None else transforms)
    produced = 0
    while num_samples is None or produced < num_samples:
        example = reader.read_example()
        # A falsy example means the reader has no more records.
        if not example:
            return
        yield decode(example)
        produced += 1


def _draw_box(canvas, instance_keypoints, color):
  """Draw one instance's keypoints and box edges onto ``canvas`` in place.

  Args:
    canvas: Writable, contiguous image array indexed (row, col, ...); its first
      two dimensions give the pixel height and width.
    instance_keypoints: Array of shape (NUM_KEYPOINTS, 3). Columns 0 and 1 are
      scaled by width and height, so they are presumably normalised x, y.
    color: Color tuple passed to cv2.
  """
  h, w = canvas.shape[:2]

  def to_pixel(kp):
    # Normalised coordinates -> integer pixel coordinates for cv2.
    return int(w * kp[0]), int(h * kp[1])

  for kp in instance_keypoints[:NUM_KEYPOINTS]:
    cv2.circle(canvas, to_pixel(kp), RADIUS, color, -1)
  for start, end in EDGES:
    cv2.line(canvas, to_pixel(instance_keypoints[start]),
             to_pixel(instance_keypoints[end]), color, 1)


def show_3d_bounding_box(path, num_samples):
  """Plot up to ``num_samples`` images from a TFRecord shard with their 3D boxes.

  Each example gets its own column. If the shard runs out early, the remaining
  panels are hidden.

  Args:
    path: TFRecord path passed to ``tfrr.TfRecordReader`` (e.g. a gs:// URL).
    num_samples: Maximum number of examples to read and plot.
  """
  if num_samples <= 0:
    return
  reader = tfrr.TfRecordReader(path, compression='', transforms={})
  # One panel per requested sample. squeeze=False keeps a 2-D axes array even
  # when num_samples == 1.
  fig, axes = plt.subplots(1, num_samples, figsize=(12, 16), squeeze=False)
  axes = axes[0]

  drawn = 0
  for ax in axes:
    example = reader.read_example()
    if not example:
      break
    img_tensor, keypoints, num_instances = decode(example)
    # Draw on a private writable, contiguous copy rather than the decoded buffer.
    canvas = img_tensor.numpy().copy()
    for object_id in range(num_instances):
      _draw_box(canvas, keypoints[object_id], colors[object_id % len(colors)])
    ax.grid(False)
    ax.imshow(canvas)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    drawn += 1

  # Hide panels left empty because the shard had fewer than num_samples records.
  for ax in axes[drawn:]:
    ax.set_visible(False)

  fig.tight_layout()
  plt.show()

In [None]:
# Visualize the 3D bounding boxes on the first samples of the first few shards.
NUM_SHARDS_TO_SHOW = 5
SAMPLES_PER_SHARD = 10

training_shards = glob_objectron()
# Slicing (rather than indexing range(5)) tolerates a listing with fewer shards.
for shard in training_shards[:NUM_SHARDS_TO_SHOW]:
  shard_name = 'gs://objectron/' + shard
  print(shard_name)
  show_3d_bounding_box(path=shard_name, num_samples=SAMPLES_PER_SHARD)