In [35]:
# This net uses the following modules of e3nn
from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct

# From pytorch_geometric
from torch_cluster import radius_graph
from torch_scatter import scatter
import torch

from torch_geometric.data import Data, DataLoader

In [36]:
def tetris():
    """Build the eight-shape "tetris" toy dataset, each shape randomly rotated.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is the single batch produced by
        loading all eight shapes through a ``DataLoader`` whose batch size
        equals the dataset size. ``labels`` is an ``(8, 7)`` tensor with one
        row per shape.
    """
    dtype = torch.get_default_dtype()

    # Four lattice points per shape, eight shapes in total.
    shapes = torch.tensor(
        [
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
            [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
            [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
            [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
        ],
        dtype=dtype,
    )

    # Column 0 is the *odd* scalar that separates the two chiral shapes
    # (+1 / -1), since they are mirror images of one another.
    # Columns 1..6 one-hot encode the remaining six shapes, in order.
    labels = torch.zeros(8, 7, dtype=dtype)
    labels[0, 0] = 1  # chiral_shape_1
    labels[1, 0] = -1  # chiral_shape_2
    labels[2:, 1:] = torch.eye(6, dtype=dtype)  # square, line, corner, L, T, zigzag

    # One independent random rotation matrix per shape.
    rotations = o3.rand_matrix(len(shapes))
    rotated_shapes = torch.einsum("zij,zaj->zai", rotations, shapes)

    # Wrap every shape as a torch_geometric graph and collate them all into one batch.
    dataset = [Data(pos=shape_pos) for shape_pos in rotated_shapes]
    loader = DataLoader(dataset, batch_size=len(dataset))
    data = next(iter(loader))

    return data, labels

In [54]:
class InvariantPolynomial(torch.nn.Module):
    """Polynomial of the point positions whose output is made of (pseudo)scalars.

    Spherical harmonics of the edge vectors are combined through two
    ``FullyConnectedTensorProduct`` layers and summed per graph. The output
    irreps are ``"0o + 6x0e"``: one odd scalar followed by six even scalars.
    """

    def __init__(self) -> None:
        super().__init__()
        # Spherical harmonics irreps up to l = 3.
        self.irreps_sh: o3.Irreps = o3.Irreps.spherical_harmonics(3)
        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps("0o + 6x0e")

        # sh (x) sh -> hidden features
        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_mid,
        )
        # hidden (x) sh -> output (pseudo)scalars
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_out,
        )
        self.irreps_out = self.tp2.irreps_out

    def forward(self, data) -> torch.Tensor:
        """Evaluate the polynomial on a batch of point clouds.

        Args:
            data: object exposing ``pos`` (point coordinates) and ``batch``
                (graph index of every point).

        Returns:
            A tensor with one row per graph, laid out as ``self.irreps_out``.
        """
        num_neighbors = 2  # typical number of neighbors
        num_nodes = 4  # typical number of nodes
        # Pin the node dimension of every per-node scatter. Without it the
        # result is only as long as the largest destination index, so trailing
        # nodes with no incoming edge would silently shrink the tensor and
        # break the final reduction over ``data.batch``.
        num_points = data.pos.shape[0]

        edge_src, edge_dst = radius_graph(x=data.pos, r=1.1, batch=data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=False,  # here we don't normalize otherwise it would not be a polynomial
            normalization="component",
        )

        # For each node, the initial features are the sum of the spherical harmonics of the neighbors
        node_features = scatter(edge_sh, edge_dst, dim=0, dim_size=num_points).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with the spherical harmonics
        edge_features = self.tp1(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=num_points).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=num_points).div(num_neighbors**0.5)

        # For each graph, all the node's features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)


# Backward-compatible alias for the original, misspelled class name, which
# existing cells still instantiate. This replaces a verbatim duplicate of the
# class body that used to live under that name.
InveriantPolynomial = InvariantPolynomial
  

In [55]:
# Training

