# Evidence-Based LM

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

from tbp.monty.frameworks.utils.logging_utils import (load_stats,
                                                        print_overall_stats,
                                                        print_unsupervised_stats)
from tbp.monty.frameworks.utils.plot_utils import (plot_graph,
                                                         show_initial_hypotheses, 
                                                         plot_evidence_at_step)

In [None]:
# Paths to the pretrained models and to the experiment whose logs are analyzed.
# Commented-out lines are alternative runs kept for quick switching.
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
# pretrained_dict = pretrain_path + "pretrained_ycb_v4_test/surf_agent_1lm_10distinctobj/pretrained/"
pretrained_dict = pretrain_path + "pretrained_ycb_v7/surf_agent_1lm_10distinctobj/pretrained/"
# pretrained_dict = pretrain_path + "pretrained_ycb/supervised_pre_training_location_noise005/pretrained/"
log_path = os.path.expanduser("~/tbp/results/monty/projects/monty_runs/")
# log_path = os.path.expanduser("~/tbp/results/monty/projects/evidence_eval_runs/logs/")
exp_name = "evidence_tests_nomt/"
exp_path = log_path + exp_name
# exp_path = os.path.expanduser("/private/var/folders/g2/tbt9tg416wjbn478nv98kt6c0000gp/T/tmpyx51cdv3/")

# Output folder for stepwise evidence figures. exp_path already ends in "/",
# so this string contains a double slash (harmless on POSIX paths).
save_path = exp_path + '/stepwise_examples/'
# save_path = os.path.expanduser("~/tbp/results/monty/figures/evidenceLM/stepwise_examples/"+exp_name)
# Load only the evaluation stats (no train or detailed stats), together with
# the pretrained models found in pretrained_dict.
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=False,
                                                                load_eval=True,
                                                                load_detailed=False,
                                                                pretrained_dict=pretrained_dict,
                                                               )

In [None]:
# IPython magic (not plain Python): use the interactive notebook backend.
%matplotlib notebook

In [None]:
# Load and analyze unsupervised learning experiments
# Loads train stats only (no eval or detailed stats) and overwrites the
# variables set by the evaluation-loading cell. Note that this cell points at
# the pretrained_ycb_v4 models, not v7.
log_path = os.path.expanduser("~/tbp/results/monty/projects/monty_runs/")
pretrained_dict = pretrain_path + "pretrained_ycb_v4/surf_agent_1lm_10distinctobj/pretrained/"

exp_name = "surf_agent_unsupervised_10distinctobj_pt/"
exp_path = log_path + exp_name
save_path = exp_path + '/stepwise_examples/'
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=True,
                                                                load_eval=False,
                                                                load_detailed=False,
                                                                pretrained_dict=pretrained_dict,
                                                               )

In [None]:
# Plot every object graph that LM_0 stored at the given epoch.
epoch = '0'
# Note: `graph` iterates over the dict keys (object names), not graph objects,
# and is left holding the last key after the loop.
for graph in lm_models[epoch]['LM_0'].keys():
    plot_graph(lm_models[epoch]['LM_0'][graph])
    plt.show()

In [None]:
# Summarize the unsupervised training run. epoch_len=10 presumably matches the
# 10 distinct objects in this experiment (TODO confirm).
print_unsupervised_stats(train_stats, epoch_len=10)

In [None]:
# Load unit test stats (for debugging)
# NOTE(review): hard-coded temporary directory from one past unit-test run;
# replace it with the current run's output folder before use.
exp_path = os.path.expanduser("/private/var/folders/g2/tbt9tg416wjbn478nv98kt6c0000gp/T/tmpm61wy18k")

save_path = exp_path + '/stepwise_examples/'
# save_path = os.path.expanduser("~/tbp/results/monty/figures/evidenceLM/stepwise_examples/"+exp_name)
# Loads train and detailed stats (no eval stats, no pretrained models).
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=True,
                                                                load_eval=False,
                                                                load_detailed=True,
                                                               )

In [None]:
# Plot the "new_object0" graph of LM_0 as stored at epoch '3'.
plot_graph(lm_models['3']['LM_0']['new_object0'])
plt.show()

In [None]:
# Show the RGBA raw observation of sensor module SM_1 at step 1 of episode 0.
plt.figure()
plt.imshow(detailed_stats['0']['SM_1']['raw_observations'][1]['rgba'])
# plt.title('mug')
plt.axis('off')
plt.show()

In [None]:
# Inspect the loaded evaluation stats.
eval_stats

In [None]:
# Inspect the loaded training stats.
train_stats

In [None]:
# List the epochs (and 'pretrained', if loaded) available in lm_models.
lm_models.keys()

In [None]:
# Side-by-side 3D scatter plots of the "new_object0" graph learned by LM_0
# after epoch 0 (left) and epoch 1 (right), to compare how the model changed.
# lm_models is indexed as lm_models[epoch]['LM_0'][object] elsewhere in this
# notebook; the 'LM_0' level was missing here. The feature-mapping lookups
# the original made were never used by the plots and have been dropped.
fig = plt.figure()
for subplot_id, epoch_key in enumerate(['0', '1'], start=1):
    # `graph` is left holding the epoch-1 model after the loop, as before.
    graph = lm_models[epoch_key]['LM_0']['new_object0']
    ax = plt.subplot(1, 2, subplot_id, projection='3d')
    s = ax.scatter(
        graph.pos[:, 0],
        graph.pos[:, 1],
        graph.pos[:, 2],
        s=10,
        c='grey',
    )
    ax.set_xticks([]), ax.set_yticks([]), ax.set_zticks([])
    ax.set_aspect("equal")
