In [1]:
from typing import List, Tuple, Optional, Union

import torch

from e3nn import o3
from e3nn.util.jit import compile_mode as e3nn_compile_mode, script as e3nn_script
from torch_cluster import radius_graph

from diffusion_edf.equiformer.graph_attention_transformer import TransBlock, EdgeDegreeEmbeddingNetwork, FeedForwardNetwork, SeparableFCTP
from diffusion_edf.equiformer.gaussian_rbf import GaussianRadialBasisLayer
from diffusion_edf.equiformer.tensor_product_rescale import LinearRS, irreps2gate, FullyConnectedTensorProductRescaleSwishGate
from diffusion_edf.equiformer.fast_activation import Activation, Gate

from diffusion_edf.embedding import NodeEmbeddingNetwork


In [2]:
# --- Runtime configuration -------------------------------------------------
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Whether to TorchScript-compile the modules built below.
# NOTE: this shadows Python's builtin `compile`; the name is kept because later cells read it.
compile = True
# Fixed seed so module initialisation and the random test inputs are reproducible.
seed = 0
torch.manual_seed(seed)

# --- Model hyper-parameters --------------------------------------------------
irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('128x0e+64x1e+32x2e')
irreps_node_attr = o3.Irreps('1x0e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
irreps_block_output = o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128
fc_neurons = [64, 64]
# Even parity (1e) like every other irreps above: node embeddings and spherical
# harmonics are all even, so no tensor-product path could ever produce odd (1o) channels.
irreps_head = o3.Irreps('32x0e+16x1e+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
nonlinear_message = True
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = o3.Irreps('384x0e+192x1e+96x2e')
norm_layer = 'layer'
rbf_max_radius = 5.0

# --- Encoders ----------------------------------------------------------------
# Radial basis expansion of edge lengths up to `rbf_max_radius`.
rbf = GaussianRadialBasisLayer(number_of_basis, cutoff=rbf_max_radius)
# Lifts the raw scalar node inputs into `irreps_node_embedding`.
node_enc = NodeEmbeddingNetwork(irreps_input=irreps_input, irreps_node_emb=irreps_node_embedding)
# Edge-degree embedding; its radial MLP takes the `number_of_basis` RBF features as input.
edge_deg_enc = EdgeDegreeEmbeddingNetwork(irreps_node_embedding=irreps_node_embedding,
                                          irreps_edge_attr=irreps_sh,
                                          fc_neurons=[number_of_basis] + fc_neurons,
                                          avg_aggregate_num=4)

if compile:
    # The RBF layer is a plain torch module; the e3nn modules need e3nn's scripting helper.
    rbf = torch.jit.script(rbf)
    node_enc = e3nn_script(node_enc)
    edge_deg_enc = e3nn_script(edge_deg_enc)

rbf = rbf.to(device)
node_enc = node_enc.to(device)
edge_deg_enc = edge_deg_enc.to(device)



In [3]:
# Stand-alone SeparableFCTP: (node embedding) x (edge spherical harmonics) -> node embedding.
# internal_weights=False, so the tensor-product weights are presumably generated from
# `edge_scalars` by an MLP sized by `fc_neurons` -- TODO confirm against SeparableFCTP.
fctp = SeparableFCTP(irreps_node_input=irreps_node_embedding,
                     irreps_edge_attr=irreps_sh,
                     irreps_node_output=irreps_node_embedding,
                     fc_neurons=[number_of_basis] + fc_neurons,
                     use_activation=True,
                     norm_layer=norm_layer,
                     internal_weights=False)

# Honour the `compile` flag, like the encoder cell does.
if compile:
    fctp = e3nn_script(fctp)
fctp = fctp.to(device)

In [4]:
# Smoke test: push 5 random rows through the FCTP.
# `batch` assigns every row to graph 0; graph-index tensors are integer-typed
# (as in the input-data cell further down), so build it as torch.long, not float.
out = fctp(node_input=irreps_node_embedding.randn(5, -1, device=device),
           edge_attr=irreps_sh.randn(5, -1, device=device),
           edge_scalars=torch.randn(5, number_of_basis, device=device),
           batch=torch.zeros(5, device=device, dtype=torch.long))

In [6]:
out.size()  # same torch.Size as `out.shape`

torch.Size([5, 480])

In [11]:
torch.std(out)  # overall standard deviation of the FCTP output

tensor(1.0172, device='cuda:0', grad_fn=<StdBackward0>)

In [5]:
# Sanity check in place of the stray `fsda` NameError that used to halt Run All.
# The FCTP maps the 5 rows onto the node-embedding irreps: 128*1 + 64*3 + 32*5 = 480 columns.
assert out.shape == (5, irreps_node_embedding.dim)

In [None]:
# Equivariant Transformer block under test. It must be defined here: the forward-pass
# cell below calls `block1`, and leaving this commented out made Run All fail with NameError.
block1 = TransBlock(irreps_node_input=irreps_node_embedding,
                    irreps_node_attr=irreps_node_attr,
                    irreps_edge_attr=irreps_sh,
                    irreps_node_output=irreps_block_output,
                    fc_neurons=[number_of_basis] + fc_neurons,
                    irreps_head=irreps_head,
                    num_heads=num_heads,
                    irreps_pre_attn=irreps_pre_attn,
                    rescale_degree=rescale_degree,
                    nonlinear_message=nonlinear_message,
                    alpha_drop=alpha_drop,
                    proj_drop=proj_drop,
                    drop_path_rate=drop_path_rate,
                    irreps_mlp_mid=irreps_mlp_mid,
                    norm_layer=norm_layer).to(device)

# TorchScript compilation of TransBlock was left disabled in the original experiment;
# enable with `block1 = e3nn_script(block1)` once it is known to script cleanly.

In [None]:
N_nodes = 5
neighball_radius = 5.0  # radius_graph cutoff (same value as rbf_max_radius)

# Random toy point cloud: scalar input features plus 3-D coordinates.
input_feature = irreps_input.randn(N_nodes, -1, device=device)
node_coord = torch.randn(N_nodes, 3, device=device)
# Every node belongs to graph 0.
batch = torch.zeros(N_nodes, device=device, dtype=torch.long)
# Dummy constant node attribute (irreps_node_attr is 1x0e); sized by N_nodes, not a literal.
node_attr = torch.ones(N_nodes, 1, device=device)

In [None]:
# Connect every pair of nodes within `neighball_radius` (no self-loops), per graph in `batch`.
# The returned edge index is unpacked row-wise into source / destination node ids.
edge_src, edge_dst = radius_graph(
    node_coord,
    r=neighball_radius,
    batch=batch,
    loop=False,
    max_num_neighbors=1000,
)

In [None]:
# --- Edge geometry: displacement, length, radial basis, spherical harmonics ---
edge_vec = node_coord[edge_src] - node_coord[edge_dst]
edge_length = edge_vec.norm(dim=1)
edge_length_emb = rbf(edge_length)
edge_sh = o3.spherical_harmonics(
    l=irreps_sh,
    x=edge_vec,
    normalize=True,
    normalization='component',
)

# --- Node features: input embedding plus the edge-degree embedding ---
node_emb = node_enc(input_feature)
edge_degree_emb = edge_deg_enc(
    node_input=node_emb,
    edge_attr=edge_sh,
    edge_scalars=edge_length_emb,
    edge_src=edge_src,
    edge_dst=edge_dst,
)
node_input = node_emb + edge_degree_emb

# The radial-basis features double as the edge scalars fed to the Transformer block.
edge_scalars = edge_length_emb

In [None]:
# One equivariant Transformer block over the radius graph built above.
node_features = block1(
    node_input=node_input,
    node_attr=node_attr,
    edge_src=edge_src,
    edge_dst=edge_dst,
    edge_attr=edge_sh,
    edge_scalars=edge_scalars,
    batch=batch,
)

In [None]:
node_features.size()  # same torch.Size as `node_features.shape`