In [45]:
import numpy as np
import torch
import json
import open3d as o3d

In [46]:
data_dir = '/data/biophys/schimmenti/Repositories/single-cell-analysis-of-organoids/measurements/point_cloud_approach/'
annotation_json = data_dir + 'annotations_DD.json'

In [47]:
with open(annotation_json) as f:
    annotations = json.load(f)

In [48]:
dataset = {}
for file_key, anns in annotations.items():
    pc_filename = data_dir + file_key
    points = np.asarray(o3d.io.read_point_cloud(pc_filename).points)
    identifier = file_key.split('/')[-1].split('.ply')[0]
    dataset[identifier] = {'points': points, 'annotations': np.array(anns)}

In [56]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from torch_geometric.utils import to_undirected

def build_edge_labelled_knn_graph(points, point_labels, k):
    """
    points: (N, d) float32 array/tensor with XYZ (or 2D) coords
    point_labels: (N,) int/bool per-node annotations
        Example rule below sets edge_label=1 if endpoints share the same node label.
        Swap the rule to match your own definition.

    Returns a PyG Data with:
      - pos, x (=pos as default), edge_index
      - edge_label_index (== edge_index)
      - edge_label (E,)
      - train/val/test masks over edges
    """
    if not torch.is_tensor(points): points = torch.tensor(points, dtype=torch.float32)
    if not torch.is_tensor(point_labels): point_labels = torch.tensor(point_labels)

    pos = points
    x   = pos  # or replace with your point features (N, F)

    # kNN graph
    edge_index = knn_graph(x=pos, k=k, loop=False)
    edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
    src, dst = edge_index
    edge_label = (point_labels[src] == point_labels[dst]).to(torch.long)

    with torch.no_grad():
        edge_vec = pos[dst] - pos[src]
        edge_len = torch.linalg.norm(edge_vec, dim=1, keepdim=True)  # (E,1)
    edge_attr = edge_len  # or torch.cat([edge_vec, edge_len], dim=1)

    data = Data(
        x=x, pos=pos, edge_index=edge_index,
        edge_attr=edge_attr,
        edge_label=edge_label,
        edge_label_index=edge_index, 
    )
    return data

In [57]:
for key in dataset.keys():
    data = build_edge_labelled_knn_graph(dataset[key]['points'], dataset[key]['annotations'], k=16)

ImportError: 'knn_graph' requires 'torch-cluster'