# ax.set_title("Learned Model with Location Noise Std=0.002 \n mug")

In [None]:
# Plot the evidence for each object in `objects` at every matching step of
# one training episode.
objects = ['new_object0']

episode = 1
step = 2  # overwritten by the loop below
lm = 'LM_0'  # not read in this cell
current_evidence_update_th = -1  # not read in this cell
save_fig = False
for step in range(train_stats['monty_matching_steps'][episode]):
    plot_evidence_at_step(detailed_stats,
                          lm_models,
                              episode, 
                              step,
                              objects,
                              save_fig=save_fig)

In [None]:
# Number of raw observations stored for SM_0 in episode 1.
len(detailed_stats[str(1)]["SM_0"]["raw_observations"])

### Show Pose Features

In [None]:
# Plot the pretrained graph of `obj`, colored by the first principal-curvature
# feature column. For a hand-picked region, also draw the stored pose vectors:
# the first three pose-vector values (black). For nodes whose
# 'pose_fully_defined' feature is set, also draw the next two 3-vectors: the
# first curvature direction in red/black and the second in orange at half
# length.
obj = 'mug'
graph = lm_models['pretrained'][0][obj]
# graph = lm_models['3']['LM_0']['new_object1']
# Feature column ranges of the 'patch' input channel.
pc1ispc2 = graph.feature_mapping['patch']['pose_fully_defined']
posev_ids = graph.feature_mapping['patch']['pose_vectors']
pcdir_ids = graph.feature_mapping['patch']['principal_curvatures']
norm_len = 0.01
# Line color for the first pose vector. The original passed the undefined
# name `color`, which raised NameError at the first selected node.
normal_color = 'black'
fig = plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    graph.pos[:, 0],
    graph.pos[:, 1],
    graph.pos[:, 2],
    s=5,
    c=graph.x[:,pcdir_ids[0]],
    cmap='coolwarm',
    vmin=-50,
    vmax=50,
)

ax.set_xticks([]), ax.set_yticks([]), ax.set_zticks([])
# ax_range = 0.07
# ax.set_xlim(0-ax_range, 0+ax_range)
# ax.set_ylim(0.05-ax_range,0.05+ax_range)
# ax.set_zlim(0-ax_range, 0+ax_range)
ax.set_title("PC1 and its direction")
fig.colorbar(s)

for p_id, p in enumerate(np.array(graph.pos)):
    # Hand-picked spatial filter so only a small region gets annotated.
    if p[0] < 0.1 and p[1] > 1.52:  # and p[2] < -0.1:
        norm = graph.x[p_id][posev_ids[0]:posev_ids[0] + 3]
        ax.plot([p[0], p[0] + norm[0] * norm_len],
                [p[1], p[1] + norm[1] * norm_len],
                [p[2], p[2] + norm[2] * norm_len],
                c=normal_color,
                )
        # Only draw curvature directions where the pose is flagged as fully
        # defined (despite its name, pc1ispc2 holds 'pose_fully_defined').
        if graph.x[p_id, pc1ispc2[0]]:
            cd1 = graph.x[p_id][posev_ids[0] + 3:posev_ids[0] + 6]
            cd2 = graph.x[p_id][posev_ids[0] + 6:posev_ids[0] + 9]
            # Red when feature column 13 >= column 14, otherwise black.
            # NOTE(review): columns 13/14 are hard-coded; presumably the two
            # principal curvatures - confirm against pcdir_ids.
            ax.plot([p[0], p[0] + cd1[0] * norm_len],
                    [p[1], p[1] + cd1[1] * norm_len],
                    [p[2], p[2] + cd1[2] * norm_len],
                    c=[int(graph.x[p_id][13] >= graph.x[p_id][14]), 0, 0],
                    )
            ax.plot([p[0], p[0] + cd2[0] * norm_len * 0.5],
                    [p[1], p[1] + cd2[1] * norm_len * 0.5],
                    [p[2], p[2] + cd2[2] * norm_len * 0.5],
                    c='orange',
                    )
plt.show()

## Show Features and Angles in Search Radius

In [None]:
# Select the nodes of the pretrained "mug" graph that lie within
# max_match_distance of node 1; the plotting code that follows uses
# `ids` and `norm_len`.
# This graph's feature_mapping is keyed by input channel first
# (feature_mapping['patch'][...]). The original's top-level
# 'principal_curvatures_log' / 'curvature_directions' lookups were never used,
# could raise KeyError, and have been removed.
max_match_distance = 0.01
graph = lm_models['pretrained'][0]['mug']
# np.where returns a tuple; ids[0] holds the selected node indices.
ids = np.where(np.linalg.norm(graph.pos - graph.pos[1], axis=1) < max_match_distance)
# Length of the plotted normal vectors.
norm_len = 0.02

