In [4]:
import glob
import os
import re

import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, HTML

try:
    from scipy.spatial import cKDTree as KDTree
except Exception:
    KDTree = None


# Data locations, relative to the notebook's working directory. Every path is derived
# from DATA_ROOT, so pointing the visualizer at another preprocessing run is a one-line change.
DATA_ROOT = '../../../data/processed/faust/preprocessed/fixedN_4096'
PAIRS_GLOB = f'{DATA_ROOT}/pairs/*.npz'                        # pair archives shown in the dropdown
ICP_DIRS = [f'{DATA_ROOT}/icp_test', f'{DATA_ROOT}/icp_all']   # searched in this order for ICP results
MATCH_DIR = f'{DATA_ROOT}/hypergct_test'                        # *_hypergct_fallback.npz match files

print('Visualizer loaded. If you see empty lists below, ensure the paths exist in this workspace.')
print('Found pairs:', len(glob.glob(PAIRS_GLOB)))

Visualizer loaded. If you see empty lists below, ensure the paths exist in this workspace.
Found pairs: 4950


In [5]:
def apply_transform(pts, T):
    """Apply a homogeneous transform ``T`` to a 2-D array of 3-D points.

    Each row ``p`` of ``pts`` is lifted to ``[p, 1]``, multiplied by ``T``, and the
    first three coordinates are kept (no perspective divide). Only the first three
    rows of ``T`` influence the result.

    Returns:
        An ``(len(pts), 3)`` array.
    """
    n_points = len(pts)
    homogeneous = np.hstack((pts, np.ones((n_points, 1))))
    return (homogeneous @ T.T)[:, :3]

def load_pair(path):
    """Load the source and target point clouds from a pair ``.npz`` archive.

    The preferred keys are ``src_coords``/``tgt_coords``; ``xyz0``/``xyz1`` are accepted
    as fallback names.

    Args:
        path: anything ``np.load`` accepts (path or binary file-like object).

    Returns:
        ``(src, tgt)`` arrays exactly as stored in the archive.

    Raises:
        KeyError: if neither key pair is present. The old code returned ``None`` arrays
            here, which only failed later with an unrelated-looking error.
    """
    # The context manager closes the archive; indexing an NpzFile reads the array
    # fully into memory, so the returned arrays remain valid afterwards.
    with np.load(path) as archive:
        available = set(archive.files)
        for src_key, tgt_key in (('src_coords', 'tgt_coords'), ('xyz0', 'xyz1')):
            if src_key in available and tgt_key in available:
                return archive[src_key], archive[tgt_key]
    raise KeyError(
        f"{path}: expected arrays 'src_coords'/'tgt_coords' or 'xyz0'/'xyz1', "
        f"found {sorted(available)}"
    )

def find_icp_for_pair(pair_path, icp_dirs=None):
    """Locate the ICP result archive belonging to a pair file.

    Each directory in ``icp_dirs`` is searched in order; missing directories are skipped.
    A ``.npz`` file matches when it contains the pair's basename (without ``.npz``)
    as a whole token, meaning no letter or digit immediately before or after it. This
    prevents ``pair_1_2`` from matching ``pair_1_23_icp.npz``. Names are scanned in sorted
    order, so the result does not depend on ``os.listdir`` ordering.

    Args:
        pair_path: path of the pair ``.npz`` file.
        icp_dirs: directories to search; defaults to the module-level ``ICP_DIRS``.

    Returns:
        Path of the first matching file, or ``None`` if there is none.
    """
    if icp_dirs is None:
        icp_dirs = ICP_DIRS
    base = os.path.splitext(os.path.basename(pair_path))[0]
    token = re.compile(r'(?<![0-9A-Za-z])' + re.escape(base) + r'(?![0-9A-Za-z])')
    for icp_dir in icp_dirs:
        if not os.path.isdir(icp_dir):
            continue
        for name in sorted(os.listdir(icp_dir)):
            if name.endswith('.npz') and token.search(name):
                return os.path.join(icp_dir, name)
    return None

def load_matches_for_pair(pair_path, match_dir=None):
    """Load the correspondence indices stored for a pair, if any.

    Reads ``<match_dir>/<pair basename>_hypergct_fallback.npz`` and returns its
    ``src_matches``/``tgt_matches`` arrays cast to ``int``.

    Args:
        pair_path: path of the pair ``.npz`` file.
        match_dir: directory holding match files; defaults to the module-level ``MATCH_DIR``.

    Returns:
        ``(src_idx, tgt_idx)`` integer arrays, or ``None`` when the file is missing or
        lacks either key.
    """
    if match_dir is None:
        match_dir = MATCH_DIR
    base = os.path.splitext(os.path.basename(pair_path))[0]
    match_path = os.path.join(match_dir, f'{base}_hypergct_fallback.npz')
    if not os.path.exists(match_path):
        return None
    # Close the archive once the arrays are read (indexing copies them into memory).
    with np.load(match_path) as archive:
        if 'src_matches' in archive.files and 'tgt_matches' in archive.files:
            return archive['src_matches'].astype(int), archive['tgt_matches'].astype(int)
    return None

