In [1]:
import torch
from scipy.spatial import ConvexHull
import torch.nn as nn
import torch.nn.functional as F
from itertools import combinations

def get_commont_vertex(edge_pair):
    """Pick the vertex index that the two edges of every pair have in common.

    ``edge_pair`` is indexed as ``[pair, edge, endpoint]``; the indexing and
    the flip along dim 1 of each edge slice imply a 3-D tensor
    (presumably ``(F, 2, 2)`` as built in ``uniform_unpool`` - verify).

    An endpoint of the first edge is selected when it equals the endpoint of
    the second edge in the same slot or in the opposite slot. The selected
    values are returned as one flat tensor in row-major order.
    """
    first_edge = edge_pair[:, 0]
    second_edge = edge_pair[:, 1]

    # Endpoint matches the second edge's endpoint at the same position ...
    same_slot = first_edge == second_edge
    # ... or at the swapped position.
    swapped_slot = first_edge == second_edge.flip(1)

    return first_edge[torch.logical_or(same_slot, swapped_slot)]

class Non(nn.Module):
    """No-op layer: ``forward`` hands its input back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity mapping; the very same object is returned.
        return x

def adjacency_matrix(vertices, faces):
    """Build the dense adjacency and inverse-degree matrices of a mesh.

    Args:
        vertices: tensor of shape ``(B, N, C)``. Only ``B`` (batch size) and
            ``N`` (vertex count) are read; the coordinates are not used.
        faces: integer tensor of shape ``(B, F, K)`` holding vertex indices.
            Edges from *all* batch entries are merged into one adjacency, i.e.
            the connectivity is treated as shared across the batch.

    Returns:
        ``(A, D)`` where ``A`` is ``(B, N, N)`` with ones for connected vertex
        pairs and ``D`` is ``(B, N, N)`` diagonal holding ``1 / degree``.
        Vertices without any edge get ``1`` on the diagonal so that
        ``D @ A`` stays finite (their row of ``A`` is all zeros anyway).
    """
    batch_size, vertex_count, _ = vertices.shape
    # Corners per face come from the faces themselves, not from the vertex
    # coordinate dimension (the two only coincide for triangles in 3-D).
    corners_per_face = faces.shape[-1]

    half_edges = torch.tensor(
        list(combinations(range(corners_per_face), 2)),
        dtype=torch.long,
        device=faces.device,
    ).reshape(-1, 2)
    # Both directions of every edge, so that A ends up symmetric.
    corner_pairs = torch.cat([half_edges, torch.flip(half_edges, dims=[1])], dim=0)

    directed_edges = faces[:, :, corner_pairs].long().reshape(-1, 2)

    A = torch.zeros(1, vertex_count, vertex_count, device=faces.device)
    A[0, directed_edges[:, 0], directed_edges[:, 1]] = 1

    degree = A[0].sum(dim=0)
    # clamp avoids 1/0 = inf for isolated vertices, which would turn into NaN
    # as soon as the matrix is multiplied with the zero row of A.
    inverse_degree = torch.diag(1 / degree.clamp(min=1))[None]

    return A.repeat(batch_size, 1, 1), inverse_degree.repeat(batch_size, 1, 1)


def _point_to_segment_distance(points, seg_start, seg_end):
    """Row-wise Euclidean distance from ``points[i]`` to segment ``seg_start[i] -> seg_end[i]``."""
    from_start = points - seg_start
    from_end = points - seg_end
    direction = seg_end - seg_start

    # Distance to the infinite line through the segment.
    dist = torch.norm(torch.cross(from_start, direction, dim=1), dim=1) / torch.norm(direction, dim=1)

    # Projection falls before the start point -> nearest point is the start.
    # "<=" also covers zero-length segments, where the line distance is 0/0.
    before_start = (from_start * direction).sum(1) <= 0
    # Projection falls past the end point -> nearest point is the end.
    past_end = (from_end * direction).sum(1) > 0

    dist = torch.where(before_start, torch.norm(from_start, dim=1), dist)
    dist = torch.where(past_end, torch.norm(from_end, dim=1), dist)
    return dist


def adaptive_unpool(vertices, faces_prev, sphere_vertices, latent_features, N_prev, drop_fraction=0.3):
    """Drop the freshly unpooled vertices that stayed close to their parent edge.

    Only batch entry 0 of every input is used. The first ``N_prev`` vertices
    ("primary") are always kept. Each remaining ("secondary") vertex is
    compared against one edge of the previous-level mesh; secondary vertex
    ``i`` is paired with the ``i``-th row of the sorted unique edge list of
    ``faces_prev`` (this is the order ``uniform_unpool`` appends midpoints in -
    callers using another ordering should verify).

    Args:
        vertices: ``(1, N, 3)`` vertex positions (primary first).
        faces_prev: ``(1, F, 3)`` triangles indexing the primary vertices.
        sphere_vertices: ``(1, N, 3)`` template positions matching ``vertices``;
            they are re-normalised to unit length and their convex hull
            provides the new faces.
        latent_features: ``(1, N, C)`` per-vertex features or ``None``.
        N_prev: number of primary vertices.
        drop_fraction: the distance threshold is the value found at this
            fraction of the sorted distances; only strictly larger distances
            are kept.

    Returns:
        ``(vertices, faces, latent_features, sphere_vertices)``, each with a
        leading batch dimension of 1 (``latent_features`` stays ``None`` when
        it was ``None``). ``faces`` lives on the device of ``sphere_vertices``.

    Raises:
        ValueError: if the number of secondary vertices differs from the
            number of unique edges in ``faces_prev``.
    """
    vertices_primary = vertices[0, :N_prev, :]
    vertices_secondary = vertices[0, N_prev:, :]
    sphere_vertices_primary = sphere_vertices[0, :N_prev]
    sphere_vertices_secondary = sphere_vertices[0, N_prev:]
    faces_primary = faces_prev[0]

    # Unique undirected edges of the previous-level triangles.
    corner_pairs = torch.tensor(list(combinations(range(3), 2)), device=faces_primary.device)
    unique_edges, _ = torch.sort(faces_primary[:, corner_pairs].reshape(-1, 2), dim=1)
    unique_edges = torch.unique(unique_edges, dim=0).long()

    if vertices_secondary.shape[0] != unique_edges.shape[0]:
        raise ValueError(
            "expected one secondary vertex per unique edge of faces_prev, got "
            "{} vertices for {} edges".format(vertices_secondary.shape[0], unique_edges.shape[0])
        )

    dist = _point_to_segment_distance(
        vertices_secondary,
        vertices_primary[unique_edges[:, 0]],
        vertices_primary[unique_edges[:, 1]],
    )

    secondary_count = dist.shape[0]
    if secondary_count == 0:
        keep = torch.zeros(0, dtype=torch.bool, device=dist.device)
    else:
        sorted_dist, _ = torch.sort(dist)
        cut = min(int(drop_fraction * secondary_count), secondary_count - 1)
        keep = dist > sorted_dist[cut]

    vertices = torch.cat([vertices_primary, vertices_secondary[keep]], dim=0)[None]
    if latent_features is not None:
        latent_features = torch.cat(
            [latent_features[0, :N_prev], latent_features[0, N_prev:][keep]], dim=0
        )[None]

    sphere_vertices = torch.cat([sphere_vertices_primary, sphere_vertices_secondary[keep]], dim=0)
    # Project back onto the unit sphere before triangulating.
    sphere_vertices = sphere_vertices / torch.sqrt(torch.sum(sphere_vertices ** 2, dim=1)[:, None])
    hull = ConvexHull(sphere_vertices.detach().cpu().numpy())
    # Keep the faces on the same device as the data instead of forcing CUDA.
    faces = torch.from_numpy(hull.simplices).long().to(sphere_vertices.device)[None]

    return vertices, faces, latent_features, sphere_vertices[None]

class GraphConv(nn.Module):
    """Graph convolution combining a per-vertex and a neighbourhood linear map.

    ``forward`` returns ``fc(x) + neighbours_fc((Dinv @ A) @ x)``.

    Args:
        in_features: size of the last dimension of the input features.
        out_features: size of the last dimension of the output.
        batch_norm: when true a ``BatchNorm1d`` is created as ``self.bc``,
            otherwise an identity layer. Note that ``forward`` currently does
            not apply ``self.bc`` (the calls are commented out).
    """

    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, batch_norm=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features)
        self.neighbours_fc = nn.Linear(in_features, out_features)

        self.bc = nn.BatchNorm1d(out_features) if batch_norm else nn.Identity()

    def forward(self, input, A, Dinv, vertices, faces):
        """Apply the convolution.

        Args:
            input: features, batched so that ``torch.bmm(coeff, input)`` is valid
                (i.e. ``(B, N, in_features)``).
            A: ``(B, N, N)`` adjacency matrix.
            Dinv: ``(B, N, N)`` inverse-degree matrix.
            vertices, faces: accepted for interface compatibility; unused.

        Returns:
            Tensor with ``out_features`` as last dimension.
        """
        # coeff = torch.bmm(torch.bmm(Dsqrtinv, A), Dsqrtinv)
        # Row normalisation: scaling A by Dinv evens out the influence of
        # vertices that have many neighbours, which is why plain A is not used.
        coeff = torch.bmm(Dinv, A)

        # Plain linear transform of every vertex's own features.
        y = self.fc(input)

        # Aggregate the features over the mesh connectivity encoded in coeff,
        # then transform the aggregated neighbourhood features.
        y_neightbours = torch.bmm(coeff, input)
        y_neightbours = self.neighbours_fc(y_neightbours)

        # y_neightbours = self.bc(y_neightbours.permute(0, 2, 1)).permute(0, 2, 1)
        y = y + y_neightbours
        # y = self.bc(y.permute(0, 2, 1)).permute(0, 2, 1)
        return y

    def extra_repr(self):
        # Report the actual output width (previously a boolean was printed).
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )
    

class Feature2VertexLayer(nn.Module):
    """Stack of graph convolutions that narrows features down to 3 outputs per vertex.

    The hidden layers shrink the width in equal steps of
    ``in_features // hidden_layer_count`` (integer arithmetic); the last layer
    maps the remaining width to 3 channels. ReLU follows every hidden layer,
    the last layer's output is returned as is.
    """

    def __init__(self, in_features, hidden_layer_count, batch_norm=False):
        super().__init__()
        # Plain list used for iteration in forward(); the same layers are
        # registered as sub-modules through the Sequential container below.
        self.gconv = [
            GraphConv(
                step * in_features // hidden_layer_count,
                (step - 1) * in_features // hidden_layer_count,
                batch_norm,
            )
            for step in reversed(range(2, hidden_layer_count + 1))
        ]
        self.gconv_layer = nn.Sequential(*self.gconv)
        self.gconv_last = GraphConv(in_features // hidden_layer_count, 3, batch_norm)

    def forward(self, features, adjacency_matrix, degree_matrix, vertices, faces):
        mesh_args = (adjacency_matrix, degree_matrix, vertices, faces)
        hidden = features
        for layer in self.gconv:
            hidden = F.relu(layer(hidden, *mesh_args))
        return self.gconv_last(hidden, *mesh_args)

class Features2Features(nn.Module):
    """Graph-convolution stack mapping ``in_features`` to ``out_features`` per vertex.

    Layout: one input layer, ``hidden_layer_count`` hidden layers of constant
    width ``out_features`` and one output layer. ReLU is applied after the
    input layer and after every hidden layer, not after the output layer.
    ``graph_conv`` is the layer class/factory called as
    ``graph_conv(in_width, out_width)``.
    """

    def __init__(self, in_features, out_features, hidden_layer_count=2, graph_conv=GraphConv):
        super().__init__()

        self.gconv_first = graph_conv(in_features, out_features)
        self.gconv_hidden = nn.Sequential(
            *[graph_conv(out_features, out_features) for _ in range(hidden_layer_count)]
        )
        self.gconv_last = graph_conv(out_features, out_features)

    def forward(self, features, adjacency_matrix, degree_matrix, vertices, faces):
        mesh_args = (adjacency_matrix, degree_matrix, vertices, faces)

        hidden = F.relu(self.gconv_first(features, *mesh_args))
        for layer in self.gconv_hidden:
            hidden = F.relu(layer(hidden, *mesh_args))
        return self.gconv_last(hidden, *mesh_args)

def _subdivide_faces(faces, vertex_count):
    """Split every triangle of one mesh into four.

    Args:
        faces: ``(F, 3)`` integer tensor of vertex indices.
        vertex_count: number of existing vertices; the midpoint of unique
            edge ``k`` receives index ``vertex_count + k``.

    Returns:
        ``(unique_edges, new_faces)``: the ``(E, 2)`` sorted unique edges and
        the ``(4F, 3)`` faces (three corner triangles per face followed by
        the middle triangles).
    """
    face_count = faces.shape[0]
    corner_pairs = torch.tensor(list(combinations(range(3), 2)), device=faces.device)

    edges = faces[:, corner_pairs]  # [face, edge, endpoint]
    unique_edges, _ = torch.sort(edges.reshape(-1, 2), dim=1)
    unique_edges, edge_to_unique = torch.unique(unique_edges, return_inverse=True, dim=0)

    # Position of each face's first edge inside the flattened edge list.
    first_edge_of_face = torch.arange(0, 3 * face_count, 3, device=faces.device)

    corner_faces = []
    midpoint_columns = []
    for pair in corner_pairs:
        # The two edges of this pair meet in exactly one corner of the face.
        edge_pair = edges[:, pair]
        first, second = edge_pair[:, 0], edge_pair[:, 1]
        shared = (first == second) | (first == second.flip(1))
        common_vertex = first[shared]

        midpoint_1 = edge_to_unique[first_edge_of_face + pair[0]] + vertex_count
        midpoint_2 = edge_to_unique[first_edge_of_face + pair[1]] + vertex_count

        midpoint_columns += [midpoint_1[:, None], midpoint_2[:, None]]
        corner_faces.append(
            torch.cat([common_vertex[:, None], midpoint_1[:, None], midpoint_2[:, None]], dim=1)
        )

    corner_faces = torch.cat(corner_faces, dim=0)
    # Each midpoint column was collected twice; unique columns leave the
    # three midpoints that form the middle triangle of every face.
    middle_faces = torch.unique(torch.cat(midpoint_columns, dim=1), dim=1)
    return unique_edges, torch.cat([corner_faces, middle_faces], dim=0)


def uniform_unpool(vertices_, faces_, identical_face_batch=True):
    """Subdivide every triangle into four by inserting one vertex per edge.

    New per-vertex values are the mean of the two edge endpoints, so the
    function works for coordinates as well as for arbitrary per-vertex
    features. New vertices are appended after the existing ones.

    Args:
        vertices_: ``(B, N, C)`` tensor, or ``None``.
        faces_: ``(B, F, 3)`` integer tensor of vertex indices.
        identical_face_batch: when true the connectivity of batch entry 0 is
            used for the whole batch; every batch entry still gets midpoints
            computed from its *own* vertices.

    Returns:
        ``(None, None)`` when ``vertices_`` is ``None``.
        With ``identical_face_batch=True``: tensors ``(B, N + E, C)`` and
        ``(B, 4F, 3)`` (two empty lists for an empty batch).
        With ``identical_face_batch=False``: two lists with one
        ``(1, N + E_i, C)`` / ``(1, 4F, 3)`` tensor per batch entry.
    """
    if vertices_ is None:
        return None, None

    batch_size, vertex_count, _ = vertices_.shape

    if identical_face_batch:
        if batch_size == 0 or len(faces_) == 0:
            return [], []
        unique_edges, new_faces = _subdivide_faces(faces_[0], vertex_count)
        midpoints = vertices_[:, unique_edges].mean(dim=2)
        new_vertices = torch.cat([vertices_, midpoints], dim=1)
        return new_vertices, new_faces[None].repeat(batch_size, 1, 1)

    new_vertices_all = []
    new_faces_all = []
    for vertices, faces in zip(vertices_, faces_):
        unique_edges, new_faces = _subdivide_faces(faces, vertex_count)
        midpoints = vertices[unique_edges].mean(dim=1)
        new_vertices_all.append(torch.cat([vertices, midpoints], dim=0)[None])
        new_faces_all.append(new_faces[None])

    return new_vertices_all, new_faces_all


In [2]:
# Decoder hyper-parameters read by MeshDecoder:
#   latent_features_count[0] / [1] - input / output feature width of the
#                                    feature-to-feature graph network
#   graph_conv_layer_count         - number of hidden graph-conv layers in it
config = dict(
    latent_features_count=[1024, 32],
    graph_conv_layer_count=3,
)

In [3]:
class MeshDecoder(nn.Module):
    """Single decoder step: unpool a template mesh, deform it, prune it adaptively.

    ``config`` must provide ``"latent_features_count"`` (input and output
    feature width of the graph network, in that order) and
    ``"graph_conv_layer_count"`` (number of hidden graph-conv layers).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.graph_conv_net = Features2Features(
            config["latent_features_count"][0],
            config["latent_features_count"][1],
            hidden_layer_count=config["graph_conv_layer_count"],
        )
        self.feature_to_vertex = Feature2VertexLayer(config["latent_features_count"][1], 3)

    def forward(self, vertices, faces, latent_features):
        """Run one unpool / deform / prune step.

        Args:
            vertices: ``(B, N, 3)`` template vertices. ``adaptive_unpool``
                re-projects them onto the unit sphere and takes their convex
                hull, so they are presumably expected to lie on a sphere -
                verify against the data.
            faces: ``(B, F, 3)`` triangles indexing ``vertices``.
            latent_features: ``(B, N, C)`` per-vertex features with
                ``C == config["latent_features_count"][0]``.

        Returns:
            ``(upsampled_vertices, upsampled_faces)`` for batch entry 0, each
            with a leading batch dimension of 1.
        """
        # adaptive_unpool needs the connectivity and vertex count from
        # *before* the uniform unpooling, so remember both.
        faces_prev = faces
        _, N_prev, _ = vertices.shape

        # Uniform unpooling: one candidate vertex (and feature row) per edge.
        sphere_vertices, faces = uniform_unpool(vertices, faces_prev)
        latent_features, _ = uniform_unpool(latent_features, faces_prev)

        A, D = adjacency_matrix(sphere_vertices, faces)

        # Graph CNN turning the encoder features into the features consumed
        # by the feature-to-vertex layer.
        updated_latent_features = self.graph_conv_net(latent_features, A, D, sphere_vertices, faces)

        # Feature-to-vertex: per-vertex displacement vectors computed from the features.
        deformation_vectors = self.feature_to_vertex(updated_latent_features, A, D, sphere_vertices, faces)

        # Mesh deformation: shift every vertex by its vector. Faces only
        # reference vertex indices, so they are unaffected.
        deformed_vertices = sphere_vertices + deformation_vectors

        # Adaptive unpooling: keep only the candidate vertices that moved
        # away from their parent edge, so flat regions are not refined.
        # The features and sphere vertices it also returns are not needed here.
        upsampled_vertices, upsampled_faces, _, _ = adaptive_unpool(
            deformed_vertices, faces_prev, sphere_vertices, latent_features, N_prev
        )

        return upsampled_vertices, upsampled_faces

