In [None]:
# Working-directory setup: the relative paths used below (e.g. 'data/BBW',
# 'cfg/bbw/...') are resolved against the LMSeg checkout. The location can be
# overridden with the LMSEG_ROOT environment variable; when it is unset the
# original hard-coded path is used, so existing setups behave as before.
import os
os.chdir(os.environ.get("LMSEG_ROOT", "/data/projects/punim2016/zexianh/LMSeg"))

import pyvista as pv
pv.start_xvfb()  # launches Xvfb (virtual framebuffer); presumably needed because the host is headless

# Render settings shared by the plot cells below.
w_size = (1000, 1000)
zoom = 1.2

import torch
import torch.nn.functional as F
import numpy as np
from pathlib import PosixPath
from data.dataset import BudjBimWallMeshDataset

# Pick one sample of the 'area2' test split and pull its arrays into NumPy.
idx = 140
bbw_set = BudjBimWallMeshDataset(root='data/BBW', split='test', test_area='area2')
dual = bbw_set[idx]

# The mesh file is found by swapping "/processed/" for "/mesh/" in the
# processed file's path and using a ".ply" suffix.
mesh_path = PosixPath(str(bbw_set.data_files[idx]).replace("/processed/", "/mesh/")).with_suffix(".ply")
mesh = pv.read(mesh_path)

face_v, f_normal, f_rgba, f_mask = (
    tensor.cpu().numpy()
    for tensor in (dual.pos, dual.normals, dual.face_rgba, dual.y)
)

# Each edge (i, j) becomes a row [2, i, j]; this array is later assigned to
# `PolyData.lines`.
edge_pairs = dual.edge_index.cpu().t().numpy()
line_prefix = np.full((edge_pairs.shape[0], 1), 2)
face_adj = np.hstack((line_prefix, edge_pairs))

## Mesh

In [None]:
# Render the mesh coloured directly by its 'RGBA' array.
rgba_mesh_view = dict(
    scalars='RGBA',
    rgb=True,
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    show_axes=False,
    show_scalar_bar=False,
)
mesh.plot(**rgba_mesh_view)

In [None]:
# Attach the labels (`dual.y`) as a per-cell array and render them.
mesh.cell_data['Mask'] = f_mask
mesh.cell_data.active_scalars_name = 'Mask'

mask_mesh_view = dict(
    scalars='Mask',
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    show_axes=False,
    show_scalar_bar=False,
)
mesh.plot(**mask_mesh_view)

In [None]:
# Graph view: points from `face_v`, line cells from `face_adj`, coloured by RGBA.
face_pcd = pv.PolyData(face_v)
face_pcd.lines = face_adj
face_pcd.point_data['RGBA'] = f_rgba
face_pcd.point_data.active_scalars_name = 'RGBA'

rgba_graph_view = dict(
    scalars='RGBA',
    rgb=True,
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    render_points_as_spheres=True,
    point_size=6,
    line_width=1.5,
    show_axes=False,
)
face_pcd.plot(**rgba_graph_view)

In [None]:
# Same graph view, coloured by the per-node labels instead of RGBA.
face_pcd.point_data['Mask'] = f_mask
face_pcd.point_data.active_scalars_name = 'Mask'

mask_graph_view = dict(
    scalars='Mask',
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    render_points_as_spheres=True,
    point_size=6,
    line_width=1.5,
    show_axes=False,
    show_scalar_bar=False,
)
face_pcd.plot(**mask_graph_view)

## Surface Normal

In [None]:
# Use the last three columns of `dual.normals` as orientation vectors and draw
# one unscaled glyph per point (subject to the glyph tolerance).
normal_vectors = f_normal[:, -3:]
face_pcd.point_data['Normal'] = normal_vectors
arrows = face_pcd.glyph(orient='Normal', scale=False, factor=1.0, tolerance=0.005)

normal_view = dict(
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    show_axes=False,
    show_scalar_bar=False,
)
arrows.plot(**normal_view)

## Prediction

In [None]:
import yaml
import torch

from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torchmetrics.classification import BinaryF1Score, BinaryJaccardIndex

