# 1. Mesh loading test (mesh读取测试)

In [1]:
import kaolin as kal
import numpy as np
import torch
import torch.nn.functional as F
import math
import os
from packaging import version


In [2]:
import torch
import torch.nn.functional as F

from packaging import version

def grid_sample_bilinear(input, grid):
    """Bilinearly sample ``input`` at the normalized locations in ``grid``.

    PyTorch 1.3 added the ``align_corners`` argument to ``F.grid_sample`` (and 1.4
    changed its default). On those versions ``align_corners=True`` is passed
    explicitly so sampling behaves the same across PyTorch releases; older
    releases do not accept the argument, so it is omitted there.
    """
    legacy_api = version.parse(torch.__version__) < version.parse('1.3')
    extra_kwargs = {} if legacy_api else {'align_corners': True}
    return F.grid_sample(input, grid, mode='bilinear', **extra_kwargs)


def symmetrize_texture(x):
    """Apply even symmetry along the last axis, doubling its length (N -> 2N).

    The last axis is flipped, and the flipped copy is split in half: its second
    half is prepended and its first half appended to the original, i.e.
    ``[flip[..., N//2:], x, flip[..., :N//2]]``.

    Works for tensors of any rank (the original only supported 4-D input).
    """
    # Flip along the last dimension.
    x_flip = torch.flip(x, (-1,))
    half = x_flip.shape[-1] // 2
    # Pad the original on both sides with half of the mirrored columns.
    return torch.cat((x_flip[..., half:], x, x_flip[..., :half]), dim=-1)


def adjust_poles(tex):
    """Replace the first and last rows of ``tex`` with their column-wise averages.

    ``tex`` is a 4-D tensor with rows on dim 2 and columns on dim 3. Rows 0 and -1
    (the poles, for mesh textures) are averaged over the columns and broadcast
    back to full width; the inner rows are returned unchanged.
    """
    width = tex.shape[3]

    def _pole_row(row):
        # Collapse the row to its mean value, then stretch it back to full width.
        return row.mean(dim=3, keepdim=True).expand(-1, -1, -1, width)

    north = _pole_row(tex[:, :, :1])
    south = _pole_row(tex[:, :, -1:])
    return torch.cat((north, tex[:, :, 1:-1], south), dim=2)
    

def circpad(x, amount=1):
    """Circularly pad ``x`` along its last axis (e.g. before a convolution).

    The last ``amount`` columns are prepended and the first ``amount`` columns
    appended, so the output is ``2 * amount`` columns wider.

    Args:
        x: tensor of any rank; padding is applied to the last axis.
        amount: number of columns to wrap on each side. Must be non-negative and
            should not exceed the width of ``x``.

    Returns:
        The padded tensor (``x`` itself when ``amount == 0``).

    Raises:
        ValueError: if ``amount`` is negative.
    """
    if amount < 0:
        raise ValueError(f'circpad amount must be non-negative, got {amount}')
    if amount == 0:
        # x[..., -0:] would select the whole tensor and double the width.
        return x
    left = x[..., :amount]
    right = x[..., -amount:]
    return torch.cat((right, x, left), dim=-1)


def qrot(q, v):
    """
    Rotate vectors by quaternions (quaternion-vector multiplication).

    Args:
        q: (B, 4) quaternions; element 0 is the scalar part, 1..3 the vector part.
        v: (B, N, 3) vectors; all N vectors of batch entry b are rotated by q[b].

    Returns:
        (B, N, 3) rotated vectors, computed as v + 2 * (w * (u x v) + u x (u x v)).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3

    scalar = q[:, :1].unsqueeze(1)
    axis = q[:, 1:].unsqueeze(1).expand(-1, v.shape[1], -1)
    u_cross_v = torch.cross(axis, v, dim=2)
    u_cross_u_cross_v = torch.cross(axis, u_cross_v, dim=2)
    return v + 2 * (scalar * u_cross_v + u_cross_u_cross_v)

def qmul(q, r):
    """
    Quaternion-quaternion multiplication: the Hamilton product q * r.

    Both inputs hold quaternions as (w, x, y, z) on their last axis and must have
    the same shape; the result is returned in that shape.
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    out_shape = q.shape

    # Batched outer product: outer[n, a, b] == r[n, a] * q[n, b]
    outer = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))

    def t(a, b):
        return outer[:, a, b]

    w = t(0, 0) - t(1, 1) - t(2, 2) - t(3, 3)
    x = t(0, 1) + t(1, 0) - t(2, 3) + t(3, 2)
    y = t(0, 2) + t(1, 3) + t(2, 0) - t(3, 1)
    z = t(0, 3) - t(1, 2) + t(2, 1) + t(3, 0)
    return torch.stack((w, x, y, z), dim=1).view(out_shape)