# Dataset and freshly initialised model; `data`, `labels` and `f` are reused by later cells.
data, labels = tetris()
f = InveriantPolynomial()
optimizer = torch.optim.Adam(f.parameters(), lr=1e-2)

# == Train ==
for step in range(200):
    # Full-batch forward pass and summed squared error.
    pred = f(data)
    squared_error = (pred - labels).pow(2)
    loss = squared_error.sum()

    # Gradient step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report only every tenth step.
    if step % 10 != 0:
        continue

    # A shape counts as correct when every rounded output matches its label row.
    shape_is_correct = (pred.round() == labels).all(dim=1)
    accuracy = shape_is_correct.double().mean(dim=0).item()
    print(f"Epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

Epoch     0 | loss 1013.1     |   0.0% accuracy
Epoch    10 | loss 89.2       |   0.0% accuracy
Epoch    20 | loss 90.5       |   0.0% accuracy
Epoch    30 | loss 26.3       |   0.0% accuracy
Epoch    40 | loss 22.8       |   0.0% accuracy
Epoch    50 | loss 12.7       |  25.0% accuracy
Epoch    60 | loss 10.9       |  12.5% accuracy
Epoch    70 | loss 8.4        |  37.5% accuracy
Epoch    80 | loss 7.4        |  37.5% accuracy
Epoch    90 | loss 6.3        |  50.0% accuracy
Epoch   100 | loss 5.6        |  50.0% accuracy
Epoch   110 | loss 4.9        |  62.5% accuracy
Epoch   120 | loss 4.4        |  62.5% accuracy
Epoch   130 | loss 3.9        |  62.5% accuracy
Epoch   140 | loss 3.6        |  75.0% accuracy
Epoch   150 | loss 3.2        |  75.0% accuracy
Epoch   160 | loss 2.9        |  75.0% accuracy
Epoch   170 | loss 2.7        |  75.0% accuracy
Epoch   180 | loss 2.5        |  75.0% accuracy
Epoch   190 | loss 2.3        |  75.0% accuracy


In [56]:
import logging
from e3nn.util.test import assert_equivariant
# == Check equivariance ==
# The model outputs only (pseudo)scalars, so its predictions should be the
# same when the shapes are regenerated under new random rotations.
print("Testing equivariance directly...")
rotated_data, _ = tetris()
max_abs_error = (f(rotated_data) - f(data)).abs().max().item()
print(f"Equivariance error = {max_abs_error:.1e}")

print("Testing equivariance using 'assert equivariance'...")


# `assert_equivariant` passes plain tensors, whereas the model expects a
# torch_geometric `Data` object; this small wrapper translates between the two.
def wrapper(pos, batch):
    graph = Data(pos=pos, batch=batch)
    return f(graph)


# `assert_equivariant` reports its error summary through `logging`,
# so INFO-level logging has to be switched on to see it.
logging.basicConfig(level=logging.INFO)

# Inputs that `assert_equivariant` will transform...
check_args = [data.pos, data.batch]
# ...according to these irreps...
check_irreps_in = [
    "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
    None,  # `None` indicates invariant, possibly non-floating-point data
]
# ...while the outputs must transform according to these irreps.
check_irreps_out = [f.irreps_out]

assert_equivariant(
    wrapper,
    args_in=check_args,
    irreps_in=check_irreps_in,
    irreps_out=check_irreps_out,
)

INFO:e3nn.util.test:Tested equivariance of `wrapper` -- max componentwise errors:
(parity_k=0, did_translate=False) -> max error=8.941e-07 in argument 0
(parity_k=0, did_translate=True) -> max error=1.490e-06 in argument 0
(parity_k=1, did_translate=False) -> max error=6.557e-07 in argument 0
(parity_k=1, did_translate=True) -> max error=4.016e-06 in argument 0


Testing equivariance directly...
Equivariance error = 1.3e-06
Testing equivariance using 'assert equivariance'...


{(0, False): tensor([8.9407e-07]),
 (0, True): tensor([1.4901e-06]),
 (1, False): tensor([6.5565e-07]),
 (1, True): tensor([4.0159e-06])}

In [57]:
def test() -> None:
    """Smoke-test the polynomial model: backward pass plus rotation invariance.

    Raises:
        AssertionError: if the outputs for two independently rotated copies of
            the dataset differ by 1e-5 or more in any component.
    """
    data, labels = tetris()
    model = InveriantPolynomial()

    # The summed squared-error loss must be differentiable end to end.
    squared_error = (model(data) - labels).pow(2)
    squared_error.sum().backward()

    # A second call to tetris() yields the same shapes under new random rotations.
    rotated_data, _ = tetris()
    deviation = model(rotated_data) - model(data)
    assert deviation.abs().max() < 1e-5

In [58]:
# Run the smoke test; it raises AssertionError if the rotation check fails.
test()

### convolution network

In [71]:
from e3nn import o3
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.math import soft_one_hot_linspace
from e3nn.util.test import assert_equivariant

def mean_std(name, x) -> None:
    """Print one line of summary statistics for tensor ``x``, labelled ``name``.

    The line shows the overall mean, then in parentheses the square root of
    the mean variance taken along dim 0, and the overall standard deviation.
    """
    overall_mean = x.mean()
    rms_std_along_dim0 = x.var(0).mean().sqrt()
    overall_std = x.std()
    print(f"{name} \t{overall_mean:.1f} +- ({rms_std_along_dim0:.1f}|{overall_std:.1f})")
    
class Convolution(torch.nn.Module):
    """Message-passing layer whose tensor-product weights are predicted per edge.

    A small fully connected network maps three scalars per edge to the weights
    of a ``FullyConnectedTensorProduct`` between the source-node features and
    the edge attributes. Messages are summed onto their destination nodes.

    Args:
        irreps_in: irreps of the incoming node features.
        irreps_sh: irreps of the edge attributes.
        irreps_out: irreps of the produced node features.
        num_neighbors: typical number of neighbors; the summed messages are
            divided by its square root.
    """

    def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
        super().__init__()

        self.num_neighbors = num_neighbors

        # No internal or shared weights: every edge gets its own weight vector
        # supplied at call time (see forward()).
        tp = FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_sh,
            irreps_out=irreps_out,
            internal_weights=False,
            shared_weights=False,
        )

        # 3 input scalars per edge -> 256 hidden units -> one weight per tensor-product path weight.
        self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
        self.tp = tp
        self.irreps_out = self.tp.irreps_out

    def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        """Compute one round of message passing.

        Args:
            node_features: per-node input features, indexed by ``edge_src``.
            edge_src: source node index of every edge.
            edge_dst: destination node index of every edge.
            edge_attr: per-edge attributes fed to the tensor product.
            edge_scalars: per-edge scalars fed to the weight network.

        Returns:
            New node features with the same number of rows as ``node_features``.
        """
        weight = self.fc(edge_scalars)
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        # Pin the output length to the number of input nodes. Otherwise scatter
        # sizes the result from the largest destination index, dropping
        # trailing nodes that receive no message.
        summed = scatter(edge_features, edge_dst, dim=0, dim_size=node_features.shape[0])
        return summed.div(self.num_neighbors**0.5)
    
    
