In [None]:
import numpy as np
import torch
import json
import open3d as o3d
from pathlib import Path

In [None]:
# Candidate locations of the point-cloud measurements, in order of preference.
# The local checkout comes first because it is the value that was effectively
# in use before; the cluster checkout is the fallback.
DATA_DIR_CANDIDATES = (
    '/Users/schimmenti/Desktop/DresdenProjects/Organoids/single-cell-analysis-of-organoids/measurements/point_cloud_approach/',
    '/data/biophys/schimmenti/Repositories/single-cell-analysis-of-organoids/measurements/point_cloud_approach/',
)
# Pick the first candidate that exists on this machine. If none exists, keep the
# first one so that the subsequent "file not found" error names the default path.
# data_dir is kept as a str with a trailing '/' because file names are appended
# to it by plain string concatenation.
data_dir = next((d for d in DATA_DIR_CANDIDATES if Path(d).is_dir()), DATA_DIR_CANDIDATES[0])
annotation_json = data_dir + 'annotations_DD.json'

In [None]:
# Load the annotation file. The result is used as a mapping from a point-cloud
# file path (relative to data_dir) to the annotations of that cloud.
# JSON is read explicitly as UTF-8 instead of relying on the platform's
# locale-dependent default encoding.
with open(annotation_json, encoding='utf-8') as f:
    annotations = json.load(f)

In [None]:
# Build the labelled dataset:
#   identifier (file name without '.ply') -> {'points': coords, 'annotations': labels}
dataset = {}
for file_key, anns in annotations.items():
    pc_filename = data_dir + file_key
    points = np.asarray(o3d.io.read_point_cloud(pc_filename).points)
    ann_array = np.array(anns)
    # NOTE(review): Open3D is expected to signal an unreadable file with a warning
    # and an empty cloud rather than an exception - fail loudly here instead.
    if len(points) == 0:
        raise ValueError(f'No points could be read from {pc_filename!r}')
    # Annotations are indexed per point when edge labels are derived, so there
    # must be exactly one annotation per point.
    if ann_array.shape[:1] != points.shape[:1]:
        raise ValueError(
            f'{file_key!r}: got annotations of shape {ann_array.shape} '
            f'for a cloud of {len(points)} points'
        )
    identifier = file_key.split('/')[-1].split('.ply')[0]
    dataset[identifier] = {'points': points, 'annotations': ann_array}

In [None]:
# Build the unlabelled dataset with the same layout as the labelled one; every
# point gets the placeholder annotation -1.
unlabelled_dataset = {}
# sorted(): Path.glob yields entries in filesystem order, which is not stable
# across machines; sorting makes the dataset order reproducible.
for file in sorted(Path(data_dir).joinpath('point_clouds_OO/').glob('*.ply')):
    # Pass a plain str: presumably not every Open3D version accepts Path objects
    # for the filename argument - TODO confirm for the pinned version.
    points = np.asarray(o3d.io.read_point_cloud(str(file.absolute())).points)
    identifier = file.name.split('.ply')[0]
    unlabelled_dataset[identifier] = {'points': points, 'annotations': np.full(len(points), -1, dtype=int)}

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph, SAGEConv
from torch_geometric.utils import to_undirected
from torch_geometric.data import InMemoryDataset, Dataset
from torch_geometric.loader import DataLoader
import math, random
from torch.utils.data import Dataset as _TorchDataset

