-
Notifications
You must be signed in to change notification settings - Fork 316
/
mesh_union.py
44 lines (35 loc) · 1.49 KB
/
mesh_union.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torch.nn import ConstantPad2d
class MeshUnion:
    """Membership matrix recording which original elements were merged into which.

    ``groups`` starts as an ``n x n`` identity matrix: row ``i`` marks the
    elements currently folded into element ``i`` (initially only itself).
    ``union`` accumulates rows, and ``rebuild_features_average`` uses the
    resulting matrix to average feature columns of merged elements.
    """

    def __init__(self, n, device=torch.device('cpu')):
        """Create the identity membership matrix for ``n`` elements on ``device``."""
        self.__size = n
        # Public alias so callers can invoke the rebuild strategy by a generic name.
        self.rebuild_features = self.rebuild_features_average
        self.groups = torch.eye(n, device=device)

    def union(self, source, target):
        """Add the membership row of ``source`` into the row of ``target``.

        Entries may exceed 1 after repeated unions; they are clamped back to
        the range [0, 1] by ``get_groups`` and ``prepare_groups``.
        """
        self.groups[target, :] += self.groups[source, :]

    def remove_group(self, index):
        """Do nothing; kept as a no-op so existing callers keep working."""
        return

    def get_group(self, edge_key):
        """Return the membership row stored for ``edge_key``."""
        return self.groups[edge_key, :]

    def get_occurrences(self):
        """Return the column sums of ``groups`` (sum over dimension 0)."""
        return torch.sum(self.groups, 0)

    def get_groups(self, tensor_mask):
        """Clamp ``groups`` to [0, 1] and return the rows selected by ``tensor_mask``.

        Side effect: ``self.groups`` is replaced by its clamped version.
        """
        self.groups = torch.clamp(self.groups, 0, 1)
        return self.groups[tensor_mask, :]

    def rebuild_features_average(self, features, mask, target_edges):
        """Average the feature columns of merged elements and pad the result.

        Args:
            features: tensor whose trailing dimension is squeezed away before
                the matrix product; presumably ``(channels, edges, 1)`` --
                TODO confirm against the caller.
            mask: boolean selector of the rows of ``groups`` to keep. May be a
                NumPy array, a torch tensor or a plain sequence.
            target_edges: minimum size of dimension 1 of the result; the
                output is zero-padded on the right up to this size.

        Returns:
            The averaged features, zero-padded along dimension 1 when it is
            shorter than ``target_edges``.

        Side effect: ``self.groups`` is replaced by the masked, clamped and
        transposed matrix built in ``prepare_groups``.
        """
        self.prepare_groups(features, mask)
        summed = torch.matmul(features.squeeze(-1), self.groups)
        # Column sums give the number of members per kept element; dividing
        # (with broadcasting) turns the summed features into averages.
        averaged = summed / torch.sum(self.groups, 0)
        missing_cols = target_edges - averaged.shape[1]
        if missing_cols > 0:
            averaged = torch.nn.functional.pad(
                averaged, (0, missing_cols, 0, 0), mode='constant', value=0)
        return averaged

    def prepare_groups(self, features, mask):
        """Reduce ``groups`` to the masked rows, transposed for the feature product.

        After the call ``self.groups`` has one column per selected row, with
        entries clamped to [0, 1]. If ``features.shape[1]`` is larger than the
        number of rows, zero rows are appended so the matrix product in
        ``rebuild_features_average`` lines up.

        Args:
            features: tensor whose ``shape[1]`` gives the required row count.
            mask: boolean selector over the rows of ``groups``; a NumPy array,
                a torch tensor or a plain sequence.
        """
        # as_tensor accepts ndarray / tensor / sequence, unlike from_numpy.
        row_mask = torch.as_tensor(mask, device=self.groups.device)
        self.groups = torch.clamp(self.groups[row_mask, :], 0, 1).transpose(1, 0)
        missing_rows = features.shape[1] - self.groups.shape[0]
        if missing_rows > 0:
            # Pad spec is (left, right, top, bottom): append rows at the bottom.
            self.groups = torch.nn.functional.pad(
                self.groups, (0, 0, 0, missing_rows), mode='constant', value=0)