class Network(torch.nn.Module):
    """Two-layer equivariant network: convolution -> gate -> convolution.

    The final convolution outputs irreps ``"0o + 6x0e"`` (one odd scalar and
    six even scalars), which are summed over the nodes of each graph.
    """

    def __init__(self) -> None:
        super().__init__()
        self.num_neighbors = 3.8  # typical number of neighbors
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)

        irreps = self.irreps_sh

        # First layer with gate
        gate = Gate(
            "16x0e + 16x0o",
            [torch.relu, torch.abs],  # scalars
            "8x0e + 8x0o + 8x0e + 8x0o",
            [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates (scalars)
            "16x1o + 16x1e",  # gated tensors, num_irreps has to match with gates
        )

        # The convolution must produce exactly what the gate consumes.
        self.conv = Convolution(irreps, self.irreps_sh, gate.irreps_in, self.num_neighbors)
        self.gate = gate
        irreps = self.gate.irreps_out

        # Final layer
        self.final = Convolution(irreps, self.irreps_sh, "0o + 6x0e", self.num_neighbors)
        self.irreps_out = self.final.irreps_out

    def forward(self, data) -> torch.Tensor:
        """Run the network on a batch of point clouds.

        Args:
            data: object exposing ``pos`` (point coordinates) and ``batch``
                (graph index of every point).

        Returns:
            A tensor with one row per graph, laid out as ``self.irreps_out``.
        """
        num_nodes = 4  # typical number of nodes

        edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_attr = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=True,  # edge vectors ARE normalized here; the edge length enters via the radial embedding below
            normalization="component",
        )
        # Embed the edge length on 3 "smooth_finite" basis functions between
        # 0.5 and 2.5. The basis size must equal the input width of the weight
        # network inside Convolution. The 3**0.5 factor is presumably
        # sqrt(number of basis functions).
        edge_length_embedded = (
            soft_one_hot_linspace(
                x=edge_vec.norm(dim=1),
                start=0.5,
                end=2.5,
                number=3,
                basis="smooth_finite",
                cutoff=True,
            )
            * 3**0.5
        )

        # Initial node features: sum of the spherical harmonics of the incoming
        # edges. dim_size keeps one row per point even when trailing points
        # have no incoming edge; without it the tensor would be too short for
        # the final reduction over data.batch.
        x = scatter(edge_attr, edge_dst, dim=0, dim_size=data.pos.shape[0]).div(self.num_neighbors**0.5)

        x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
        x = self.gate(x)
        x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

        # For each graph, all the node's features are summed
        return scatter(x, data.batch, dim=0).div(num_nodes**0.5)
        
