In [1]:
import os
import sys
import glob
import torch
import numpy as np
import open3d as o3d
import meshplot as mp

In [2]:
base_dir = os.path.dirname(os.getcwd())
out_dir = os.path.join(base_dir, "results", "models")
data_dir = os.path.join(base_dir, "data", "original")
train_files = glob.glob(os.path.join(data_dir, "train", "*", "*.obj"))
valid_files = glob.glob(os.path.join(data_dir, "val", "*", "*.obj"))
print(len(train_files), len(valid_files))

src_dir = os.path.join(base_dir, "src")
sys.path.append(os.path.join(src_dir))

7003 1088


In [3]:
from utils_polygen import load_pipeline
from pytorch_trainer import Trainer, Reporter
from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen

In [4]:
def read_objfile(file_path):
    vertices = []
    normals = []
    faces = []
    
    with open(file_path) as fr:
        for line in fr:
            data = line.split()
            if len(data) > 0:
                if data[0] == "v":
                    vertices.append(data[1:])
                elif data[0] == "vn":
                    normals.append(data[1:])
                elif data[0] == "f":
                    face = np.array([
                        [int(p.split("/")[0]), int(p.split("/")[2])]
                        for p in data[1:]
                    ]) - 1
                    faces.append(face)
    
    vertices = np.array(vertices, dtype=np.float32)
    normals = np.array(normals, dtype=np.float32)
    return vertices, normals, faces

def read_objfile_for_validate(file_path, return_o3d=False):
    # only for develop-time validation purpose.
    # this func force to load .obj file as triangle-mesh.
    
    obj = o3d.io.read_triangle_mesh(file_path)
    if return_o3d:
        return obj
    else:
        v = np.asarray(obj.vertices, dtype=np.float32)
        f = np.asarray(obj.triangles, dtype=np.int32)
        return v, f

def write_objfile(file_path, vertices, normals, faces):
    # write .obj file input-obj-style (mainly, header string is copy and paste).
    
    with open(file_path, "w") as fw:
        print("# Blender v2.82 (sub 7) OBJ File: ''", file=fw)
        print("# www.blender.org", file=fw)
        print("o test", file=fw)
        
        for v in vertices:
            print("v " + " ".join([str(c) for c in v]), file=fw)
        print("# {} vertices\n".format(len(vertices)), file=fw)
        
        for n in normals:
            print("vn " + " ".join([str(c) for c in n]), file=fw)
        print("# {} normals\n".format(len(normals)), file=fw)
            
        for f in faces:
            print("f " + " ".join(["{}//{}".format(c[0]+1, c[1]+1) for c in f]), file=fw)
        print("# {} faces\n".format(len(faces)), file=fw)
        
        print("# End of File", file=fw)

def validate_pipeline(v, n, f, out_dir):
    temp_path = os.path.join(out_dir, "temp.obj")
    write_objfile(temp_path, v, n, f)
    v_valid, f_valid = read_objfile_for_validate(temp_path)
    print(v_valid.shape, f_valid.shape)
    mp.plot(v_valid, f_valid)

In [5]:
now_state = "lamp"
indeces = {
    "lamp": 0,
}
for i, path in enumerate(train_files):
    state = path.split("/")[9]
    if now_state != state:
        now_state = state
        indeces[state] = i
print(indeces)

now_state = "lamp"
indeces = {
    "lamp": 0,
}
for i, path in enumerate(valid_files):
    state = path.split("/")[9]
    if now_state != state:
        now_state = state
        indeces[state] = i
print(indeces)

{'lamp': 0, 'basket': 402, 'chair': 452, 'sofa': 2294, 'table': 3231}
{'lamp': 0, 'basket': 60, 'chair': 66, 'sofa': 388, 'table': 517}


In [6]:
mode2files = {
    0: train_files,
    1: valid_files,
}

In [18]:
mode = 0
#idx = 458
idx = 460
#mode = 1
#idx = 458
vertices, normals, faces = read_objfile(mode2files[mode][idx])
print(vertices.shape, normals.shape, len(faces))
validate_pipeline(vertices, normals, faces, out_dir)

(58, 3) (18, 3) 31
(174, 3) (112, 3)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

In [19]:
vs, ns, fs = load_pipeline(mode2files[mode][idx], remove_normal_ids=False)
validate_pipeline(vs, ns, fs, out_dir)

(174, 3) (112, 3)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…

In [20]:
config = FacePolyGenConfig(embed_dim=128, src__reformer__depth=9, tgt__reformer__depth=9)
model = FacePolyGen(config)
ckpt = torch.load(os.path.join(out_dir, "model_epoch_47"), map_location=torch.device('cpu'))
model.load_state_dict(ckpt['state_dict'])