def build_edge_labelled_knn_graph(points, point_labels, k):
    """
    Turn a labelled point set into a symmetric kNN graph with per-edge labels.

    points: per-node coordinates; converted to a float32 tensor unless it
        already is a tensor (in which case it is used as-is).
    point_labels: per-node annotations; converted to a tensor unless it already
        is one.
    k: number of nearest neighbours requested from knn_graph.

    An edge is labelled 1 when both endpoints carry the same node label and 0
    otherwise. Swap that rule to match your own definition.

    Returns a PyG Data with:
      - x and pos (both the coordinates), edge_index
      - edge_attr: per edge, the offset vector dst - src followed by its length
      - edge_label (E,) as long
      - edge_label_index (== edge_index)
    """
    pos = points if torch.is_tensor(points) else torch.tensor(points, dtype=torch.float32)
    labels = point_labels if torch.is_tensor(point_labels) else torch.tensor(point_labels)

    # Directed kNN neighbourhood (no self loops), then symmetrised.
    directed = knn_graph(x=pos, k=k, loop=False)
    edge_index = to_undirected(directed, num_nodes=pos.size(0))
    row, col = edge_index

    # 1 where both endpoints share a label, 0 otherwise.
    same_label = labels[row] == labels[col]
    edge_label = same_label.to(torch.long)

    # Geometric edge features, computed outside of autograd.
    with torch.no_grad():
        offset = pos[col] - pos[row]
        length = torch.linalg.norm(offset, dim=1, keepdim=True)  # (E,1)
    edge_attr = torch.cat([offset, length], dim=1)

    # Coordinates double as node features; replace x with real features if available.
    return Data(
        x=pos,
        pos=pos,
        edge_index=edge_index,
        edge_attr=edge_attr,
        edge_label=edge_label,
        edge_label_index=edge_index,
    )
class RandomRotation:
    """Randomly rotate 3D points about a random axis or a fixed axis.

    Parameters
    ----------
    angle_deg : float
        Maximum rotation magnitude in degrees; each call draws an angle
        uniformly from [-angle_deg, angle_deg].
    axis : array-like of 3 floats, optional
        Fixed rotation axis (need not be normalised). If None, a new axis is
        drawn from an isotropic normal distribution on every call.
    seed : int, optional
        If given, the instance owns private random generators seeded once at
        construction: successive calls yield different rotations, the sequence
        is reproducible, and the global `random` / `np.random` state is left
        untouched. If None, the global generators are used.

    Raises
    ------
    ValueError
        If `axis` is given but is not a non-zero vector of 3 components.
    """
    def __init__(self, angle_deg=180, axis=None, seed=None):
        self.angle_rad = math.radians(angle_deg)
        if axis is None:
            self.axis = None
        else:
            self.axis = np.asarray(axis, dtype=float)
            # `not (norm > 0)` also rejects NaN components.
            if self.axis.shape != (3,) or not np.linalg.norm(self.axis) > 0:
                raise ValueError(f'axis must be a non-zero vector of 3 components, got {axis!r}')
        self.seed = seed
        # Seed once here rather than on every call: reseeding per call would make
        # every call return the very same rotation and would clobber global RNG state.
        if seed is None:
            self._angle_rng = random
            self._axis_rng = np.random
        else:
            self._angle_rng = random.Random(seed)
            self._axis_rng = np.random.RandomState(seed)

    def __call__(self, points):
        """Rotate `points` (rows are 3D points) by a freshly drawn rotation.

        A numpy input (or anything np.asarray accepts) returns a numpy array.
        A torch tensor input returns a float32 tensor on the input's device.
        """
        was_torch = torch.is_tensor(points)
        pts = points.detach().cpu().numpy() if was_torch else np.asarray(points)

        # Draw the angle first, then the axis.
        angle = self._angle_rng.uniform(-self.angle_rad, self.angle_rad)
        axis = self._axis_rng.normal(size=3) if self.axis is None else self.axis
        axis = axis / (np.linalg.norm(axis) + 1e-12)
        ux, uy, uz = axis
        c = math.cos(angle)
        s = math.sin(angle)
        # Rodrigues' rotation matrix for unit axis (ux, uy, uz) and the drawn angle.
        R = np.array([
            [c + ux*ux*(1-c), ux*uy*(1-c) - uz*s, ux*uz*(1-c) + uy*s],
            [uy*ux*(1-c) + uz*s, c + uy*uy*(1-c), uy*uz*(1-c) - ux*s],
            [uz*ux*(1-c) - uy*s, uz*uy*(1-c) + ux*s, c + uz*uz*(1-c)],
        ])
        rotated = pts.dot(R.T)
        if was_torch:
            return torch.as_tensor(rotated, dtype=torch.float32, device=points.device)
        return rotated

