# Test Part One
1. `LinearModuleDict`
2. `SelfGateModuleDict`
3. `MultiLayerPerceptionModuleDict`
4. `DepthwiseTensorProductModuleDict`

In [1]:
from typing import Dict, Tuple
from torch import Tensor
import torch
import e3nn
from e3nn.o3 import Irreps

# Irreps of the edge features for the two edge types exercised in this cell.
# Edge-type keys: edge_a -> "6-8", edge_b -> "8-6".
irreps_edges_a: Irreps = Irreps("32x0e+16x1e")
irreps_edges_b: Irreps = Irreps("8x0e+16x1e")
irreps_edge_dict: Dict[str, Irreps] = {"6-8": irreps_edges_a, "8-6": irreps_edges_b}
irreps_edge_length: Irreps = Irreps("16x0e") # must contain only "0e" (scalar) irreps
irreps_edge_vec: Irreps = Irreps("8x0e+4x1e+4x2e+4x3e")

# Number of edges per edge type, and the size of the extra middle axis of the edge features.
num_edges_a: int = 7
num_edges_b: int = 12
num_heads: int = 3


# = parameter 0 = (seed used for every random tensor created below)
seed = 0
# Both calls seed with the same value; presumably they target the same global
# torch RNG, so the second call is redundant but harmless -- TODO confirm.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
# NOTE: the first two draws (without the num_heads axis) are overwritten by the
# next two lines and never used. They presumably still advance the global RNG,
# so deleting them would change every random tensor created afterwards.
# `-1` presumably marks where the irreps (feature) axis goes -- confirm against
# e3nn's `Irreps.randn`.
edge_fea_a = irreps_edges_a.randn(num_edges_a, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, -1, normalization="component")
# Edge features actually used: one sample per (edge, head).
edge_fea_a = irreps_edges_a.randn(num_edges_a, num_heads, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, num_heads, -1, normalization="component")
edge_fea_dict = {"6-8": edge_fea_a, "8-6": edge_fea_b}
# Edge-vector embeddings have no head axis. The "8-6" sample is additionally
# multiplied by a single random scalar, so its overall scale differs from "6-8".
edge_vec_a = irreps_edge_vec.randn(num_edges_a, -1, normalization="component")
edge_vec_b = torch.randn(1) * irreps_edge_vec.randn(num_edges_b, -1, normalization="component") 
edge_vec_dict = {"6-8": edge_vec_a, "8-6": edge_vec_b}
# Edge-length embeddings (scalar irreps only), one row per edge.
edge_length_dict = {"6-8": irreps_edge_length.randn(num_edges_a, -1, normalization="component"), 
                    "8-6": irreps_edge_length.randn(num_edges_b, -1, normalization="component")}

# Modules under test in "Test Part One"; each is built once, keyed by edge type.
from modelname.nn.moduledict import LinearModuleDict
from modelname.nn.moduledict import SelfGateModuleDict
from modelname.nn.moduledict import MultiLayerPerceptionModuleDict
from modelname.nn.moduledict import DepthwiseTensorProductModuleDict
# net_1: linear map per edge type, same irreps in and out, with bias enabled.
net_1 = LinearModuleDict(irreps_in_dict=irreps_edge_dict, irreps_out_dict=irreps_edge_dict, bias=True)
# net_2: gate activation per edge type (SiLU passed for scalars, Sigmoid for gates).
net_2 = SelfGateModuleDict(irreps_in_dict=irreps_edge_dict, act_scalars=torch.nn.SiLU(), act_gates=torch.nn.Sigmoid())
# net_3: MLP per edge type with a single hidden irreps "4x0e+2x1e+1x2e",
# activation and layer-type normalization switched on.
net_3 = MultiLayerPerceptionModuleDict(irreps_in_dict=irreps_edge_dict, irreps_out_dict=irreps_edge_dict,
        irreps_mid_list=[Irreps("4x0e+2x1e+1x2e")], add_last_linear="bias", if_act=True, if_norm=True, 
        act_scalars=torch.nn.SiLU(), act_gates=torch.nn.Sigmoid(), norm_type="layer"
    )   
# TODO: why is the output variance wrong when irreps_mid_list contains "1e"?
# TODO: why is the "1e" variance always about 10x the "0e" variance, with rescale=True or False?
# net_4: depthwise tensor product of edge features with the edge-vector embedding.
# The inner MLP (mlp_* arguments) is presumably driven by the edge-length
# embedding, since its hidden irreps must be scalar-only -- TODO confirm in the module.
net_4 = DepthwiseTensorProductModuleDict(irreps_in_dict=irreps_edge_dict, 
        irreps_edge_vec_embed=irreps_edge_vec, 
        irreps_edge_length_embed=irreps_edge_length,
        dtp_internal_weights=False,
        mlp_irreps_mid_list=[Irreps("10x0e")], # must contain only "0e" (scalar) irreps
        mlp_add_last_linear=None, 
        mlp_if_act=True,
        mlp_if_norm=True,
        mlp_act_scalars=torch.nn.SiLU(), mlp_act_gates=torch.nn.Sigmoid(),
        mlp_norm_type="layer"
    )
# Author's notes on the settings intended for the "alpha" and "value" uses of this module:
#   alpha: dtp_internal_weights=False
#   alpha: mlp_add_last_linear="bias"
#   alpha: mlp_if_act=True
#   alpha: mlp_if_norm=True
#   alpha: mlp_act_scalars=torch.nn.SiLU(), mlp_act_gates=torch.nn.Sigmoid()
#   alpha: mlp_norm_type="layer"
#   value: dtp_internal_weights=True
# TODO: mlp_add_last_linear=None with mlp_if_norm=True seems to give a better output variance.