def get_angles_for_all_hypotheses(hyp_f, query_f):
    """Return the angle between every hypothesis vector and its query vector.

    Args:
        hyp_f: Array of shape (num_hyp, num_nn, dim); vectors are expected to
            be unit length, otherwise the result is not a true angle.
        query_f: Array of shape (num_hyp, dim), one query vector per hypothesis.

    Returns:
        Array of shape (num_hyp, num_nn) with angles in radians.
    """
    # Per-hypothesis dot products, clipped so rounding error cannot push
    # arccos outside its [-1, 1] domain.
    cosines = np.clip(np.einsum("ijk,ik->ij", hyp_f, query_f), -1, 1)
    return np.arccos(cosines)

# Plot all graph nodes in grey and the neighbors selected in `ids` in red.
# Then draw each neighbor's normal, colored by its angle to node 1's normal:
# [a, 1-a, a] with a = angle / pi, i.e. green for 0 and magenta for pi.
fig = plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    graph.pos[:, 0],
    graph.pos[:, 1],
    graph.pos[:, 2],
    s=5,
    alpha=0.5,
    c='grey'#np.arctan(np.array(graph.x[:,color_id])/100)
)
s2 = ax.scatter(
    graph.pos[ids, 0],
    graph.pos[ids, 1],
    graph.pos[ids, 2],
    s=15,
    alpha=0.5,
    c='red'#np.arctan(np.array(graph.x[:,color_id])/100)
)

ax.set_xticks([]), ax.set_yticks([]), ax.set_zticks([])
# ax.set_title("Learned Locations and Norm")
ax.set_aspect("equal")

# `p` is overwritten in the loop; `norm` (node 1's normal) is the query below.
p = np.array(graph.pos)[1]
norm = np.array(graph.norm[1])
# Shape (1, k, 3): ids is the tuple from np.where, so indexing gives (k, 3),
# and the extra list adds the hypothesis axis get_angles_for_all_hypotheses expects.
all_norms = np.array([np.array(graph.norm)[ids]])
colors = get_angles_for_all_hypotheses(all_norms, np.array([norm]))/np.pi
for c_id, p_id in enumerate(ids[0]):
    p = np.array(graph.pos)[p_id]
    norm = np.array(graph.norm[p_id])
    
    ax.plot([p[0], p[0] + norm[0] * norm_len],
            [p[1], p[1] + norm[1] * norm_len],
            [p[2], p[2] + norm[2] * norm_len],
            color=[colors[0,c_id],1-colors[0,c_id],colors[0,c_id]],
           )
plt.xlabel('x')
plt.ylabel('y')
plt.show()

## Show Evidences

In [None]:
# Plot the initial hypotheses for 'mug' in episode 0 (set save_fig=True to
# write the figure to save_path).
show_initial_hypotheses(detailed_stats, 0, 'mug', save_fig=False, 
                        save_path=save_path)

In [None]:
# Settings for the evidence plots. The plotting cell that follows passes
# step=5 and save_fig=False explicitly rather than reading `step`/`save_fig`;
# `lm` and `current_evidence_update_th` are not read there either.
objects = ['mug','bowl','dice','banana']

episode = 0
step = 2
lm = 'LM_0'
current_evidence_update_th = -1
save_fig = True

In [None]:
# Plot the evidence for `objects` at step 5 of `episode` (not saved).
plot_evidence_at_step(detailed_stats,
                      lm_models,
                              episode,
                              objects=objects,
                              step=5,
                              save_fig=False, 
                              save_path=save_path)

In [None]:
# Plot (and save to save_path) the evidence for the two learned objects at
# every matching step of one evaluation episode.
episode = 8
objects = ["new_object0", "new_object1"]
# Iterate over the step count of *this* episode. The original read
# eval_stats['monty_matching_steps'][2] (episode 2's count) while plotting
# episode 8, so it could plot too few or too many steps.
for step in range(eval_stats['monty_matching_steps'][episode]):
    plot_evidence_at_step(detailed_stats,
                          lm_models,
                          episode,
                          step,
                          objects,
                          # is_touch_sensor=True,
                          save_fig=True,
                          save_path=save_path)

In [None]:
# Keys of the first 'evidences' entry stored for LM_0 in episode 5.
detailed_stats[str(5)]['LM_0']['evidences'][0].keys()

## Plot Observations

In [None]:
# Save one figure per step with the view-finder image (SM_1) next to the
# sensor-patch image (SM_0) of one episode, e.g. as animation frames.
episode = '0'
toshow = 'rgba'
obj = 'mug'
viz_path = os.path.expanduser("~/tbp/results/monty/figures/vizanimations/visionLM/")
# savefig fails if the output folder does not exist yet.
os.makedirs(viz_path, exist_ok=True)

for step in range(len(detailed_stats[episode]['SM_0']['raw_observations'])):
    fig = plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(detailed_stats[episode]['SM_1']['raw_observations'][step][toshow])
    plt.axis('off')
    plt.title("View Finder")
    plt.subplot(1, 2, 2)
    plt.imshow(detailed_stats[episode]['SM_0']['raw_observations'][step][toshow])
    plt.axis('off')
    plt.title("Patch 0")
    plt.savefig(
        viz_path + f"visionLMObs_{obj}_{step}.png",
        bbox_inches="tight",
    )
    # Close each frame once it is saved. Otherwise every step leaves an open
    # figure, and long episodes accumulate them (matplotlib warns past 20).
    plt.close(fig)

## Observations with 5 LMs

