In [None]:
%load_ext autoreload
%autoreload 2
from data_loader_h5 import H5Dataset
from data_loader_jsonl import JSONLDataset

# Build the training dataset from the JSONL export and report its length.
jsonl_path = "data/cvla-droid-block-v3"
train_dataset = JSONLDataset(jsonl_file_path=jsonl_path)
num_samples = len(train_dataset)
print(num_samples)

In [None]:
# Augment depth - change suffix

import numpy as np
import re
import random
import matplotlib.pyplot as plt

from torchvision.transforms import v2

# Allowed depth interval (cm) for the closest object after augmentation,
# and the largest shift (cm) applied in either direction.
depth_range = (25, 100)
max_shift = 35

# get first data item, get entry dictionary
# suffix = (12) = obj_start (h w d rv0 rv1 rv2) obj_end (h w d rv0 rv1 rv2)
suffix = train_dataset[0][1]["suffix"]
print("suffix", suffix)

# convert string tokens to int, one row of 6 values per object
# depth in cm, x and y in [0, 1023]
loc_strings = [int(x) for x in re.findall(r"<(?:loc|seg)(\d+)>", suffix)]
loc_strings = np.array(loc_strings)
loc_strings = loc_strings.reshape(-1, 6)
print("to integer", loc_strings)

# get min depth of start and end objects
# Copy the column: a plain slice is a view into loc_strings and would show the
# shifted values once the depth column is modified below.
depth_cm = loc_strings[:, 2].copy()
min_depth = int(depth_cm.min())

# Sample the new depth of the closest object within +-max_shift of its current
# depth, restricted to depth_range.
local_depth_range = [min_depth - max_shift, min_depth + max_shift]
local_depth_range = np.clip(local_depth_range, *depth_range)
target_depth = random.randint(int(local_depth_range[0]), int(local_depth_range[1]))

# The shift is (target - current). The opposite sign would mirror the sample
# around min_depth and could push the result outside depth_range (even below 0).
delta_depth_cm = target_depth - min_depth

# randomly change the depth (same shift for every object keeps relative depth)
loc_strings[:, 2] += delta_depth_cm
suffix_new = "".join(
    "<loc{a[0]:04d}><loc{a[1]:04d}><loc{a[2]:04d}><seg{a[3]:03d}><seg{a[4]:03d}><seg{a[5]:03d}>".format(a=x)
    for x in loc_strings
)

print("depth", depth_cm)
print("min depth", min_depth)
print("delta depth", delta_depth_cm)
print("new suffix", suffix_new)

In [None]:
# Crop image, change suffix
def imginfo(img):
    """Print the type, dtype, shape, min and max of an array-like image."""
    print(type(img), img.dtype, img.shape, img.min(), img.max())

# get first data item, get image
image = train_dataset[0][0]
image_width, image_height = image.size[:2] # 1280 x 720
print("image size", image_width, image_height)


# get from maniskill (1024, 1024) to pixels in (image_height, image_width)
loc_h = (loc_strings[:, 0]/(1024-1)*image_height).round().astype(int) # rows: (obj_start_h obj_end_h)
loc_w = (loc_strings[:, 1]/(1024-1)*image_width).round().astype(int) # cols: (obj_start_w obj_end_w)

# Square image crop centred on obj_start. The window is clamped so it stays
# inside the image where possible; an out-of-bounds crop would be zero-padded.
crop_size = int(image_width/3)
crop_half = crop_size // 2
top = int(np.clip(loc_h[0]-crop_half, 0, max(image_height-crop_size, 0)))
left = int(np.clip(loc_w[0]-crop_half, 0, max(image_width-crop_size, 0)))
crop_shape = (top, left, crop_size, crop_size) # yxhw
crop = v2.functional.crop(image, *crop_shape)
print(crop_shape)

# Shift the pixel coordinates into the crop frame and go back to maniskill
# (1024, 1024). The crop is crop_size x crop_size, so normalise by crop_size,
# not by the full image size.
crop_h = (loc_h - top)/crop_size*(1024-1)
crop_w = (loc_w - left)/crop_size*(1024-1)

# Objects outside the crop window cannot be encoded as valid tokens.
outside = (crop_h < 0) | (crop_h > 1024-1) | (crop_w < 0) | (crop_w > 1024-1)
if outside.any():
    print("warning: objects outside the crop, clipped to its border:", np.flatnonzero(outside))

# Write into a copy so re-running this cell does not compound the transform
# on loc_strings.
loc_cropped = loc_strings.copy()
loc_cropped[:, 0] = np.clip(crop_h, 0, 1024-1).round().astype(int)
loc_cropped[:, 1] = np.clip(crop_w, 0, 1024-1).round().astype(int)

suffix_new = "".join(
    "<loc{a[0]:04d}><loc{a[1]:04d}><loc{a[2]:04d}><seg{a[3]:03d}><seg{a[4]:03d}><seg{a[5]:03d}>".format(a=x)
    for x in loc_cropped
)

print(suffix_new)

plot_width, plot_height = 448, 448
dpi = 100
figsize = (plot_width / dpi, plot_height / dpi)
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
ax.imshow(crop)
# Tokens are normalised to the crop, so scale by crop_size to get crop pixels.
ax.scatter(loc_cropped[:, 1]/(1024-1)*crop_size, loc_cropped[:, 0]/(1024-1)*crop_size, color='green')
ax.axis("off");