In [1]:
import matplotlib.pyplot as plt
import numpy as np
from model_util_scannet_CIL_37 import ScannetDatasetConfig
import random
from memory_bank_object import Memory_Bank_Object
# Fixed seed so the stochastic steps of this notebook can be repeated.
SEED = 42
# Seed both the stdlib and the NumPy global generators. The code that performs
# the 'random' selection lives outside this notebook, so which generator it
# draws from is not visible here — TODO confirm and drop whichever is unused.
random.seed(SEED)
np.random.seed(SEED)

In [2]:
# Read the serialized object reservoir from 'object_reservoir.pth'.
# NOTE(review): allow_pickle=True lets NumPy unpickle arbitrary Python objects,
# which can execute code while loading — only point this at a trusted file.
with open('object_reservoir.pth', 'rb') as reservoir_file:
    object_reservoir = np.load(reservoir_file, allow_pickle=True)

# Dataset configuration object.
config = ScannetDatasetConfig()

In [3]:
# Gather the distinct 'object_class' values found across the reservoir entries.
object_classes = {entry['object_class'] for entry in object_reservoir}
print(object_classes)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}


In [4]:
# Hard-coded object counts, listed in descending order (presumably one entry
# per class — TODO confirm where these numbers come from).
# NOTE(review): an earlier note in this cell put the reservoir size at 29579
# objects, while these 36 counts add up to 25490; the list therefore does not
# account for every object in the reservoir.
num_objects = [6721, 2026, 1985, 1554, 1427, 1271, 1255, 928, 745, 661, 657, 551, 486, 481, 406, 390, 386, 364, 307, 300, 292, 279, 253, 251, 216, 201, 190, 186, 177, 170, 116, 113, 52, 39, 32, 22]
print(sum(num_objects))  # 25490
print(len(num_objects))  # 36

25490
36


In [5]:
# # select a fixed object (index 7) from the reservoir for inspection
# object_idx = 7

# object_0 = object_reservoir[object_idx]
# for _key in object_0.keys():
#     # print key and value if the key is not 'object_point_cloud'
#     if _key != 'object_point_cloud':
#         print(_key, object_0[_key])
#         if _key == 'object_class':
#             print(config.class2type[object_0[_key]])
#     else:
#         print(_key, object_0[_key].shape)
# # visualize object_0['object_point_cloud']
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(object_0['object_point_cloud'][:,0], object_0['object_point_cloud'][:,1], object_0['object_point_cloud'][:,2])
# plt.show()

In [6]:
# Settings handed to the memory bank: its total budget and the path it loads from.
TOTAL_BUDGET = 1000
LOAD_PATH = 'object_reservoir.pth'

# Index of the memory-bank entry that later cells print for inspection.
obj_idx_check = 88

mbo = Memory_Bank_Object(load_path=LOAD_PATH, total_budget=TOTAL_BUDGET)

In [7]:
# First update: classes 0-6, selected with the 'random' criteria.
mbo.update_memory(classes=list(range(7)), criteria='random', budget=-1)
# Number of objects held by the memory bank after the update.
print(len(mbo))
# Show the entry stored at index obj_idx_check.
print(mbo[obj_idx_check])

Adding 1000 objects to the memory bank.
1000
{'scene_name': 'scene0310_00', 'object_id': 39, 'object_class': 0, 'object_point_cloud': array([[  1.1953244,   1.3193747,   1.1298935, 206.       , 190.       ,
        178.       ],
       [  1.1702533,   1.3535439,   0.9768777, 110.       , 114.       ,
        115.       ],
       [  1.1036168,   1.2563448,   0.9998543, 103.       ,  95.       ,
        109.       ],
       ...,
       [  1.095832 ,   1.3010659,   1.1273091, 153.       , 145.       ,
        150.       ],
       [  1.0963621,   1.3119166,   1.0237806, 126.       , 123.       ,
        128.       ],
       [  1.0927912,   1.330611 ,   1.1266642, 136.       , 132.       ,
        138.       ]], dtype=float32)}


In [8]:
# Snapshot mbo.one_hot_mask as it stands at this point in the notebook.
# copy() yields an independent array, so later changes to mbo.one_hot_mask
# cannot alter this snapshot. No zero-filled placeholder is needed first:
# the copy already has the right shape and values.
one_hot_mask = mbo.one_hot_mask.copy()

In [9]:
# Next update: classes 7-13, again with the 'random' criteria.
mbo.update_memory(classes=list(range(7, 14)), criteria='random', budget=-1)
# Number of objects held by the memory bank after the update.
print(len(mbo))
# Inspection of the entry at index obj_idx_check is disabled for this stage.
# print(mbo[obj_idx_check])

