diff --git a/.gitignore b/.gitignore index 3e204eea9188..2e7c4c976eab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ +.pytest_cache/ data/ build/ dist/ diff --git a/test/utils/test_get_mesh_laplacian.py b/test/utils/test_get_mesh_laplacian.py new file mode 100644 index 000000000000..f438a26c3be5 --- /dev/null +++ b/test/utils/test_get_mesh_laplacian.py @@ -0,0 +1,98 @@ +import torch + +from torch_geometric.utils import get_mesh_laplacian + + +def test_get_mesh_laplacian_of_cube(): + pos = torch.tensor([ + [1.0, 1.0, 1.0], + [1.0, -1.0, 1.0], + [-1.0, -1.0, 1.0], + [-1.0, 1.0, 1.0], + [1.0, 1.0, -1.0], + [1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0], + [-1.0, 1.0, -1.0], + ]) + + face = torch.tensor([ + [0, 1, 2], + [0, 3, 2], + [4, 5, 1], + [4, 0, 1], + [7, 6, 5], + [7, 4, 5], + [3, 2, 6], + [3, 7, 6], + [4, 0, 3], + [4, 7, 3], + [1, 5, 6], + [1, 2, 6], + ]) + + edge_index, edge_weight = get_mesh_laplacian(pos, face.t()) + + assert edge_index.tolist() == [ + [ + 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, + 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7 + ], + [ + 1, 2, 3, 4, 0, 2, 4, 5, 6, 0, 1, 3, 6, 0, 2, 4, 6, 7, 0, 1, 3, 5, + 7, 1, 4, 6, 7, 1, 2, 3, 5, 7, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7 + ], + ] + + assert torch.allclose( + edge_weight, + torch.tensor([ + -1.0, -0.0, -1.0, -1.0, -1.0, -1.0, -0.0, -1.0, -0.0, -0.0, -1.0, + -1.0, -1.0, -1.0, -1.0, -0.0, -0.0, -1.0, -1.0, -0.0, -0.0, -1.0, + -1.0, -1.0, -1.0, -1.0, -0.0, -0.0, -1.0, -0.0, -1.0, -1.0, -1.0, + -1.0, -0.0, -1.0, 1.125, 0.9, 1.125, 0.9, 0.9, 1.125, 0.9, 1.125 + ])) + + +def test_get_mesh_laplacian_of_irregular_triangular_prism(): + pos = torch.tensor([ + [0.0, 0.0, 0.0], + [4.0, 0.0, 0.0], + [0.0, 0.0, -3.0], + [1.0, 5.0, -1.0], + [3.0, 5.0, -1.0], + [2.0, 5.0, -2.0], + ]) + + face = torch.tensor([ + [0, 1, 2], + [3, 4, 5], + [0, 1, 4], + [0, 3, 4], + [1, 2, 5], + [1, 4, 5], + [2, 0, 3], + [2, 5, 3], + ]) + + edge_index, edge_weight = get_mesh_laplacian(pos, face.t()) + + assert edge_index.tolist() == [ + [ + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, + 5, 5, 0, 1, 2, 3, 4, 5 + ], + [ + 1, 2, 3, 4, 0, 2, 4, 5, 0, 1, 3, 5, 0, 2, 4, 5, 0, 1, 3, 5, 1, 2, + 3, 4, 0, 1, 2, 3, 4, 5 + ], + ] + + assert torch.allclose( + edge_weight, + torch.tensor([ + -0.938834, -1.451131, -0.490290, -0.000000, -0.938834, -0.378790, + -0.577017, -0.077878, -1.451131, -0.378790, -0.163153, -0.344203, + -0.490290, -0.163153, -1.421842, -2.387739, -0.000000, -0.577017, + -1.421842, -2.550610, -0.077878, -0.344203, -2.387739, -2.550610, + 0.298518, 0.183356, 0.233502, 0.761257, 0.688181, 0.768849 + ])) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 060f5c646fc2..537c036ea215 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -11,6 +11,7 @@ from .subgraph import get_num_hops, subgraph, k_hop_subgraph from .homophily import homophily from .get_laplacian import get_laplacian +from .get_mesh_laplacian import get_mesh_laplacian from .mask import index_to_mask from .to_dense_batch import to_dense_batch from .to_dense_adj import to_dense_adj @@ -53,6 +54,7 @@ 'k_hop_subgraph', 'homophily', 'get_laplacian', + 'get_mesh_laplacian', 'index_to_mask', 'to_dense_batch', 'to_dense_adj', diff --git a/torch_geometric/utils/get_mesh_laplacian.py b/torch_geometric/utils/get_mesh_laplacian.py new file mode 100644 index 000000000000..2a44218ee881 --- /dev/null +++ b/torch_geometric/utils/get_mesh_laplacian.py @@ -0,0 +1,75 @@ +from typing import Tuple + +import torch +from torch import Tensor +from torch_scatter import scatter_add + +from torch_geometric.utils import add_self_loops, coalesce + + +def get_mesh_laplacian(pos: Tensor, face: Tensor) -> Tuple[Tensor, Tensor]: + r"""Computes the mesh Laplacian of a mesh given by :obj:`pos` and + :obj:`face`. + It is computed as + + .. math:: + \mathbf{L}_{ij} = \begin{cases} + \frac{\cot \angle_{ikj} + \cot \angle_{ilj}}{2 a_{ij}} & + \mbox{if } i, j \mbox{ is an edge} \\ + \sum_{j \in N(i)}{L_{ij}} & + \mbox{if } i \mbox{ is in the diagonal} \\ + 0 \mbox{ otherwise} + \end{cases} + + where :math:`a_{ij}` is the local area element, *i.e.* one-third of the + neighbouring triangle's area. + + Args: + pos (Tensor): The node positions. + face (LongTensor): The face indices. + """ + assert pos.size(1) == 3 and face.size(0) == 3 + + num_nodes = pos.shape[0] + + def add_angles(left, centre, right): + left_pos, central_pos, right_pos = pos[left], pos[centre], pos[right] + left_vec = left_pos - central_pos + right_vec = right_pos - central_pos + dot = torch.einsum('ij, ij -> i', left_vec, right_vec) + cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1) + cot = dot / cross # cot = cos / sin + area = cross / 6.0 # one-third of a triangle's area is cross / 6.0 + return cot / 2.0, area / 2.0 + + # For each triangle face, add all three angles: + cot_201, area_201 = add_angles(face[2], face[0], face[1]) + cot_012, area_012 = add_angles(face[0], face[1], face[2]) + cot_120, area_120 = add_angles(face[1], face[2], face[0]) + + cot_weight = torch.cat( + [cot_201, cot_201, cot_012, cot_012, cot_120, cot_120]) + area_weight = torch.cat( + [area_201, area_201, area_012, area_012, area_120, area_120]) + + edge_index = torch.cat([ + torch.stack([face[2], face[1]], dim=0), + torch.stack([face[1], face[2]], dim=0), + torch.stack([face[0], face[2]], dim=0), + torch.stack([face[2], face[0]], dim=0), + torch.stack([face[1], face[0]], dim=0), + torch.stack([face[0], face[1]], dim=0) + ], dim=1) + + edge_index, weight = coalesce(edge_index, [cot_weight, area_weight]) + cot_weight, area_weight = weight + + # Compute the diagonal part: + row, col = edge_index + cot_deg = scatter_add(cot_weight, row, dim=0, dim_size=num_nodes) + area_deg = scatter_add(area_weight, row, dim=0, dim_size=num_nodes) + deg = cot_deg / area_deg + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + edge_weight = torch.cat([-cot_weight, deg], dim=0) + + return edge_index, edge_weight