src__max_seq_len changed, because of lsh-attention's bucket_size
before: 2400 --> after: 2592 (with bucket_size: 48)
tgt__max_seq_len changed, because of lsh-attention's bucket_size
before: 5600 --> after: 5664 (with bucket_size: 48)


<All keys matched successfully>

In [21]:
inputs = {"vertices": [torch.tensor(vs)]}
lengths = [len(f) for f in fs]
print(sum(lengths))
[f[:, 0] for f in fs]

174


[array([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 44, 45, 47,
        49, 51]),
 array([57, 51, 36, 38,  1,  7, 15, 11, 19, 23]),
 array([57, 23, 22, 56]),
 array([56, 22, 18, 53]),
 array([55, 52, 17, 21]),
 array([55, 21, 20, 54]),
 array([54, 20, 16,  8, 12,  4,  0, 37, 35, 50]),
 array([53, 18, 19, 11, 10,  3,  2,  9,  8, 16, 17, 52]),
 array([51, 49, 34, 36]),
 array([50, 35, 33, 48]),
 array([49, 47, 32, 34]),
 array([48, 33, 31, 46]),
 array([47, 45, 30, 32]),
 array([46, 31, 27, 42]),
 array([45, 44, 29, 30]),
 array([44, 41, 26, 29]),
 array([43, 42, 27, 28]),
 array([43, 28, 24, 39]),
 array([41, 40, 25, 26]),
 array([40, 39, 24, 25]),
 array([38, 37,  0,  1]),
 array([38, 36, 34, 32, 30, 29, 26, 25, 24, 28, 27, 31, 33, 35, 37]),
 array([23, 19, 18, 22]),
 array([21, 17, 16, 20]),
 array([15, 14, 10, 11]),
 array([15,  7,  6, 14]),
 array([14,  6,  3, 10]),
 array([13, 12,  8,  9]),
 array([13,  9,  2,  5]),
 array([13,  5,  4, 12]),
 array([7, 1, 0, 4, 5, 2, 3,

In [22]:
model.eval()
with torch.no_grad():
    pred = model.predict(inputs, seed=0, max_seq_len=sum(lengths))
    # pred = model.predict(inputs, seed=0, max_seq_len=83)

0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51])
19, 20, 21, 22, 23, tensor([57, 51, 36, 49])
24, 25, 26, 27, 28, tensor([57, 23, 22, 56])
29, 30, 31, 32, 33, 34, 35, 36, 37, tensor([56, 22, 18, 10,  3,  2, 16, 20])
38, 39, 40, 41, 42, tensor([55, 52, 17, 53])
43, 44, 45, 46, 47, tensor([47, 32, 28, 43])
48, 49, 50, 51, 52, tensor([47, 45, 30, 29])
53, 54, 55, 56, 57, tensor([44, 40, 25, 39])
58, 59, 60, 61, 62, 63, 64, tensor([41, 40, 25, 26, 29, 30])
65, 66, 67, 68, 69, tensor([38, 37, 35, 36])
70, 71, 72, 73, 74, tensor([23, 22, 21,  5])
75, 76, 77, 78, 79, tensor([23, 19, 18, 22])
80, 81, 82, 83, 84, tensor([19, 11, 10, 18])
85, 86, 87, 88, 89, 90, 91, 92, 93, 

In [24]:
pred

[tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51]),
 tensor([57, 51, 36, 49]),
 tensor([57, 23, 22, 56]),
 tensor([56, 22, 18, 10,  3,  2, 16, 20]),
 tensor([55, 52, 17, 53]),
 tensor([47, 32, 28, 43]),
 tensor([47, 45, 30, 29]),
 tensor([44, 40, 25, 39]),
 tensor([41, 40, 25, 26, 29, 30]),
 tensor([38, 37, 35, 36]),
 tensor([23, 22, 21,  5]),
 tensor([23, 19, 18, 22]),
 tensor([19, 11, 10, 18]),
 tensor([ 7,  1,  0,  4, 12,  8,  3,  6])]

In [25]:
faces = []
for f in pred[:-1]:
    if len(f) <= 2:
        continue
    f = f[:, None].repeat(1, 2)
    faces.append(f.numpy())

In [26]:
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(vs)
pcd.estimate_normals(
    search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)
)
normals = np.asarray(pcd.normals)

In [27]:
vs.shape, normals.shape

((58, 3), (58, 3))

In [28]:
print(vs.shape, normals.shape, len(faces))
validate_pipeline(vertices, normals, faces, out_dir)

(58, 3) (58, 3) 13
(41, 3) (40, 3)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…