Adding 500 objects to the memory bank.
1000


In [10]:
# Snapshot of mbo.one_hot_mask at this stage. Take a copy rather than a
# reference so the snapshot stays fixed even if the memory bank later modifies
# its mask in place (whether it does is not visible from this notebook).
mask_stage_1 = mbo.one_hot_mask.copy()

In [18]:
# Locate the positions that are 0 in the earlier snapshot (one_hot_mask) but 1
# in mask_stage_1, then report how many there are.
absent_in_snapshot = one_hot_mask == 0
present_in_stage_1 = mask_stage_1 == 1
indices = np.where(absent_in_snapshot & present_in_stage_1)
print(len(indices[0]))  # the recorded run printed 500

500


In [19]:
# Next update: classes 14-20, again with the 'random' criteria.
mbo.update_memory(classes=list(range(14, 21)), criteria='random', budget=-1)
# Number of objects held by the memory bank after the update.
print(len(mbo))
# Inspection of the entry at index obj_idx_check is disabled for this stage.
# print(mbo[obj_idx_check])

Adding 333 objects to the memory bank.
1000


In [20]:
# Snapshot of mbo.one_hot_mask at this stage. Take a copy rather than a
# reference so the snapshot stays fixed even if the memory bank later modifies
# its mask in place (whether it does is not visible from this notebook).
mask_stage_2 = mbo.one_hot_mask.copy()

In [24]:
# Locate the positions that are 1 in both the earlier snapshot (one_hot_mask)
# and mask_stage_2, then report how many there are.
set_in_snapshot = one_hot_mask == 1
set_in_stage_2 = mask_stage_2 == 1
indices = np.where(set_in_snapshot & set_in_stage_2)
print(len(indices[0]))  # the recorded run printed 331

331


In [10]:
# Next update: classes 21-27, again with the 'random' criteria.
mbo.update_memory(classes=list(range(21, 28)), criteria='random', budget=-1)
# Number of objects held by the memory bank after the update.
print(len(mbo))
# Show the entry stored at index obj_idx_check.
print(mbo[obj_idx_check])

Adding 250 objects to the memory bank.
1000
{'scene_name': 'scene0359_00', 'object_id': 3, 'object_class': 19, 'object_point_cloud': array([[ 5.0468940e-01,  8.3762985e-01,  6.0741168e-01,  1.3400000e+02,
         9.7000000e+01,  5.6000000e+01],
       [ 2.9532003e-01, -1.2838641e+00,  5.9754020e-01,  1.4900000e+02,
         1.1500000e+02,  8.4000000e+01],
       [ 2.6913908e-01,  8.1671840e-01,  1.1035556e-03,  5.1000000e+01,
         4.9000000e+01,  3.9000000e+01],
       ...,
       [ 4.4670415e-01,  1.3078301e-01,  8.1197667e-01,  4.1000000e+01,
         3.6000000e+01,  2.9000000e+01],
       [ 9.1425705e-01,  5.1757568e-01, -4.3650065e-03,  5.0000000e+01,
         4.6000000e+01,  3.6000000e+01],
       [ 4.6300647e-01,  6.4903134e-01,  4.6623667e-04,  4.6000000e+01,
         4.3000000e+01,  3.3000000e+01]], dtype=float32)}


In [11]:
# Next update: classes 28-34, again with the 'random' criteria.
mbo.update_memory(classes=list(range(28, 35)), criteria='random', budget=-1)
# Number of objects held by the memory bank after the update.
print(len(mbo))
# Show the entry stored at index obj_idx_check.
print(mbo[obj_idx_check])

Adding 200 objects to the memory bank.
1000
{'scene_name': 'scene0280_01', 'object_id': 4, 'object_class': 19, 'object_point_cloud': array([[  2.2629323 ,   1.8884851 ,   0.9729849 , 196.        ,
        189.        , 177.        ],
       [  1.937362  ,   1.3635402 ,   0.5903004 , 144.        ,
        110.        , 132.        ],
       [  1.7024807 ,   0.43043536,   0.51419413,  72.        ,
         34.        ,  71.        ],
       ...,
       [  2.0084867 ,   1.9972068 ,   0.6695033 ,  74.        ,
         69.        ,  62.        ],
       [  2.2880542 ,   0.5738802 ,   0.5614863 , 151.        ,
        155.        , 142.        ],
       [  2.021271  ,   1.1111015 ,   0.6105524 ,  35.        ,
         34.        ,  41.        ]], dtype=float32)}