In [13]:
class MeshTemplate:
    """Template mesh that is deformed by UV-space displacement maps.

    On construction the mesh is loaded with kaolin, moved to the GPU, and several
    lookup tables are precomputed:

    * reflection indices across the x = 0 plane (``neg_indices``, ``pos_indices``,
      ``nonneg_indices``) used to enforce left/right symmetry;
    * ``topo_map``: per-vertex averaged UV coordinates remapped to [-1, 1] with v
      flipped, used as the sampling grid for displacement maps;
    * ``tangent_map``: per-vertex (normal, tangent, bitangent) frames in which
      displacements are expressed.
    """

    def __init__(self, mesh_path, is_symmetric=True):
        """Load the template from an .obj file and precompute its lookup tables.

        Args:
            mesh_path: path to the .obj template. The UV grid is assumed to have
                32 segments and 31 rings if the path contains ``'31rings'``,
                otherwise 16 rings.
            is_symmetric: if True, displacements are predicted only for vertices
                with x >= 0 and mirrored onto the x < 0 half.
        """
        MeshTemplate._monkey_patch_dependencies()

        mesh = kal.rep.TriangleMesh.from_obj(mesh_path, enable_adjacency=True)
        mesh.cuda()

        print('---- Mesh definition ----')
        print(f'Vertices: {mesh.vertices.shape}')
        print(f'Indices: {mesh.faces.shape}')
        print(f'UV coords: {mesh.uvs.shape}')
        print(f'UV indices: {mesh.face_textures.shape}')

        # Vertices with the largest / smallest y coordinate.
        poles = [mesh.vertices[:, 1].argmax().item(), mesh.vertices[:, 1].argmin().item()] # North pole, south pole

        # Compute reflection information (for mesh symmetry) across the x = 0 plane
        axis = 0
        if version.parse(torch.__version__) < version.parse('1.2'):
            neg_indices = torch.nonzero(mesh.vertices[:, axis] < -1e-4)[:, 0].cpu().numpy()
            zero_indices = torch.nonzero(torch.abs(mesh.vertices[:, axis]) < 1e-4)[:, 0].cpu().numpy()
        else:
            neg_indices = torch.nonzero(mesh.vertices[:, axis] < -1e-4, as_tuple=False)[:, 0].cpu().numpy()
            zero_indices = torch.nonzero(torch.abs(mesh.vertices[:, axis]) < 1e-4, as_tuple=False)[:, 0].cpu().numpy()

        # For every negative-x vertex, find the vertex closest to its mirror image.
        # NOTE: no exactness / uniqueness / completeness checks are performed, so on
        # a mesh that is not perfectly symmetric the matches may be approximate and
        # some vertices may not belong to any of the three index sets.
        pos_indices = []
        for idx in neg_indices:
            opposite_vtx = mesh.vertices[idx].clone()
            opposite_vtx[axis] *= -1
            dists = (mesh.vertices - opposite_vtx).norm(dim=-1)
            _, minidx = torch.min(dists, dim=0)
            pos_indices.append(minidx.item())
        # Explicit dtype so an empty list still yields an integer index array.
        pos_indices = np.array(pos_indices, dtype=np.int64)

        pos_indices = torch.LongTensor(pos_indices).cuda()
        neg_indices = torch.LongTensor(neg_indices).cuda()
        zero_indices = torch.LongTensor(zero_indices).cuda()
        # Vertices whose displacements are sampled directly in symmetric mode.
        nonneg_indices = torch.cat((pos_indices, zero_indices))

        index_list = {}
        segments = 32
        rings = 31 if '31rings' in mesh_path else 16
        print(f'The mesh has {rings} rings')
        print('-------------------------')
        # Collect, per vertex, the UVs of all face corners that reference it
        # (a vertex on the texture seam has several UVs), in grid units.
        for faces, vertices in zip(mesh.face_textures, mesh.faces):
            for face, vertex in zip(faces, vertices):
                vertex_idx = vertex.item()
                res = mesh.uvs[face].cpu().numpy() * [segments, rings]
                if math.isclose(res[0], segments, abs_tol=1e-4):
                    res[0] = 0 # Wrap around: u == 1 is the same column as u == 0
                index_list.setdefault(vertex_idx, []).append(res)

        # Average each vertex's UVs and normalize back to [0, 1].
        topo_map = torch.zeros(mesh.vertices.shape[0], 2)
        for idx, data in index_list.items():
            avg = np.mean(np.array(data, dtype=np.float32), axis=0) / [segments, rings]
            topo_map[idx] = torch.Tensor(avg)

        # Map [0, 1] -> [-1, 1] and flip v so the map can serve as a grid_sample grid
        topo_map = topo_map * 2 - 1
        topo_map = topo_map * torch.FloatTensor([1, -1]).to(topo_map.device)
        topo_map = topo_map.cuda()
        nonneg_topo_map = topo_map[nonneg_indices]

        # Force x = 0 for zero-indices if symmetry is enabled
        symmetry_mask = torch.ones_like(mesh.vertices).unsqueeze(0)
        symmetry_mask[:, zero_indices, 0] = 0

        # Compute mesh tangent map (per-vertex normals, tangents, and bitangents).
        # Normals are the normalized vertex positions (radial directions).
        mesh_normals = F.normalize(mesh.vertices, dim=1)
        up_vector = torch.Tensor([[0, 1, 0]]).to(mesh_normals.device).expand_as(mesh_normals)
        mesh_tangents = F.normalize(torch.cross(mesh_normals, up_vector, dim=1), dim=1)
        mesh_bitangents = torch.cross(mesh_normals, mesh_tangents, dim=1)
        # North pole and south pole have no (bi)tangent
        mesh_tangents[poles[0]] = 0
        mesh_bitangents[poles[0]] = 0
        mesh_tangents[poles[1]] = 0
        mesh_bitangents[poles[1]] = 0

        # (V, 3, 3): row 0 = normal, row 1 = tangent, row 2 = bitangent.
        tangent_map = torch.stack((mesh_normals, mesh_tangents, mesh_bitangents), dim=1).cuda()
        nonneg_tangent_map = tangent_map[nonneg_indices] # For symmetric meshes

        self.mesh = mesh
        self.topo_map = topo_map
        self.nonneg_topo_map = nonneg_topo_map
        self.nonneg_indices = nonneg_indices
        self.neg_indices = neg_indices
        self.pos_indices = pos_indices
        self.symmetry_mask = symmetry_mask
        self.tangent_map = tangent_map
        self.nonneg_tangent_map = nonneg_tangent_map
        self.is_symmetric = is_symmetric

    def deform(self, deltas):
        """
        Deform this mesh template along its tangent map, using the provided vertex displacements.

        Args:
            deltas: (B, V, 3) displacements in the local (normal, tangent,
                bitangent) frame. V is the number of non-negative vertices when
                the template is symmetric, the number of all vertices otherwise.

        Returns:
            (B, V, 3) object-space displacements:
            deltas[..., 0] * normal + deltas[..., 1] * tangent + deltas[..., 2] * bitangent.
        """
        tgm = self.nonneg_tangent_map if self.is_symmetric else self.tangent_map
        # Row-vector times per-vertex frame matrix: (B, V, 1, 3) @ (B, V, 3, 3).
        return (deltas.unsqueeze(-2) @ tgm.expand(deltas.shape[0], -1, -1, -1)).squeeze(-2)

    def compute_normals(self, vertex_positions):
        """
        Compute face normals from the *final* vertex positions (not deltas).

        Args:
            vertex_positions: (B, V, 3) vertex positions.

        Returns:
            (B, F, 3) unit normals, oriented as (b - a) x (c - a) for each face (a, b, c).
        """
        a = vertex_positions[:, self.mesh.faces[:, 0]]
        b = vertex_positions[:, self.mesh.faces[:, 1]]
        c = vertex_positions[:, self.mesh.faces[:, 2]]
        v1 = b - a
        v2 = c - a
        normal = torch.cross(v1, v2, dim=2)
        return F.normalize(normal, dim=2)

    def get_vertex_positions(self, displacement_map):
        """
        Deform this mesh template using the provided UV displacement map.
        Output: 3D vertex positions in object space.

        Args:
            displacement_map: 4-D (B, 3, H, W) map of local-frame displacements.

        Returns:
            (B, V, 3) deformed vertex positions for all V template vertices.
        """
        topo = self.nonneg_topo_map if self.is_symmetric else self.topo_map
        _, displacement_map_padded = self.adjust_uv_and_texture(displacement_map)
        if self.is_symmetric:
            # Compensate for the circular padding applied to the map
            delta = 1/(2*displacement_map.shape[3])
            expansion = (displacement_map.shape[3]+1)/displacement_map.shape[3]
            topo = topo.clone()
            topo[:, 0] = (topo[:, 0] + 1 + 2*delta - expansion)/expansion # Only for x axis
        topo_expanded = topo.unsqueeze(0).unsqueeze(-2).expand(displacement_map.shape[0], -1, -1, -1)
        # grid_sample -> (B, C, V, 1) -> (B, V, C)
        vertex_deltas_local = grid_sample_bilinear(displacement_map_padded, topo_expanded).squeeze(-1).permute(0, 2, 1)
        vertex_deltas = self.deform(vertex_deltas_local)
        if self.is_symmetric:
            # Symmetrize: scatter the non-negative half, mirror it onto the negative half.
            # Zero-initialized so vertices outside every index set get no displacement
            # (torch.Tensor(...) would leave them as uninitialized memory).
            vtx_n = torch.zeros(vertex_deltas.shape[0], self.topo_map.shape[0], 3, device=vertex_deltas.device)
            vtx_n[:, self.nonneg_indices] = vertex_deltas
            vtx_n2 = vtx_n.clone()
            vtx_n2[:, self.neg_indices] = vtx_n[:, self.pos_indices] * torch.Tensor([-1, 1, 1]).to(vtx_n.device)
            vertex_deltas = vtx_n2 * self.symmetry_mask
        # v' = v + R @ delta
        vertex_positions = self.mesh.vertices.unsqueeze(0) + vertex_deltas
        return vertex_positions

    def adjust_uv_and_texture(self, texture, return_texture=True):
        """
        Returns the UV coordinates of this mesh template,
        and preprocesses the provided texture to account for boundary conditions.
        If the mesh is symmetric, the texture and UVs are adjusted accordingly.

        Args:
            texture: 4-D (B, C, H, W) texture or displacement map.
            return_texture: unused; kept for API compatibility.

        Returns:
            (uvs, texture): UVs expanded to batch size B, and the texture either
            circularly padded by one column on each side (symmetric) or with its
            first column appended at the end (non-symmetric).
        """
        if self.is_symmetric:
            # Shift/scale u so that it addresses the padded texture
            delta = 1/(2*texture.shape[3])
            expansion = (texture.shape[3]+1)/texture.shape[3]
            uvs = self.mesh.uvs.clone()
            uvs[:, 0] = (uvs[:, 0] + delta)/expansion

            uvs = uvs.expand(texture.shape[0], -1, -1)
            texture = circpad(texture, 1) # Circular padding
        else:
            uvs = self.mesh.uvs.expand(texture.shape[0], -1, -1)
            # Duplicate the first column to close the seam at u == 1
            texture = torch.cat((texture, texture[:, :, :, :1]), dim=3)

        return uvs, texture

    def forward_renderer(self, vertex_positions, texture, num_gpus=1, **kwargs):
        """Prepare the renderer inputs (the renderer call itself is disabled).

        Returns:
            (vertex_positions, faces, uvs, padded texture, face UV indices). Faces
            and face UV indices are tiled ``num_gpus`` times. ``kwargs`` are
            currently unused.
        """
        mesh_faces = self.mesh.faces
        mesh_face_textures = self.mesh.face_textures
        if num_gpus > 1:
            mesh_faces = mesh_faces.repeat(num_gpus, 1)
            mesh_face_textures = mesh_face_textures.repeat(num_gpus, 1)

        input_uvs, input_texture = self.adjust_uv_and_texture(texture)

        return vertex_positions, mesh_faces, input_uvs, input_texture, mesh_face_textures

    def export_obj(self, path_prefix, vertex_positions, texture):
        """Write ``<prefix>.obj``, ``<prefix>.mtl`` and ``<prefix>.png``.

        Args:
            path_prefix: output path without extension; its basename is the material name.
            vertex_positions: (V, 3) vertex positions of a single mesh.
            texture: (C, H, W) texture, presumably with values in [0, 1]
                (it is scaled by 255 and clamped) — TODO confirm.
        """
        assert len(vertex_positions.shape) == 2
        mesh_path = path_prefix + '.obj'
        material_path = path_prefix + '.mtl'
        material_name = os.path.basename(path_prefix)

        # Export mesh .obj (OBJ indices are 1-based)
        with open(mesh_path, 'w') as file:
            print('mtllib ' + os.path.basename(material_path), file=file)
            for v in vertex_positions:
                print('v {:.5f} {:.5f} {:.5f}'.format(*v), file=file)
            for uv in self.mesh.uvs:
                print('vt {:.5f} {:.5f}'.format(*uv), file=file)
            print('usemtl ' + material_name, file=file)
            for f, ft in zip(self.mesh.faces, self.mesh.face_textures):
                print('f {}/{} {}/{} {}/{}'.format(f[0]+1, ft[0]+1, f[1]+1, ft[1]+1, f[2]+1, ft[2]+1), file=file)

        # Export material .mtl
        with open(material_path, 'w') as file:
            print('newmtl ' + material_name, file=file)
            print('Ka 1.000 1.000 1.000', file=file)
            print('Kd 1.000 1.000 1.000', file=file)
            print('Ks 0.000 0.000 0.000', file=file)
            print('d 1.0', file=file)
            print('illum 1', file=file)
            print('map_Ka ' + material_name + '.png', file=file)
            print('map_Kd ' + material_name + '.png', file=file)

        # Export texture
        import imageio
        texture = (texture.permute(1, 2, 0)*255).clamp(0, 255).cpu().byte().numpy()
        imageio.imwrite(path_prefix + '.png', texture)

    @staticmethod
    def _monkey_patch_dependencies():
        """Patch torch / kaolin for compatibility with the installed PyTorch version."""
        if version.parse(torch.__version__) < version.parse('1.2'):
            # Patch only once: patching again would store the patched function as
            # `_where_original`, making torch.where call itself forever.
            if not hasattr(torch, '_where_original'):
                def torch_where_patched(*args, **kwargs):
                    if len(args) == 1:
                        return (torch.nonzero(args[0]), )
                    else:
                        return torch._where_original(*args, **kwargs)

                torch._where_original = torch.where
                torch.where = torch_where_patched

        if version.parse(torch.__version__) >= version.parse('1.5'):
            try:
                from .monkey_patches import compute_adjacency_info_patched
            except ImportError:
                # Relative imports fail outside a package (e.g. in this notebook);
                # fall back to a top-level `monkey_patches` module on sys.path.
                from monkey_patches import compute_adjacency_info_patched
            # Monkey patch
            kal.rep.Mesh.compute_adjacency_info = staticmethod(compute_adjacency_info_patched)
                