from train.trainer import Trainer
from model.net import LMSegNet
from model.PointNet.net import PointNetSeg
from model.PointNet2.net import PointNet2
from model.RandlaNet.net import RandlaNet
from model.DeeperGCN.net import DeeperGCN
from model.GraphUNet.net import GraphUNet
from model.PointTransformer.net import PointTransformer


# Read the LMSeg "feature" configuration.
with open('cfg/bbw/lmseg_feature.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# LMSegNet takes these config entries positionally, in exactly this order.
lmseg_arg_keys = (
    'in_channels',
    'out_channels',
    'hid_channels',
    'num_convs',
    'pool_factors',
    'num_nbrs',
    'num_block',
    'alpha',
    'beta',
)
model = LMSegNet(*(cfg[key] for key in lmseg_arg_keys))


# with open('cfg/bbw/lmseg_rgb.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = LMSegNet(cfg['in_channels'], 
#                      cfg['out_channels'],
#                      cfg['hid_channels'], 
#                      cfg['num_convs'], 
#                      cfg['pool_factors'], 
#                      cfg['num_nbrs'],
#                      cfg['num_block'],
#                      cfg['alpha'], 
#                      cfg['beta'])


# with open('cfg/bbw/lmseg_normals.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = LMSegNet(cfg['in_channels'], 
#                      cfg['out_channels'],
#                      cfg['hid_channels'], 
#                      cfg['num_convs'], 
#                      cfg['pool_factors'], 
#                      cfg['num_nbrs'],
#                      cfg['num_block'],
#                      cfg['alpha'], 
#                      cfg['beta'])


# with open('cfg/bbw/randlanet_feature.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = RandlaNet(cfg['in_channels'], 
#                       cfg['out_channels'],
#                       cfg['decimation'],
#                       cfg['num_nbrs'])


# with open('cfg/bbw/pointnet_feature.yaml', 'r') as f:    
#     cfg = yaml.safe_load(f)    
#     import torch_geometric.transforms as T
#     bbw_set.transform.transforms.append(T.FixedPoints(cfg['num_points']))
#     model = PointNetSeg(cfg['in_channels'], cfg['out_channels'], get_trans_feat=False)   


# with open('cfg/bbw/pointnet2_feature.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = PointNet2(cfg['in_channels'], 
#                       cfg['out_channels'],
#                       cfg['pool_ratio'],
#                       cfg['num_nbrs'])


# with open('cfg/bbw/ptr_feature.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = PointTransformer(cfg['in_channels'], 
#                              cfg['out_channels'], 
#                              cfg['hid_channels'], 
#                              cfg['pool_ratio'], 
#                              cfg['num_nbrs'])
    
    
# with open('cfg/bbw/deepergcn_feature.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = DeeperGCN(cfg['in_channels'], 
#                       cfg['out_channels'],
#                       cfg['hid_channels'], 
#                       cfg['num_layers'])


# with open('cfg/bbw/gunet_feature.yaml', 'r') as f:
#     cfg = yaml.safe_load(f)    
#     model = GraphUNet(cfg['in_channels'], 
#                       cfg['hid_channels'], 
#                       cfg['out_channels'], 
#                       cfg['depth'],
#                       cfg['pool_ratios'],
#                       cfg['sum_res'],
#                       cfg['act'])
    
# Point the config at the 'area2' sub-directory (presumably where the
# checkpoints for this test area are stored) and run on the CPU.
cfg.update(path=cfg['path'] + '/area2', device='cpu')

trainer = Trainer(cfg=cfg)
ckpt_name = f"epoch{cfg['epoch']}.pth"
model = trainer.load_weights(model, ckpt_name)

## Evaluation

In [None]:
# Evaluate the loaded checkpoint on the whole test split.
loader_options = dict(
    batch_size=cfg['batch'],
    shuffle=False,
    num_workers=cfg['workers'],
)
test_loader = DataLoader(bbw_set, **loader_options)

metric_dict = dict(f1=BinaryF1Score(), mIoU=BinaryJaccardIndex())

eval_checkpoint = f"epoch{cfg['epoch']}.pth"
cm = trainer.eval(model,
                  test_loader,
                  metric=metric_dict,
                  ckpt=eval_checkpoint,
                  verbose=True)

In [None]:
# Single-sample inference for the visualisations below.
# Switch to evaluation mode explicitly so this cell gives inference-mode
# predictions even when the evaluation cell above has not been run.
model.eval()
with torch.no_grad():
    y = model(Batch.from_data_list([bbw_set[idx]]).to(cfg['device']))['y']
    prob = torch.sigmoid(y).detach().cpu().numpy()

# Binarise: probabilities strictly above the threshold become 1, the rest 0.
# The array keeps the model output's shape and floating dtype.
thred = 0.5
pred = (prob > thred).astype(prob.dtype)

In [None]:
# Show the binarised prediction as a per-cell array on the mesh.
mesh.cell_data['Pred'] = pred
mesh.cell_data.active_scalars_name = 'Pred'

pred_view = dict(
    scalars='Pred',
    cpos='iso',
    window_size=w_size,
    zoom=zoom,
    show_axes=False,
    show_scalar_bar=False,
)
mesh.plot(**pred_view)

In [None]:
# Per-cell disagreement between prediction and label.
# Both sides are flattened first: if one array is (N, 1) and the other (N,),
# a plain `!=` would broadcast to an (N, N) matrix instead of comparing
# element by element.
error = np.ravel(pred) != np.ravel(f_mask)

mesh.cell_data['Error'] = error
mesh.cell_data.active_scalars_name = 'Error'
mesh.plot(scalars='Error',
          cpos='iso',
          cmap=['white', 'red'],  # white = agreement, red = mismatch
          window_size=w_size,
          zoom=zoom,
          show_axes=False,
          show_scalar_bar=False)

In [None]:
# Registry filled by the forward hooks below: hook name -> NumPy copy (or list
# of NumPy copies) of what the hooked module returned.
activations = dict()


def _to_numpy(tensor):
    """Detach a tensor, move it to the CPU and convert it to a NumPy array."""
    return tensor.detach().cpu().numpy()


def hook_encoder(name):
    """Build a forward hook for a module returning a 4-tuple of sequences.

    The hook unpacks the output as (pos_down, x_down, batch_down,
    edge_index_down) and stores NumPy copies of every entry of the first,
    third and fourth sequence under '<name>/pos_down', '<name>/batch_down'
    and '<name>/edge_down'. The second element (x_down) is not stored.
    """
    def _hook(mod, inp, out):
        pos_down, x_down, batch_down, edge_index_down = out
        captured = (
            ('pos_down', pos_down),
            ('batch_down', batch_down),
            ('edge_down', edge_index_down),
        )
        for suffix, tensors in captured:
            activations[f'{name}/{suffix}'] = list(map(_to_numpy, tensors))
    return _hook


def hook_tensor(name):
    """Build a forward hook that stores a single-tensor output under `name`."""
    def _hook(mod, inp, out):
        activations[name] = _to_numpy(out)
    return _hook

In [None]:
# Predictive uncertainty: binary entropy of the sigmoid output per element.
model.eval()
with torch.no_grad():
    out = model(Batch.from_data_list([bbw_set[idx]]).to(cfg['device']))
    Z = out['y'].detach().cpu().squeeze(-1)  # trailing size-1 axis (if any) removed

eps = 1e-12  # keeps log() finite when a probability is exactly 0 or 1
P = torch.sigmoid(Z)
P_np = P.numpy()
pos_term = P_np * np.log(P_np + eps)
neg_term = (1.0 - P_np) * np.log(1.0 - P_np + eps)
entropy = -(pos_term + neg_term)

# Min-max rescale for display.
emin, emax = entropy.min(), entropy.max()
entropy_norm = (entropy - emin) / (emax - emin + 1e-12)

mesh.cell_data['Entropy'] = entropy_norm
mesh.cell_data.active_scalars_name = 'Entropy'

entropy_bar = {
    "title": "Uncertainty (Entropy)",
    "position_x": 0.2,     # horizontal placement (0 = left, 1 = right)
    "position_y": 0.0,     # lower values push the bar down
    "title_font_size": 20,
    "label_font_size": 16,
    "italic": False,
    "bold": True,          # honored on some backends; harmless otherwise
}
entropy_view = dict(
    scalars='Entropy',
    cpos='iso',
    window_size=(1000, 1050),
    zoom=1.2,
    cmap='viridis',        # high values = more uncertain
    show_axes=False,
    show_scalar_bar=True,
    scalar_bar_args=entropy_bar,
)
mesh.plot(**entropy_view)

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import numpy as np

def pca_whiten(X, n_components=32, seed=0):
    """Project X onto `n_components` whitened principal components.

    X is cast to float32 before fitting; `seed` is forwarded to PCA as
    `random_state`. Returns the transformed array from `fit_transform`.
    """
    features = np.asarray(X, dtype=np.float32)
    reducer = PCA(n_components=n_components, whiten=True, random_state=seed)
    return reducer.fit_transform(features)

def l2_normalize(X, eps=1e-12):
    """Divide every row of X by (its Euclidean norm + eps).

    `eps` keeps the division finite for all-zero rows, which stay zero.
    """
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    denominators = row_norms + eps
    return X / denominators

def idw_upsample(src_pos, src_feat, dst_pos, k=1, eps=1e-12, power=2.0):
    """Inverse-distance-weighted interpolation of `src_feat` onto `dst_pos`.

    For every destination point the `k` nearest source points are looked up,
    weighted by 1 / distance**power (with `eps` added to the squared distance
    and to the weight sum to avoid division by zero), and their feature rows
    are combined with those normalised weights.

    Returns an array with one row per destination point and one column per
    source feature column.
    """
    knn = NearestNeighbors(n_neighbors=k)
    knn.fit(src_pos)
    nbr_idx = knn.kneighbors(dst_pos, return_distance=False)

    offsets = dst_pos[:, None, :] - src_pos[nbr_idx]
    sq_dist = np.sum(offsets ** 2, axis=-1) + eps
    weights = 1.0 / (sq_dist ** (power / 2.0))
    weights = weights / (weights.sum(axis=1, keepdims=True) + eps)

    nbr_feat = src_feat[nbr_idx]       # (n_dst, k, d)
    return np.einsum('nk,nkd->nd', weights, nbr_feat)

def kmeans_labels(X, n_clusters=8, seed=0):
    """Cluster the rows of X with K-means and return int32 cluster ids.

    Uses 10 initialisations (`n_init=10`) and `seed` as `random_state`.
    """
    clusterer = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    assignments = clusterer.fit_predict(X)
    return assignments.astype(np.int32)

# Register forward hooks on the encoder and on the last hierarchical / local
# down-convolution, run one forward pass, then read the captured activations.
h_enc = model.Enc.register_forward_hook(hook_encoder('Enc'))
h_hga = model.Enc.down_convs_hier[-1].register_forward_hook(hook_tensor('HGA_deep'))
h_lga = model.Enc.down_convs_local[-1].register_forward_hook(hook_tensor('LGA_last'))

try:
    model.eval()
    with torch.no_grad():
        _ = model(Batch.from_data_list([bbw_set[idx]]).to(cfg['device']))
finally:
    # Always detach the hooks: otherwise every re-run of this cell stacks
    # another set on the model and each later forward pass keeps overwriting
    # `activations`.
    for hook_handle in (h_enc, h_hga, h_lga):
        hook_handle.remove()

pos_list   = activations['Enc/pos_down']
pos_orig   = pos_list[0]                   # (N0, 3)
pos_deep   = pos_list[-1]                  # (Nd, 3)
H_hga_pool = activations['HGA_deep']       # (Nd, d_h)
H_lga_pool = activations['LGA_last']       # (Nd, d_l)

# PCA-whiten the pooled HGA / LGA activations down to 32 components.
Z_hga = pca_whiten(H_hga_pool, n_components=32, seed=0)
Z_lga = pca_whiten(H_lga_pool, n_components=32, seed=0)

# L2 normalization
Z_hga = l2_normalize(Z_hga)
Z_lga = l2_normalize(Z_lga)

# K-means clustering
n_clusters = 8
labels_hga_pool = kmeans_labels(Z_hga, n_clusters=n_clusters, seed=0)
labels_lga_pool = kmeans_labels(Z_lga, n_clusters=n_clusters, seed=0)

# Upsample back to the original resolution (k=1 copies the nearest pooled
# node's label). idw_upsample returns label * normalised weight as a float,
# and that weight can sit marginally below 1, so round to the nearest integer
# before the cast; a plain astype() truncates and could turn label c into c-1.
labels_hga_full = np.rint(idw_upsample(pos_deep, labels_hga_pool[:, None], pos_orig, k=1)).ravel().astype(np.int32)
labels_lga_full = np.rint(idw_upsample(pos_deep, labels_lga_pool[:, None], pos_orig, k=1)).ravel().astype(np.int32)

mesh.cell_data['HGA_clusters'] = labels_hga_full
mesh.cell_data['LGA_clusters'] = labels_lga_full

# Legend text for each cluster id.
annotations_hga = {int(i): f"HGA-{i}" for i in range(n_clusters)}
annotations_lga = {int(i): f"LGA-{i}" for i in range(n_clusters)}

# Plot the HGA clusters first, then the LGA clusters, with identical settings.
cluster_views = (
    ('HGA_clusters', annotations_hga),
    ('LGA_clusters', annotations_lga),
)
for cluster_array, cluster_annotations in cluster_views:
    mesh.cell_data.active_scalars_name = cluster_array
    mesh.plot(
        scalars=cluster_array,
        cpos='iso',
        window_size=(800, 800),
        zoom=1.1,
        cmap='Set3',
        categories=True,
        annotations=cluster_annotations,
        show_axes=False,
        show_scalar_bar=True,
        # A fresh dict per call, as in the two original calls.
        scalar_bar_args={"title": "", "position_x": 0.2, "position_y": 0.0,
                         "title_font_size": 20, "label_font_size": 16,
                         "italic": False, "bold": True},
    )

## Gen Aggr: Learnable `t` Parameters

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
# Configure seaborn in ONE call. `set_theme` re-applies its defaults for every
# argument that is not passed, so a second `set_theme(rc=...)` call would
# silently replace the "pastel" palette with the default one.
sns.set_theme(style="darkgrid",
              palette="pastel",
              rc={"figure.dpi": 250,
                  'savefig.dpi': 250,
                  "axes.spines.right": False,
                  "axes.spines.top": False})

In [None]:
def flatten_all(arrs):
    """Return each non-None entry of `arrs` as a flat 1-D array, in order."""
    flattened = []
    for item in arrs:
        if item is None:
            continue
        flattened.append(np.asarray(item).reshape(-1))
    return flattened

def stack_long(module_name, mode, arrays):
    """Expand per-layer arrays into long-format records.

    Layers are numbered from 1 by position in `arrays`; a None entry is
    skipped but still consumes its layer number. Every scalar of a layer's
    flattened array becomes one dict with keys "Module", "Mode", "Layer", "t".
    """
    rows = []
    for layer_no, layer_values in enumerate(arrays, start=1):
        if layer_values is not None:
            rows.extend(
                {"Module": module_name, "Mode": mode, "Layer": layer_no, "t": value}
                for value in np.asarray(layer_values).reshape(-1)
            )
    return rows

def plot_compact_split_violin(HGA_t_avg, LGA_t_avg, HGA_t_max, LGA_t_max):
    """Draw split violins of the `t` values per layer, HGA+ vs. LGA+.

    Each argument is a list with one array per layer. The *_avg lists are
    labelled "t=0.0" and the *_max lists "t=1.0". The x axis has one slot per
    (layer, mode) pair, and within a slot the violin is split by module.
    """
    groups = (
        ("HGA+", "t=0.0", HGA_t_avg),
        ("LGA+", "t=0.0", LGA_t_avg),
        ("HGA+", "t=1.0", HGA_t_max),
        ("LGA+", "t=1.0", LGA_t_max),
    )
    records = []
    for module_name, mode, values in groups:
        records.extend(stack_long(module_name, mode, flatten_all(values)))
    df = pd.DataFrame(records)

    # x-axis order: for every layer, the t=0.0 slot followed by the t=1.0 slot.
    layers = sorted(df["Layer"].unique())
    x_order = [
        f"L{layer}\n({mode})"
        for layer in layers
        for mode in ("t=0.0", "t=1.0")
    ]
    slot_labels = [
        f"L{layer}\n({mode})" for layer, mode in zip(df["Layer"], df["Mode"])
    ]
    df["LayerMode"] = pd.Categorical(slot_labels, categories=x_order, ordered=True)

    sns.set_context("paper")
    plt.figure(figsize=(9, 4))

    violin_options = dict(
        hue="Module", split=True, inner="quartile", cut=0, linewidth=0.8
    )
    ax = sns.violinplot(data=df, x="LayerMode", y="t", **violin_options)

    ax.set_xlabel("", fontsize=10, weight="bold")
    ax.set_ylabel("t", fontsize=10, weight="bold")
    ax.set_title("Learnable t by layer & module — HGA+ vs. LGA+", fontsize=11, weight="bold")

    # Dashed vertical separators between consecutive layer groups.
    num_per_layer = 2  # two x slots per layer (t=0.0, t=1.0)
    for boundary in range(1, len(layers)):
        ax.axvline(boundary * num_per_layer - 0.5, color="gray", linestyle="--", linewidth=0.8)

    # Legend styling.
    leg = ax.legend(title="Module", fontsize=9, title_fontsize=9, frameon=False, loc="upper left")
    for legend_text in leg.texts:
        legend_text.set_fontsize(9)

    plt.tight_layout(pad=0.4)
    plt.show()


def _gather_t(wanted_names):
    """Return NumPy copies of the model parameters whose name is in `wanted_names`.

    Values come back in `model.named_parameters()` order; names that do not
    exist in the model are skipped.
    """
    return [
        param.detach().cpu().numpy()
        for name, param in model.named_parameters()
        if name in wanted_names
    ]


# `t` parameters of the three hierarchical (HGA) and local (LGA) down-convs,
# for the "avg" and "max" aggregators respectively.
HGA_t_avg = _gather_t(['Enc.down_convs_hier.0.gen_aggr_avg.t',
                       'Enc.down_convs_hier.1.gen_aggr_avg.t',
                       'Enc.down_convs_hier.2.gen_aggr_avg.t'])

LGA_t_avg = _gather_t(['Enc.down_convs_local.0.gen_aggr_avg.t',
                       'Enc.down_convs_local.1.gen_aggr_avg.t',
                       'Enc.down_convs_local.2.gen_aggr_avg.t'])

HGA_t_max = _gather_t(['Enc.down_convs_hier.0.gen_aggr_max.t',
                       'Enc.down_convs_hier.1.gen_aggr_max.t',
                       'Enc.down_convs_hier.2.gen_aggr_max.t'])

LGA_t_max = _gather_t(['Enc.down_convs_local.0.gen_aggr_max.t',
                       'Enc.down_convs_local.1.gen_aggr_max.t',
                       'Enc.down_convs_local.2.gen_aggr_max.t'])

plot_compact_split_violin(HGA_t_avg, LGA_t_avg, HGA_t_max, LGA_t_max)

## Surface Area

In [None]:
import trimesh 
import torch

# Collect the per-face area and the label vector of every test mesh.
test_mask = []
face_areas = []

# NOTE: the loop variable is deliberately NOT called `idx`. The notebook-level
# `idx` selects the sample visualised in other cells, and reusing the name
# here would leave it pointing at the last dataset item after the loop.
# Assumes `bbw_set.mesh_list[i]` is the mesh of `bbw_set[i]` (TODO confirm).
for sample_idx, mesh_file in enumerate(bbw_set.mesh_list):
    tri_mesh = trimesh.load(mesh_file, force="mesh")

    face_areas.append(torch.from_numpy(np.asarray(tri_mesh.area_faces)))
    test_mask.append(bbw_set[sample_idx].y.cpu().view(-1).long())

# Concatenate across meshes into one flat tensor each.
test_mask = torch.hstack(test_mask)
face_areas = torch.hstack(face_areas)

In [None]:
from torch_geometric.nn import aggr 

# Sum the face areas per label value and show each label's share of the total.
sum_aggr = aggr.SumAggregation()
area_column = face_areas.view(-1, 1)
area_per_mask = sum_aggr(area_column, test_mask)

total_area = area_per_mask.sum()
area_per_mask / total_area

## FPS Pooling

In [None]:
import torch
from model.pool import RandomPooling, FPSPooling, EdgeRandomPooling

In [None]:
num_pool = 4
fps = FPSPooling()

# Fetch the sample once: each `bbw_set[idx]` lookup goes through the dataset
# again, so three separate lookups repeat that work and could return mutually
# inconsistent tensors if a stochastic transform is attached.
sample = bbw_set[idx]
p = sample.pos
y = sample.y
edge_index = sample.edge_index

# Single-graph bookkeeping: every node belongs to graph 0, and `ptr` holds the
# [start, end] node offsets of that graph.
batch = torch.zeros(p.shape[0], dtype=torch.long)
ptr = torch.tensor([0, p.shape[0]], dtype=torch.long)

In [None]:
from torch_geometric.nn import radius_graph, knn_graph

# Level 0 is the unpooled input; every iteration appends one pooled level.
p_down, y_down, edge_down = [p], [y], [edge_index]

for level in range(num_pool):
    # Only the selected node indices and the new ptr are used from the pooling
    # result; the pooled edges are replaced by a radius graph below.
    edge_index_pool, node_index, ptr = fps(p, edge_index, 3, ptr)
    p, y = p[node_index], y[node_index]

    # Rebuild connectivity on the first two coordinates with a radius that
    # grows linearly with the level.
    edge_index = radius_graph(x=p[:, :2],
                              r=0.05 * (level + 1),
                              loop=False,
                              max_num_neighbors=3,
                              flow='target_to_source')

    p_down.append(p)
    y_down.append(y)
    edge_down.append(edge_index)

In [None]:
# Convert the chosen pooling level to NumPy for plotting.
pool_idx = 3

p_pool, y_pool = (
    level_tensor.cpu().numpy()
    for level_tensor in (p_down[pool_idx], y_down[pool_idx])
)

# Rows of [2, i, j], one per edge, keeping the edge array's dtype.
edge_pool = edge_down[pool_idx].cpu().t().numpy()
line_prefix = np.full((edge_pool.shape[0], 1), 2)
edge_pool = np.hstack((line_prefix, edge_pool), dtype=edge_pool.dtype)

In [None]:
# Point-cloud view of the pooled nodes coloured by label (edges left out).
face_pcd = pv.PolyData(p_pool)

# face_pcd.lines = edge_pool
face_pcd.point_data['mask'] = y_pool
face_pcd.point_data.active_scalars_name = 'mask'

pooled_view = dict(
    scalars='mask',
    cmap='viridis',
    cpos='iso',
    window_size=w_size,
    render_points_as_spheres=True,
    point_size=6,
    line_width=1,
)
face_pcd.plot(**pooled_view)

## Random Decimation

In [None]:
num_pool = 4
decimate = RandomPooling()

# Fetch the sample once: each `bbw_set[idx]` lookup goes through the dataset
# again, so three separate lookups repeat that work and could return mutually
# inconsistent tensors if a stochastic transform is attached.
sample = bbw_set[idx]
p = sample.pos
y = sample.y
edge_index = sample.edge_index

# Single-graph bookkeeping: every node belongs to graph 0, and `ptr` holds the
# [start, end] node offsets of that graph.
batch = torch.zeros(p.shape[0], dtype=torch.long)
ptr = torch.tensor([0, p.shape[0]], dtype=torch.long)

In [None]:
from torch_geometric.nn import radius_graph, knn_graph
from torch_geometric.utils import coalesce, to_undirected

# Level 0 is the unpooled input; every iteration appends one decimated level.
p_down, y_down, edge_down = [p], [y], [edge_index]

for level in range(num_pool):
    # Decimation is given only the first two coordinates.
    edge_index_pool, node_index, ptr = decimate(p[:, :2], edge_index, 3, ptr)
    p, y = p[node_index], y[node_index]

    # New radius-graph edges (radius grows with the level) are merged with
    # the pooled edges, de-duplicated and made undirected.
    new_edge_index = radius_graph(x=p[:, :2],
                                  r=0.05 * (level + 1),
                                  loop=False,
                                  max_num_neighbors=3,
                                  flow='target_to_source')
    merged_edges = torch.cat([new_edge_index, edge_index_pool], dim=-1)
    edge_index = to_undirected(coalesce(merged_edges))

    p_down.append(p)
    y_down.append(y)
    edge_down.append(edge_index)

In [None]:
# Convert the chosen decimation level to NumPy for plotting.
pool_idx = 3

p_pool, y_pool = (
    level_tensor.cpu().numpy()
    for level_tensor in (p_down[pool_idx], y_down[pool_idx])
)

# Rows of [2, i, j], one per edge.
edge_pool = edge_down[pool_idx].cpu().t().numpy()
line_prefix = np.full((edge_pool.shape[0], 1), 2)
edge_pool = np.hstack((line_prefix, edge_pool))

In [None]:
# Point-cloud view of the decimated nodes coloured by label (edges left out).
face_pcd = pv.PolyData(p_pool)

# face_pcd.lines = edge_pool
face_pcd.point_data['mask'] = y_pool
face_pcd.point_data.active_scalars_name = 'mask'

decimated_view = dict(
    scalars='mask',
    cmap='viridis',
    cpos='iso',
    window_size=w_size,
    render_points_as_spheres=True,
    point_size=6,
    line_width=1,
)
face_pcd.plot(**decimated_view)

## Edge Decimation

In [None]:
num_pool = 4
edge_decimate = EdgeRandomPooling()

# Fetch the sample once: each `bbw_set[idx]` lookup goes through the dataset
# again, so three separate lookups repeat that work and could return mutually
# inconsistent tensors if a stochastic transform is attached.
sample = bbw_set[idx]
p = sample.pos
y = sample.y
edge_index = sample.edge_index

In [None]:

# Level 0 is the unpooled input; every iteration appends one decimated level.
p_down = [p]
y_down = [y]
edge_down = [edge_index]
dropout_rate = [0.8, 0.75, 0.75, 0.5]  # one rate per pooling level

for i in range(0, num_pool):
    # Set this level's rate BEFORE pooling. Assigning it after the call would
    # run level 0 with the pooler's default rate, shift every listed rate to
    # the next level, and never use the last one.
    edge_decimate.dropout_rate = dropout_rate[i]
    edge_index_pool, node_index = edge_decimate(p, edge_index)

    # Keep only the surviving nodes and continue from the pooled graph.
    p = p[node_index]
    y = y[node_index]
    edge_index = edge_index_pool

    p_down.append(p)
    y_down.append(y)
    edge_down.append(edge_index)

In [None]:
# Convert the chosen edge-decimation level to NumPy for plotting.
pool_idx = 2

p_pool, y_pool = (
    level_tensor.cpu().numpy()
    for level_tensor in (p_down[pool_idx], y_down[pool_idx])
)

# Rows of [2, i, j], one per edge; assigned to `PolyData.lines` when plotting.
edge_pool = edge_down[pool_idx].cpu().t().numpy()
line_prefix = np.full((edge_pool.shape[0], 1), 2)
edge_pool = np.hstack((line_prefix, edge_pool))

In [None]:
# Graph view of the edge-decimated level: nodes coloured by label, edges drawn.
face_pcd = pv.PolyData(p_pool)

face_pcd.lines = edge_pool
face_pcd.point_data['mask'] = y_pool
face_pcd.point_data.active_scalars_name = 'mask'

edge_decimated_view = dict(
    scalars='mask',
    cmap='viridis',
    cpos='iso',
    window_size=w_size,
    render_points_as_spheres=True,
    point_size=6,
    line_width=1,
)
face_pcd.plot(**edge_decimated_view)