In [1]:
from typing import List, Tuple, Optional, Union

import torch

from e3nn import o3
from e3nn.util.jit import compile_mode as e3nn_compile_mode, script as e3nn_script
from torch_cluster import radius_graph

from diffusion_edf.equiformer.graph_attention_transformer import TransBlock, EdgeDegreeEmbeddingNetwork
from diffusion_edf.equiformer.gaussian_rbf import GaussianRadialBasisLayer
from diffusion_edf.equiformer.tensor_product_rescale import LinearRS
from diffusion_edf.equiformer.layer_norm import EquivariantLayerNormV2
from diffusion_edf.equiformer.drop import DropPath, GraphDropPath, EquivariantDropout, EquivariantScalarsDropout

from diffusion_edf.embedding import NodeEmbeddingNetwork


In [2]:
device = 'cuda:0'
compile = True

irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('128x0e+64x1e+32x2e')
irreps_node_attr = o3.Irreps('1x0e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
irreps_block_output = o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128
fc_neurons = [64, 64]
irreps_head = o3.Irreps('32x0e+16x1o+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
nonlinear_message = False
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = o3.Irreps('128x0e+64x1e+32x2e')
norm_layer = 'layer'
rbf_max_radius = 5.0




rbf = GaussianRadialBasisLayer(number_of_basis, cutoff=rbf_max_radius)
node_enc = NodeEmbeddingNetwork(irreps_input=irreps_input, irreps_node_emb=irreps_node_embedding)
edge_deg_enc = EdgeDegreeEmbeddingNetwork(irreps_node_embedding=irreps_node_embedding,
                                          irreps_edge_attr=irreps_sh, 
                                          fc_neurons = [number_of_basis] + fc_neurons, 
                                          avg_aggregate_num = 4)


if compile:
    rbf = torch.jit.script(rbf).to(device)
    node_enc = e3nn_script(node_enc).to(device)
    edge_deg_enc = e3nn_script(edge_deg_enc).to(device)
else:
    rbf = rbf.to(device)
    node_enc = node_enc.to(device)
    edge_deg_enc = edge_deg_enc.to(device)



In [3]:
drop = DropPath(drop_prob=0.1)
drop = e3nn_script(drop).to(device)

drop(irreps_node_embedding.randn(5,-1, device=device))

tensor([[-1.1914,  0.3048,  0.2026,  ...,  0.6521, -0.0531, -0.8030],
        [ 0.0000, -0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [-1.7094,  1.5936, -1.0857,  ..., -1.5737,  0.4267,  1.3627],
        [ 2.2153,  0.3201,  0.3205,  ..., -0.2401,  0.2434,  0.8181],
        [-0.4106,  1.0596,  1.1603,  ...,  2.3205,  0.6703, -0.9027]],
       device='cuda:0')

In [4]:
drop = GraphDropPath(drop_prob=0.1)
drop = e3nn_script(drop).to(device)

drop(irreps_node_embedding.randn(5,-1, device=device), batch=torch.zeros(5, device=device, dtype=torch.long))

tensor([[-1.0839,  0.8606, -2.0560,  ...,  1.2124,  2.0675, -0.3863],
        [ 0.3570, -0.0194, -2.8741,  ...,  0.6079,  0.4545,  2.4892],
        [ 0.1902,  0.5185,  0.2567,  ...,  0.4576,  0.1538,  1.8178],
        [-0.7451, -0.1206, -0.9822,  ...,  1.3339, -0.5530,  0.6448],
        [-0.8463,  1.2973, -0.5368,  ..., -1.8737, -0.6353,  1.1132]],
       device='cuda:0')

In [5]:
drop = EquivariantDropout(irreps=irreps_node_embedding, drop_prob=0.1)
drop = e3nn_script(drop).to(device)

drop(irreps_node_embedding.randn(5,-1, device=device))

tensor([[-0.3162,  0.4865, -0.6099,  ...,  2.3546,  0.5001, -1.9272],
        [ 0.5995,  0.7884, -1.2302,  ..., -1.6519, -1.6850, -0.2015],
        [ 0.0734, -1.8530,  0.3828,  ...,  1.8543, -1.6515,  1.0917],
        [ 1.1131, -0.4321, -0.6455,  ...,  0.7858, -1.1276,  1.4185],
        [-0.0000, -1.3218, -1.1581,  ...,  0.2792,  1.5140,  0.2949]],
       device='cuda:0')

In [6]:
drop = EquivariantScalarsDropout(irreps=irreps_node_embedding, drop_prob=0.1)
drop = e3nn_script(drop).to(device)

drop(irreps_node_embedding.randn(5,-1, device=device))

tensor([[ 2.0147, -0.1392,  0.0690,  ..., -1.6172,  0.8197,  0.0666],
        [ 0.2125, -0.2278, -1.1496,  ...,  1.2899,  0.3301,  0.5980],
        [ 1.7557, -0.5144, -0.0000,  ..., -1.3151, -1.0133, -2.1951],
        [-2.5364, -2.6835,  0.0365,  ...,  0.2068,  0.9934, -1.2140],
        [ 0.4186,  1.0241,  1.2136,  ...,  0.7510, -0.0035, -0.3442]],
       device='cuda:0')

In [7]:
fsda

NameError: name 'fsda' is not defined

In [None]:
# block1 = TransBlock(irreps_node_input=irreps_node_embedding, 
#                     irreps_node_attr=irreps_node_attr,
#                     irreps_edge_attr=irreps_sh, 
#                     irreps_node_output=irreps_block_output,
#                     fc_neurons= [number_of_basis] + fc_neurons, 
#                     irreps_head=irreps_head, 
#                     num_heads=num_heads, 
#                     irreps_pre_attn=irreps_pre_attn, 
#                     rescale_degree=rescale_degree,
#                     nonlinear_message=nonlinear_message,
#                     alpha_drop=alpha_drop, 
#                     proj_drop=proj_drop,
#                     drop_path_rate=drop_path_rate,
#                     irreps_mlp_mid=irreps_mlp_mid,
#                     norm_layer=norm_layer).to(device)



# # block1 = e3nn.util.jit.script(block1)

In [None]:
N_nodes = 5
neighball_radius = 5.0

input_feature = irreps_input.randn(N_nodes, -1, device=device)
node_coord = torch.randn(N_nodes, 3, device=device)
batch = torch.zeros(N_nodes, device=device, dtype=torch.long)
node_attr = torch.ones(5, 1, device=device) # dummy

In [None]:
edge_src, edge_dst = radius_graph(node_coord, r=neighball_radius, batch=batch, loop=False, max_num_neighbors=1000)

In [None]:
edge_vec = node_coord.index_select(0, edge_src) - node_coord.index_select(0, edge_dst)
edge_sh = o3.spherical_harmonics(l=irreps_sh, x=edge_vec, normalize=True, normalization='component')
edge_length = edge_vec.norm(dim=1)
edge_length_emb = rbf(edge_length)

node_emb = node_enc(input_feature)
edge_degree_emb = edge_deg_enc(node_input = node_emb, 
                               edge_attr = edge_sh, 
                               edge_scalars = edge_length_emb, 
                               edge_src = edge_src, 
                               edge_dst = edge_dst)

node_input = node_emb + edge_degree_emb
edge_scalars = edge_length_emb

In [None]:
node_features =  block1(node_input=node_input, 
                        node_attr = node_attr, 
                        edge_src = edge_src, 
                        edge_dst = edge_dst,
                        edge_attr = edge_sh,
                        edge_scalars = edge_scalars,
                        batch = batch)

In [None]:
node_features.shape