In [None]:
# Save one figure per step showing the view finder (SM_5) and the five sensor
# patches (SM_0 ... SM_4) of a five-LM experiment in a 2 x 3 grid.
episode = '0'
# step = 40
toshow = 'rgba'
# Object name used in the output file names. The original cell silently
# depended on `obj` being left over from an earlier cell.
obj = 'mug'
viz_path = os.path.expanduser("~/tbp/results/monty/figures/vizanimations/fiveLMs/")
# savefig fails if the output folder does not exist yet.
os.makedirs(viz_path, exist_ok=True)
# (sensor module id, subplot title), in subplot order.
panels = [('SM_5', "View Finder")] + [(f'SM_{i}', f"Patch {i}") for i in range(5)]

for step in range(len(detailed_stats[episode]['SM_5']['raw_observations'])):
    fig = plt.figure()
    for panel_id, (sm_id, title) in enumerate(panels, start=1):
        plt.subplot(2, 3, panel_id)
        plt.imshow(detailed_stats[episode][sm_id]['raw_observations'][step][toshow])
        plt.axis('off')
        plt.title(title)
    plt.savefig(
        viz_path + f"fiveLMObs_{obj}_{step}.png",
        bbox_inches="tight",
    )
    # Close each saved frame so figures do not pile up across steps.
    plt.close(fig)

In [None]:
# Episodes available in the detailed stats.
detailed_stats.keys()

## YCB Objects

In [None]:
from tbp.monty.frameworks.environments.ycb import YCB_OBJECTS_LIST
# Folder of pre-rendered object images named "<object>_<rotation>.png".
path_to_images = os.path.expanduser("~/tbp/results/monty/figures/all_object_views/")

In [None]:
# Show every YCB object at one rotation in an 8 x 12 grid (room for 96 images).
# rotation_to_show = '0_0_0'
rotation_to_show = '90_180_270'
plt.figure(figsize=(25,15))
for i, obj_name in enumerate(YCB_OBJECTS_LIST):
    img_path = path_to_images + obj_name + '_' + rotation_to_show + '.png'
    img = plt.imread(img_path)
    plt.subplot(8, 12, i+1)
    # Crop the top 28 pixel rows, which hold the title baked into the image.
    plt.imshow(img[28:]) # Don't show title from image
    plt.axis('off')
    plt.title(obj_name)
# plt.suptitle('All YCB objects in rotation ' + rotation_to_show.replace('_',', '), fontsize=20)
plt.show()