In [None]:
# Build the template from a custom mesh. The path does not contain '31rings',
# so MeshTemplate assumes 16 rings for the UV grid.
# NOTE(review): the file name contains a double space ("0001_2_01  1022") — confirm it is intentional.
temp = MeshTemplate('/home/zf/vscode/3d/DR_3DFM/data/cylinder_template_mesh/0001_2_01  1022_norm_.obj')

# 2. Mesh transformation (mesh变化)


In [35]:
import math

In [97]:
a = kal.rep.TriangleMesh.from_obj('/home/zf/vscode/3d/DR_3DFM/data/cylinder_template_mesh/uvsphere_31rings.obj', enable_adjacency=True)
vertices = a.vertices
axis = 1
# Sort the vertices by their coordinate along `axis` (y), then keep one
# representative per distinct level: a value is recorded only when it differs
# from the previously recorded one by more than the tolerance.
vertices_new = sorted(vertices, key=lambda x: x[axis])
y = []
for vtx in vertices_new:
    coord = vtx[axis]
    if not y or not math.isclose(y[-1], coord, abs_tol=0.0001):
        y.append(coord)
# Number of distinct y levels
len(y)


32

In [98]:
import copy
# Independent copy of mesh `a`, so in-place edits to `b` leave `a` untouched.
b = copy.deepcopy(a)