class PCGraphDataset(_TorchDataset):
    """Map-style dataset that turns each stored point cloud into a kNN graph on access.

    data_dict maps an identifier to a dict with 'points' and 'annotations'.
    Nothing is cached: every __getitem__ call normalises the cloud, applies the
    optional transform and rebuilds the graph.
    """

    def __init__(self, data_dict, k=6, transform=None):
        self.data_dict = data_dict
        # Freeze the key order so that integer indices stay stable.
        self.keys = [name for name in data_dict]
        self.transform = transform
        self.k = k

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        entry = self.data_dict[self.keys[idx]]
        coords = entry['points']
        node_labels = entry['annotations']
        # Standardise each coordinate axis independently: zero mean, unit std.
        shifted = coords - coords.mean(axis=0)
        normalised = shifted / (shifted.std(axis=0) + 1e-12)
        # The transform (e.g. augmentation) sees the normalised coordinates.
        if self.transform is not None:
            normalised = self.transform(normalised)
        return build_edge_labelled_knn_graph(normalised, node_labels, k=self.k)

class EdgeClassifier(nn.Module):
    """Binary edge classifier: GraphSAGE node encoder followed by an edge MLP.

    in_ch:   number of input node features.
    hid:     width of both GraphSAGE layers.
    edge_in: number of raw edge attributes; None or <= 0 disables the edge
             attribute branch.
    dropout: dropout probability used after each GNN layer and inside the head.

    forward returns one logit per entry of data.edge_label_index.
    """

    def __init__(self, in_ch, hid=64, edge_in=1, dropout=0.1):
        super().__init__()
        self.dropout = dropout

        # Node encoder: two stacked GraphSAGE layers.
        self.gnn = nn.ModuleList([SAGEConv(in_ch, hid), SAGEConv(hid, hid)])

        # Optional projection of the raw edge attributes to 16 features.
        use_edge_attr = edge_in is not None and edge_in > 0
        if use_edge_attr:
            width = max(8, edge_in)
            self.edge_attr_mlp = nn.Sequential(
                nn.Linear(edge_in, width),
                nn.ReLU(),
                nn.Linear(width, 16),
                nn.ReLU(),
            )
        else:
            self.edge_attr_mlp = None

        # Head input: h_i, h_j, |h_i - h_j| (3 * hid) plus the projected edge attributes.
        head_in = 3 * hid
        if use_edge_attr:
            head_in += 16
        self.edge_mlp = nn.Sequential(
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, data):
        # Encode nodes: conv -> ReLU -> (optional) dropout, per layer.
        h = data.x
        for conv in self.gnn:
            h = F.relu(conv(h, data.edge_index))
            if self.dropout > 0:
                h = F.dropout(h, p=self.dropout, training=self.training)

        # Gather the embeddings of both endpoints of every edge to score.
        src, dst = data.edge_label_index
        h_src = h[src]
        h_dst = h[dst]
        parts = [h_src, h_dst, torch.abs(h_src - h_dst)]

        # Append projected edge attributes when both the data and the model have them.
        edge_attr = getattr(data, "edge_attr", None)
        if edge_attr is not None and self.edge_attr_mlp is not None:
            parts.append(self.edge_attr_mlp(edge_attr))

        return self.edge_mlp(torch.cat(parts, dim=1)).squeeze(-1)


In [None]:
# Training loop (concise)
import torch.optim as optim
import numpy as np
from sklearn import metrics
from torch_geometric.loader import DataLoader as PyGDataLoader

# Data pipeline: random-rotation augmentation -> on-the-fly kNN graphs -> mini-batches.
k_for_nn = 12
transform = RandomRotation(angle_deg=180, axis=None)
pcg_dataset = PCGraphDataset(dataset, k=k_for_nn, transform=transform)
data_ldr = PyGDataLoader(pcg_dataset, batch_size=5, shuffle=True)

# Prefer the GPU when one is available.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# in_ch=3: node features are the point coordinates.
# edge_in=4: build_edge_labelled_knn_graph emits the offset vector (3) plus its length (1).
model = EdgeClassifier(in_ch=3, hid=64, edge_in=4, dropout=0.1)
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# The model outputs raw logits, hence the sigmoid-fused binary cross entropy.
criterion = torch.nn.BCEWithLogitsLoss()

