In [None]:
import torch 
from torch.utils.data import Dataset
import os
from tqdm import tqdm
import json
import numpy as np
import matplotlib.pyplot as plt
import random
import cv2
from learning.datatools import augment_bounding_box, augment_rgb, random_dropout, voxel_normalize, pad_collate_fn
from learning.models import PointNetPlusPlusMSG
from blender.utils import get_visible_objects

  warn(f"Failed to load image Python extension: {e}")


In [23]:

STORE_IN_RAM = True
class ApplePointCloudDataset(Dataset):
    """Per-apple point-cloud dataset built from rendered orchard scenes.

    Every record is one apple (bounding box, centre, occlusion rate) read from a
    JSON-lines manifest. A sample is the tuple ``(pc, center, meta)`` returned by
    ``_build_sample``.

    NOTE(review): when ``STORE_IN_RAM`` is true every sample is built once inside
    ``__init__``, so random augmentation is applied a single time and the same
    augmented sample is returned on every epoch. The on-demand path
    (``STORE_IN_RAM`` false) also skips the apple instance mask that the
    pre-loading path applies. Confirm both are intended.
    """

    def __init__(self, data_root: str, manifest_path: str, config: dict = None, augment=True):
        """Read the manifest and, if ``STORE_IN_RAM`` is set, pre-build all samples.

        Args:
            data_root: directory holding the per-scene files.
            manifest_path: JSON-lines file; each line needs the keys ``stem``,
                ``boxes``, ``centers`` and ``occ_rates``.
            config: optional settings (``voxel_size``, ``percentile``,
                ``subset_size``, ``normalize``, ``SEED``). ``None`` means defaults.
            augment: whether ``_build_sample`` applies the random augmentations.
        """
        # A fresh dict per call avoids sharing one mutable default between instances.
        config = {} if config is None else config
        self.root = data_root
        self.augment = augment
        self.records = []
        self.voxel_size = config.get("voxel_size", 0.045)
        self.percentile = config.get("percentile", 95)
        self.subset_size = config.get("subset_size", 1.0)
        self.normalize = config.get("normalize", True)

        with open(manifest_path) as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            scenes = [json.loads(line) for line in f if line.strip()]

        if self.subset_size < 1.0:
            # A missing SEED yields None, which makes NumPy seed itself from the OS.
            np.random.seed(config.get("SEED"))
            np.random.shuffle(scenes)
            scenes = scenes[:int(len(scenes) * self.subset_size)]
        print(f"Loading {len(scenes)} scenes from {manifest_path} …")

        for scene_i, scene in enumerate(scenes):
            stem = scene["stem"]
            for bbox, center, occ_rate in zip(scene["boxes"], scene["centers"], scene["occ_rates"]):
                center[1] = -center[1]  # flip y-axis to match the point cloud
                self.records.append({
                    "stem": stem,
                    "bbox": bbox,
                    "occ_rate": occ_rate,
                    "center": center,
                })
            print(f"Processed scene {scene_i+1}/{len(scenes)}: {stem} with {len(scene['boxes'])} apples.")

        if STORE_IN_RAM:
            self._preload_samples()

    def _preload_samples(self):
        """Build every sample once, loading each scene's arrays a single time."""
        print(f"Pre-loading {len(self.records)} apples into RAM …")
        records_by_stem = {}
        for rec in self.records:
            records_by_stem.setdefault(rec["stem"], []).append(rec)

        self.records = []
        for stem, recs in records_by_stem.items():
            xyz, rgb = self._load_scene_xyzrgb(stem)

            _, _, instance_mask, _ = get_visible_objects(
                    exr_path=os.path.join(self.root, f"{stem}_id0000.exr"),
                    id_mapping_path=os.path.join(self.root, f"{stem}_id_map.json"),
                    conditional=lambda id, name: 'apple' in name and 'stem' not in name
                )

            # Zero out every pixel that is not part of a selected apple instance.
            keep = (instance_mask > 0).astype(np.float32)[:, :, np.newaxis]
            rgb = rgb * keep
            xyz = xyz * keep
            for rec in recs:
                rec["sample"] = self._build_sample(rec, xyz=xyz, rgb=rgb)
                self.records.append(rec)

    def _load_scene_xyzrgb(self, stem: str):
        """Return ``(xyz, rgb)`` for a scene, using the compressed cache if readable.

        When ``<root>/zipped/<stem>.npz`` cannot be read, the raw ``.npy`` point
        cloud and ``.png`` image are loaded and the cache is (re)written.

        Raises:
            FileNotFoundError: the RGB image could not be read by OpenCV.
        """
        zipped = os.path.join(self.root, "zipped", f"{stem}.npz")
        try:
            with np.load(zipped) as data:
                return data["xyz"], data["rgb"]
        except Exception:
            # Deliberate best effort: any unreadable cache falls back to the raw files.
            pass

        xyz = np.load(os.path.join(self.root, f"{stem}_pc.npy"))
        png_path = os.path.join(self.root, f"{stem}_rgb0000.png")
        bgr = cv2.imread(png_path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read RGB image: {png_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        os.makedirs(os.path.dirname(zipped), exist_ok=True)
        np.savez_compressed(zipped, xyz=xyz, rgb=rgb)
        return xyz, rgb

    def _build_sample(self, rec, xyz=None, rgb=None):
        """Create ``(pc, center, meta)`` for one apple record.

        Args:
            rec: record dict with ``stem``, ``bbox``, ``center`` and ``occ_rate``.
            xyz, rgb: optional pre-loaded scene arrays; loaded from disk if either
                is missing.

        Returns:
            pc: float32 array with six columns (xyz followed by rgb scaled by 1/255).
            center: float32 tensor, the apple centre in the same frame as ``pc``.
            meta: dict with the record fields plus the normalisation parameters.
        """
        stem, bbox, center, occ = \
            rec["stem"], rec["bbox"], rec["center"], rec["occ_rate"]
        if xyz is None or rgb is None:
            xyz, rgb = self._load_scene_xyzrgb(stem)
        # concatenate allocates a new array, so the in-place edits below never
        # touch the caller's xyz/rgb.
        xyzrgb = np.concatenate((xyz, rgb), axis=2)

        if self.augment:
            bbox = augment_bounding_box(bbox)

        x1, y1, x2, y2 = map(int, bbox)
        crop = xyzrgb[min(y1, y2):max(y1, y2), min(x1, x2):max(x1, x2)]
        if self.augment:
            crop[:, :, 3:] = augment_rgb(crop[:, :, 3:])

        pc = crop.reshape(-1, 6)
        # Keep points with 0.45 <= |z| <= 2.75; this also discards pixels whose
        # xyz was zeroed by the instance mask.
        abs_z = np.abs(pc[:, 2])
        pc = pc[~((abs_z < .45) | (abs_z > 2.75))]
        # Drop rows containing NaN or +/-inf in any column.
        pc = pc[np.isfinite(pc).all(1)]
        if self.augment:
            pc = random_dropout(pc, (0.3, 0.7))
        if self.normalize:
            norm_pc, norm_ctr, scale = voxel_normalize(
                pc[:, :3], voxel_size=self.voxel_size, percentile=self.percentile)
        else:
            norm_pc = pc[:, :3].copy()
            norm_ctr = np.array([0.0, 0.0, 0.0])
            scale = 1.0
        pc[:, :3] = norm_pc
        # normalize rgb channels to [0, 1]
        pc[:, 3:6] = pc[:, 3:6] / 255.0
        # Convert explicitly so the arithmetic does not rely on implicit
        # tensor/ndarray interoperability.
        center_t = ((torch.tensor(center) - torch.as_tensor(norm_ctr)) / float(scale)).float()
        meta = dict(stem=stem, bbox=bbox, occ_rate=occ,
                    norm_center=norm_ctr, norm_scale=scale, orig_center=torch.tensor(center))
        return pc.astype(np.float32), center_t, meta

    def __getitem__(self, idx):
        """Return the cached sample, or build it on demand when not pre-loaded."""
        rec = self.records[idx]
        if STORE_IN_RAM:
            return rec["sample"]          # already built & cached
        return self._build_sample(rec)

    def __len__(self):
        """Number of apple records."""
        return len(self.records)




In [24]:
# Dataset locations, spelled out once so the train and test datasets cannot
# drift apart through a copy-paste edit.
DATASET_DIR = '/home/siddhartha/RIVAL/learning2localize/blender/dataset'
RAW_DATA_ROOT = os.path.join(DATASET_DIR, 'raw', 'apple_orchard-5-20-fp-only')
CURATED_DIR = os.path.join(DATASET_DIR, 'curated', 'apple-orchard-v2-fp-only')

config = {
    'voxel_size': 0.0045,    # voxel size passed to voxel_normalize (the class default is 0.045)
    'percentile': 95,        # distance percentile used as the normalization scale
    'subset_size': 0.0025,   # keep a random 0.25% of the scenes in each manifest
    'normalize': False,      # leave point clouds in their original coordinates
    'SEED': 42,              # seeds the scene shuffle used for subsetting
}

train_ds = ApplePointCloudDataset(
    data_root=RAW_DATA_ROOT,
    manifest_path=os.path.join(CURATED_DIR, 'train.jsonl'),
    augment=False,
    config=config
)

test_ds = ApplePointCloudDataset(
    data_root=RAW_DATA_ROOT,
    manifest_path=os.path.join(CURATED_DIR, 'test.jsonl'),
    augment=False,
    config=config
)

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
Loading 3 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/test.jsonl …
Processed scene 1/3: 6f77f32a-699e-4ce8-93b9-4f2dfa4e5d41 with 3 apples.
Processed scene 2/3: f9eaf674-fd24-4b5e-ba78-66c19020a22e with 1 apples.
Processed scene 3/3: f94bfa33-44ec-4738-9685-04012b5a4db2 with 1 apples.
Pre-loading 5 apples into RAM …


In [25]:
import plotly.graph_objects as go


def show_first_sample(ds_config: dict, title: str) -> None:
    """Build an augmented training dataset and plot the first sample of one batch.

    A fresh ``ApplePointCloudDataset`` is created from ``ds_config``, a shuffled
    batch of two is drawn, and the first point cloud of that batch is shown as a
    3-D scatter together with its centre (red marker).

    Args:
        ds_config: dataset configuration forwarded to ``ApplePointCloudDataset``.
        title: figure title.
    """
    ds = ApplePointCloudDataset(
        data_root='/home/siddhartha/RIVAL/learning2localize/blender/dataset/raw/apple_orchard-5-20-fp-only',
        manifest_path='/home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl',
        augment=True,
        config=ds_config
    )
    dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=True, collate_fn=pad_collate_fn)
    for pc, center, mask, meta in dl:
        pc = pc[0].squeeze(0).numpy()
        center = center[0].squeeze(0).numpy()
        # Drop all-zero rows (presumably the padding added by pad_collate_fn — TODO confirm).
        pc = pc[~np.all(pc == 0, axis=1)]

        print(f"PC shape: {pc.shape}, Center: {center}, Meta: {meta}")

        fig = go.Figure()
        fig.add_trace(go.Scatter3d(
            x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
            mode='markers',
            marker=dict(size=2, color=pc[:, 3:6])
        ))
        fig.add_trace(go.Scatter3d(
            x=[center[0]], y=[center[1]], z=[center[2]],
            mode='markers',
            marker=dict(size=5, color='red'),
            name='Center'
        ))
        fig.update_layout(title=title)
        fig.show()
        break  # only the first batch is visualised


percentiles = [50, 75, 90, 95, 99]
for percentile in percentiles:
    test_config = config.copy()
    test_config['percentile'] = percentile

    test_config['normalize'] = True  # always normalize for visualization
    show_first_sample(test_config, f"Point Cloud with Percentile {percentile}")

    test_config['normalize'] = False  # disable normalization for visualization
    show_first_sample(test_config, f"Un-normalized Point Cloud with Percentile {percentile}")

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (4096, 6), Center: [ 0.17501667 -0.17115146  0.9878926 ], Meta: {'stem': ['1904b255-a77e-4a0e-ace8-2063621bdfb5', 'd14c031f-e61b-4042-b49d-9bb4983fcff6'], 'bbox': [array([434, 611, 598, 467]), array([751, 519, 897, 397])], 'occ_rate': tensor([0.8017, 0.2885]), 'norm_center': [array([-0.08367842,  0.13290914, -1.07855656]), array([ 0.11961846,  0.06514136, -1.25789832])], 'norm_scale': tens

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (1192, 6), Center: [ 0.37074056  0.10650968 -1.7828145 ], Meta: {'stem': ['dfc5776d-2121-4c0e-a87a-00a9711802da', 'ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c'], 'bbox': [array([ 983,  477, 1036,  422]), array([  1, 411,  81, 334])], 'occ_rate': tensor([0.6988, 0.7068]), 'norm_center': [array([0., 0., 0.]), array([0., 0., 0.])], 'norm_scale': tensor([1., 1.]), 'orig_center': tensor([[ 0.3707,  0.

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (3280, 6), Center: [0.28651348 0.1883363  0.89601684], Meta: {'stem': ['1904b255-a77e-4a0e-ace8-2063621bdfb5', 'dfc5776d-2121-4c0e-a87a-00a9711802da'], 'bbox': [array([ 18, 230, 138, 115]), array([249, 262, 334, 186])], 'occ_rate': tensor([0.4625, 0.7520]), 'norm_center': [array([-0.30017265, -0.12285734, -0.94876478]), array([-0.34224878, -0.15478063, -1.75297076])], 'norm_scale': tensor(

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (369, 6), Center: [ 0.63529664 -0.08804248 -1.9781961 ], Meta: {'stem': ['1904b255-a77e-4a0e-ace8-2063621bdfb5', '1904b255-a77e-4a0e-ace8-2063621bdfb5'], 'bbox': [array([1185,  316, 1237,  271]), array([633, 483, 726, 387])], 'occ_rate': tensor([0.3706, 0.6215]), 'norm_center': [array([0., 0., 0.]), array([0., 0., 0.])], 'norm_scale': tensor([1., 1.]), 'orig_center': tensor([[ 0.6353, -0.0

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (3337, 6), Center: [-0.04282379 -0.008199    0.13011077], Meta: {'stem': ['dfc5776d-2121-4c0e-a87a-00a9711802da', '138dbf82-441b-4898-bfae-4065a18d5451'], 'bbox': [array([878, 553, 960, 473]), array([1256,  257, 1280,  210])], 'occ_rate': tensor([0.8007, 0.4260]), 'norm_center': [array([ 0.30390327,  0.19034194, -1.88418304]), array([ 0.79724905, -0.20053871, -2.25422204])], 'norm_scale': 

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (369, 6), Center: [ 0.63529664 -0.08804248 -1.9781961 ], Meta: {'stem': ['1904b255-a77e-4a0e-ace8-2063621bdfb5', 'dfc5776d-2121-4c0e-a87a-00a9711802da'], 'bbox': [array([1185,  316, 1237,  271]), array([383, 254, 445, 197])], 'occ_rate': tensor([0.3706, 0.1588]), 'norm_center': [array([0., 0., 0.]), array([0., 0., 0.])], 'norm_scale': tensor([1., 1.]), 'orig_center': tensor([[ 0.6353, -0.0

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (556, 6), Center: [-0.07438499  0.01593652  0.05911535], Meta: {'stem': ['dfc5776d-2121-4c0e-a87a-00a9711802da', '4a31a671-6ff1-4658-929d-efeb57dca187'], 'bbox': [array([  0, 526,  25, 471]), array([515, 517, 629, 410])], 'occ_rate': tensor([0.2495, 0.6955]), 'norm_center': [array([-0.84310678,  0.21491903, -2.38934221]), array([-0.03005575,  0.07374731, -0.91682183])], 'norm_scale': tenso

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (4096, 6), Center: [ 0.49132076  0.17387155 -1.4489166 ], Meta: {'stem': ['ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c', 'dfc5776d-2121-4c0e-a87a-00a9711802da'], 'bbox': [array([1189,  589, 1285,  491]), array([124, 559, 193, 494])], 'occ_rate': tensor([0.8320, 0.5729]), 'norm_center': [array([0., 0., 0.]), array([0., 0., 0.])], 'norm_scale': tensor([1., 1.]), 'orig_center': tensor([[ 0.4913,  0.

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (579, 6), Center: [-0.02085172 -0.01098124 -0.01339353], Meta: {'stem': ['dfc5776d-2121-4c0e-a87a-00a9711802da', 'ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c'], 'bbox': [array([ 17, 478,  63, 431]), array([  1, 411,  81, 334])], 'occ_rate': tensor([0.3292, 0.7068]), 'norm_center': [array([-0.76852615,  0.14663589, -2.28836453]), array([-0.43330534,  0.00883666, -1.26837986])], 'norm_scale': tenso

Loading 6 scenes from /home/siddhartha/RIVAL/learning2localize/blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Processed scene 1/6: 4a31a671-6ff1-4658-929d-efeb57dca187 with 3 apples.
Processed scene 2/6: 1904b255-a77e-4a0e-ace8-2063621bdfb5 with 6 apples.
Processed scene 3/6: ceef1f10-1d3c-4d6b-bf54-22adf6a69b7c with 3 apples.
Processed scene 4/6: dfc5776d-2121-4c0e-a87a-00a9711802da with 14 apples.
Processed scene 5/6: d14c031f-e61b-4042-b49d-9bb4983fcff6 with 2 apples.
Processed scene 6/6: 138dbf82-441b-4898-bfae-4065a18d5451 with 1 apples.
Pre-loading 29 apples into RAM …
PC shape: (3484, 6), Center: [-0.03474403  0.06281897 -0.90720534], Meta: {'stem': ['4a31a671-6ff1-4658-929d-efeb57dca187', '1904b255-a77e-4a0e-ace8-2063621bdfb5'], 'bbox': [array([515, 517, 629, 410]), array([1185,  316, 1237,  271])], 'occ_rate': tensor([0.6955, 0.3706]), 'norm_center': [array([0., 0., 0.]), array([0., 0., 0.])], 'norm_scale': tensor([1., 1.]), 'orig_center': tensor([[-0.0347,  0.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Utilities for point cloud processing (FPS, grouping)
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two batched point sets.

    Expands ``|s - d|^2`` as ``-2 * s.d + |s|^2 + |d|^2``.

    Args:
        src: tensor of shape (B, N, C).
        dst: tensor of shape (B, M, C).

    Returns:
        Tensor of shape (B, N, M); entry ``[b, i, j]`` is the squared distance
        between ``src[b, i]`` and ``dst[b, j]``.
    """
    n_batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    src_norm = torch.sum(src ** 2, -1).view(n_batch, n_src, 1)
    dst_norm = torch.sum(dst ** 2, -1).view(n_batch, 1, n_dst)
    return cross + src_norm + dst_norm

def farthest_point_sample(xyz, npoint):
    """Iterative farthest-point sampling.

    Starts from a random point per batch element, then repeatedly picks the
    point whose distance to the already-selected set is largest.

    Args:
        xyz: tensor of shape (B, N, 3).
        npoint: number of indices to select per batch element.

    Returns:
        Long tensor of shape (B, npoint) with the selected point indices.
    """
    device = xyz.device
    n_batch, n_points, _ = xyz.shape
    selected = torch.zeros(n_batch, npoint, dtype=torch.long).to(device)
    # Squared distance from each point to its nearest selected point so far.
    nearest_sq = torch.ones(n_batch, n_points).to(device) * 1e10
    current = torch.randint(0, n_points, (n_batch,), dtype=torch.long).to(device)
    batch_ids = torch.arange(n_batch, dtype=torch.long).to(device)

    for step in range(npoint):
        selected[:, step] = current
        anchor = xyz[batch_ids, current, :].view(n_batch, 1, 3)
        sq_to_anchor = ((xyz - anchor) ** 2).sum(-1)
        closer = sq_to_anchor < nearest_sq
        nearest_sq[closer] = sq_to_anchor[closer]
        current = nearest_sq.max(-1)[1]
    return selected

def index_points(points, idx):
    """Gather rows of ``points`` along dim 1 using per-batch indices.

    Args:
        points: tensor of shape (B, N, C).
        idx: long tensor whose first dimension is B; any trailing shape.

    Returns:
        Tensor of shape ``idx.shape + (C,)``.
    """
    n_batch = points.shape[0]
    n_channels = points.shape[-1]
    out_shape = [*idx.shape, -1]
    # Flatten the trailing index dims so a single gather along dim 1 suffices.
    flat_idx = idx.reshape(n_batch, -1)
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, n_channels)
    gathered = torch.gather(points, 1, gather_idx)
    return gathered.view(*out_shape)


def query_ball_point(radius, nsample, xyz, new_xyz):
    """Return the indices of the ``nsample`` nearest points to each query point.

    NOTE(review): ``radius`` is accepted but never used, so this is a
    k-nearest-neighbour lookup rather than a radius-limited ball query.

    Args:
        radius: unused.
        nsample: number of neighbour indices to keep per query point.
        xyz: candidate points, passed as ``dst`` to ``square_distance``.
        new_xyz: query points, passed as ``src`` to ``square_distance``.

    Returns:
        Index tensor sliced to at most ``nsample`` entries on the last dim,
        ordered by increasing squared distance.
    """
    sq_dists = square_distance(new_xyz, xyz)
    nearest_first = torch.argsort(sq_dists)
    return nearest_first[:, :, :nsample]

def sample_and_group(npoint, radius, nsample, xyz, points):
    """Pick centroids by farthest-point sampling and group neighbours around them.

    Args:
        npoint: number of centroids to sample.
        radius: forwarded to ``query_ball_point``.
        nsample: neighbours gathered per centroid.
        xyz: point coordinates.
        points: optional per-point features, or ``None``.

    Returns:
        new_xyz: coordinates of the sampled centroids.
        new_points: neighbour coordinates relative to their centroid, with the
            neighbours' features concatenated on the last dim when ``points``
            is given.
    """
    centroid_idx = farthest_point_sample(xyz, npoint)
    new_xyz = index_points(xyz, centroid_idx)
    neighbor_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    # Express each neighbour relative to its centroid.
    local_xyz = index_points(xyz, neighbor_idx) - new_xyz.unsqueeze(2)
    if points is None:
        return new_xyz, local_xyz
    neighbor_feats = index_points(points, neighbor_idx)
    return new_xyz, torch.cat([local_xyz, neighbor_feats], dim=-1)

class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: sample centroids, group neighbours, shared MLP, max-pool."""

    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        """
        Args:
            npoint: number of centroids sampled per cloud.
            radius: grouping radius forwarded to ``sample_and_group``.
            nsample: neighbours grouped per centroid.
            in_channel: feature channels per point, excluding the 3 xyz channels.
            mlp: output widths of the successive 1x1 convolutions.
        """
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        # The first conv also sees the 3 relative-xyz channels.
        widths = [in_channel + 3, *mlp]
        self.mlp_convs = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, 1) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, xyz, points):
        """Return ``(new_xyz, features)`` for the sampled centroids."""
        centroids, grouped = sample_and_group(
            self.npoint, self.radius, self.nsample, xyz, points)
        feats = grouped.permute(0, 3, 2, 1)  # (B, D, nsample, npoint)
        for conv in self.mlp_convs:
            feats = F.relu(conv(feats))
        pooled = feats.max(dim=2)[0]  # (B, D', npoint)
        return centroids, pooled.permute(0, 2, 1)

class PointNetPlusPlusNoMask(nn.Module):
    """Three set-abstraction stages followed by a global max-pool and an MLP head."""

    def __init__(self, input_dim=3, output_dim=1,npoints = [1024,256,64], radii=[0.005, 0.1, 0.3], nsamples=[64, 128, 256],
                    mlp_channels=[[128, 128, 256], [256, 256, 512], [512, 512, 1024]]):
        """
        Args:
            input_dim: channels per input point, including the 3 xyz channels.
            output_dim: width of the final linear layer.
            npoints, radii, nsamples: per-stage settings of the three
                ``PointNetSetAbstraction`` layers.
            mlp_channels: per-stage MLP widths; each stage's last width is the
                next stage's input feature width.
        """
        super().__init__()
        self.input_dim = input_dim - 3  # Only the additional features beyond XYZ

        self.sa1 = PointNetSetAbstraction(npoint=npoints[0], radius=radii[0], nsample=nsamples[0], in_channel=self.input_dim, mlp=mlp_channels[0])
        self.sa2 = PointNetSetAbstraction(npoint=npoints[1], radius=radii[1], nsample=nsamples[1], in_channel=mlp_channels[0][-1], mlp=mlp_channels[1])
        self.sa3 = PointNetSetAbstraction(npoint=npoints[2], radius=radii[2], nsample=nsamples[2], in_channel=mlp_channels[1][-1], mlp=mlp_channels[2])

        # The head consumes the last stage's feature width. Deriving it from
        # mlp_channels (instead of hard-coding 1024) keeps non-default channel
        # configurations working; the default still yields 1024.
        global_feat_dim = mlp_channels[2][-1]
        self.fc = nn.Sequential(
            nn.Linear(global_feat_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def forward(self, x):
        """
        x: Tensor of shape (B, N, input_dim)
        First 3 dims are XYZ. Remaining are features (optional).

        Returns a tensor of shape (B, output_dim).
        """
        xyz = x[:, :, :3]
        points = x[:, :, 3:] if x.shape[2] > 3 else None

        l1_xyz, l1_points = self.sa1(xyz, points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)
        # Global max pool over the remaining centroids.
        feat = torch.max(l3_points, 1)[0]
        return self.fc(feat)


In [None]:
# Instantiate the network on the GPU.
model = PointNetPlusPlusNoMask(input_dim=6, output_dim=1).cuda()

# One sample per batch, so the default collate never has to stack clouds of
# different sizes.
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=12,
    pin_memory=True,
)

# Smoke test: iterate both loaders once to confirm every sample can be fetched.
for _ in tqdm(train_loader, desc="Training"):
    continue
for _ in tqdm(test_loader, desc="Testing"):
    continue

In [None]:
import plotly.graph_objects as go   

In [None]:
def voxel_normalize(points, voxel_size=0.005, percentile=95):
    """Normalize using a voxel grid to handle irregular point density.

    Points are bucketed into cubic voxels of edge ``voxel_size``; each occupied
    voxel contributes one centre (the mean of its points). The cloud is then
    shifted by the median voxel centre and divided by the ``percentile``-th
    percentile of the voxel centres' distances to that median.

    Rows containing NaN are ignored when computing the statistics but are kept
    (still NaN) in the output.

    Args:
        points: array of shape (N, D).
        voxel_size: voxel edge length.
        percentile: percentile of voxel-centre distances used as the scale.

    Returns:
        (scaled_points, center, scale). When no usable row exists the points
        are returned unchanged with a zero centre and scale 1.0. When the
        computed scale is not positive (for example all points fall in a single
        voxel) the scale falls back to 1.0 instead of dividing by zero.
    """
    points = np.asarray(points)

    # Indices of rows without NaN; only these take part in the voxel statistics.
    valid_idx = np.flatnonzero(~np.isnan(points).any(axis=1))
    if valid_idx.size == 0:
        return points.copy(), np.zeros(points.shape[1]), 1.0

    # Group valid point indices by their integer voxel coordinate.
    voxel_keys = np.floor(points[valid_idx] / voxel_size).astype(int)
    voxel_grid = {}
    for point_i, key in zip(valid_idx, map(tuple, voxel_keys)):
        voxel_grid.setdefault(key, []).append(point_i)

    # One centre per occupied voxel, so dense regions do not dominate.
    voxel_centers = np.array(
        [np.mean(points[members], axis=0) for members in voxel_grid.values()])

    # Center and scale using voxel centers
    center = np.median(voxel_centers, axis=0)
    distances = np.linalg.norm(voxel_centers - center, axis=1)
    scale = np.percentile(distances, percentile)
    if not scale > 0:
        # Degenerate spread (or NaN): keep the centred points at their original scale.
        scale = 1.0

    # Normalize all points
    scaled_points = (points - center) / scale

    return scaled_points, center, scale

def _show_cloud_with_center(xyz_pts, point_colors, ctr):
    """Plot a point cloud plus one red centre marker, show it, and return the figure."""
    cloud_fig = go.Figure(data=[go.Scatter3d(
        x=xyz_pts[:, 0], y=xyz_pts[:, 1], z=xyz_pts[:, 2],
        mode='markers',
        marker=dict(size=2, color=point_colors)
    )])
    cloud_fig.add_trace(go.Scatter3d(
        x=[ctr[0]], y=[ctr[1]], z=[ctr[2]],
        mode='markers',
        marker=dict(size=5, color='red')
    ))
    cloud_fig.show()
    return cloud_fig


optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
for epoch in range(1):
    model.train()
    epoch_z_err = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        pc, centers, meta = batch
        # batch_size is 1, so drop the leading batch dimension.
        pc = pc.cpu().numpy().squeeze(0)
        centers = centers.cpu().numpy().squeeze(0)

        # 1) The sample exactly as delivered by the loader.
        fig = _show_cloud_with_center(pc, pc[:, 3:], centers)

        # 2) The same sample after voxel normalization, with the centre mapped
        #    through the same shift and scale.
        norm_pc, norm_ctr, scale = voxel_normalize(
            pc[:, :3],
            voxel_size=train_ds.voxel_size,
            percentile=train_ds.percentile,
        )
        new_ctr = ((np.array(centers) - norm_ctr) / scale).astype(np.float32)
        fig = _show_cloud_with_center(norm_pc, pc[:, 3:], new_ctr)

        # 3) Sanity check on one random point: undoing the normalization should
        #    reproduce the original point-to-centre distance.
        rand_idx = np.random.randint(0, pc.shape[0])
        rand_point_original = pc[rand_idx, :3]
        rand_point_normalized = norm_pc[rand_idx, :3]
        orig_dist_to_ctr = np.linalg.norm(rand_point_original - centers)
        norm_dist_to_ctr = np.linalg.norm(rand_point_normalized - new_ctr)
        restored_point = (rand_point_normalized * scale) + norm_ctr
        reconstructed_dist_to_ctr = np.linalg.norm(restored_point - centers)
        print(f"Original distance to center: {orig_dist_to_ctr:.4f}, "
              f"Normalized distance to center: {norm_dist_to_ctr:.4f}, "
              f"Reconstructed distance to center: {reconstructed_dist_to_ctr:.4f}")

        # 4) Normalization parameters stored by the dataset for this sample.
        norm_scale = meta["norm_scale"].cuda().view(-1, 1)
        norm_center = meta["norm_center"].cuda().view(-1, 3)
        print("Norm scale:", norm_scale, "Norm center:", norm_center)
        print("Num points ", pc.shape[0])
        print("Stem", meta['stem'])

        break  # inspect a single batch only
        

In [None]:
# Fresh model and optimizer for the training run.
model = PointNetPlusPlusNoMask(input_dim=6, output_dim=1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic tensors; the loop below does not use them.
fake_data = torch.randn(10, 2048, 6).cuda()  # Example batch of point clouds
fake_labels = torch.randn(10, 1).cuda()  # Example batch of labels

fig, ax = plt.subplots(figsize=(10, 5))
losses = []
for epoch in range(1000):
    model.train()
    batch_losses = []
    for pc, centers, meta in train_loader:
        pc = pc.cuda()
        centers = centers.cuda()
        # The regression target is the z-coordinate of the centre only.
        target_z = centers[:, 2]

        optimizer.zero_grad()
        outputs = model(pc)
        loss = F.mse_loss(outputs, target_z.unsqueeze(1))
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

        print("Raw z diff", (target_z - outputs.squeeze()).abs().mean().item())
    epoch_loss = sum(batch_losses) / len(train_loader)
    losses.append(epoch_loss)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

# Mean loss per epoch.
ax.plot(losses, label='Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()
plt.show()

In [13]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch_geometric.data
import torch.nn.functional as F
from torch.nn import init



# Define the GCN model
class GCN(torch.nn.Module):
    """Stack of GCNConv layers followed by a min-pool over all nodes."""

    def __init__(self, in_len, out_len, hidden_sizes=[128, 512, 1024, 512, 128]):
        """
        Args:
            in_len: feature width of the input nodes.
            out_len: feature width produced by the last layer.
            hidden_sizes: widths of the intermediate layers (read, never mutated).
        """
        super(GCN, self).__init__()
        widths = [in_len, *hidden_sizes, out_len]
        self.layers = torch.nn.ModuleList(
            [GCNConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer, then apply Xavier-uniform weights.

        The layer's own ``reset_parameters`` runs first. Calling it after the
        manual initialisation, as before, overwrote the weights that had just
        been set and left the Xavier call without effect.
        """
        for conv in self.layers:
            conv.reset_parameters()
            init.xavier_uniform_(conv.lin.weight)
            if conv.lin.bias is not None:
                init.zeros_(conv.lin.bias)

    def forward(self, data):
        """Return a (1, out_len) tensor for the graph in ``data``.

        ReLU is applied after every layer, including the last one, so the
        pooled output is never negative.
        """
        x, edge_index = data.x, data.edge_index
        for layer in self.layers:
            x = layer(x, edge_index)
            x = F.relu(x)
        # min pool over the node dimension
        x = torch.min(x, dim=0)[0].unsqueeze(0)
        return x


# Load model
model = GCN(in_len=6, out_len=2)
print(
    "Num params:",
    sum(param.numel() for param in model.parameters() if param.requires_grad),
)
model.eval()

# Random graph: 100 nodes with 6 features each and 200 random edges.
num_points = 100
edges = torch.randint(0, num_points, (2, 200))
point_cloud = torch.randn(num_points, 6)
data = torch_geometric.data.Data(x=point_cloud, edge_index=edges)

# Single forward pass without gradient tracking to check the output shape.
with torch.no_grad():
    output = model(data)
print("Output shape:", output.shape)


Num params: 1182978
Output shape: torch.Size([1, 2])


In [14]:
from learning.datatools import ApplePointCloudDataset, pad_collate_fn

# Dataset locations, relative to the notebook's working directory.
DATA_ROOT = "./blender/dataset/raw/apple_orchard-5-20-fp-only"
MANIFEST_DIR = "./blender/dataset/curated/apple-orchard-v2-fp-only"
# Fixed seed so the train/validation split is identical on every re-run.
SPLIT_SEED = 0

config = {
    'voxel_size': 0.045,
    'percentile': 95
}

# Full training manifest; split 80/20 into train and validation below.
full_trainset = ApplePointCloudDataset(
    data_root=DATA_ROOT,
    manifest_path=f"{MANIFEST_DIR}/train.jsonl",
    config=config,
    augment=True,
)
train_size = int(0.8 * len(full_trainset))
# NOTE(review): valset is a view onto the augment=True dataset, so validation
# samples are presumably augmented as well -- confirm whether that is intended.
trainset, valset = torch.utils.data.random_split(
    full_trainset,
    [train_size, len(full_trainset) - train_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
testset = ApplePointCloudDataset(
    data_root=DATA_ROOT,
    manifest_path=f"{MANIFEST_DIR}/test.jsonl",
    config=config,
    augment=False,
)

# Options shared by all three loaders; only `shuffle` differs between them.
loader_kwargs = dict(
    batch_size=8,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_collate_fn,
)
train_dl = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
val_dl = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)
test_dl = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

Loading 2751 scenes from ./blender/dataset/curated/apple-orchard-v2-fp-only/train.jsonl …
Loading 1581 scenes from ./blender/dataset/curated/apple-orchard-v2-fp-only/test.jsonl …


In [17]:
from sklearn.neighbors import NearestNeighbors
def create_edges(pc, k=10):
    '''Create edges for a point cloud using k-nearest neighbors.

    Args:
        pc: Array of shape (N, C) with C >= 3; only the first three columns
            are used as coordinates for the neighbour search.
        k: Number of neighbours queried per point. It is clamped to N so that
            clouds with fewer than k points no longer raise. Because the cloud
            is queried against itself, a point is normally returned as its own
            nearest neighbour, and that self-edge is dropped.

    Returns:
        LongTensor of shape (2, E) holding directed edges (i, j), where j is
        one of the neighbours found for i and j != i. The shape is (2, 0)
        when there are no edges (e.g. an empty or single-point cloud).
    '''
    num_points = pc.shape[0]
    if num_points == 0:
        # Nothing to fit; return a well-formed empty edge index.
        return torch.empty((2, 0), dtype=torch.long)

    xyz = pc[:, :3]
    n_neighbors = min(k, num_points)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(xyz)
    # indices[i] lists the neighbours found for point i, closest first.
    _, indices = nbrs.kneighbors(xyz)

    # Vectorised equivalent of the nested (i, j) loop, in the same edge order.
    src = np.repeat(np.arange(num_points, dtype=np.int64), n_neighbors)
    dst = indices.reshape(-1).astype(np.int64, copy=False)
    keep = src != dst  # drop self-loops
    return torch.from_numpy(np.stack([src[keep], dst[keep]])).contiguous()
# Train a single-output GCN to regress the third component of each sample's
# `centers` entry from its point cloud.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(in_len=6, out_len=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    for i, (clouds, centers, masks, aux) in enumerate(train_dl):
        # NOTE(review): `masks` is not used, so any padded points added by
        # pad_collate_fn presumably take part in the graph and the pooling --
        # confirm the mask convention and filter the clouds if so.

        # Build the kNN graphs from the CPU copy once, rather than moving each
        # cloud to the GPU and back again for the neighbour search.
        clouds_np = clouds.detach().cpu().numpy()
        clouds = clouds.to(device)
        labels = centers[:, 2].unsqueeze(1).float().to(device)
        # One scale per sample, used to express the z error in un-normalised
        # units (presumably metres, per the variable name -- TODO confirm).
        norm_scale = aux["norm_scale"].view(-1, 1).to(device)

        optimizer.zero_grad()
        loss = 0.0
        z_errs_m = 0.0
        for j, cloud in enumerate(clouds):
            edges = create_edges(clouds_np[j]).to(device)
            data = torch_geometric.data.Data(x=cloud, edge_index=edges)
            output = model(data)
            target = labels[j].unsqueeze(0)
            loss = loss + F.mse_loss(output, target)
            # Scale by this sample's own factor (previously the whole (B, 1)
            # scale vector was broadcast against every sample's error), and
            # detach so the metric does not keep the autograd graph alive.
            z_errs_m = z_errs_m + (output.detach() - target).abs() * norm_scale[j]
        z_errs_m = z_errs_m / len(clouds)

        loss = loss / len(clouds)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {loss.item():.4f} Z Error: {z_errs_m.mean().item():.4f}")
    # Guard against an empty loader so the average cannot divide by zero.
    epoch_loss /= max(len(train_dl), 1)
    print(f"Epoch {epoch+1}, Average Loss: {epoch_loss:.4f}")

Epoch 1, Batch 1, Loss: 0.3029 Z Error: 0.2923


KeyboardInterrupt: 