In [99]:
# Scale the copied mesh in place: x and z by 0.8, y by 2.
# Vectorized and device-agnostic (a per-row NumPy round-trip only works on CPU).
# The product is computed in float64 and rounded back to the vertex dtype,
# which gives the same values as multiplying float32 rows by a float64 NumPy scale.
vertex_scale = torch.tensor([0.8, 2, 0.8], dtype=torch.float64, device=b.vertices.device)
b.vertices.copy_((b.vertices.double() * vertex_scale).to(b.vertices.dtype))

In [100]:
# Write the scaled mesh to '1.obj' (relative to the current working directory).
b.save_mesh('1.obj')

# 3. Adjacency matrix test (adj测试)

In [5]:
import os
import pickle

import numpy as np
import torch
import trimesh
from scipy.sparse import coo_matrix



def torch_sparse_tensor(indices, value, size):
    """Build a float32 sparse COO tensor from (row, col) index pairs.

    Args:
        indices: (nnz, 2) array-like of [row, col] pairs (NumPy array or nested list).
        value: (nnz,) array-like of entry values.
        size: (n_rows, n_cols) shape of the matrix.

    Returns:
        A ``torch.sparse_coo`` tensor of dtype float32 with shape ``size``.
    """
    indices = np.asarray(indices)
    coo = coo_matrix((value, (indices[:, 0], indices[:, 1])), shape=size)

    i = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    v = torch.tensor(coo.data, dtype=torch.float)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(i, v, coo.shape)


