Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mesh Laplacian computation #4187

Merged
merged 17 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .pytest_cache/v/cache/lastfailed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
daniel-unyi-42 marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions .pytest_cache/v/cache/nodeids
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[
"test/utils/test_get_mesh_laplacian.py::test_get_mesh_laplacian"
]
59 changes: 59 additions & 0 deletions test/utils/test_get_mesh_laplacian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch

from torch_geometric.utils import get_mesh_laplacian


def test_get_mesh_laplacian():

# cube
pos = torch.Tensor([[1.0, 1.0, 1.0], [1.0, -1.0, 1.0], \
[-1.0, -1.0, 1.0], [-1.0, 1.0, 1.0], \
[1.0, 1.0, -1.0], [1.0, -1.0, -1.0], \
[-1.0, -1.0, -1.0], [-1.0, 1.0, -1.0]])
face = torch.Tensor([[0, 1, 2], [0, 3, 2], [4, 5, 1], \
[4, 0, 1], [7, 6, 5], [7, 4, 5], \
[3, 2, 6], [3, 7, 6], [4, 0, 3], \
[4, 7, 3], [1, 5, 6], [1, 2, 6]]).long().t()

lap = get_mesh_laplacian(pos, face)
assert torch.all(lap[0] == torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, \
2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, \
5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, \
0, 1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4, 0, 2, 4, 5, 6, 0, 1, 3, \
6, 0, 2, 4, 6, 7, 0, 1, 3, 5, 7, 1, \
4, 6, 7, 1, 2, 3, 5, 7, 3, 4, 5, 6, \
0, 1, 2, 3, 4, 5, 6, 7]]))
assert torch.allclose(lap[1], torch.Tensor([-1.0, -0.0, -1.0, -1.0, -1.0, -1.0, \
-0.0, -1.0, -0.0, -0.0, -1.0, -1.0, \
-1.0, -1.0, -1.0, -0.0, -0.0, -1.0, \
-1.0, -0.0, -0.0, -1.0, -1.0, -1.0, \
-1.0, -1.0, -0.0, -0.0, -1.0, -0.0, \
-1.0, -1.0, -1.0, -1.0, -0.0, -1.0, \
1.125, 0.9, 1.125, 0.9, \
0.9, 1.125, 0.9, 1.125]))

# irregular triangular prism
pos = torch.Tensor([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], \
[0.0, 0.0, -3.0], [1.0, 5.0, -1.0], \
[3.0, 5.0, -1.0], [2.0, 5.0, -2.0]])
face = torch.Tensor([[0, 1, 2], [3, 4, 5], [0, 1, 4], [0, 3, 4], \
[1, 2, 5], [1, 4, 5], [2, 0, 3], [2, 5, 3]]).long().t()

lap = get_mesh_laplacian(pos, face)
assert torch.all(lap[0] == torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, \
3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, \
0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 0, 2, 4, 5, 0, 1, 3, 5, \
0, 2, 4, 5, 0, 1, 3, 5, 1, 2, 3, 4, \
0, 1, 2, 3, 4, 5]]))
assert torch.allclose(lap[1], torch.Tensor([-0.938834, -1.451131, -0.490290, \
-0.000000, -0.938834, -0.378790, \
-0.577017, -0.077878, -1.451131, \
-0.378790, -0.163153, -0.344203, \
-0.490290, -0.163153, -1.421842, \
-2.387739, -0.000000, -0.577017, \
-1.421842, -2.550610, -0.077878, \
-0.344203, -2.387739, -2.550610, \
0.298518, 0.183356, 0.233502, \
0.761257, 0.688181, 0.768849]))
2 changes: 2 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .subgraph import get_num_hops, subgraph, k_hop_subgraph
from .homophily import homophily
from .get_laplacian import get_laplacian
from .get_mesh_laplacian import get_mesh_laplacian
from .mask import index_to_mask
from .to_dense_batch import to_dense_batch
from .to_dense_adj import to_dense_adj
Expand Down Expand Up @@ -53,6 +54,7 @@
'k_hop_subgraph',
'homophily',
'get_laplacian',
'get_mesh_laplacian',
'index_to_mask',
'to_dense_batch',
'to_dense_adj',
Expand Down
74 changes: 74 additions & 0 deletions torch_geometric/utils/get_mesh_laplacian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Tuple

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import add_self_loops, coalesce


def get_mesh_laplacian(pos: Tensor, face: Tensor) -> Tuple[Tensor, Tensor]:
r"""Computes the mesh Laplacian of the mesh given by
:obj:`pos` and :obj:`face`. It is computed as

.. math::
\mathbf{L}_{ij} = \begin{cases}
\frac{\cot \angle_{ikj} + \cot \angle_{ilj}}{2 a_{ij}} &
\mbox{if } i, j \mbox{ is an edge,} \\
\sum_{j \in N(i)}{L_{ij}} &
\mbox{if } i \mbox{ is in the diagonal,} \\
0 \mbox{ otherwise.}
\end{cases}

where :math:`a_{ij}` is the local area element,
i.e. one-third of the neighbouring triangle's area.

Args:
pos (Tensor): The node positions.
face (LongTensor): The face indices.
"""

assert pos.shape[1] == 3
assert face.shape[0] == 3

num_nodes = pos.shape[0]

def add_angles(left, centre, right):
left_pos, central_pos, right_pos = pos[left], pos[centre], pos[right]
left_vec = left_pos - central_pos
right_vec = right_pos - central_pos
dot = torch.einsum('ij, ij -> i', left_vec, right_vec)
cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1)
cot = dot / cross # cot = cos / sin
area = cross / 6.0 # one-third of a triangle's area is cross / 6.0
return cot / 2.0, area / 2.0

# for each triangle face, add all 3 angles
cot_201, area_201 = add_angles(face[2], face[0], face[1])
cot_012, area_012 = add_angles(face[0], face[1], face[2])
cot_120, area_120 = add_angles(face[1], face[2], face[0])

cot_weight = torch.cat(
[cot_201, cot_201, cot_012, cot_012, cot_120, cot_120])
area_weight = torch.cat(
[area_201, area_201, area_012, area_012, area_120, area_120])
edge_index = torch.cat([
torch.stack([face[2], face[1]], dim=0),
torch.stack([face[1], face[2]], dim=0),
torch.stack([face[0], face[2]], dim=0),
torch.stack([face[2], face[0]], dim=0),
torch.stack([face[1], face[0]], dim=0),
torch.stack([face[0], face[1]], dim=0)
], dim=1)

_, cot_weight = coalesce(edge_index, cot_weight)
edge_index, area_weight = coalesce(edge_index, area_weight)

# compute the diagonal part
row, col = edge_index
cot_deg = scatter_add(cot_weight, row, dim=0, dim_size=num_nodes)
area_deg = scatter_add(area_weight, row, dim=0, dim_size=num_nodes)
deg = cot_deg / area_deg
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
edge_weight = torch.cat([-cot_weight, deg], dim=0)
return edge_index, edge_weight