In [18]:
import torch
import numpy as np
from itertools import product, chain
from torch import nn
import torch.nn.functional as F
import argparse


In [11]:
class LearntNeighbourhoodSampling(nn.Module):

    def __init__(self, config, features_count, step):
        super(LearntNeighbourhoodSampling, self).__init__()

        D, H, W = config.patch_shape 
        self.shape = torch.tensor([W, H, D]).cuda().float()

        self.shift = torch.tensor(list(product((-1, 0, 1), repeat=3)))[None].float() * torch.tensor([[[2 ** (config.steps+1 - step)/(W), 2 ** (config.steps+1 - step)/(H), 2 ** (config.steps+1 - step)/(D)]]])[None]
        self.shift = self.shift.cuda()

        self.sum_neighbourhood = nn.Conv2d(features_count, features_count, kernel_size=(1, 27), padding=0).cuda()

        # torch.nn.init.kaiming_normal_(self.sum_neighbourhood.weight, nonlinearity='relu')
        # torch.nn.init.constant_(self.sum_neighbourhood.bias, 0)
        self.shift_delta = nn.Conv1d(features_count, 27*3, kernel_size=(1), padding=0).cuda()
        self.shift_delta.weight.data.fill_(0.0)
        self.shift_delta.bias.data.fill_(0.0)

        self.feature_diff_1 = nn.Linear(features_count + 3, features_count)
        self.feature_diff_2 = nn.Linear(features_count, features_count) 

        self.feature_center_1 = nn.Linear(features_count + 3, features_count)
        self.feature_center_2 = nn.Linear(features_count, features_count)

    def forward(self, voxel_features, vertices):

        B, N, _ = vertices.shape
        center = vertices[:, :, None, None]
        features = F.grid_sample(voxel_features, center, mode='bilinear', padding_mode='border', align_corners=True)
        features = features[:, :, :, 0, 0]
        shift_delta = self.shift_delta(features).permute(0, 2, 1).view(B, N, 27, 1, 3)
        shift_delta[:,:,0,:,:] = shift_delta[:,:,0,:,:] * 0 # setting first shift to zero so it samples at the exact point
 
        # neighbourhood = vertices[:, :, None, None] + self.shift[:, :, :, None] + shift_delta
        neighbourhood = vertices[:, :, None, None] + shift_delta
        features = F.grid_sample(voxel_features, neighbourhood, mode='bilinear', padding_mode='border', align_corners=True)
        features = features[:, :, :, :, 0]
        features = torch.cat([features, neighbourhood.permute(0,4,1,2,3)[:,:,:,:,0]], dim=1)

        features_diff_from_center = features - features[:,:,:,0][:,:,:,None] # 0 is the index of the center cordinate in shifts
        features_diff_from_center = features_diff_from_center.permute([0,3,2,1])
        features_diff_from_center = self.feature_diff_1(features_diff_from_center)
        features_diff_from_center = self.feature_diff_2(features_diff_from_center)
        features_diff_from_center = features_diff_from_center.permute([0,3,2,1])
        
        features_diff_from_center = self.sum_neighbourhood(features_diff_from_center)[:, :, :, 0].transpose(2, 1)

        center_feautres =  features[:,:,:,13].transpose(2, 1)
        center_feautres = self.feature_center_1(center_feautres)
        center_feautres = self.feature_center_2(center_feautres)

        features = center_feautres + features_diff_from_center 
        return features
    
class GraphConv(nn.Module):
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, batch_norm=False):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features)
        self.neighbours_fc = nn.Linear(in_features, out_features)

        self.bc = nn.BatchNorm1d(out_features) if batch_norm else Non()

    def forward(self, input, A, Dinv, vertices, faces):
        # coeff = torch.bmm(torch.bmm(Dsqrtinv, A), Dsqrtinv)
        coeff = torch.bmm(Dinv, A) # row normalization 

        y = self.fc(input)
        y_neightbours = torch.bmm(coeff, input)
        y_neightbours = self.neighbours_fc(y_neightbours)
 
 
        # y_neightbours = self.bc(y_neightbours.permute(0, 2, 1)).permute(0, 2, 1)
        y = y + y_neightbours
        # y = self.bc(y.permute(0, 2, 1)).permute(0, 2, 1)
        return y

    def extra_repr(self):
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features is not None
        )
    