In [10]:

class Ellipsoid:
    """Ellipsoid mesh hierarchy loaded from a pickled info file plus face .obj files.

    Attributes (from how the pickle is indexed below):
        coord: vertex coordinates minus ``mesh_pos``.
        edges, laplace_idx, adj_mat: lists with one entry per level (3 levels).
        unpool_idx: 2 index tensors used between consecutive levels.
        faces, obj_fmt_faces: per-level faces read from face1.obj .. face3.obj.
    """

    def __init__(self, mesh_pos, file):
        """
        Args:
            mesh_pos: 3-element offset (tensor or sequence) subtracted from every coordinate.
            file: path to the pickled info file. ``face1.obj``..``face3.obj`` are
                read from the same directory.
        """
        # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
        with open(file, "rb") as fp:
            fp_info = pickle.load(fp, encoding='latin1')

        # shape: n_pts * 3
        # as_tensor avoids the copy-construct warning when mesh_pos is already a tensor.
        self.coord = torch.tensor(fp_info[0]) - torch.as_tensor(mesh_pos, dtype=torch.float)

        # edges & faces & lap_idx
        # edge: num_edges * 2
        # faces: num_faces * 4
        # laplace_idx: num_pts * 10
        self.edges, self.laplace_idx = [], []

        for i in range(3):
            self.edges.append(torch.tensor(fp_info[1 + i][1][0], dtype=torch.long))
            self.laplace_idx.append(torch.tensor(fp_info[7][i], dtype=torch.long))

        # unpool index
        # num_pool_edges * 2
        # pool_01: 462 * 2, pool_12: 1848 * 2
        self.unpool_idx = [torch.tensor(fp_info[4][i], dtype=torch.long) for i in range(2)]

        # loops and adjacent edges
        self.adj_mat = []
        for i in range(1, 4):
            # 0: np.array, 2D, pos
            # 1: np.array, 1D, vals
            # 2: tuple - shape, n * n
            adj_mat = torch_sparse_tensor(*fp_info[i][1])
            self.adj_mat.append(adj_mat)

        ellipsoid_dir = os.path.dirname(file)
        self.faces = []
        self.obj_fmt_faces = []
        # faces: f * 3, original ellipsoid, and two after deformations
        for i in range(1, 4):
            face_file = os.path.join(ellipsoid_dir, "face%d.obj" % i)
            faces = np.loadtxt(face_file, dtype='|S32')
            self.obj_fmt_faces.append(faces)
            # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
            # OBJ indices are 1-based, hence the -1.
            self.faces.append(torch.tensor(faces[:, 1:].astype(int) - 1))


