In [3]:
from typing import Dict, Tuple, List
from torch import Tensor
from e3nn.o3 import Irreps
import torch
from torch_geometric.data import HeteroData
from modelname.graph import to_dtype_edge

# OUT_SLICE = slice(800, 805)                               ## (Notice !)
# IN_SLICE = slice(OUT_SLICE.start+4, OUT_SLICE.stop+4)     ## (Notice !)

def outs_rot_invariance(rot: Tensor,
                        model: torch.nn.Module,
                        data: HeteroData,
                        simple_sorted_block_irreps_dict: Dict[Tuple[int, int], Irreps]
                        ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
    """Compute both sides of a rotation-equivariance check for ``model``.

    Despite the function name, nothing here tests invariance, and translations
    are not exercised at all.  The function returns the two quantities that
    must coincide (up to numerical error) when ``model`` is rotation
    *equivariant*:

    * ``out_1`` -- ``model(data)`` with every output tensor right-multiplied by
      the transposed representation matrix of ``rot`` for that block's irreps.
    * ``out_2`` -- the model evaluated on a new graph whose edge features were
      right-multiplied by the transposed representation matrix of ``rot`` for
      ``Irreps("0e+1o") + <block irreps>``.

    Args:
        rot: Rotation matrix handed to ``Irreps.D_from_matrix``.  It is cast to
            the dtype/device of each tensor it is applied to, so it does not
            have to match the dataset's dtype.
        model: Module called as ``model(graph)``; its result is iterated with
            ``.items()`` and is expected to map edge keys to tensors.
        data: Graph providing ``x_dict``, ``edge_index_dict`` and
            ``edge_fea_dict``.  It is not modified.
        simple_sorted_block_irreps_dict: Irreps of each block, looked up with
            ``to_dtype_edge(edge, [Tuple, int])``.

    Returns:
        ``(out_1, out_2)`` as described above.  ``out_1`` is a new dict; the
        dict returned by the model is left untouched.
    """

    def _rep_like(irreps: Irreps, like: Tensor) -> Tensor:
        # Build the representation matrix in the dtype/device of the tensor it
        # multiplies; a float32 `rot` against float64 features would otherwise
        # make the matmul raise.
        return irreps.D_from_matrix(rot.to(dtype=like.dtype, device=like.device))

    # ---- Side 1: forward pass on the original graph, then rotate the outputs.
    raw_out: Dict[str, Tensor] = model(data)
    out_1: Dict[str, Tensor] = {}
    for edge, v in raw_out.items():
        # The predicted blocks are already in "simple sorted" order, so the
        # irreps from the lookup table apply to `v` directly.
        irreps_out: Irreps = simple_sorted_block_irreps_dict[to_dtype_edge(edge, [Tuple, int])]
        out_1[edge] = v @ _rep_like(irreps_out, v).T

    # ---- Side 2: rotate the inputs, then run the forward pass.
    # Node features and connectivity are reused unchanged; only the edge
    # features are rotated.  Any other attribute of `data` is NOT carried over
    # (assumes the model reads only these three dicts -- TODO confirm).
    #
    # NOTE(review): a fresh HeteroData is used because writing
    # `data.edge_fea_dict[edge] = ...` was observed to leave `data` unchanged.
    # Presumably `data.<name>_dict` is a dict assembled on every attribute
    # access, so item assignment only touches a temporary -- confirm against
    # the torch_geometric HeteroData documentation.
    rotated_edge_fea: Dict = {}
    for edge, v in data.edge_fea_dict.items():
        # Per the original author's notes the columns are laid out as
        # distance (1 column, "0e"), direction vector (3 columns, "1o"),
        # followed by the already simple-sorted block features.
        irreps_in: Irreps = Irreps("0e+1o") + simple_sorted_block_irreps_dict[to_dtype_edge(edge, [Tuple, int])]
        rotated_edge_fea[edge] = v @ _rep_like(irreps_in, v).T

    new_data: HeteroData = HeteroData()
    new_data.x_dict = data.x_dict
    new_data.edge_index_dict = data.edge_index_dict
    new_data.edge_fea_dict = rotated_edge_fea

    out_2 = model(new_data)

    return out_1, out_2

In [4]:
import os
import e3nn
import torch

from modelname.config import read_train_config, config_recorder
from modelname.graph import HeteroDataset
from modelname.old_0_model import Net


import warnings
warnings.filterwarnings("ignore")
# ---- Configuration -------------------------------------------------------
# Atom-pair block whose output is inspected, and the same key in the string
# form used by the model's output dict.
PAIR_TYPE: Tuple[int, int] = (83, 83)
PAIR_TYPE_MODEL_OUT: str = to_dtype_edge(PAIR_TYPE, dtype=str)

root_dir: str = r"/home/muyj/Project/Project_1106_deephe3_example/modelname"
json_file = os.path.join(root_dir, "config_dir", "config.json")
config: config_recorder = read_train_config(json_file=json_file)

# ---- Dataset and model ---------------------------------------------------
dataset_kwargs = dict(
    processed_dir=config.processed_dir,
    graph_file=config.graph_file,
    is_spin=config.is_spin,
    default_dtype_torch=config.default_dtype_torch,
)
dataset: HeteroDataset = HeteroDataset(**dataset_kwargs)

block_irreps_dict = dataset.simple_sorted_block_irreps_dict
print(block_irreps_dict[PAIR_TYPE])
model: Net = Net(block_irreps_dict=block_irreps_dict)

# ---- Random rotation -----------------------------------------------------
# The seed is set after the model has been constructed, so it pins the
# rotation drawn below but not the model's initial parameters.
seed = 1
torch.manual_seed(seed)
torch.random.manual_seed(seed)
rot = e3nn.o3.rand_matrix()
print(rot)

# ---- Run both sides of the equivariance check ----------------------------
out_1, out_2 = outs_rot_invariance(
    rot=rot,
    model=model,
    data=dataset[0],
    simple_sorted_block_irreps_dict=block_irreps_dict,
)
# Keep only the block of interest from each result dict.
out_1, out_2 = out_1[PAIR_TYPE_MODEL_OUT], out_2[PAIR_TYPE_MODEL_OUT]

# ---- Summary statistics --------------------------------------------------
# Output magnitudes are reported next to the error so the error can be read
# relative to the typical size of the outputs.
abs_out_1 = out_1.abs()
abs_out_2 = out_2.abs()
min_out_1, ave_out_1 = abs_out_1.min(), abs_out_1.mean()
min_out_2, ave_out_2 = abs_out_2.min(), abs_out_2.mean()

error = out_2 - out_1
abs_error = error.abs()
max_error, ave_error = abs_error.max(), abs_error.mean()

# ---- Report --------------------------------------------------------------
print(f"Is {type(model).__name__} rotation equivariant?")
report = (
    f"Max Error: {max_error}.\n"
    f"Mean Error: {ave_error}.\n"
    f"("
    f"Min Absolute of out_1: {min_out_1}\n"
    f"Min Absolute of out_2: {min_out_2}\n"
    f"Mean Absolute Value out_1: {ave_out_1}\n"
    f"Mean Absolute Value of out_2: {ave_out_2})"
)
print(report)

# is_zero = (error[OUT_SLICE] == 0).all()       ## (Notice !)
# max_error = error[OUT_SLICE].abs().max()      ## (Notice !)
# ave_error = error[OUT_SLICE].abs().mean()     ## (Notice !)
# print(is_zero)                                ## (Notice !)
# print(f"Max Error: {max_error}.\n"            ## (Notice !)
#       f"Mean Error: {ave_error}.\n")          ## (Notice !)

90x0e+202x1e+192x2e+112x3e+40x4e+8x5e
tensor([[-0.1991, -0.9799,  0.0117],
        [ 0.9644, -0.1939,  0.1797],
        [-0.1738,  0.0470,  0.9837]])
Is Net rotation equivariant?
Max Error: 0.00021076202392578125.
Mean Error: 5.320175091583224e-07.
(Min Absolute of out_1: 0.0
Min Absolute of out_2: 0.0
Mean Absolute Value out_1: 0.21476514637470245
Mean Absolute Value of out_2: 0.2147650271654129)


In [5]:
import torch
import e3nn
from e3nn.o3 import Irreps

# Demonstration: slicing a tensor yields a view, not an independent copy.
# Two copies of an l=2 even irrep on the input side (10 components in the
# printed output) and one copy on the output side (5 components).
irreps_in = Irreps("2x2e")
irreps_out = Irreps("1x2e")

# One random feature row laid out according to `irreps_in`.
fea_in = irreps_in.randn(1, -1)
# (Notice!) Plain assignment (`b = a`) only binds a second name to the same
# tensor, and basic slicing returns a view that shares storage with the
# original -- neither copies data, so an in-place write through `fea_out`
# would also be visible in `fea_in`.
fea_out = fea_in[:, :irreps_out.dim]
# Alternative that decouples the two: `fea_in + 0.0` allocates a new tensor
# before the slice is taken.
# fea_out = (fea_in + 0.0)[:, :irreps_out.dim]
print(f"fea_in: {fea_in}")
print(f"fea_out: {fea_out}")
# The raw string below is a disabled experiment (an in-place write through the
# view).  Being the last expression of the cell, the string itself is echoed
# as the cell's output.
r'''
fea_out[0, 0] = 0.00
# fea_out = fea_out * 100
print(f"fea_in: {fea_in}")
print(f"fea_out: {fea_out}")
'''


fea_in: tensor([[ 0.7244, -0.7022,  1.1661,  0.2605,  0.3506,  1.0203, -1.8349, -2.2149,
          0.0436,  1.3240]])
fea_out: tensor([[ 0.7244, -0.7022,  1.1661,  0.2605,  0.3506]])


'\nfea_out[0, 0] = 0.00\n# fea_out = fea_out * 100\nprint(f"fea_in: {fea_in}")\nprint(f"fea_out: {fea_out}")\n'

In [6]:
# Draw a reproducible random rotation.
seed = 1
torch.manual_seed(seed)
torch.random.manual_seed(seed)
rot = e3nn.o3.rand_matrix()

# Representation matrices of `rot` for the input and the output irreps.
D_in, D_out = (irreps.D_from_matrix(rot) for irreps in (irreps_in, irreps_out))

# Rotate both feature rows.  Each matmul produces a new tensor, so after this
# point `fea_out` is no longer a view of `fea_in`.
# Caution: the names are rebound to their rotated values, so running this cell
# a second time rotates the already-rotated features again.
fea_in = torch.matmul(fea_in, D_in.t())
fea_out = torch.matmul(fea_out, D_out.t())

print(f"fea_in: {fea_in}")
print(f"fea_out: {fea_out}")

fea_in: tensor([[-0.6085,  1.0241, -0.3608,  0.6639, -0.7606, -0.3931,  1.1587,  0.8509,
          1.2740,  2.6879]])
fea_out: tensor([[-0.6085,  1.0241, -0.3608,  0.6639, -0.7606]])