In [None]:
# Show all saved views (one per rotation) of a single object in a 4-column grid.
obj_to_show = 'potted_meat_can'
# File names look like "<object>_<rotation>.png". Matching on "<object>_"
# excludes objects whose names merely start with obj_to_show. Sorting makes
# the panel order independent of os.listdir's arbitrary order.
prefix = obj_to_show + '_'
matching_files = sorted(f for f in os.listdir(path_to_images) if f.startswith(prefix))
# Keep the original 8 x 4 layout; add rows only when there are more than 32
# views, where plt.subplot(8, 4, i) would otherwise raise.
n_cols = 4
n_rows = max(8, -(-len(matching_files) // n_cols))
plt.figure(figsize=(7,15))
for i, file in enumerate(matching_files, start=1):
    img = plt.imread(path_to_images + file)
    plt.subplot(n_rows, n_cols, i)
    plt.imshow(img[28:]) # Don't show title from image
    plt.axis('off')
    # Title shows the rotation, e.g. "90_180_270" -> "90, 180, 270".
    plt.title(file[len(obj_to_show) + 1:-4].replace('_', ', '))
plt.show()

# Scratch experiments (to be deleted eventually)

In [None]:
def get_angle(vec1, vec2):
    """Return the angle (radians) between two unit vectors.

    The inputs are not normalized here, so they must already have unit length
    for the result to be a true angle. np.dot also accepts an (n, d) array
    against a (d,) vector, giving n angles.
    """
    # Clip so rounding error cannot push arccos outside its domain.
    cosine = np.clip(np.dot(vec1, vec2), -1, 1)
    return np.arccos(cosine)

def get_pose_error(query_features, node_features, weights=None):
    """Return the weighted sum of angular pose errors between two feature sets.

    Compares, via get_angle, the point normals and the two curvature
    directions: the first three values of "curvature_directions" and the
    remaining values.

    Args:
        query_features: Mapping with "point_normal" and "curvature_directions"
            (presumably 3 and 6 values respectively - two stacked directions).
        node_features: Mapping with the same keys and layout.
        weights: Three weights for (point normal, first curvature direction,
            second curvature direction). Defaults to [1, 1, 1].

    Returns:
        Weighted sum of the three angles, in radians.
    """
    # None instead of a list literal avoids a shared mutable default.
    if weights is None:
        weights = [1, 1, 1]
    # (The original also printed pn_error here as a debugging leftover.)
    pn_error = get_angle(
        query_features["point_normal"],
        node_features["point_normal"],
    )
    cd1_error = get_angle(
        query_features["curvature_directions"][:3],
        node_features["curvature_directions"][:3],
    )
    cd2_error = get_angle(
        query_features["curvature_directions"][3:],
        node_features["curvature_directions"][3:],
    )
    overall_error = np.array([pn_error, cd1_error, cd2_error]) * weights
    return np.sum(overall_error)

def get_node_features(graph, node_id, feature_keys, input_channel=None):
    """Collect copies of selected feature columns for one or more graph nodes.

    Args:
        graph: Object with a feature tensor ``x`` and a ``feature_mapping``
            giving, per feature key, a [start, end) column range into ``x``.
        node_id: Integer row index (or integer index array/tensor) into
            ``graph.x``. Float indices are rejected by the indexing.
        feature_keys: Feature names to extract.
        input_channel: Optional input-channel key (e.g. "patch") for graphs
            whose feature_mapping is nested by channel, as in
            ``feature_mapping['patch']['pose_vectors']``. When None (the
            default, the original behavior), feature_mapping is indexed by
            feature name directly.

    Returns:
        Dict mapping each key to ``graph.x[node_id, start:end].clone()``.
    """
    mapping = graph.feature_mapping
    if input_channel is not None:
        mapping = mapping[input_channel]
    node_features = {}
    for key in feature_keys:
        key_ids = mapping[key]
        # clone() so later in-place edits do not write back into graph.x.
        node_features[key] = graph.x[node_id, key_ids[0]:key_ids[1]].clone()
    return node_features

In [None]:
# Two example (approximately unit-length, nearly orthogonal) direction vectors
# for trying out get_angle.
dir1 = np.array([-0.07582931, -0.99661191, -0.03185299])
dir2 = np.array([0.996991  , -0.07531361,-0.01835211])

In [None]:
import math

In [None]:
# Angle between dir1 and dir2, in degrees.
math.degrees(get_angle(dir1, dir2))

In [None]:
# Query pose features of LM_0 at the first step of episode 0: the point normal
# and the two curvature directions (first and last three values of cds),
# both as a list of three vectors (qf) and as a feature dict (qff).
pn = detailed_stats['0']['LM_0']['point_normal'][0]
cds = detailed_stats['0']['LM_0']['curvature_directions'][0]
qf = [pn, cds[:3], cds[3:]]
qff = {'point_normal':pn, 'curvature_directions':cds}

In [None]:
# Pose features of the first 20 nodes of `graph`, stacked per node as
# [point_normal, curvature direction 1, curvature direction 2].
# Node ids must be integers: np.linspace(0, 19, 20) returned floats, which
# array/tensor indexing rejects.
nfs = get_node_features(graph, np.arange(20), ['point_normal', 'curvature_directions'])
nfsm = np.hstack([nfs['point_normal'],
                  nfs['curvature_directions'][:, :3],
                  nfs['curvature_directions'][:, 3:]])

In [None]:
# Pose error between the query features (qff) and each node's features, one
# node at a time.
all_pe = []
for i in range(nfs['point_normal'].shape[0]):
    nff = {'point_normal': nfs['point_normal'][i],
           'curvature_directions': nfs['curvature_directions'][i]}
    pe = get_pose_error(qff,nff)
    all_pe.append(pe)
print(all_pe)

In [None]:
def get_pose_error_matrix(query_features, node_features):
    """Return the angle between each node feature vector and the query vector.

    Args:
        query_features: Unit vector of shape [3].
        node_features: Unit vectors of shape [n, 3].

    Returns:
        Array of shape [n] with angles in radians.
    """
    # Same computation as get_angle(node_features, query_features), inlined:
    # clipped dot products fed to arccos.
    cosines = np.clip(np.dot(node_features, query_features), -1, 1)
    return np.arccos(cosines)

In [None]:
from scipy.spatial.transform import Rotation

In [None]:
# Query pose as a 3 x 3 matrix with rows [point_normal, cd1, cd2].
# NOTE(review): this assumes qf is still the 3-vector list built from pn/cds;
# a later cell rebinds qf to an (n, 3) array.
qf_o = np.array(qf)
qf_o

In [None]:
# Apply a known rotation (xyz Euler angles 50, 10, 0 degrees) to each row.
rotation = Rotation.from_euler('xyz', [50,10,0],degrees=True)
qf_t = rotation.apply(qf_o)
qf_t

In [None]:
# Recover the rotation with scipy. Rotation.align_vectors(a, b) returns the
# rotation that best maps b onto a, so r_scipy maps qf_t back onto qf_o,
# i.e. it should be close to the inverse of `rotation`.
r_scipy, err =  Rotation.align_vectors(qf_o, qf_t)
print(np.round(r_scipy.as_euler('xyz', degrees=True),2))
print(r_scipy.as_quat())
print(np.round(r_scipy.as_matrix(),2))

In [None]:
def check_orthonormal(matrix):
    """Return whether a square matrix is (approximately) orthonormal.

    Runs two checks, each with a mean-absolute-error tolerance of 0.01:
    - orthogonality: inv(matrix) is close to matrix.T;
    - normality: every row has unit Euclidean length.
    For each failed check a message with its error is printed.

    Args:
        matrix: 2D array, e.g. three stacked row vectors.

    Returns:
        True if both checks pass, otherwise False. A matrix that cannot be
        inverted (singular or non-square) is reported as not orthogonal
        instead of raising LinAlgError.
    """
    try:
        orthogonal_error = np.mean(np.abs(np.linalg.inv(matrix) - matrix.T))
    except np.linalg.LinAlgError:
        orthogonal_error = np.inf
    is_orthogonal = orthogonal_error < 0.01
    if not is_orthogonal:
        print(f"not orthogonal. Error: {orthogonal_error}")
    # Compare row norms with 1 by broadcasting, so any number of rows works
    # (the original subtracted the literal [1, 1, 1] and only handled 3 rows).
    normal_error = np.mean(np.abs(np.linalg.norm(matrix, axis=1) - 1))
    is_normal = normal_error < 0.01
    if not is_normal:
        print(f"not normal. Error: {normal_error}")
    return is_orthogonal and is_normal

# Both the original and the rotated query pose should be orthonormal.
print(check_orthonormal(qf_o))
print(check_orthonormal(qf_t))

In [None]:
# Pose features of the first 200 nodes of `graph`, stacked per node as
# [point_normal, curvature direction 1, curvature direction 2].
# Node ids must be integers: np.linspace(0, 199, 200) returned floats, which
# array/tensor indexing rejects.
nfs = get_node_features(graph, np.arange(200), ['point_normal', 'curvature_directions'])
nfsm = np.hstack([nfs['point_normal'],
                  nfs['curvature_directions'][:, :3],
                  nfs['curvature_directions'][:, 3:]])

In [None]:
# Print the index and 3 x 3 pose matrix of every node whose pose vectors are
# not orthonormal. nf_rs is left holding the last node's matrix.
for i, nf in enumerate(nfsm):
    nf_rs = nf.reshape((3,3))
    if not check_orthonormal(nf_rs):
        print(i)
        print(nf_rs)

In [None]:
def get_right_hand_angle(v1, v2, pn):
    """Return the signed angle from v1 to v2 about the axis pn, in radians.

    Computed as atan2(dot(v1 x v2, pn), dot(v1, v2)), so the result lies in
    [-pi, pi]. It is positive when turning v1 towards v2 is a right-handed
    rotation about pn. It is a true angle when v1 and v2 are perpendicular to
    the unit vector pn.
    """
    sine_term = np.dot(np.cross(v1, v2), pn)
    cosine_term = np.dot(v1, v2)
    return np.arctan2(sine_term, cosine_term)

In [None]:
# Signed angle between the two curvature directions (rows 1 and 2) of the last
# node's pose matrix, about its point normal (row 0).
get_right_hand_angle(nf_rs[1], nf_rs[2], nf_rs[0])

In [None]:
# Sanity check: turning +x towards -z about +y gives +pi/2.
get_right_hand_angle([1,0,0], [0,0,-1], [0,1,0])

In [None]:
def align_orthonormal_vectors(m1, m2):
    """Find the rotation that maps the rows of m1 onto the rows of m2.

    Args:
        m1: 3 x 3 array whose rows are orthonormal vectors.
        m2: 3 x 3 array whose rows are orthonormal vectors.

    Returns:
        (rotation, error): a scipy Rotation that (ideally) satisfies
        rotation.apply(m1) == m2, plus the mean absolute difference between
        rotation.apply(m1) and m2.

    Raises:
        AssertionError: if check_orthonormal rejects m1 or m2 (only while
            assertions are enabled).
    """
    assert check_orthonormal(m1), "m1 is not orthonormal"
    assert check_orthonormal(m2), "m2 is not orthonormal"
    # With orthonormal rows, inv(m1) == m1.T. Because the vectors are stored as
    # rows rather than columns, m1.T @ m2 is the *transpose* of the
    # m1 -> m2 rotation, so it is inverted below.
    transposed_rotation = m1.T @ m2
    m1_to_m2 = Rotation.from_matrix(transposed_rotation).inv()
    residual = np.mean(np.abs(m1_to_m2.apply(m1) - m2))
    return m1_to_m2, residual

def align_multiple_orthonormal_vectors(ms1, ms2, as_scipy=True):
    """Batched rotation estimate: compute ms1[i].T @ ms2 for every i.

    Args:
        ms1: Array of shape (n, 3, 3); each ms1[i] holds orthonormal rows.
        ms2: Array of shape (3, 3) or (n, 3, 3), broadcast against ms1.
        as_scipy: If True, return a list of scipy Rotation objects; otherwise
            return the raw stacked matrices.

    Returns:
        List of n Rotation objects, or an array of shape (n, 3, 3).

    Note:
        Unlike align_orthonormal_vectors, these matrices are NOT inverted. For
        orthonormal rows, each one is the rotation that takes ms2's rows back
        onto ms1[i]'s rows.
    """
    rot_mats = np.matmul(np.swapaxes(ms1, 1, 2), ms2)
    if not as_scipy:
        return rot_mats
    return [Rotation.from_matrix(mat) for mat in rot_mats]

In [None]:
# Same experiment with the manual method: r_alt maps qf_o's rows onto qf_t's,
# so it should match `rotation` (the inverse of r_scipy above).
r_alt, err = align_orthonormal_vectors(qf_o, qf_t)

In [None]:
# One 3 x 3 pose matrix per node (rows: point normal, cd1, cd2); node 0 is
# used as the reference pose.
nodefs = nfsm.reshape((nfsm.shape[0], 3, 3))
first_v = nodefs[0]
nodefs[0]

In [None]:
# rot_mats[i] = nodefs[i].T @ first_v, as raw (n, 3, 3) matrices.
rot_mats = align_multiple_orthonormal_vectors(nodefs, first_v, as_scipy=False)

In [None]:
# Check: applying each rot_mats[i] to node 0's normal should give node i's
# normal, so every row here should be ~0.
np.round(rot_mats.dot(first_v[0]) - nodefs[:,0],2)

In [None]:
# Split the pose matrices: nf = first two rows per node (n, 2, 3), and
# qf = third row per node (n, 3). This rebinds the earlier `qf` list.
nf = nodefs[:,:2,:]
qf = nodefs[:,2,:]
print(nf.shape)
print(qf.shape)

In [None]:
def get_angles_for_all_hypotheses(hyp_f, query_f):
    """Angles between each hypothesis' neighbor features and its query feature.

    Args:
        hyp_f: Array of shape (num_hyp, num_nn, 3).
        query_f: Array of shape (num_hyp, 3).

    Returns:
        Array of shape (num_hyp, num_nn): for each hypothesis, the num_nn
        angles (radians) between its neighbor vectors and its query vector.
    """
    dots = np.einsum("ijk,ik->ij", hyp_f, query_f)
    # Clip to [-1, 1] so rounding error cannot make arccos return NaN.
    return np.arccos(np.clip(dots, -1, 1))

In [None]:
# Per node: angles between its first two pose rows and its third row.
get_angles_for_all_hypotheses(nf, qf)

In [None]:
# 100 angle errors evenly spaced from 0 to pi/2 radians.
angle_err = np.linspace(0,np.pi/2,100)

In [None]:
# Inspect the sampled angle errors.
angle_err

In [None]:
# Candidate evidence curve 0.5 - sin(angle): +0.5 at zero error, -0.5 at pi/2.
plt.figure()
plt.plot(angle_err, -(np.sin(angle_err)-0.5))
plt.xlabel('angle error (radians)')
plt.ylabel('evidence')
plt.show()

In [None]:
# 100 samples of zero-mean Gaussian noise with standard deviation 0.1.
a = np.random.normal(0,0.1,100)

In [None]:
# Histogram of the noise samples.
plt.figure()
plt.hist(a)
plt.title("standard deviation = 0.1")
plt.show()

#### Point Normal Calculation

In [None]:
# Depth image of SM_0's first raw observation in episode 3.
# NOTE(review): the 'mug' title is hard-coded; confirm the episode's object.
plt.figure()
plt.imshow(detailed_stats['3']['SM_0']['raw_observations'][0]['depth'])
plt.title('mug')
plt.axis('off')
plt.show()

In [None]:
# What SM_0 logs per episode.
detailed_stats['0']['SM_0'].keys()

In [None]:
from tbp.monty.frameworks.environment_utils.transforms import DepthTo3DLocations
from tbp.monty.frameworks.environment_utils.graph_utils import get_point_normal, get_center_point_normal
import quaternion as qt

In [None]:
# Depth-to-3D transform configured for a single 64 x 64 "patch" sensor.
# Note: `transform` is not used by any of the cells that follow.
transform = DepthTo3DLocations(
                agent_id="agent_id_0",
                sensor_ids=["patch"],
                resolutions=[[64,64]],
                world_coord=True,
                zooms=[10],
                get_all_points=True,
                use_semantic_sensor=True,
            )

In [None]:
# Patch point cloud of SM_0's first raw observation in episode 3. Later cells
# use columns 0-2 as xyz and column 3 > 0 as "on object".
pc = np.array(detailed_stats['3']['SM_0']['raw_observations'][0]['semantic_3d'])
pc.shape

In [None]:
# Treat the points as a flattened square grid (index = col + obs_dim * row)
# and find the center pixel's index.
obs_dim = int(np.sqrt(pc.shape[0]))
half_obs_dim = obs_dim // 2
center_id = half_obs_dim + obs_dim * half_obs_dim
on_obj = pc[:, 3] > 0
# Number of on-object points before the center, i.e. the center's index in
# oopc when the center itself lies on the object.
adjusted_center_id = sum(on_obj[:center_id])
oopc = pc[on_obj, :3]

In [None]:
# Point normal and first curvature direction computed by SM_0 for this step.
norm = detailed_stats['3']['SM_0']['processed_observations'][0]['features']['point_normal']
cd1 = detailed_stats['3']['SM_0']['processed_observations'][0]['features']['curvature_directions'][:3]

In [None]:
# Plot the patch point cloud (colored by column 3), the center point (red),
# SM_0's point normal (blue) and its first curvature direction (red).
plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    pc[:, 0],
    pc[:, 1],
    pc[:, 2],
    c=pc[:, 3],
    alpha=0.05
)
s = ax.scatter(
    pc[center_id, 0],
    pc[center_id, 1],
    pc[center_id, 2],
    s=50,
    c='red'
)
cp = pc[center_id]
norm_len = 0.01
ax.plot([cp[0], cp[0] + norm[0] * norm_len],
            [cp[1], cp[1] + norm[1] * norm_len],
            [cp[2], cp[2] + norm[2] * norm_len],
            c='blue',
           )
ax.plot([cp[0], cp[0] + cd1[0] * norm_len],
        [cp[1], cp[1] + cd1[1] * norm_len],
        [cp[2], cp[2] + cd1[2] * norm_len],
        c='red',
       )

plt.show()

In [None]:
# Recompute the grid size and center index (same values as computed earlier).
obs_dim = int(np.sqrt(pc.shape[0]))
half_obs_dim = obs_dim // 2
center_id = half_obs_dim + obs_dim * half_obs_dim

In [None]:
# Scratch arithmetic: 0.156 is roughly 10/64 (see the 10/64 cell).
obs_dim*0.156

In [None]:
# Scratch arithmetic.
64//1

In [None]:
# Scratch arithmetic (= 0.15625).
10/64

In [None]:
# Indices of the points one grid row above and one row below the center.
center_id_1up = half_obs_dim + obs_dim * (half_obs_dim - 1)
center_id_1down = half_obs_dim + obs_dim * (half_obs_dim + 1)

In [None]:
def get_center_point_normal(point_cloud, patch_radius_frac=5):
    """Estimate the surface normal at the center of a square patch point cloud.

    NOTE: this experimental version shadows the get_center_point_normal
    imported from graph_utils, and prints debugging output.

    The points are treated as a flattened square grid
    (index = col + side * row). Columns 0-2 are used as xyz, and column 3 > 0
    marks a point as on the object. Two normal estimates are formed from
    offsets of side // patch_radius_frac pixels:
    -cross(up, right) and -cross(down, left). Each is used only if both of
    its neighbor points are on the object. A missing estimate is replaced by
    the other one, and [0, 0, 1] is used when both are missing. The final
    normal is the normalized mean of the two.

    Returns:
        (vecdown_norm, vecleft_norm, norm, norm1, norm2): the unit offsets to
        the lower and left neighbors, the final unit normal, and the two
        individual estimates.
    """
    side = int(np.sqrt(point_cloud.shape[0]))
    mid = side // 2
    center = mid + side * mid
    offset = side // patch_radius_frac
    up_id = mid + side * (mid - offset)
    down_id = mid + side * (mid + offset)
    right_id = center + offset
    left_id = center - offset

    center_xyz = point_cloud[center, :3]

    def _unit_offset(idx):
        # Unit vector from the center point to point idx.
        vec = point_cloud[idx, :3] - center_xyz
        return vec / np.linalg.norm(vec)

    def _on_object(idx):
        return point_cloud[idx, 3] > 0

    vecup_norm = _unit_offset(up_id)
    vecdown_norm = _unit_offset(down_id)
    vecright_norm = _unit_offset(right_id)
    vecleft_norm = _unit_offset(left_id)

    norm1 = None
    if _on_object(up_id) and _on_object(right_id):
        norm1 = -np.cross(vecup_norm, vecright_norm)
    else:
        print("first off")
    norm2 = None
    if _on_object(down_id) and _on_object(left_id):
        norm2 = -np.cross(vecdown_norm, vecleft_norm)
    else:
        print("second off")

    # Fall back to whichever estimate exists; if neither does, use +z.
    if norm1 is None and norm2 is None:
        print("Too many off object points around center")
        norm1 = norm2 = [0, 0, 1]
    elif norm1 is None:
        norm1 = norm2
    elif norm2 is None:
        norm2 = norm1
    print(norm1)
    print(norm2)

    mean_normal = np.mean([norm1, norm2], axis=0)
    norm = mean_normal / np.linalg.norm(mean_normal)
    return vecdown_norm, vecleft_norm, norm, norm1, norm2

In [None]:
# Run the local get_center_point_normal (not the imported one it shadows) on
# the patch point cloud; vec1_norm/vec2_norm are the down/left unit offsets.
vec1_norm, vec2_norm, norm_alt, norm1, norm2 = get_center_point_normal(pc)

In [None]:
# Zoomed plot comparing normal estimates at the patch center: the points two
# columns left/right and one row up/down of the center plus the center (red).
# Arrows: down/left offsets (red), local estimate norm_alt (green, double
# length), SM_0's processed normal `norm` (blue), norm1 (pink), norm2 (yellow).
plt.figure()
to_plot = np.array([center_id-2, center_id, center_id +2, center_id_1up, center_id_1down])
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    pc[to_plot, 0],
    pc[to_plot, 1],
    pc[to_plot, 2],
    c=pc[to_plot, 3],
)
s = ax.scatter(
    pc[center_id, 0],
    pc[center_id, 1],
    pc[center_id, 2],
    s=5,
    c='red'
)
cp = pc[center_id]
norm_len = 0.0001