In [11]:
# Zero offset: Ellipsoid subtracts mesh_pos from the coordinates, so they stay unchanged.
mesh_pos = torch.Tensor([0,0,0])
obj =  Ellipsoid(mesh_pos,'/home/zf/vscode/3d/DR_3DFM/data/info_ellipsoid.dat')

In [21]:
# Random dense 156x156 matrix, the same size as obj.adj_mat[0].
a = torch.rand((156,156))

In [24]:
# Inspect the random matrix.
a

tensor([[0.7474, 0.6870, 0.2964,  ..., 0.2934, 0.4218, 0.5662],
        [0.9128, 0.7624, 0.7298,  ..., 0.3978, 0.3712, 0.7036],
        [0.4203, 0.7345, 0.1927,  ..., 0.3878, 0.3320, 0.3252],
        ...,
        [0.9169, 0.6973, 0.8225,  ..., 0.0531, 0.8146, 0.3423],
        [0.6804, 0.5123, 0.3365,  ..., 0.2018, 0.5788, 0.4836],
        [0.0626, 0.2697, 0.1756,  ..., 0.0088, 0.8361, 0.9275]])

In [25]:
# First sparse adjacency matrix built by Ellipsoid.
b = obj.adj_mat[0]

In [26]:
# Inspect the sparse COO tensor (indices, values, size, nnz).
b

tensor(indices=tensor([[  0,   0,   0,  ..., 155, 155, 155],
                       [  0,   1,   2,  ..., 153, 154, 155]]),
       values=tensor([ 0.3548, -0.1935, -0.2290,  ..., -0.1935, -0.2290,
                       0.3548]),
       size=(156, 156), nnz=1080, layout=torch.sparse_coo)

In [32]:
# Shape of the sparse adjacency matrix.
b.shape


torch.Size([156, 156])

In [42]:
# Same tensor as `b`, accessed directly on the Ellipsoid.
obj.adj_mat[0]

tensor(indices=tensor([[  0,   0,   0,  ..., 155, 155, 155],
                       [  0,   1,   2,  ..., 153, 154, 155]]),
       values=tensor([ 0.3548, -0.1935, -0.2290,  ..., -0.1935, -0.2290,
                       0.3548]),
       size=(156, 156), nnz=1080, layout=torch.sparse_coo)