In [9]:
import numpy as np
from stl import mesh

SPHERE_PATH = "Sphere.stl"

# NOTE(review): the ValueError recorded in this cell's output ("cannot reshape
# array ...") is presumably raised while loading one of these .npy files
# (truncated file?) - verify the files on disk.
vertices = np.load("0a0f3b60/vertices.npy")
faces = np.load("0a0f3b60/faces.npy")

latent_features = torch.rand(1, 1250, 1024)

loaded_mesh = mesh.Mesh.from_file(SPHERE_PATH)

# The STL stores three corner coordinates per triangle, so shared corners
# appear many times. Build an indexed mesh instead:
#   vertex_dict     - coordinate tuple -> vertex index (de-duplicates corners)
#   sphere_vertices - (n, 3) float array of unique x, y, z coordinates
#   sphere_faces    - list of [i, j, k] vertex indices, one entry per triangle
vertex_dict = {}
sphere_vertex_rows = []
sphere_faces = []

for triangle in loaded_mesh.vectors:
    face = []
    for vertex in triangle:
        vertex_tuple = tuple(vertex)
        index = vertex_dict.get(vertex_tuple)
        if index is None:
            index = len(sphere_vertex_rows)
            vertex_dict[vertex_tuple] = index
            sphere_vertex_rows.append(vertex)
        face.append(index)
    sphere_faces.append(face)

# Convert once at the end; appending to a NumPy array inside the loop copies
# the whole array on every new vertex.
sphere_vertices = np.array(sphere_vertex_rows, dtype=float).reshape(-1, 3)

ValueError: cannot reshape array of size 4456416 into shape (3882114,3)

In [None]:
# Convert the NumPy / list mesh data to tensors; the sphere template also gets
# a leading batch dimension of 1.
sphere_vertices = torch.tensor(sphere_vertices).unsqueeze(0)
sphere_faces = torch.tensor(sphere_faces).unsqueeze(0)
vertices = torch.tensor(vertices)
faces = torch.tensor(faces)

# Inference only: evaluation mode, no gradient tracking.
decoder = MeshDecoder(config).eval()
with torch.no_grad():
    deformed_vertices, deformed_faces = decoder(sphere_vertices, sphere_faces, latent_features)