In [6]:
import torch
from e3nn.o3 import Irreps
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork

# Define the input and output irreps
irreps_in = Irreps("1x0e + 3x1o")  # One scalar and three vectors
irreps_out = Irreps("1x0e")        # One scalar output

# Initialize the SimpleNetwork
network = SimpleNetwork(
    irreps_in=irreps_in,    # Input representation
    irreps_out=irreps_out,  # Output representation
    # num_layers=3,        # Number of hidden layers
    # fc_neurons=32           # Neurons per hidden layer
)

# Define input data: a batch of input features, e.g., batch size 10
x = torch.randn(10, irreps_in.dim)  # Random input data

# Pass the data through the network
output = network(x)

print(output.shape)  # Should match the output dimension of irreps_out


TypeError: __init__() missing 3 required positional arguments: 'max_radius', 'num_neighbors', and 'num_nodes'

In [12]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from e3nn.o3 import Irreps, FullyConnectedTensorProduct, Linear, Norm
from e3nn.nn import Gate

# 定义等变图神经网络的层
class EquivariantMessagePassing(MessagePassing):
    def __init__(self, irreps_in, irreps_out, hidden_irreps):
        super().__init__(aggr='add')  # "Add" aggregation (like summing up messages)
        self.tp1 = FullyConnectedTensorProduct(irreps_in, irreps_in, hidden_irreps)
        
        # Manually filter scalar parts of hidden_irreps
        irreps_scalars = Irreps([(mul, ir) for mul, ir in hidden_irreps if ir.l == 0])
        irreps_gates = Irreps(f"{irreps_scalars.num_irreps}x0e")  # Gates should match scalar parts
        irreps_gated = hidden_irreps
        
        self.gate = Gate(
            irreps_scalars=irreps_scalars,
            act_scalars=[torch.nn.functional.relu] * irreps_scalars.num_irreps,
            irreps_gates=irreps_gates,
            act_gates=[torch.sigmoid] * irreps_gates.num_irreps,
            irreps_gated=irreps_gated
        )
        self.tp2 = FullyConnectedTensorProduct(hidden_irreps, irreps_in, irreps_out)
    
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        # Compute message passing between nodes i and j
        message = self.tp1(x_i, x_j)
        message = self.gate(message)
        return message
    
    def update(self, inputs):
        return self.tp2(inputs)

# 定义整个网络
class EquivariantGraphNetwork(torch.nn.Module):
    def __init__(self, irreps_in, irreps_out, hidden_irreps):
        super().__init__()
        self.layer1 = EquivariantMessagePassing(irreps_in, hidden_irreps, hidden_irreps)
        self.layer2 = EquivariantMessagePassing(hidden_irreps, irreps_out, hidden_irreps)
    
    def forward(self, x, edge_index):
        # 第一层消息传递
        x1 = self.layer1(x, edge_index)
        # 第二层消息传递
        x2 = self.layer2(x1, edge_index)
        return x1, x2

# 设置 Irreps（不可约表示）
irreps_scalar = Irreps("1x0e")  # 标量：旋转不变
irreps_vector = Irreps("1x1o")  # 矢量：三维向量

# 输入的不可约表示: 标量和矢量的组合
irreps_in = Irreps("1x0e + 1x1o")
# 输出的不可约表示: 假设是标量
irreps_out = Irreps("1x0e")

# 隐藏层的不可约表示: 设置为 4 个标量和 4 个矢量
hidden_irreps = Irreps("4x0e + 4x1o")

# 初始化等变图神经网络
model = EquivariantGraphNetwork(irreps_in, irreps_out, hidden_irreps)

# 构建图数据
N = 5  # 假设有 5 个节点
x_scalar = torch.randn(N, 1)  # 每个节点的标量输入
x_vector = torch.randn(N, 3)  # 每个节点的矢量输入

# 拼接标量和矢量输入
x = torch.cat([x_scalar, x_vector], dim=-1)

# 构建边列表 (edge_index)，例如下面是一个全连接图的边
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    [1, 2, 3, 4, 0, 2, 3, 4, 0, 1]
], dtype=torch.long)

# 将数据转换为 PyTorch Geometric Data 对象
data = Data(x=x, edge_index=edge_index)

# 使用网络进行前向传播
output_layer1, output_layer2 = model(data.x, data.edge_index)

print("Output of layer 1:", output_layer1)
print("Output of layer 2:", output_layer2)


ValueError: There are 8 irreps in irreps_gated, but a different number (4) of gate scalars in irreps_gates