def main() -> None:
    """Train ``Network`` on the tetris dataset, then check its equivariance.

    Prints the model, a progress line every tenth training step, and the
    result of a direct rotation check. Finally runs e3nn's
    ``assert_equivariant`` helper on the trained model.
    """
    data, labels = tetris()
    net = Network()

    print("Built a model:")
    print(net)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    # == Training ==
    for step in range(200):
        pred = net(data)
        squared_error = (pred - labels).pow(2)
        loss = squared_error.sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report only every tenth step.
        if step % 10 != 0:
            continue

        # A shape counts as correct when every rounded output matches its label row.
        shape_is_correct = (pred.round() == labels).all(dim=1)
        accuracy = shape_is_correct.double().mean(dim=0).item()
        print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

    # == Check equivariance ==
    # The model outputs only (pseudo)scalars, so its predictions should be the
    # same when the shapes are regenerated under new random rotations.
    print("Testing equivariance directly...")
    rotated_data, _ = tetris()
    max_abs_error = (net(rotated_data) - net(data)).abs().max().item()
    print(f"Equivariance error = {max_abs_error:.1e}")

    print("Testing equivariance using `assert_equivariance`...")

    # The library helper `assert_equivariant` additionally tests parity and
    # translation, and can handle non-(pseudo)scalar outputs. It passes plain
    # tensors, so a small wrapper rebuilds the torch_geometric `Data` object.
    def wrapper(pos, batch):
        graph = Data(pos=pos, batch=batch)
        return net(graph)

    # `assert_equivariant` reports its error summary through `logging`,
    # so INFO-level logging has to be switched on to see it.
    logging.basicConfig(level=logging.INFO)

    # Inputs that `assert_equivariant` will transform...
    check_args = [data.pos, data.batch]
    # ...according to these irreps...
    check_irreps_in = [
        "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
        None,  # `None` indicates invariant, possibly non-floating-point data
    ]
    # ...while the outputs must transform according to these irreps.
    check_irreps_out = [net.irreps_out]

    assert_equivariant(
        wrapper,
        args_in=check_args,
        irreps_in=check_irreps_in,
        irreps_out=check_irreps_out,
    )

