# Example
This example shows how to build an equivariant neural network using our method.

In [1]:
import torch
from e3nn.o3 import rand_matrix

from tensor_frames.lframes import LFrames
from tensor_frames.lframes.learning_lframes import WrappedLearnedLFrames
from tensor_frames.nn.embedding.radial import TrivialRadialEmbedding
from tensor_frames.nn.gcn_conv import GCNConv
from tensor_frames.nn.local_global import FromLocalToGlobalFrame
from tensor_frames.reps import TensorReps

First we create a `SimpleLoCaConv` class that combines learned local frames with a `GCNConv` layer using tensorial messages. The `WrappedLearnedLFrames` module uses a radial embedding to compute scalar features from the positions; these features are also used to compute the local frames.

The forward pass executes the following steps:
1. Compute local features and local frames using the `lframes_module`.
2. Apply the GCN convolution in the local frame using tensorial message passing.
3. Transform the output features back to the global frame.
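
To make the three steps concrete, here is a dependency-free toy sketch. It is not the library's implementation: `toy_frame` builds a hand-crafted frame from neighboring positions via Gram-Schmidt (an assumption standing in for the learned frames), and a componentwise `tanh` stands in for the convolution. All helper names are hypothetical.

```python
import math

def sub(a, b):
    return [x - y for x, y in zip(a, b)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalize(a):
    norm = math.sqrt(dot(a, a))
    return [x / norm for x in a]

def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]

def matvec(m, v):
    return [dot(row, v) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

def toy_frame(p, q, r):
    """Hypothetical frame at p (rows = orthonormal basis) from two other points."""
    e1 = normalize(sub(q, p))
    u = sub(r, p)
    proj = dot(u, e1)
    e2 = normalize([u[k] - proj * e1[k] for k in range(3)])  # Gram-Schmidt step
    return [e1, e2, cross(e1, e2)]

def toy_layer(vectors, pos):
    n = len(pos)
    out = []
    for i in range(n):
        # step 1: build the local frame and express the feature in it
        frame = toy_frame(pos[i], pos[(i + 1) % n], pos[(i + 2) % n])
        local = matvec(frame, vectors[i])
        # step 2: any operation is allowed here, since local coefficients are invariant
        local = [math.tanh(c) for c in local]
        # step 3: transform the result back to the global frame
        out.append(matvec(transpose(frame), local))
    return out

# a fixed global rotation about the z-axis
c, s = math.cos(0.7), math.sin(0.7)
rot = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

pos = [[0.0, 0.0, 0.0], [1.0, 0.2, -0.3], [0.1, 1.5, 0.4], [-0.7, 0.3, 1.2]]
vectors = [[0.3, -1.0, 0.5], [1.2, 0.4, -0.8], [-0.5, 0.9, 0.1], [0.7, 0.7, -1.1]]

out = toy_layer(vectors, pos)
out_rotated_input = toy_layer([matvec(rot, v) for v in vectors],
                              [matvec(rot, p) for p in pos])

# rotating the output must agree with running the layer on rotated inputs
max_err = max(abs(a - b)
              for o, o_rot in zip(out, out_rotated_input)
              for a, b in zip(matvec(rot, o), o_rot))
print(max_err < 1e-9)
```

Because the frames rotate together with the input, the local coefficients do not change under a global rotation. This is why an otherwise non-equivariant operation such as a componentwise `tanh` can be used in step 2.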

In [2]:
class SimpleLoCaConv(torch.nn.Module):
    def __init__(self, in_reps, out_reps):
        super().__init__()
        self.lframes_module = WrappedLearnedLFrames(
            in_reps=in_reps, hidden_channels=[128], radial_module=TrivialRadialEmbedding()
        )
        self.gcn_conv = GCNConv(in_reps, out_reps)
        self.from_local_to_global = FromLocalToGlobalFrame(out_reps)

    def forward(self, x, pos, batch, edge_index):
        x_local, lframes = self.lframes_module(x=x, pos=pos, batch=batch, edge_index=edge_index)
        out_local = self.gcn_conv(x=x_local, edge_index=edge_index, lframes=lframes)
        out_global = self.from_local_to_global(x=out_local, lframes=lframes)
        return out_global

Now we generate some random data. We create a graph with 10 nodes, where each node has 5 scalar features and 1 vectorial feature (3 components), for 8 components per node in total. The graph is fully connected, meaning every node is connected to every other node; the construction below also includes self-loops.

In [3]:
# create dummy data consisting of 10 nodes in 3D space with 5 scalar features and 1 vectorial feature each
pos = torch.randn(10, 3)  # vectorial positions
features = torch.randn(10, 5 + 3)  # 5 scalar + 1 vectorial (3) = 8 components
edge_index = torch.stack(
    torch.meshgrid(torch.arange(10), torch.arange(10), indexing="ij")
).reshape(2, -1)  # fully connected graph
batch = torch.zeros(pos.shape[0], dtype=torch.long)  # all nodes belong to the same graph
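
As a quick illustration of what this `edge_index` contains, the following plain-Python sketch enumerates the same ordered pairs without `torch`. The variable names are illustrative only. Whether self-loops are desirable depends on the layer in use, which this example does not specify, so the filtered variant is shown purely for comparison.

```python
from itertools import product

num_nodes = 10

# all ordered pairs (i, j), including i == j, in the same row-major order
# as the meshgrid(..., indexing="ij") + reshape(2, -1) construction
pairs = list(product(range(num_nodes), repeat=2))
edge_index_rows = [[i for i, _ in pairs], [j for _, j in pairs]]

num_edges = len(pairs)                           # 100
num_self_loops = sum(i == j for i, j in pairs)   # 10

# variant without self-loops, for comparison
pairs_no_loops = [(i, j) for i, j in pairs if i != j]  # 90 edges

print(num_edges, num_self_loops, len(pairs_no_loops))
```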

Next we define the transformation behavior of the features through the `TensorReps` class. Here we specify that the input features consist of 5 scalar features and 1 vectorial feature, and the output features should have the same structure.

In [4]:
# create model
in_reps = TensorReps("5x0n + 1x1n")  # 5 scalar features + 1 vectorial feature
out_reps = TensorReps("5x0n + 1x1n")  # 5 scalar features + 1 vectorial feature as output
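
To see why `"5x0n + 1x1n"` matches the 8 feature components created earlier, the sketch below counts components from such a string. `rep_dim` is a hypothetical helper, not part of the library. It assumes each term has the form `<multiplicity>x<order><parity letter>` and that an order-k tensor contributes 3^k components, which holds for the scalars and vectors used here but is an assumption for higher orders.

```python
import re

def rep_dim(rep_string):
    """Count feature components of a string such as '5x0n + 1x1n' (hypothetical helper)."""
    total = 0
    for term in rep_string.split("+"):
        match = re.fullmatch(r"\s*(\d+)x(\d+)([a-z])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse term: {term!r}")
        multiplicity, order = int(match.group(1)), int(match.group(2))
        # assumption: an order-k tensor in 3D has 3**k components
        total += multiplicity * 3 ** order
    return total

print(rep_dim("5x0n + 1x1n"))  # 5 * 1 + 1 * 3 = 8
```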

Now we can create a model instance and use the random data to perform a forward pass.


In [5]:
model = SimpleLoCaConv(in_reps=in_reps, out_reps=out_reps)

# forward pass
output = model(x=features, pos=pos, edge_index=edge_index, batch=batch)
print(output.shape)  # should be [10, 8] -> 5 scalar + 1 vectorial (3) = 8 features per node

torch.Size([10, 8])


Next, we can check that the network is indeed equivariant. To do this we generate a random global rotation.

In [6]:
# now check the equivariance of the model
global_rot = rand_matrix(1).repeat(pos.shape[0], 1, 1)  # one random rotation, repeated for every node
global_frame = LFrames(global_rot)
print("Rotation:\n", global_frame.matrices.shape)

Rotation:
 torch.Size([10, 3, 3])


We transform the positions and the node features through their corresponding representations. The `get_transform_class()` method returns a module that transforms the features according to their representation.

In [7]:
# rotate positions and features with the same global rotation
rotated_pos = TensorReps("1x1n").get_transform_class().transform_coeffs(pos, global_frame)
rotated_features = in_reps.get_transform_class().transform_coeffs(features, global_frame)
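
The effect of such a transformation on a single node's 8 components can be sketched by hand: the scalars stay unchanged and the vector part is multiplied by the rotation matrix. The sketch assumes the layout "5 scalars followed by 3 vector components" from the data cell above. `transform_features` and `rot_z` are hypothetical helpers, and the library's actual convention (for example active versus passive rotation) may differ.

```python
import math

def rot_z(angle):
    """Rotation matrix about the z-axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def transform_features(feat, rot, num_scalars=5):
    """Leave scalars untouched, rotate the trailing 3-vector (hypothetical helper)."""
    scalars = feat[:num_scalars]
    vec = feat[num_scalars:num_scalars + 3]
    rotated = [sum(rot[r][c] * vec[c] for c in range(3)) for r in range(3)]
    return scalars + rotated

feat = [1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 0.0, 0.0]  # 5 scalars + vector (1, 0, 0)
new_feat = transform_features(feat, rot_z(math.pi / 2))

# the scalars are unchanged; the vector now points (approximately) along the y-axis
print(new_feat[:5], [round(v, 6) for v in new_feat[5:]])
```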

If we rotate the output features back to the original frame, they should match the output computed from the unrotated input features. Agreement up to numerical tolerance confirms that our network is equivariant.

In [8]:
rotated_output = model(x=rotated_features, pos=rotated_pos, edge_index=edge_index, batch=batch)
unrotated_output = out_reps.get_transform_class().transform_coeffs(
    rotated_output, global_frame.inverse_lframes()
)

print(torch.allclose(output, unrotated_output, atol=1e-4))  # should be True

True
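
The same kind of check can be packaged as a small reusable test. The sketch below is plain Python with hypothetical helper names. It compares "apply the function, then rotate" against "rotate, then apply the function", and it includes a deliberately non-equivariant function to show that the check can fail. The tolerance is tight here because Python floats are double precision; the `atol=1e-4` used above is a more forgiving choice that suits `torch`'s default `float32`.

```python
import math

def rotate(points, rot):
    return [[sum(rot[r][c] * p[c] for c in range(3)) for r in range(3)] for p in points]

def is_equivariant(fn, points, rot, atol=1e-9):
    """True if fn(rotate(x)) matches rotate(fn(x)) within atol."""
    out_then_rot = rotate(fn(points), rot)
    rot_then_out = fn(rotate(points, rot))
    return all(abs(a - b) <= atol
               for pa, pb in zip(out_then_rot, rot_then_out)
               for a, b in zip(pa, pb))

def centered(points):
    """Subtract the centroid: commutes with rotations."""
    n = len(points)
    mean = [sum(p[k] for p in points) / n for k in range(3)]
    return [[p[k] - mean[k] for k in range(3)] for p in points]

def shifted(points):
    """Add a fixed vector along x: does not commute with rotations."""
    return [[p[0] + 1.0, p[1], p[2]] for p in points]

c, s = math.cos(0.9), math.sin(0.9)
rot = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
points = [[0.0, 0.0, 0.0], [1.0, 0.2, -0.3], [0.1, 1.5, 0.4]]

print(is_equivariant(centered, points, rot))  # True
print(is_equivariant(shifted, points, rot))   # False
```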