def _match_segments(src_t, tgt, sidx, tidx, limit):
    """Return ``(3n, 3)`` polyline vertices ``[start, end, NaN]`` for the first n correspondences.

    ``n = min(len(sidx), len(tidx), limit)``. Clamping to both index arrays avoids an
    IndexError when they have different lengths. The NaN row after each pair breaks the
    polyline, so every correspondence is drawn as its own segment.
    """
    n = max(0, min(len(sidx), len(tidx), limit))
    starts = np.asarray(src_t[sidx[:n]], dtype=float).reshape(n, 3)
    ends = np.asarray(tgt[tidx[:n]], dtype=float).reshape(n, 3)
    gaps = np.full((n, 3), np.nan)
    return np.stack([starts, ends, gaps], axis=1).reshape(-1, 3)


def build_figure(src, src_t, tgt, matches=None, match_limit=500, outlier_idx=None):
    """Build an interactive 3-D plotly figure for one registration pair.

    Args:
        src: original source points (columns 0..2 are plotted as x, y, z).
        src_t: source points after the ICP transform.
        tgt: target points.
        matches: optional ``(src_idx, tgt_idx)`` correspondences, drawn as red
            segments from ``src_t[src_idx]`` to ``tgt[tgt_idx]``.
        match_limit: maximum number of correspondences to draw; ``<= 0`` draws none.
        outlier_idx: optional indices into ``src_t`` to highlight with larger red markers.

    Returns:
        A ``plotly.graph_objects.Figure`` with data-proportional axes.
    """
    def cloud(pts, color, name, size=2):
        # One marker trace per point set.
        return go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], mode='markers',
                            marker=dict(size=size, color=color), name=name)

    traces = [
        cloud(src, 'orange', 'src'),
        cloud(src_t, 'green', 'src_transformed'),
        cloud(tgt, 'cyan', 'tgt'),
    ]

    if outlier_idx is not None and len(outlier_idx) > 0:
        traces.append(cloud(src_t[outlier_idx], 'red', 'outliers', size=4))

    if matches is not None and match_limit > 0:
        sidx, tidx = matches
        segs = _match_segments(src_t, tgt, sidx, tidx, match_limit)
        if len(segs):
            # tolist(): plotly serializes NaN as null, which breaks the line between segments.
            traces.append(go.Scatter3d(x=segs[:, 0].tolist(), y=segs[:, 1].tolist(),
                                       z=segs[:, 2].tolist(), mode='lines',
                                       line=dict(color='red', width=1), name='matches'))

    fig = go.Figure(data=traces)
    fig.update_layout(scene=dict(aspectmode='data'), height=800)
    return fig

In [7]:
# Widget UI: pair selector, match/outlier toggles with their sliders, save button, output panes.
pair_files = sorted(glob.glob(PAIRS_GLOB))

# Pair selection and export.
pair_dropdown = widgets.Dropdown(
    options=pair_files,
    description='Pair:',
    layout=widgets.Layout(width='800px'),
)
save_btn = widgets.Button(description='Save PNG')

# Correspondence-line controls.
show_matches_cb = widgets.Checkbox(value=True, description='Show matches')
match_slider = widgets.IntSlider(
    value=200, min=0, max=2000, step=10, description='Match count'
)

# Far-point (outlier) highlighting controls.
show_outliers_cb = widgets.Checkbox(value=False, description='Highlight far points')
outlier_thresh = widgets.FloatSlider(
    value=0.02, min=0.0, max=1.0, step=0.001, description='Distance thresh'
)

# Output panes: `fig_box` holds the figure, `out` is a bordered status-message area.
fig_box = widgets.Output()
out = widgets.Output(layout={'border': '1px solid black'})