In [72]:
# Train the gated convolution network and run both equivariance checks.
main()

Built a model:
Network(
  (conv): Convolution(
    (fc): FullyConnectedNet[3, 256, 272]
    (tp): FullyConnectedTensorProduct(1x0e+1x1o+1x2e+1x3o x 1x0e+1x1o+1x2e+1x3o -> 32x0o+32x0e+16x1o+16x1e | 272 paths | 272 weights)
  )
  (gate): Gate (32x0o+32x0e+16x1o+16x1e -> 16x0e+16x0e+8x1o+8x1e+8x1e+8x1o)
  (final): Convolution(
    (fc): FullyConnectedNet[3, 256, 304]
    (tp): FullyConnectedTensorProduct(32x0e+8x1o+16x1e+8x1o x 1x0e+1x1o+1x2e+1x3o -> 1x0o+6x0e | 304 paths | 304 weights)
  )
)
epoch     0 | loss 200.9      |   0.0% accuracy
epoch    10 | loss 36.4       |   0.0% accuracy
epoch    20 | loss 12.4       |  37.5% accuracy
epoch    30 | loss 6.6        |  50.0% accuracy
epoch    40 | loss 4.8        |  50.0% accuracy
epoch    50 | loss 3.6        |  62.5% accuracy
epoch    60 | loss 2.9        |  75.0% accuracy
epoch    70 | loss 2.4        |  75.0% accuracy
epoch    80 | loss 2.0        |  75.0% accuracy
epoch    90 | loss 1.7        |  87.5% accuracy
epoch   100 | loss 1.5   

INFO:e3nn.util.test:Tested equivariance of `wrapper` -- max componentwise errors:
(parity_k=0, did_translate=False) -> max error=8.866e-07 in argument 0
(parity_k=0, did_translate=True) -> max error=3.906e-06 in argument 0
(parity_k=1, did_translate=False) -> max error=5.737e-07 in argument 0
(parity_k=1, did_translate=True) -> max error=1.490e-06 in argument 0


Testing equivariance directly...
Equivariance error = 9.4e-07
Testing equivariance using `assert_equivariance`...


In [75]:
def profile(device: str = "cpu") -> None:
    """Profile one training step of ``Network`` with ``torch.profiler``.

    The schedule skips 50 steps, warms up for one and records one. When the
    recorded step completes, the averaged table is printed and a Chrome trace
    is written to ``test_trace_<k>.json``.

    Args:
        device: device to run on. CUDA activity is only profiled, and the
            table only sorted by CUDA time, when this is a CUDA device;
            otherwise CPU activity and CPU time are used.
    """
    on_cuda = torch.device(device).type == "cuda"

    data, labels = tetris()
    data = data.to(device=device)
    labels = labels.to(device=device)

    f = Network()
    f.to(device=device)

    optim = torch.optim.Adam(f.parameters(), lr=1e-2)

    # Profile what actually runs: requesting CUDA activity and sorting by CUDA
    # time is meaningless when everything is pinned to the CPU.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    sort_key = "self_cuda_time_total" if on_cuda else "self_cpu_time_total"

    traces_written = 0

    def trace_handler(p) -> None:
        # Called by the profiler each time an active window has been recorded.
        nonlocal traces_written
        print(p.key_averages().table(sort_by=sort_key, row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(traces_written) + ".json")
        traces_written += 1

    with torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=50, warmup=1, active=1),
        on_trace_ready=trace_handler,
    ) as p:
        # 50 wait + 1 warmup + 1 active = 52 steps.
        for _ in range(52):
            pred = f(data)
            loss = (pred - labels).pow(2).sum()

            optim.zero_grad()
            loss.backward()
            optim.step()

            # Advance the profiler schedule. Without this call the profiler
            # never leaves the wait phase and trace_handler is never invoked.
            p.step()


In [76]:
# Run the profiling loop: 52 training steps of Network under torch.profiler.
profile()