def train_epoch(loader):
    """Run one optimisation pass over `loader` and return the mean loss per edge.

    Uses the module-level `model`, `opt`, `criterion` and `device`. Each batch
    loss is weighted by its number of edge labels, so the returned value is an
    edge-weighted average (0.0 if the loader yields nothing).
    """
    model.train()
    loss_sum, edge_count = 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        targets = batch.edge_label.float().to(device)
        batch_loss = criterion(logits, targets)

        opt.zero_grad()
        batch_loss.backward()
        opt.step()

        n_edges = targets.numel()
        loss_sum += batch_loss.item() * n_edges
        edge_count += n_edges
    return loss_sum / max(1, edge_count)

def evaluate(loader):
    """Score the module-level `model` on `loader` with thresholded edge predictions.

    An edge is predicted positive when its sigmoid probability exceeds 0.5.
    Returns a dict with accuracy ('acc'), precision ('prec'), recall ('rec')
    and F1 ('f1') pooled over all edges; all zeros when the loader is empty.
    """
    model.eval()
    truth_chunks, pred_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            is_positive = torch.sigmoid(model(batch)) > 0.5
            pred_chunks.append(is_positive.long().cpu().numpy())
            truth_chunks.append(batch.edge_label.long().cpu().numpy())

    # Nothing to score: report zeros instead of failing inside concatenate.
    if not truth_chunks:
        return dict(acc=0.0, prec=0.0, rec=0.0, f1=0.0)

    y_true = np.concatenate(truth_chunks)
    y_pred = np.concatenate(pred_chunks)
    return {
        'acc': metrics.accuracy_score(y_true, y_pred),
        'prec': metrics.precision_score(y_true, y_pred, zero_division=0),
        'rec': metrics.recall_score(y_true, y_pred, zero_division=0),
        'f1': metrics.f1_score(y_true, y_pred, zero_division=0),
    }

# Quick run: adjust epochs and batch_size as needed.
# Metrics are computed on the same loader that is used for training, so they are
# training metrics (measured on freshly augmented samples), not held-out ones.
epochs = 100
for ep in range(1, epochs + 1):
    # Evaluated left to right: train first, then score.
    loss, stats = train_epoch(data_ldr), evaluate(data_ldr)
    print(f'ep {ep:02d} loss={loss:.4f} acc={stats["acc"]:.4f} f1={stats["f1"]:.4f}')

print('done')

In [None]:
# Wrap the unlabelled clouds without augmentation. batch_size=1 and shuffle=False
# keep a one-to-one, in-order correspondence between batches and dataset keys.
pcg_unlabel = PCGraphDataset(unlabelled_dataset, k=k_for_nn, transform=None)
unl_ldr = PyGDataLoader(pcg_unlabel, batch_size=1, shuffle=False)

# Run inference and store, per cloud, the scored edges and their probabilities.
model.eval()
predictions = {}
with torch.no_grad():
    for i, (key, batch) in enumerate(zip(pcg_unlabel.keys, unl_ldr)):
        batch = batch.to(device)
        probs = torch.sigmoid(model(batch)).cpu().numpy()
        # edge_label_index is (2, E): first row sources, second row targets.
        src, dst = (row.tolist() for row in batch.edge_label_index.cpu().numpy())
        predictions[key] = dict(src=np.array(src), dst=np.array(dst), prob=np.array(probs))
        print(f'[{i}] {key}: edges={len(probs)} mean_prob={np.mean(probs):.4f}')

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
# Group the points of every unlabelled cloud into connected components of the
# graph formed by the edges predicted positive (probability > 0.5), then colour
# each point by its component.
for key in predictions.keys():
    points = unlabelled_dataset[key]['points']
    src = predictions[key]['src']
    dst = predictions[key]['dst']
    probs = predictions[key]['prob']
    pred = probs > 0.5

    # Register every point as a node first, so that points without any positive
    # edge still form their own (singleton) component instead of being dropped.
    graph = nx.Graph()
    graph.add_nodes_from(range(len(points)))
    graph.add_edges_from(zip(src[pred].tolist(), dst[pred].tolist()))
    components = list(nx.connected_components(graph))

    # labels[i] is the component index of point i. Writing through the node ids
    # keeps the labels aligned with the rows of `points`.
    labels = np.full(len(points), -1, dtype=int)
    for c_idx, component in enumerate(components):
        labels[list(component)] = c_idx

    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points[:, 2], points[:, 1], points[:,0], c=labels, cmap='tab10', s=3)
    ax.set_title(key)
    plt.show()