ax.plot([cp[0], cp[0] + vec1_norm[0] * norm_len],
        [cp[1], cp[1] + vec1_norm[1] * norm_len],
        [cp[2], cp[2] + vec1_norm[2] * norm_len],
        c='red',
       )
ax.plot([cp[0], cp[0] + vec2_norm[0] * norm_len],
        [cp[1], cp[1] + vec2_norm[1] * norm_len],
        [cp[2], cp[2] + vec2_norm[2] * norm_len],
        c='red',
       )
ax.plot([cp[0], cp[0] + norm_alt[0] * norm_len*2],
        [cp[1], cp[1] + norm_alt[1] * norm_len*2],
        [cp[2], cp[2] + norm_alt[2] * norm_len*2],
        c='green',
       )
ax.plot([cp[0], cp[0] + norm[0] * norm_len],
            [cp[1], cp[1] + norm[1] * norm_len],
            [cp[2], cp[2] + norm[2] * norm_len],
            c='blue',
           )
ax.plot([cp[0], cp[0] + norm1[0] * norm_len],
            [cp[1], cp[1] + norm1[1] * norm_len],
            [cp[2], cp[2] + norm1[2] * norm_len],
            c='pink',
           )
ax.plot([cp[0], cp[0] + norm2[0] * norm_len],
            [cp[1], cp[1] + norm2[1] * norm_len],
            [cp[2], cp[2] + norm2[2] * norm_len],
            c='yellow',
           )
ax.set_aspect('equal')
plt.show()