def compute_outliers(src_t, tgt, thresh, chunk_size=256):
    """Find transformed source points whose nearest target point is farther than ``thresh``.

    Uses the KD-tree class bound to ``KDTree`` when available. Otherwise it falls back to
    a chunked brute-force search.

    Args:
        src_t: transformed source points, shape (N, 3).
        tgt: target points, shape (M, 3).
        thresh: Euclidean distance threshold, in the same units as the coordinates.
        chunk_size: source rows processed per step in the brute-force fallback. This
            bounds its temporary ``chunk_size x M x 3`` difference array.

    Returns:
        ``(idx, dists)``: indices of points with ``dist > thresh``, and the (N,)
        nearest-neighbour distances. With an empty target, every distance is ``inf``.
    """
    src_t = np.asarray(src_t, dtype=float)
    tgt = np.asarray(tgt, dtype=float)
    if len(tgt) == 0:
        # Nothing to be near: every source point is infinitely far away.
        dists = np.full(len(src_t), np.inf)
    elif KDTree is not None:
        dists, _ = KDTree(tgt).query(src_t, k=1)
    else:
        # Chunked brute force: a one-shot broadcast would allocate N*M*3 floats
        # (~400 MB for 4096 x 4096 points).
        dists = np.empty(len(src_t))
        step = max(1, int(chunk_size))
        for start in range(0, len(src_t), step):
            diff = src_t[start:start + step, None, :] - tgt[None, :, :]
            dists[start:start + step] = np.sqrt((diff ** 2).sum(-1)).min(axis=1)
    mask = dists > thresh
    return np.nonzero(mask)[0], dists

def _load_icp_transform(pair_path):
    """Return the ``'trans'`` array from the pair's ICP file, or the 4x4 identity.

    Identity is used when no ICP file is found or the file has no ``'trans'`` array.
    """
    T = np.eye(4)
    icp_path = find_icp_for_pair(pair_path)
    if icp_path is not None:
        # Close the archive after reading; indexing copies the array into memory.
        with np.load(icp_path) as icp:
            if 'trans' in icp.files:
                T = icp['trans']
    return T


def _render_pair(pair_path):
    """Build the figure for ``pair_path`` according to the current widget state.

    Returns:
        ``(fig, outlier_msg)``. ``outlier_msg`` is a summary string when outlier
        highlighting is enabled, otherwise ``None``.
    """
    src, tgt = load_pair(pair_path)
    src_t = apply_transform(src, _load_icp_transform(pair_path))

    matches = load_matches_for_pair(pair_path) if show_matches_cb.value else None

    outlier_idx, outlier_msg = None, None
    if show_outliers_cb.value:
        outlier_idx, _ = compute_outliers(src_t, tgt, outlier_thresh.value)
        pct = 100.0 * len(outlier_idx) / len(src_t) if len(src_t) > 0 else 0.0
        outlier_msg = (f'Outliers (> {outlier_thresh.value:.4f}): '
                       f'{len(outlier_idx)} / {len(src_t)} ({pct:.1f}%)')

    fig = build_figure(src, src_t, tgt, matches=matches,
                       match_limit=match_slider.value, outlier_idx=outlier_idx)
    return fig, outlier_msg


def update_fig(*args):
    """Observer callback: redraw the selected pair into ``fig_box``.

    ``*args`` absorbs the change notification passed by ``observe``.
    """
    with fig_box:
        fig_box.clear_output(wait=True)
        # Clear immediately: with wait=True the clear is deferred until new output
        # arrives, so a stale outlier message lingered after highlighting was turned off.
        out.clear_output()
        if not pair_dropdown.value:
            print('No pair selected')
            return
        fig, outlier_msg = _render_pair(pair_dropdown.value)
        if outlier_msg is not None:
            with out:
                print(outlier_msg)
        display(fig)


def on_save_clicked(b):
    """Button callback: write the current figure next to the pair file as ``<pair>_vis.png``.

    All messages are printed inside ``out``. Output printed from a button callback
    outside an Output widget does not appear in the notebook.
    """
    with out:
        out.clear_output()
        if not pair_dropdown.value:
            return
        fig, outlier_msg = _render_pair(pair_dropdown.value)
        if outlier_msg is not None:
            print(outlier_msg)
        out_path = os.path.splitext(pair_dropdown.value)[0] + '_vis.png'
        try:
            fig.write_image(out_path)
            print('Saved', out_path)
        except Exception as e:
            # Best effort: static image export depends on the optional kaleido package.
            print('Failed to save image (kaleido may be missing):', e)

# Redraw whenever any control that affects the figure changes value.
for _control in (pair_dropdown, show_matches_cb, match_slider, show_outliers_cb, outlier_thresh):
    _control.observe(update_fig, names='value')

save_btn.on_click(on_save_clicked)

ui = widgets.VBox([
    widgets.HBox([pair_dropdown, save_btn]),
    widgets.HBox([show_matches_cb, match_slider, show_outliers_cb, outlier_thresh]),
    fig_box,
    out,
])
display(ui)

# Initial render. A Dropdown created with non-empty `options` already has the first
# option selected, so re-assigning that same value does not fire the observer.
# In that case, call update_fig() directly.
if pair_files:
    if pair_dropdown.value != pair_files[0]:
        pair_dropdown.value = pair_files[0]  # triggers update_fig via the observer
    else:
        update_fig()
else:
    print('No pair files found at', PAIRS_GLOB)

VBox(children=(HBox(children=(Dropdown(description='Pair:', layout=Layout(width='800px'), options=('../../../d…