class Feature2VertexLayer(nn.Module):

    def __init__(self, in_features, hidden_layer_count, batch_norm=False):
        super(Feature2VertexLayer, self).__init__()
        self.gconv = []
        for i in range(hidden_layer_count, 1, -1):
            self.gconv += [GraphConv(i * in_features // hidden_layer_count, (i-1) * in_features // hidden_layer_count, batch_norm)]
        self.gconv_layer = nn.Sequential(*self.gconv)
        self.gconv_last = GraphConv(in_features // hidden_layer_count, 3, batch_norm)

    def forward(self, features, adjacency_matrix, degree_matrix, vertices, faces):
        for gconv_hidden in self.gconv:
            features = F.relu(gconv_hidden(features, adjacency_matrix, degree_matrix,vertices,faces))
        return self.gconv_last(features, adjacency_matrix, degree_matrix,vertices,faces)

class Features2Features(nn.Module):

    def __init__(self, in_features, out_features, hidden_layer_count=2, graph_conv=GraphConv):
        super(Features2Features, self).__init__()

        self.gconv_first = graph_conv(in_features, out_features)
        gconv_hidden = []
        for i in range(hidden_layer_count):
            gconv_hidden += [graph_conv(out_features, out_features)]
        self.gconv_hidden = nn.Sequential(*gconv_hidden)
        self.gconv_last = graph_conv(out_features, out_features)

    def forward(self, features, adjacency_matrix, degree_matrix, vertices, faces):
        features = F.relu(self.gconv_first(features, adjacency_matrix, degree_matrix, vertices,faces))
        for gconv_hidden in self.gconv_hidden:
            features = F.relu(gconv_hidden(features, adjacency_matrix, degree_matrix, vertices,faces))
        return self.gconv_last(features, adjacency_matrix, degree_matrix, vertices, faces)

class UNetLayer(nn.Module):
    """ U-Net Layer """
    def __init__(self, num_channels_in, num_channels_out, ndims, batch_norm=False):

        super(UNetLayer, self).__init__()

        conv_op = nn.Conv2d if ndims == 2 else nn.Conv3d
        batch_nrom_op = nn.BatchNorm2d if ndims == 2 else nn.BatchNorm3d

        conv1 = conv_op(num_channels_in,  num_channels_out, kernel_size=3, padding=1)
        conv2 = conv_op(num_channels_out, num_channels_out, kernel_size=3, padding=1)

        bn1 = batch_nrom_op(num_channels_out)
        bn2 = batch_nrom_op(num_channels_out)
        self.unet_layer = nn.Sequential(conv1, bn1, nn.ReLU(), conv2, bn2, nn.ReLU())

    def forward(self, x):
        return self.unet_layer(x)

In [24]:
config = argparse.Namespace(
    batch_size=1,
    ndims=3,
    num_classes=5,
    first_layer_channels=64,
    steps=3,
    graph_conv_layer_count=2,
    batch_norm=True,
    patch_shape = (32, 32, 32)
    # Dodaj inne parametry w razie potrzeby
)

In [25]:
class V2M_Decoder(nn.Module):

    def __init__(self, config):
        
        super(V2M_Decoder, self).__init__()

        self.config = config
        self.max_pool = nn.MaxPool3d(2) if config.ndims == 3 else nn.MaxPool2d(2) 

        ConvTransposeLayer = nn.ConvTranspose3d #if config.ndims == 3, ale u nas zawsze mamy 3D, nie operujemy na płaskich meshach (?)


        ''' Up layers ''' 
        self.skip_count = []
        self.latent_features_coount = []
        for i in range(config.steps+1):
            self.skip_count += [config.first_layer_channels * 2 ** (config.steps-i)] 
            self.latent_features_coount += [32]

        dim = 3

        up_std_conv_layers = []
        up_f2f_layers = []
        up_f2v_layers = []
        for i in range(config.steps+1):
            graph_unet_layers = []
            feature2vertex_layers = []
            skip = LearntNeighbourhoodSampling(config, self.skip_count[i], i)
            # lyr = Feature2VertexLayer(self.skip_count[i])
            if i == 0:
                grid_upconv_layer = None
                grid_unet_layer = None
                for k in range(config.num_classes-1):
                    graph_unet_layers += [Features2Features(self.skip_count[i] + dim, self.latent_features_coount[i], hidden_layer_count=config.graph_conv_layer_count)] # , graph_conv=GraphConv
            else:
                grid_upconv_layer = ConvTransposeLayer(in_channels=config.first_layer_channels   * 2**(config.steps - i+1), out_channels=config.first_layer_channels * 2**(config.steps-i), kernel_size=2, stride=2)
                grid_unet_layer = UNetLayer(config.first_layer_channels * 2**(config.steps - i + 1), config.first_layer_channels * 2**(config.steps-i), config.ndims, config.batch_norm)
                for k in range(config.num_classes-1):
                    graph_unet_layers += [Features2Features(self.skip_count[i] + self.latent_features_coount[i-1] + dim, self.latent_features_coount[i], hidden_layer_count=config.graph_conv_layer_count)] #, graph_conv=GraphConv if i < config.steps else GraphConvNoNeighbours
            for k in range(config.num_classes-1):
                feature2vertex_layers += [Feature2VertexLayer(self.latent_features_coount[i], 3)] 

            up_std_conv_layers.append((skip, grid_upconv_layer, grid_unet_layer))
            up_f2f_layers.append(graph_unet_layers)
            up_f2v_layers.append(feature2vertex_layers)

        self.up_std_conv_layers = up_std_conv_layers
        self.up_f2f_layers = up_f2f_layers
        self.up_f2v_layers = up_f2v_layers

        self.decoder_std_conv = nn.Sequential(*chain(*up_std_conv_layers))
        self.decoder_f2f = nn.Sequential(*chain(*up_f2f_layers))
        self.decoder_f2v = nn.Sequential(*chain(*up_f2v_layers))

    def forward(self, x):
        # Wejście x to już cechy obiektu, przekazujemy je przez dekoder
        x = self.decoder_std_conv(x)
        x = self.decoder_f2f(x)
        x = self.decoder_f2v(x)
        return x

In [26]:
# Inicjalizuj model z podaną konfiguracją
model = V2M_Decoder(config)

# Przykładowe dane wejściowe
input_features = torch.rand((config.batch_size, config.first_layer_channels, config.ndims))

# Przetestuj propagację przez model
output = model.forward(input_features)

# Wypisz wynik
print("Wynik propagacji przez model:")
print(output)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MeshDeformationEncoder(nn.Module):
    def __init__(self, num_vertices, num_faces, input_size):
        super(MeshDeformationEncoder, self).__init__()
        self.embedding_layer = nn.Linear(input_size, 256)
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_vertices * 3)  # 3 współrzędne na wierzchołek

    def forward(self, vertices, faces, latent_features):
        x = F.relu(self.embedding_layer(latent_features))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        decoded_vertices = self.fc3(x)
        decoded_vertices = decoded_vertices.view(decoded_vertices.size(0), -1, 3)
        decoded_vertices = F.normalize(decoded_vertices, p=2, dim=2) #nie wiem czy potrzebuję normalizować, i czy to nie zaszkodzi bardziej szczerze mówiąc
        num_vertices_decoded = decoded_vertices.size(1)
        decoded_vertices = decoded_vertices.view(1, num_vertices_decoded, 3)

        print("original vertices size:", vertices.size())
        print("new vertices size:", decoded_vertices.size())
        deformed_vertices = decoded_vertices

        return deformed_vertices

In [124]:
import numpy as np
from stl import mesh

loaded_mesh = mesh.Mesh.from_file('Sphere.stl')
vertex_dict = {} #for storing vertices and their indices, helps to avoid duplicating the same vertices
vertices = np.empty((0, 3), dtype=float) #array of vertices - shape (n, 3), stores x, y, z coordinates of each vertex
faces = [] #array of faces - shape (n, 3), stores indices of vertices in each face

for vectors in loaded_mesh.vectors:
        for vertex in vectors:
            vertex_tuple = tuple(vertex)

            if vertex_tuple in vertex_dict:
                index = vertex_dict[vertex_tuple]
            else:
                index = len(vertices)
                vertices = np.append(vertices, [vertex], axis=0)
                vertex_dict[vertex_tuple] = index

        faces.append([vertex_dict[tuple(vertex)] for vertex in vectors])

In [125]:
input_size = 500

# Załaduj dane latentne (przykładowe dane, zastąp je własnymi)
latent_features = torch.randn(1, input_size)

# Wygeneruj przykładową tablicę wierzchołków reprezentującą sferę
num_vertices = 256  # Dla przykładu używam 256 wierzchołków

sphere_vertices, faces = torch.tensor(vertices), torch.tensor(faces)

faces = torch.tensor(faces, dtype=torch.long).view(1, -1, 3)
num_faces = (num_vertices - 1) * (num_vertices - 1) * 2

# Utwórz instancję enkodera
encoder = MeshDeformationEncoder(num_vertices, num_faces, input_size)

# Ustaw model w tryb ewaluacji (bez trenowania)
encoder.eval()

# Uzyskaj zdeformowane współrzędne wierzchołków i ścian
with torch.no_grad():
    deformed_vertices = encoder(sphere_vertices, faces, latent_features)

print("Oryginalne wierzchołki sfery:")
print(sphere_vertices)

print("\nIndeksy ścian:")
print(faces)

print("\nZdeformowane współrzędne wierzchołków:")
print(deformed_vertices)

original vertices size: torch.Size([1250, 3])
new vertices size: torch.Size([1, 256, 3])
Oryginalne wierzchołki sfery:
tensor([[  3.1148,   0.3935,  49.9013],
        [  3.1148,  -0.3935,  49.9013],
        [  3.1395,   0.0000,  49.9013],
        ...,
        [  2.9191,  -1.1557, -49.9013],
        [  3.0409,  -0.7808, -49.9013],
        [  3.1148,  -0.3935, -49.9013]], dtype=torch.float64)

Indeksy ścian:
tensor([[[   0,    1,    2],
         [   3,    1,    0],
         [   3,    4,    1],
         ...,
         [1226, 1223, 1227],
         [1226, 1224, 1223],
         [1224, 1226, 1225]]])

Zdeformowane współrzędne wierzchołków:
tensor([[[-1.7938e-01,  5.7605e-01,  7.9749e-01],
         [ 6.6789e-01,  7.3533e-01,  1.1495e-01],
         [ 5.1060e-02, -7.6379e-01, -6.4344e-01],
         [-5.8193e-01, -7.4212e-01, -3.3259e-01],
         [-9.1538e-01,  3.3410e-01,  2.2463e-01],
         [ 6.3569e-01, -1.6621e-01, -7.5384e-01],
         [-9.6677e-01, -3.1371e-02, -2.5372e-01],
         [

  faces = torch.tensor(faces, dtype=torch.long).view(1, -1, 3)


In [127]:
import numpy as np
from stl import mesh

def save_mesh_to_stl(deformed_vertices, faces, stl_filename):
    # Przygotuj obiekt STL
    mesh_data = mesh.Mesh(np.zeros(len(faces[0]), dtype=mesh.Mesh.dtype))

    for i in range(len(faces[0])):
        for j in range(3):
            print(faces[0][i][j])
            try:
                mesh_data.vectors[i][j] = deformed_vertices[0, faces[0][i][j], :]
            except IndexError:
                mesh_data.vectors[i][j] = deformed_vertices[0, 0, :]

    # Zapisz obiekt STL do pliku
    mesh_data.save(stl_filename)

save_mesh_to_stl(deformed_vertices, faces, "output_mesh.stl")
print("File saved!")

tensor(0)
tensor(1)
tensor(2)
tensor(3)
tensor(1)
tensor(0)
tensor(3)
tensor(4)
tensor(1)
tensor(5)
tensor(4)
tensor(3)
tensor(5)
tensor(6)
tensor(4)
tensor(7)
tensor(6)
tensor(5)
tensor(7)
tensor(8)
tensor(6)
tensor(9)
tensor(8)
tensor(7)
tensor(9)
tensor(10)
tensor(8)
tensor(11)
tensor(10)
tensor(9)
tensor(11)
tensor(12)
tensor(10)
tensor(13)
tensor(12)
tensor(11)
tensor(13)
tensor(14)
tensor(12)
tensor(15)
tensor(14)
tensor(13)
tensor(15)
tensor(16)
tensor(14)
tensor(17)
tensor(16)
tensor(15)
tensor(17)
tensor(18)
tensor(16)
tensor(19)
tensor(18)
tensor(17)
tensor(19)
tensor(20)
tensor(18)
tensor(21)
tensor(20)
tensor(19)
tensor(21)
tensor(22)
tensor(20)
tensor(23)
tensor(22)
tensor(21)
tensor(23)
tensor(24)
tensor(22)
tensor(25)
tensor(24)
tensor(23)
tensor(25)
tensor(26)
tensor(24)
tensor(27)
tensor(26)
tensor(25)
tensor(27)
tensor(28)
tensor(26)
tensor(29)
tensor(28)
tensor(27)
tensor(29)
tensor(30)
tensor(28)
tensor(31)
tensor(30)
tensor(29)
tensor(31)
tensor(32)
tensor(30)
tens