# EGNN - PyTorch

Implementation of "E(n)-Equivariant Graph Neural Networks", in Pytorch. May be used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.

Link to paper: https://arxiv.org/pdf/2102.09844v1.pdf

Credit: https://github.com/lucidrains/egnn-pytorch

Google Colab: https://drive.google.com/file/d/1r5faRM8TQGFt8u3Gf_GNHlpo86hfH46B/view?usp=sharing

In [None]:
# Install egnn-pytorch library
!pip install egnn-pytorch

### <b>Usage</b>

In [2]:
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

In [4]:
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

A full EGNN network

In [5]:
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

Only attend to sparse neighbors, given to the network as an adjacency matrix.

In [6]:
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments

In [7]:
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,           # fetch up to 3rd degree neighbors
    adj_dim = 8,                   # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

### <b>Edges</b>

If you need to pass in continuous edges

In [8]:
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

## <b>Stability</b>

The initial architecture for EGNN suffered from instability when there was high number of neighbors. Thankfully, there seems to be two solutions that largely mitigate this.

In [9]:
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

## <b>Example</b>

To run the protein backbone denoising example, first install <code>sidechainnet</code>

In [None]:
!pip install sidechainnet

In [12]:
# Taken from the denoise_sparse.py file
import torch
import torch.nn.functional as F

from torch import nn
from torch.optim import Adam

from einops import rearrange, repeat

import sidechainnet as scn
from egnn_pytorch.egnn_pytorch import EGNN_Network

torch.set_default_dtype(torch.float64)

BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16

def cycle(loader, len_thres = 200):
    while True:
        for data in loader:
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 200 * 3,   # maximum number of positions - absolute positional embedding since there is inherent order in the sequence
    depth = 5,
    dim = 8,
    num_nearest_neighbors = 16,
    fourier_features = 2,
    norm_coors = True,
    coor_weights_clamp_value = 2.
).cuda()

data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

dl = cycle(data['train'])
optim = Adam(net.parameters(), lr=1e-3)

for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda().type(torch.float64)
        masks = masks.cuda().bool()

        l = seqs.shape[1]
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # Keeping only the backbone coordinates

        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        i = torch.arange(seq.shape[-1], device = seq.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        noised_coords = coords + torch.randn_like(coords)

        feats, denoised_coords = net(seq, noised_coords, adj_mat = adj_mat, mask = masks)

        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()

SidechainNet(12, 30) was not found in ./sidechainnet_data.
Downloading from https://pitt.box.com/shared/static/hbatd2a750tx8e27yizwinc3hsceeeui.pkl


Downloading file chunks (estimated): 100%|█████████▉| 53059/53085 [03:41<00:00, 249.01chunk/s]

Downloaded SidechainNet to ./sidechainnet_data/sidechainnet_casp12_30.pkl.
SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp12_30.pkl.
loss: 0.9533114220863741
loss: 1.0338555187628922
loss: 1.017777472031135
loss: 1.1071745778064523
loss: 1.0398457268030783
loss: 0.991408605401272
loss: 0.9334685539777721
loss: 0.9690967102615862
loss: 1.0087449563643915
loss: 0.9918022527871619
loss: 1.0251756451352376
loss: 0.9521198788178261
loss: 0.8595171786813144
loss: 0.9942374839622138
loss: 0.971821479404791
loss: 1.229556048778545
loss: 0.9023371740956092
loss: 0.8715814248861558
loss: 0.9833517888991157
loss: 0.887537514864425
loss: 0.9624522769589032
loss: 0.9145576634609973
loss: 0.9660634270532449
loss: 0.8962380119664084
loss: 0.9501079554080546
loss: 0.9462586261867424
loss: 0.9255149947664215
loss: 0.8587703206442908
loss: 0.9337779115607365
loss: 1.011458078380807
loss: 0.9255404319965019
loss: 0.8941405525495936


Downloading file chunks (estimated): 53086chunk [04:00, 249.01chunk/s]                        

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
loss: 0.5008102299737249
loss: 0.3116210015661671
loss: 0.4088390401849865
loss: 0.2956810076054512
loss: 0.298556099289429
loss: 0.3642078679275275
loss: 0.2944852266858359
loss: 0.3547388301283935
loss: 0.3545877455579473
loss: 0.3391852379996731
loss: 0.28490701655438644
loss: 0.3061515019081935
loss: 0.31709608255837235
loss: 0.3663557597645103
loss: 0.42692809316349295
loss: 0.3467423477823183
loss: 0.35458471658029916
loss: 0.5475757459523621
loss: 0.32699711136184717
loss: 0.35187021792886014
loss: 0.3805297302734712
loss: 0.24630857397291103
loss: 0.37515415729427465
loss: 0.32608499671753327
loss: 0.3976378305031718
loss: 0.4125225126453087
loss: 0.4602921506700109
loss: 0.38218807610901534
loss: 0.3985347342424857
loss: 0.3618261202186783
loss: 0.3471810073947113
loss: 0.4315658016220574
loss: 0.39855957894941985
loss: 0.38367209348668224
loss: 0.38446412651092277
loss: 0.41161132097338626
loss: 0.37370769237274