# infer.py
import sys
from absl import flags, app, logging
from utils.model_3D import *
import numpy as np
import pickle
from os.path import join as j
from utils.metrics import chamfer_distance, mean_p2p_distance
from utils.gp_utils import (
direction_distance_given_class,
distance_from_centers,
ClusteredGPs,
)
from utils.common import *
from vis.vis import *
from rich.progress import track
from rich.console import Console
# NOTE(review): this runs after the module-level `utils` / `vis` imports above,
# so it cannot affect how those are resolved; it only matters for imports
# executed later. Confirm whether it is still needed.
sys.path.append("..")

# Command-line flags consumed by main() through FLAGS.
FLAGS = flags.FLAGS
flags.DEFINE_bool(
    name="vis",
    default=False,
    help="Visualize gp output or not",
)
flags.DEFINE_string(
    name="model_3d",
    default=None,
    help="Test model 3D",
)
flags.DEFINE_boolean(
    name="normalize",
    default=False,
    help="Unit sphere normalization",
)
def main(args):
    """Run the saved clustered GPs on a test model and export the output points.

    Reads ``<MODEL_PATH>/test/<model_3d>.ply``, assigns each point to a cluster
    of the pickled k-means model, feeds the per-cluster directions to the
    loaded ``ClusteredGPs`` and writes the returned points to
    ``<RESULTS_PATH>/<model_3d>/out_points.txt``. With ``--vis`` the points are
    also passed to ``vis_pcd_open3D`` together with a ``pcd_vis.png`` path.

    Args:
        args: Positional command-line arguments handed over by ``app.run``;
            not used.

    Raises:
        app.UsageError: If ``--model_3d`` was not given on the command line.
    """
    # The flag defaults to None; report a usage error instead of letting the
    # string concatenation below fail with an opaque TypeError.
    if FLAGS.model_3d is None:
        raise app.UsageError(
            "--model_3d is required (name of the test model, without '.ply')"
        )

    # NOTE(review): the file imports `join as j`, yet `jn` is used here;
    # presumably `jn` comes from the `utils.common` star import — confirm.

    # load the test points
    points = load3DModel(jn(MODEL_PATH, 'test', FLAGS.model_3d + '.ply'))
    # In-place and unseeded, so the point order differs from run to run.
    np.random.shuffle(points)
    if FLAGS.normalize:
        # scale to unit sphere
        points = fitModel2UnitSphere(points, buffer=1.03)

    # load the fitted kmeans centers
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only load kmeans.pkl files that come from a trusted source.
    with open(jn(RESULTS_PATH, FLAGS.model_3d, 'kmeans.pkl'), "rb") as f:
        kmeans = pickle.load(f)
    centers = kmeans.cluster_centers_
    class_idxs = kmeans.predict(points.astype("double"))

    # get distance each point from the reference points
    distances = distance_from_centers(points, centers, class_idxs)
    # Only the directions and the ordering indices are needed below.
    _, phi_thetas, _, sorted_indices = direction_distance_given_class(
        points, distances, centers, class_idxs
    )

    # load the trained GP's
    cls_gps = ClusteredGPs(centers)
    cls_gps.__load__(jn(RESULTS_PATH, FLAGS.model_3d))
    _, xyz, _ = cls_gps.predict_(phi_thetas, centers, class_idxs)

    # Reorder the input points with sorted_indices before comparing them
    # against the predicted points.
    cls_gps.eval(points[sorted_indices], xyz)

    export_3D_points(xyz, jn(RESULTS_PATH, FLAGS.model_3d, "out_points.txt"))
    if FLAGS.vis:
        vis_pcd_open3D(xyz, jn(RESULTS_PATH, FLAGS.model_3d, "./pcd_vis.png"))
# Script entry point: hand control to absl, which then invokes main().
if __name__ == "__main__":
    app.run(main=main)