# = parameter 1 = (module under test: net_1, net_2, net_3 or net_4)
net = net_4

# Show the module structure before running it.
print(net)

# The depthwise tensor product needs the edge-vector and edge-length embeddings
# in addition to the edge features; every other module takes a single x_dict.
if isinstance(net, DepthwiseTensorProductModuleDict):
    forward_inputs = dict(
        edge_fea_dict=edge_fea_dict,
        edge_vec_dict=edge_vec_dict,
        edge_length_dict=edge_length_dict,
    )
else:
    forward_inputs = dict(x_dict=edge_fea_dict)
out = net(**forward_inputs)

# = parameter 2 = (edge type to inspect: "8-6" or "6-8")
key = "8-6"

# Compare first and second moments of the input and output for that edge type.
in_, out_ = edge_fea_dict[key], out[key]
in_mean, in_var = in_.mean(), in_.var()
out_mean, out_var = out_.mean(), out_.var()

moment_report = (
    f"Mean Value of Output: {out_mean}\n"
    f"Variance value of Output: {out_var}\n"
    f"("
    f"Mean Value of Input: {in_mean}\n"
    f"Variance value of Input: {in_var})\n"
)
print(moment_report)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


DepthwiseTensorProductModuleDict(
  (dtp_dict): ModuleDict(
    (DepthwiseTensorProductModule(6-8)): DepthwiseTensorProductModule(
      (tp): TensorProductRescale(
        (tp): TensorProduct(32x0e+16x1e x 8x0e+4x1e+4x2e+4x3e -> 32x0e+16x1e | 512 paths | 512 weights)
        (bias): ParameterList()
      )
    )
    (DepthwiseTensorProductModule(8-6)): DepthwiseTensorProductModule(
      (tp): TensorProductRescale(
        (tp): TensorProduct(8x0e+16x1e x 8x0e+4x1e+4x2e+4x3e -> 8x0e+16x1e | 320 paths | 320 weights)
        (bias): ParameterList()
      )
    )
  )
  (mlp_dict): MultiLayerPerceptionModuleDict(
    (mlp_dict): ModuleDict(
      (MultiLayerPerceptionModule(6-8)): MultiLayerPerceptionModule(
        (net): Sequential(
          (0): LinearRS(
            (tp): TensorProduct(16x0e x 1x0e -> 10x0e | 160 paths | 160 weights)
            (bias): ParameterList()
          )
          (1): EquivariantLayerNormV2(10x0e, eps=1e-05)
          (2): SelfGateModule(
            (scal

In [2]:
def outs_rot_invariance(rot: Tensor,
                        model: torch.nn.Module,
                        x_dict: Dict[str, Tensor],
                        irreps_in_dict: Dict[str, Irreps],
                        irreps_out_dict: Dict[str, Irreps] = None,  # ignored for DepthwiseTensorProductModuleDict
                        irreps_edge_vec: Irreps = None,
                        irreps_edge_length: Irreps = None,
                        edge_vec_dict: Dict[str, Tensor] = None,
                        edge_length_dict: Dict[str, Tensor] = None):
    """Compute both sides of the rotation-equivariance identity for ``model``.

    ``out_1`` is the model output on the original inputs, rotated afterwards;
    ``out_2`` is the model output on inputs that were rotated first. For an
    equivariant model the two agree up to floating-point error. Despite the
    function name, this checks equivariance rather than invariance.

    NOTE: later cells of this notebook define other functions with this same
    name and different signatures; whichever cell ran last is the one bound.

    Args:
        rot: rotation matrix handed to ``Irreps.D_from_matrix``.
        model: module under test. ``DepthwiseTensorProductModuleDict`` is called
            with ``edge_fea_dict`` / ``edge_vec_dict`` / ``edge_length_dict``;
            any other module is called with ``x_dict`` only.
        x_dict: input features per edge type. Not modified.
        irreps_in_dict: irreps of ``x_dict`` per edge type.
        irreps_out_dict: irreps of the model output per edge type. Required
            unless ``model`` is a ``DepthwiseTensorProductModuleDict``, in which
            case ``irreps_in_dict`` is used instead.
        irreps_edge_vec: irreps shared by every entry of ``edge_vec_dict``.
        irreps_edge_length: irreps shared by every entry of ``edge_length_dict``.
        edge_vec_dict: edge-vector embeddings per edge type (depthwise TP only).
        edge_length_dict: edge-length embeddings per edge type (depthwise TP only).

    Returns:
        ``(out_1, out_2)``, two dicts keyed like the model output.

    Raises:
        ValueError: if an argument required for the given model type is ``None``.
    """
    is_dtp = isinstance(model, DepthwiseTensorProductModuleDict)

    # Fail early with a readable message instead of an AttributeError/TypeError
    # on ``None`` somewhere in the middle of the computation.
    if is_dtp:
        required = {
            "irreps_edge_vec": irreps_edge_vec,
            "irreps_edge_length": irreps_edge_length,
            "edge_vec_dict": edge_vec_dict,
            "edge_length_dict": edge_length_dict,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                "DepthwiseTensorProductModuleDict requires: " + ", ".join(missing)
            )
        # The depthwise tensor product is treated as mapping each edge type back
        # onto its input irreps, so any caller-supplied value is overridden.
        irreps_out_dict = irreps_in_dict
    elif irreps_out_dict is None:
        raise ValueError("irreps_out_dict is required for this model type")

    # --- Side 1: forward pass first, then rotate the output. ---
    if is_dtp:
        raw_out = model(edge_fea_dict=x_dict, edge_vec_dict=edge_vec_dict,
                        edge_length_dict=edge_length_dict)
    else:
        raw_out = model(x_dict=x_dict)
    # Build a new dict so the dict returned by the model is left untouched.
    out_1: Dict[str, Tensor] = {
        edge: v @ irreps_out_dict[edge].D_from_matrix(rot).T
        for edge, v in raw_out.items()
    }

    # --- Side 2: rotate the inputs first, then forward pass. ---
    # The rotated features go into a fresh dict on purpose: writing them back
    # into ``x_dict`` would silently change the caller's data.
    new_x_dict: Dict[str, Tensor] = {
        edge: v @ irreps_in_dict[edge].D_from_matrix(rot).T
        for edge, v in x_dict.items()
    }

    if is_dtp:
        # These two matrices do not depend on the edge type, so compute them once.
        D_vec_T = irreps_edge_vec.D_from_matrix(rot).T
        D_length_T = irreps_edge_length.D_from_matrix(rot).T
        new_edge_vec_dict: Dict[str, Tensor] = {
            edge: edge_vec_dict[edge] @ D_vec_T for edge in x_dict
        }
        new_edge_length_dict: Dict[str, Tensor] = {
            edge: edge_length_dict[edge] @ D_length_T for edge in x_dict
        }
        out_2: Dict[str, Tensor] = model(edge_fea_dict=new_x_dict,
                                         edge_vec_dict=new_edge_vec_dict,
                                         edge_length_dict=new_edge_length_dict)
    else:
        out_2 = model(x_dict=new_x_dict)

    return out_1, out_2


# Re-seed right before drawing the rotation so the same matrix is used on every run.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
rot = e3nn.o3.rand_matrix()
print(rot)

# out_1_dict: model output rotated afterwards; out_2_dict: model output on rotated inputs.
equivariance_inputs = dict(
    rot=rot,
    model=net,
    x_dict=edge_fea_dict,
    irreps_in_dict=irreps_edge_dict,
    irreps_out_dict=irreps_edge_dict,
    irreps_edge_vec=irreps_edge_vec,
    irreps_edge_length=irreps_edge_length,
    edge_vec_dict=edge_vec_dict,
    edge_length_dict=edge_length_dict,
)
out_1_dict, out_2_dict = outs_rot_invariance(**equivariance_inputs)

# = parameter 3 = (the edge type inspected is `key`, chosen in the previous cell)
out_1, out_2 = out_1_dict[key], out_2_dict[key]

# Magnitude statistics of both outputs, to put the error into perspective.
abs_out_1, abs_out_2 = out_1.abs(), out_2.abs()
min_out_1, min_out_2 = abs_out_1.min(), abs_out_2.min()
max_out_1, max_out_2 = abs_out_1.max(), abs_out_2.max()
ave_out_1, ave_out_2 = abs_out_1.mean(), abs_out_2.mean()

# Element-wise disagreement between the two sides of the equivariance identity.
error = out_2 - out_1
abs_error = error.abs()
max_error, ave_error = abs_error.max(), abs_error.mean()

print(f"Is {type(net).__name__} rotation equivariant?")
equivariance_report = (
    f"Max Error: {max_error}.\n"
    f"Mean Error (L1Loss): {ave_error}.\n"
    f"("
    f"Min Absolute of out_1: {min_out_1}\n"
    f"Min Absolute of out_2: {min_out_2}\n"
    f"Max Absolute of out_1: {max_out_1}\n"
    f"Max Absolute of out_2: {max_out_2}\n"
    f"Mean Absolute Value out_1: {ave_out_1}\n"
    f"Mean Absolute Value of out_2: {ave_out_2})"
)
print(equivariance_report)

tensor([[-0.1334,  0.0134,  0.9910],
        [-0.5643, -0.8230, -0.0649],
        [ 0.8147, -0.5678,  0.1174]])
Is DepthwiseTensorProductModuleDict rotation equivariant?
Max Error: 1.5735626220703125e-05.
Mean Error (L1Loss): 1.2073162451997632e-06.
(Min Absolute of out_1: 0.00013166152348276228
Min Absolute of out_2: 0.00013256072998046875
Max Absolute of out_1: 18.542585372924805
Max Absolute of out_2: 18.54258918762207
Mean Absolute Value out_1: 1.7347888946533203
Mean Absolute Value of out_2: 1.7347891330718994)


# Test Part Two
5. `ElementLevelScatterModuleDict`
6. `TypeLevelScatterModuleDict`

In [3]:
from typing import Dict, Tuple
from torch import Tensor
import torch
import e3nn
from e3nn.o3 import Irreps

# Edge-type keys: edge_a -> "6-8", edge_b -> "8-6".
# Node-type keys: node_a -> "6",   node_b -> "8".

# Node irreps are identical for both node types; edge (message) irreps differ per edge type.
irreps_nodes_a: Irreps = Irreps("16x0e+8x1e+4x2e")
irreps_nodes_b: Irreps = Irreps("16x0e+8x1e+4x2e")
irreps_node_dict: Dict[str, Irreps] = {"6": irreps_nodes_a, "8": irreps_nodes_b}
irreps_edges_a: Irreps = Irreps("1x0e+16x2e")
irreps_edges_b: Irreps = Irreps("8x0e+16x1e")
irreps_edge_dict: Dict[str, Irreps] = {"6-8": irreps_edges_a, "8-6": irreps_edges_b}

num_nodes_a: int = 6
num_nodes_b: int = 12
num_nodes_dict: Dict[str, int] = {"6": num_nodes_a, "8": num_nodes_b}
# Edge indices have shape (2, num_edges) after the transpose.
# "6-8": row 0 holds "6" indices 0..5, row 1 holds "8" indices 0..10 (6 * 11 = 66 edges).
# "8-6": row 0 holds "8" indices 0..10, row 1 holds "6" indices 0..4 (5 * 11 = 55 edges).
# Which row is source and which is destination is decided by the scatter modules -- TODO confirm.
edge_index_a: Tensor = torch.tensor([[x, y] for x in range(num_nodes_a) for y in range(num_nodes_b-1)]).T.contiguous()
edge_index_b: Tensor = torch.tensor([[y, x] for x in range(num_nodes_a-1) for y in range(num_nodes_b-1)]).T.contiguous()
# Nearly fully connected: the last "6" node (index 5) has no edge in "8-6",
# and the last "8" node (index 11) has no edge in either edge type.
edge_index_dict = {"6-8": edge_index_a, "8-6": edge_index_b} 

# Edge counts are derived from the index tensors rather than hard-coded.
num_edges_a: int = edge_index_a.shape[1]
num_edges_b: int = edge_index_b.shape[1]
num_heads: int = 3

# = parameter 0 = (seed used for every random tensor created below)
seed = 0
# Both calls seed with the same value; presumably they target the same global
# torch RNG, so the second call is redundant but harmless -- TODO confirm.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
# NOTE: the first two draws (without the num_heads axis) are overwritten by the
# next two lines and never used. They presumably still advance the global RNG,
# so deleting them would change the tensors drawn afterwards.
edge_fea_a = irreps_edges_a.randn(num_edges_a, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, -1, normalization="component")
# Messages actually used: one sample per (edge, head).
edge_fea_a = irreps_edges_a.randn(num_edges_a, num_heads, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, num_heads, -1, normalization="component")
edge_fea_dict = {"6-8": edge_fea_a, "8-6": edge_fea_b}

from modelname.nn.moduledict import ElementLevelScatterModuleDict, TypeLevelScatterModuleDict

# net_5 scatters per-edge messages onto nodes; net_6 consumes net_5's output.
net_5 = ElementLevelScatterModuleDict(
    irreps_message_dict=irreps_edge_dict,
    irreps_node_dict=irreps_node_dict,
)
net_6 = TypeLevelScatterModuleDict()

# = parameter 1 = (module under test: net_5 or net_6)
net = net_6
print(net)

# Both branches start with the element-level scatter on the same inputs.
scatter_inputs = dict(
    message_dict=edge_fea_dict,
    edge_index_dict=edge_index_dict,
    num_nodes_dict=num_nodes_dict,
)
if isinstance(net, ElementLevelScatterModuleDict):
    out, other = net(**scatter_inputs)
elif isinstance(net, TypeLevelScatterModuleDict):
    # The type-level scatter is fed the element-level result and its second return value.
    out, other = net_5(**scatter_inputs)
    out = net(type_message_dict=out, num_types_dict=other)

# = parameter 2 = (node type to inspect: "6" or "8")
key: str = "6"
print(out[key].shape)
if isinstance(net, ElementLevelScatterModuleDict):
    # Only the element-level module's second return value is shown.
    print(other[key])

TypeLevelScatterModuleDict()
torch.Size([6, 3, 60])


  num_type_scatter = out_scatter.new_tensor(num_type_scatter)


In [4]:
def outs_rot_invariance(rot: Tensor,
                        model: torch.nn.Module,

                        irreps_in_dict: Dict[str, Irreps],
                        irreps_out_dict: Dict[str, Irreps],

                        edge_fea_dict: Dict[str, Tensor],
                        edge_index_dict: Dict[str, Tensor],
                        num_nodes_dict: Dict[str, int],
                        element_scatter: torch.nn.Module = None,
                        ):
    """Compute both sides of the rotation-equivariance identity for a scatter module.

    ``out_1`` is the model output on the original messages, rotated afterwards;
    ``out_2`` is the model output on messages that were rotated first. For an
    equivariant model the two agree up to floating-point error. Despite the
    function name, this checks equivariance rather than invariance.

    NOTE: other cells of this notebook define functions with this same name and
    different signatures; whichever cell ran last is the one bound.

    Args:
        rot: rotation matrix handed to ``Irreps.D_from_matrix``.
        model: an ``ElementLevelScatterModuleDict`` or a
            ``TypeLevelScatterModuleDict``.
        irreps_in_dict: irreps of ``edge_fea_dict`` per edge type.
        irreps_out_dict: irreps of the model output per output key.
        edge_fea_dict: messages per edge type. Not modified.
        edge_index_dict: edge indices per edge type.
        num_nodes_dict: node count per node type.
        element_scatter: element-level scatter module that feeds a
            ``TypeLevelScatterModuleDict``. When ``None``, the notebook-level
            global ``net_5`` is used, as before this parameter existed.

    Returns:
        ``(out_1, out_2)``, two dicts keyed like the model output.

    Raises:
        TypeError: if ``model`` is neither of the two supported module types.
    """
    is_element_level = isinstance(model, ElementLevelScatterModuleDict)
    if not is_element_level and not isinstance(model, TypeLevelScatterModuleDict):
        # Previously this case only surfaced later as an UnboundLocalError.
        raise TypeError(
            "model must be an ElementLevelScatterModuleDict or a "
            f"TypeLevelScatterModuleDict, got {type(model).__name__}"
        )

    def _forward(message_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run the scatter pipeline on ``message_dict`` and return the feature dict."""
        if is_element_level:
            out, _ = model(message_dict=message_dict, edge_index_dict=edge_index_dict,
                           num_nodes_dict=num_nodes_dict)
            return out
        # The type-level module consumes the element-level result, so run that first.
        scatter = element_scatter if element_scatter is not None else net_5
        out, others = scatter(message_dict=message_dict, edge_index_dict=edge_index_dict,
                              num_nodes_dict=num_nodes_dict)
        return model(type_message_dict=out, num_types_dict=others)

    # --- Side 1: forward pass first, then rotate the output. ---
    # Build a new dict so the dict returned by the model is left untouched.
    out_1: Dict[str, Tensor] = {
        key: v @ irreps_out_dict[key].D_from_matrix(rot).T
        for key, v in _forward(edge_fea_dict).items()
    }

    # --- Side 2: rotate the messages first, then forward pass. ---
    # The rotated messages go into a fresh dict on purpose: writing them back
    # into ``edge_fea_dict`` would silently change the caller's data.
    new_edge_fea_dict: Dict[str, Tensor] = {
        edge: v @ irreps_in_dict[edge].D_from_matrix(rot).T
        for edge, v in edge_fea_dict.items()
    }
    out_2 = _forward(new_edge_fea_dict)

    return out_1, out_2


# Re-seed right before drawing the rotation so the same matrix is used on every run.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
rot = e3nn.o3.rand_matrix()
print(rot)

# Messages carry the edge irreps, the scattered result carries the node irreps.
equivariance_inputs = dict(
    rot=rot,
    model=net,
    irreps_in_dict=irreps_edge_dict,
    irreps_out_dict=irreps_node_dict,
    edge_fea_dict=edge_fea_dict,
    edge_index_dict=edge_index_dict,
    num_nodes_dict=num_nodes_dict,
)
out_1_dict, out_2_dict = outs_rot_invariance(**equivariance_inputs)

# = parameter 3 = (the node type inspected is `key`, chosen in the previous cell)
out_1, out_2 = out_1_dict[key], out_2_dict[key]

# Magnitude statistics of both outputs, to put the error into perspective.
abs_out_1, abs_out_2 = out_1.abs(), out_2.abs()
min_out_1, min_out_2 = abs_out_1.min(), abs_out_2.min()
max_out_1, max_out_2 = abs_out_1.max(), abs_out_2.max()
ave_out_1, ave_out_2 = abs_out_1.mean(), abs_out_2.mean()

# Element-wise disagreement between the two sides of the equivariance identity.
error = out_2 - out_1
abs_error = error.abs()
max_error, ave_error = abs_error.max(), abs_error.mean()

print(f"Is {type(net).__name__} rotation equivariant?")
equivariance_report = (
    f"Max Error: {max_error}.\n"
    f"Mean Error (L1Loss): {ave_error}.\n"
    f"("
    f"Min Absolute of out_1: {min_out_1}\n"
    f"Min Absolute of out_2: {min_out_2}\n"
    f"Max Absolute of out_1: {max_out_1}\n"
    f"Max Absolute of out_2: {max_out_2}\n"
    f"Mean Absolute Value out_1: {ave_out_1}\n"
    f"Mean Absolute Value of out_2: {ave_out_2})"
)
print(equivariance_report)

tensor([[-0.1334,  0.0134,  0.9910],
        [-0.5643, -0.8230, -0.0649],
        [ 0.8147, -0.5678,  0.1174]])
Is TypeLevelScatterModuleDict rotation equivariant?
Max Error: 1.7881393432617188e-07.
Mean Error (L1Loss): 1.2281034855732287e-08.
(Min Absolute of out_1: 0.0
Min Absolute of out_2: 0.0
Max Absolute of out_1: 1.086525321006775
Max Absolute of out_2: 1.0865254402160645
Mean Absolute Value out_1: 0.13905927538871765
Mean Absolute Value of out_2: 0.13905927538871765)


# Test Part Three
7. `SoftmaxScatterNormModuleDict` (Only for "0e" Scalars)

In [5]:
from typing import Dict, Tuple
from torch import Tensor
import torch
import e3nn
from e3nn.o3 import Irreps

# Edge-type keys: edge_a -> "6-8", edge_b -> "8-6".
# Scalar-only ("0e") edge irreps are used in this part.
irreps_edges_a: Irreps = Irreps("10x0e")
irreps_edges_b: Irreps = Irreps("8x0e")
irreps_edge_dict: Dict[str, Irreps] = {"6-8": irreps_edges_a, "8-6": irreps_edges_b}

num_nodes_a: int = 6
num_nodes_b: int = 12
num_nodes_dict: Dict[str, int] = {"6": num_nodes_a, "8": num_nodes_b}
# Edge indices have shape (2, num_edges) after the transpose.
# "6-8": row 0 holds "6" indices 0..5, row 1 holds "8" indices 0..10 (6 * 11 = 66 edges).
# "8-6": row 0 holds "8" indices 0..10, row 1 holds "6" indices 0..4 (5 * 11 = 55 edges).
edge_index_a: Tensor = torch.tensor([[x, y] for x in range(num_nodes_a) for y in range(num_nodes_b-1)]).T.contiguous()
edge_index_b: Tensor = torch.tensor([[y, x] for x in range(num_nodes_a-1) for y in range(num_nodes_b-1)]).T.contiguous()
# Nearly fully connected: the last "6" node (index 5) has no edge in "8-6",
# and the last "8" node (index 11) has no edge in either edge type.
edge_index_dict = {"6-8": edge_index_a, "8-6": edge_index_b} 

# Edge counts are derived from the index tensors rather than hard-coded.
num_edges_a: int = edge_index_a.shape[1]
num_edges_b: int = edge_index_b.shape[1]
num_heads: int = 3


# = parameter 0 = (seed used for every random tensor created below)
seed = 0
# Both calls seed with the same value; presumably they target the same global
# torch RNG, so the second call is redundant but harmless -- TODO confirm.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
# NOTE: the first two draws (without the num_heads axis) are overwritten by the
# next two lines and never used. They presumably still advance the global RNG,
# so deleting them would change the tensors drawn afterwards.
edge_fea_a = irreps_edges_a.randn(num_edges_a, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, -1, normalization="component")
# Edge features actually used: one sample per (edge, head).
edge_fea_a = irreps_edges_a.randn(num_edges_a, num_heads, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, num_heads, -1, normalization="component")
edge_fea_dict = {"6-8": edge_fea_a, "8-6": edge_fea_b}


from modelname.nn.moduledict import SoftmaxScatterNormModuleDict

net_7 = SoftmaxScatterNormModuleDict(irreps_in_dict=irreps_edge_dict)

# = parameter 1 = (module under test: net_7)
net = net_7
print(net)

if isinstance(net, SoftmaxScatterNormModuleDict):
    softmax_inputs = dict(
        edge_fea_dict=edge_fea_dict,
        edge_index_dict=edge_index_dict,
        num_nodes_dict=num_nodes_dict,
    )
    out = net(**softmax_inputs)

# = parameter 2 = (edge type to inspect: "8-6" or "6-8")
key = "8-6"

# Compare first and second moments of the input and output for that edge type.
in_, out_ = edge_fea_dict[key], out[key]
in_mean, in_var = in_.mean(), in_.var()
out_mean, out_var = out_.mean(), out_.var()

moment_report = (
    f"Mean Value of Output: {out_mean}\n"
    f"Variance value of Output: {out_var}\n"
    f"("
    f"Mean Value of Input: {in_mean}\n"
    f"Variance value of Input: {in_var})\n"
)
print(moment_report)


SoftmaxScatterNormModuleDict()
Mean Value of Output: 0.09090909361839294
Variance value of Output: 0.007198275998234749
(Mean Value of Input: 0.023028507828712463
Variance value of Input: 0.9667189717292786)



# Test Part Four
8. `TransformerModuleDict`

In [6]:
from typing import Dict, Tuple
from torch import Tensor
import torch
import e3nn
from e3nn.o3 import Irreps

# Edge-type keys: edge_a -> "6-8", edge_b -> "8-6".
# Node-type keys: node_a -> "6",   node_b -> "8".

# Node irreps are identical for both node types; edge irreps differ per edge type.
irreps_nodes_a: Irreps = Irreps("16x0e+8x1e+4x2e")
irreps_nodes_b: Irreps = Irreps("16x0e+8x1e+4x2e")
irreps_node_dict: Dict[str, Irreps] = {"6": irreps_nodes_a, "8": irreps_nodes_b}
irreps_edges_a: Irreps = Irreps("32x0e+16x1e")
irreps_edges_b: Irreps = Irreps("8x0e+16x1e")
irreps_edge_dict: Dict[str, Irreps] = {"6-8": irreps_edges_a, "8-6": irreps_edges_b}
irreps_edge_length: Irreps = Irreps("16x0e") # must contain only "0e" (scalar) irreps
irreps_edge_vec: Irreps = Irreps("8x0e+4x1e+4x2e+4x3e")
# Unused alternative (one edge-length irreps per edge type), kept for reference:
# irreps_edge_length_dict = {"6-8": irreps_edge_length, "8-6": irreps_edge_length}

num_nodes_a: int = 6
num_nodes_b: int = 12
num_nodes_dict: Dict[str, int] = {"6": num_nodes_a, "8": num_nodes_b}
# Edge indices have shape (2, num_edges) after the transpose.
# "6-8": row 0 holds "6" indices 0..5, row 1 holds "8" indices 0..10 (6 * 11 = 66 edges).
# "8-6": row 0 holds "8" indices 0..10, row 1 holds "6" indices 0..4 (5 * 11 = 55 edges).
edge_index_a: Tensor = torch.tensor([[x, y] for x in range(num_nodes_a) for y in range(num_nodes_b-1)]).T.contiguous()
edge_index_b: Tensor = torch.tensor([[y, x] for x in range(num_nodes_a-1) for y in range(num_nodes_b-1)]).T.contiguous()
# Nearly fully connected: the last "6" node (index 5) has no edge in "8-6",
# and the last "8" node (index 11) has no edge in either edge type.
edge_index_dict = {"6-8": edge_index_a, "8-6": edge_index_b} 

# Edge counts are derived from the index tensors rather than hard-coded.
num_edges_a: int = edge_index_a.shape[1]
num_edges_b: int = edge_index_b.shape[1]
num_heads: int = 4

# = parameter 0 = (seed used for every random tensor created below)
seed = 0
# Both calls seed with the same value; presumably they target the same global
# torch RNG, so the second call is redundant but harmless -- TODO confirm.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
# Unlike the earlier parts, the edge features here carry no explicit head axis;
# num_heads is only passed to the transformer module itself.
edge_fea_a = irreps_edges_a.randn(num_edges_a, -1, normalization="component")
edge_fea_b = irreps_edges_b.randn(num_edges_b, -1, normalization="component")
edge_fea_dict = {"6-8": edge_fea_a, "8-6": edge_fea_b}
# For each input kind, the second sample is multiplied by a fresh random scalar,
# so the two types are on different scales.
edge_vec_a = irreps_edge_vec.randn(num_edges_a, -1, normalization="component")
edge_vec_b = torch.randn(1) * irreps_edge_vec.randn(num_edges_b, -1, normalization="component") 
edge_vec_dict = {"6-8": edge_vec_a, "8-6": edge_vec_b}
edge_length_a = irreps_edge_length.randn(num_edges_a, -1, normalization="component")
edge_length_b = torch.randn(1) * irreps_edge_length.randn(num_edges_b, -1, normalization="component")
edge_length_dict = {"6-8": edge_length_a, "8-6": edge_length_b}
node_fea_a = irreps_nodes_a.randn(num_nodes_a, -1, normalization="component")
node_fea_b = torch.randn(1) * irreps_nodes_b.randn(num_nodes_b, -1, normalization="component")
node_fea_dict = {"6": node_fea_a, "8": node_fea_b}

from modelname.nn.moduledict import TransformerModuleDict

# Dropout is set to 0.0 so that repeated forward passes on the same input agree,
# which the equivariance check in the next cell relies on.
transformer_config = dict(
    irreps_node_fea_dict=irreps_node_dict,
    irreps_edge_fea_dict=irreps_edge_dict,
    irreps_edge_vec_embed=irreps_edge_vec,
    irreps_edge_length_embed=irreps_edge_length,
    num_heads=num_heads,
    alpha_dropout=0.0,
)
net_7 = TransformerModuleDict(**transformer_config)

# = parameter 1 = (module under test: net_7)
net = net_7
print(net)

if isinstance(net, TransformerModuleDict):
    transformer_inputs = dict(
        edge_fea_dict=edge_fea_dict,
        node_fea_dict=node_fea_dict,
        edge_index_dict=edge_index_dict,
        edge_vec_embed_dict=edge_vec_dict,
        edge_length_embed_dict=edge_length_dict,
        num_nodes_dict=num_nodes_dict,
    )
    out_edge, out_node = net(**transformer_inputs)

# = parameter 2 = (edge type "6-8"/"8-6" and node type "6"/"8" to inspect)
key_edge: str = "8-6"
key_node: str = "8"
print(out_edge[key_edge].shape)
print(out_node[key_node].shape)

TransformerModuleDict(
  (node_src_linear_dict): LinearModuleDict(
    (linear_dict): ModuleDict(
      (Linear(6-8)): LinearRS(
        (tp): TensorProduct(16x0e+8x1e+4x2e x 1x0e -> 32x0e+16x1e | 640 paths | 640 weights)
        (bias): ParameterList(  (0): Parameter containing: [torch.float32 of size 32])
      )
      (Linear(8-6)): LinearRS(
        (tp): TensorProduct(16x0e+8x1e+4x2e x 1x0e -> 8x0e+16x1e | 256 paths | 256 weights)
        (bias): ParameterList(  (0): Parameter containing: [torch.float32 of size 8])
      )
    )
  )
  (node_dst_linear_dict): LinearModuleDict(
    (linear_dict): ModuleDict(
      (Linear(6-8)): LinearRS(
        (tp): TensorProduct(16x0e+8x1e+4x2e x 1x0e -> 32x0e+16x1e | 640 paths | 640 weights)
        (bias): ParameterList()
      )
      (Linear(8-6)): LinearRS(
        (tp): TensorProduct(16x0e+8x1e+4x2e x 1x0e -> 8x0e+16x1e | 256 paths | 256 weights)
        (bias): ParameterList()
      )
    )
  )
  (message_dtp_dict): DepthwiseTensorProduct

In [7]:
def outs_rot_invariance(rot: Tensor,
                        model: torch.nn.Module,

                        irreps_node_dict: Dict[str, Irreps],
                        irreps_edge_dict: Dict[str, Irreps],
                        irreps_edge_vec: Irreps,

                        edge_fea_dict: Dict[str, Tensor],
                        node_fea_dict: Dict[str, Tensor],
                        edge_index_dict: Dict[str, Tensor],
                        edge_vec_dict: Dict[str, Tensor],
                        edge_length_dict: Dict[str, Tensor],
                        num_nodes_dict: Dict[str, int],
                        ):
    """Compute both sides of the rotation-equivariance identity for a transformer block.

    The ``*_1`` results are the model outputs on the original inputs, rotated
    afterwards; the ``*_2`` results are the model outputs on inputs that were
    rotated first. For an equivariant model they agree up to floating-point
    error. Despite the function name, this checks equivariance rather than
    invariance.

    NOTE: earlier cells of this notebook define other functions with this same
    name and different signatures; whichever cell ran last is the one bound.

    Args:
        rot: rotation matrix handed to ``Irreps.D_from_matrix``.
        model: a ``TransformerModuleDict``.
        irreps_node_dict: irreps of node features (input and output) per node type.
        irreps_edge_dict: irreps of edge features (input and output) per edge type.
        irreps_edge_vec: irreps shared by every entry of ``edge_vec_dict``.
        edge_fea_dict: edge features per edge type. Not modified.
        node_fea_dict: node features per node type. Not modified.
        edge_index_dict: edge indices per edge type.
        edge_vec_dict: edge-vector embeddings per edge type. Not modified.
        edge_length_dict: edge-length embeddings per edge type. Passed unchanged
            to both forward passes (presumably because they are scalar-only and
            therefore unaffected by rotation -- confirm against the caller).
        num_nodes_dict: node count per node type.

    Returns:
        ``(out_edge_1, out_node_1, out_edge_2, out_node_2)``.

    Raises:
        TypeError: if ``model`` is not a ``TransformerModuleDict``.
    """
    if not isinstance(model, TransformerModuleDict):
        # Previously this case only surfaced later as an UnboundLocalError.
        raise TypeError(
            f"model must be a TransformerModuleDict, got {type(model).__name__}"
        )

    def _rotate_by_type(fea_dict: Dict[str, Tensor],
                        irreps_dict: Dict[str, Irreps]) -> Dict[str, Tensor]:
        """Return a new dict with each entry rotated by the D matrix of its type's irreps."""
        return {
            key: v @ irreps_dict[key].D_from_matrix(rot).T
            for key, v in fea_dict.items()
        }

    def _forward(edge_fea, node_fea, edge_vec):
        """Single forward pass; index, length and node-count inputs never change."""
        return model(
            edge_fea_dict=edge_fea,
            node_fea_dict=node_fea,
            edge_index_dict=edge_index_dict,
            edge_vec_embed_dict=edge_vec,
            edge_length_embed_dict=edge_length_dict,
            num_nodes_dict=num_nodes_dict,
        )

    # --- Side 1: forward pass first, then rotate the outputs. ---
    # New dicts are built so the dicts returned by the model are left untouched.
    raw_edge_out, raw_node_out = _forward(edge_fea_dict, node_fea_dict, edge_vec_dict)
    out_edge_1 = _rotate_by_type(raw_edge_out, irreps_edge_dict)
    out_node_1 = _rotate_by_type(raw_node_out, irreps_node_dict)

    # --- Side 2: rotate the inputs first, then forward pass. ---
    new_edge_fea_dict = _rotate_by_type(edge_fea_dict, irreps_edge_dict)
    new_node_fea_dict = _rotate_by_type(node_fea_dict, irreps_node_dict)
    # All edge-vector embeddings share one irreps, so the D matrix is computed once.
    D_vec_T = irreps_edge_vec.D_from_matrix(rot).T
    new_edge_vec_dict: Dict[str, Tensor] = {
        edge: v @ D_vec_T for edge, v in edge_vec_dict.items()
    }
    out_edge_2, out_node_2 = _forward(new_edge_fea_dict, new_node_fea_dict,
                                      new_edge_vec_dict)

    return out_edge_1, out_node_1, out_edge_2, out_node_2


# Re-seed right before drawing the rotation so the same matrix is used on every run.
torch.manual_seed(seed)
torch.random.manual_seed(seed)
rot = e3nn.o3.rand_matrix()
print(rot)

equivariance_inputs = dict(
    rot=rot,
    model=net,
    irreps_node_dict=irreps_node_dict,
    irreps_edge_dict=irreps_edge_dict,
    irreps_edge_vec=irreps_edge_vec,
    edge_fea_dict=edge_fea_dict,
    node_fea_dict=node_fea_dict,
    edge_index_dict=edge_index_dict,
    edge_vec_dict=edge_vec_dict,
    edge_length_dict=edge_length_dict,
    num_nodes_dict=num_nodes_dict,
)
out_edge_1, out_node_1, out_edge_2, out_node_2 = outs_rot_invariance(**equivariance_inputs)

# = parameter 3 =
# Remember to build the model with alpha_dropout=0.0 when testing equivariance.
# Edge outputs are inspected here ...
out_1, out_2 = out_edge_1[key_edge], out_edge_2[key_edge]
# ... to inspect node outputs instead, use:
#   out_1 = out_node_1[key_node]
#   out_2 = out_node_2[key_node]

# Magnitude statistics of both outputs, to put the error into perspective.
abs_out_1, abs_out_2 = out_1.abs(), out_2.abs()
min_out_1, min_out_2 = abs_out_1.min(), abs_out_2.min()
max_out_1, max_out_2 = abs_out_1.max(), abs_out_2.max()
ave_out_1, ave_out_2 = abs_out_1.mean(), abs_out_2.mean()

# Element-wise disagreement between the two sides of the equivariance identity.
error = out_2 - out_1
abs_error = error.abs()
max_error, ave_error = abs_error.max(), abs_error.mean()

print(f"Is {type(net).__name__} rotation equivariant?")
equivariance_report = (
    f"Max Error: {max_error}.\n"
    f"Mean Error (L1Loss): {ave_error}.\n"
    f"("
    f"Min Absolute of out_1: {min_out_1}\n"
    f"Min Absolute of out_2: {min_out_2}\n"
    f"Max Absolute of out_1: {max_out_1}\n"
    f"Max Absolute of out_2: {max_out_2}\n"
    f"Mean Absolute Value out_1: {ave_out_1}\n"
    f"Mean Absolute Value of out_2: {ave_out_2})"
)
print(equivariance_report)

tensor([[-0.1334,  0.0134,  0.9910],
        [-0.5643, -0.8230, -0.0649],
        [ 0.8147, -0.5678,  0.1174]])
Is TransformerModuleDict rotation equivariant?
Max Error: 1.4007091522216797e-06.
Mean Error (L1Loss): 8.836180853677433e-08.
(Min Absolute of out_1: 3.4095199225703254e-05
Min Absolute of out_2: 3.4074535506078973e-05
Max Absolute of out_1: 1.0643510818481445
Max Absolute of out_2: 1.0643517971038818
Mean Absolute Value out_1: 0.10497470200061798
Mean Absolute Value of out_2: 0.10497473180294037)
