# Check Equivariance and Invariance

In [1]:
import sys
sys.path.append("..")

In [2]:
import torch
import numpy as np

In [7]:
import fegnn.util.functional as smolF
from fegnn.models.egnn import ImprovedEgnnLayer, EgnnDynamics, CoordNorm

In [4]:
def gen_fake_data(n_atoms: list[int], d_model: int, n_coord_sets: int):
    coords = []
    feats = []
    masks = []

    for n in n_atoms:
        coords.append(torch.randn((n, n_coord_sets, 3)))
        feats.append(torch.randn((n, d_model)))
        masks.append(torch.ones(n))

    coords = smolF.pad_tensors(coords).transpose(1, 2)
    feats = smolF.pad_tensors(feats)
    mask = smolF.pad_tensors(masks)
    adj_matrix = smolF.adj_from_node_mask(mask)
    mask = mask.unsqueeze(1).expand(-1, n_coord_sets, -1)

    return coords, feats, adj_matrix, mask

In [5]:
def mae(t1, t2):
    err = torch.abs(t1 - t2)
    return err.mean()

In [6]:
def is_equal(t1, t2, eps=1e-5):
        err = torch.abs(t1 - t2)
        equal = err < eps
        return equal

In [8]:
d_model = 256
d_edge = 128
n_coord_sets = 64
edge_attention = "pairwise"
coord_attn_clamp = "tanh"

### Test Equivariance of CoordNorm Layers

In [9]:
# Generate some fake data

n_atoms = [12, 33, 9, 17, 24, 5]
coords, feats, adj, mask = gen_fake_data(n_atoms, d_model, n_coord_sets)

coords = smolF.zero_com(coords, mask)
coords = coords * mask.unsqueeze(-1)

print("coords", coords.shape)
print("feats", feats.shape)
print("adj", adj.shape)
print("mask", mask.shape)

coords torch.Size([6, 64, 33, 3])
feats torch.Size([6, 33, 256])
adj torch.Size([6, 33, 33])
mask torch.Size([6, 64, 33])


In [10]:
norm = CoordNorm(n_coord_sets, mlp_scale=False)
norm_w_mlp = CoordNorm(n_coord_sets, mlp_scale=True)

In [11]:
# Apply random shift

shift = torch.rand((len(n_atoms), n_coord_sets, max(n_atoms), 3))
shifted = coords + shift

In [12]:
coord_norm_out = norm(coords, mask)
shifted_norm_out = norm(shifted, mask)

coord_norm_mlp_out = norm_w_mlp(coords, mask)
shifted_norm_mlp_out = norm_w_mlp(shifted, mask)

In [20]:
coord_norm_com = smolF.calc_com(coord_norm_out, mask).mean()
shifted_norm_com = smolF.calc_com(shifted_norm_out, mask).mean()
coord_norm_mlp_com = smolF.calc_com(coord_norm_mlp_out, mask).mean()
shifted_norm_mlp_com = smolF.calc_com(shifted_norm_mlp_out, mask).mean()

In [21]:
print("Coord norm CoM", coord_norm_com.mean())
print("Shifted norm CoM", shifted_norm_com.mean())
print("Coord norm MLP CoM", coord_norm_mlp_com.mean())
print("Shifted norm MLP CoM", shifted_norm_mlp_com.mean())

Coord norm CoM tensor(-4.2556e-10, grad_fn=<MeanBackward0>)
Shifted norm CoM tensor(2.3044e-09, grad_fn=<MeanBackward0>)
Coord norm MLP CoM tensor(-9.3003e-11, grad_fn=<MeanBackward0>)
Shifted norm MLP CoM tensor(1.4760e-10, grad_fn=<MeanBackward0>)


In [23]:
print("Coord norm CoM", is_equal(coord_norm_com, torch.zeros_like(coord_norm_com)).all())
print("Shifted norm CoM", is_equal(shifted_norm_com, torch.zeros_like(shifted_norm_com)).all())
print("Coord norm MLP CoM", is_equal(coord_norm_mlp_com, torch.zeros_like(coord_norm_mlp_com)).all())
print("Shifted norm MLP CoM", is_equal(shifted_norm_mlp_com, torch.zeros_like(shifted_norm_mlp_com)).all())

Coord norm CoM tensor(True)
Shifted norm CoM tensor(True)
Coord norm MLP CoM tensor(True)
Shifted norm MLP CoM tensor(True)


### Test Equivariance of the Model Layer

In [25]:
layer = ImprovedEgnnLayer(
    d_model,
    d_edge,
    n_coord_sets,
    edge_attention=edge_attention,
    coord_attn_clamp=coord_attn_clamp,
    eps=1e-6
)

In [26]:
# Generate some fake data

n_atoms = [12, 33, 9, 17, 24, 5]
coords, feats, adj, mask = gen_fake_data(n_atoms, d_model, n_coord_sets)

coords = smolF.zero_com(coords, mask)
coords = coords * mask.unsqueeze(-1)

print("coords", coords.shape)
print("feats", feats.shape)
print("adj", adj.shape)
print("mask", mask.shape)

coords torch.Size([6, 64, 33, 3])
feats torch.Size([6, 33, 256])
adj torch.Size([6, 33, 33])
mask torch.Size([6, 64, 33])


In [27]:
# Apply a random rotation

rotation = tuple((np.random.rand(3) * np.pi * 2).tolist())
rotated_coords = [smolF.rotate(cs.flatten(0, 1), rotation) for cs in coords]
rotated_coords = torch.stack(rotated_coords).unflatten(1, (n_coord_sets, -1)).float()

rotated_coords = smolF.zero_com(rotated_coords, mask)
rotated_coords = rotated_coords * mask.unsqueeze(-1)

print("rotated", rotated_coords.shape)

rotated torch.Size([6, 64, 33, 3])


In [28]:
print("CoM", smolF.calc_com(coords, mask).mean())
print("Rotated CoM", smolF.calc_com(rotated_coords, mask).mean())

CoM tensor(-4.8196e-11)
Rotated CoM tensor(-4.0742e-10)


In [29]:
coords_out, feats_out = layer(coords, feats, adj, mask)
coords_out = coords_out.detach().cpu()
feats_out = feats_out.detach().cpu()

coords_out_rotated = [smolF.rotate(cs.flatten(0, 1), rotation) for cs in coords_out]
coords_out_rotated = torch.stack(coords_out_rotated).unflatten(1, (n_coord_sets, -1)).float()

coords_out_rotated = smolF.zero_com(coords_out_rotated, mask)
coords_out_rotated = coords_out_rotated * mask.unsqueeze(-1)

In [30]:
rotated_coords_out, rotated_feats_out = layer(rotated_coords, feats, adj, mask)
rotated_coords_out = rotated_coords_out.detach().cpu()
rotated_feats_out = rotated_feats_out.detach().cpu()

rotated_coords_out = smolF.zero_com(rotated_coords_out, mask)
rotated_coords_out = rotated_coords_out * mask.unsqueeze(-1)

In [31]:
print("Out rotated CoM", smolF.calc_com(coords_out_rotated, mask).mean())
print("Rotated out CoM", smolF.calc_com(rotated_coords_out, mask).mean())

Out rotated CoM tensor(1.0651e-09)
Rotated out CoM tensor(-3.8316e-11)


In [32]:
# Test invariant features
invariant = is_equal(feats_out, rotated_feats_out)

print("All invariant:", (invariant.sum() == invariant.numel()).item())
print(feats_out[0, 0, :8])
print(rotated_feats_out[0, 0, :8])
print("MAE:", mae(feats_out, rotated_feats_out))

All invariant: True
tensor([ 0.2683,  0.1237,  0.5861, -0.9327, -1.3947, -2.0149,  1.2569, -1.7422])
tensor([ 0.2683,  0.1237,  0.5861, -0.9327, -1.3947, -2.0149,  1.2569, -1.7422])
MAE: tensor(5.5412e-08)


In [33]:
# Test equivariant features
equivariant = is_equal(coords_out_rotated, rotated_coords_out)

print("All equivariant:", (equivariant.sum() == equivariant.numel()).item())
print(coords_out_rotated[0, 0, :8])
print(rotated_coords_out[0, 0, :8])
print("MAE:", mae(coords_out_rotated, rotated_coords_out))

All equivariant: True
tensor([[ 1.5528, -0.2032, -0.1132],
        [-0.4472,  0.9468, -0.7487],
        [-1.4964,  0.2370,  0.8031],
        [ 0.0391,  0.9592, -0.1951],
        [-1.5740, -1.5283,  0.2595],
        [ 0.7647, -0.0716,  1.3367],
        [-0.4526,  1.6164,  0.7683],
        [-0.7029, -0.5096, -1.4388]])
tensor([[ 1.5528, -0.2032, -0.1132],
        [-0.4472,  0.9468, -0.7487],
        [-1.4964,  0.2370,  0.8031],
        [ 0.0391,  0.9592, -0.1951],
        [-1.5740, -1.5283,  0.2595],
        [ 0.7647, -0.0716,  1.3367],
        [-0.4526,  1.6164,  0.7683],
        [-0.7029, -0.5096, -1.4388]])
MAE: tensor(1.7985e-08)


### Test equivariance of whole model

In [34]:
n_layers = 9
dynamics = EgnnDynamics(d_model, layer, n_layers, n_coord_sets)

In [35]:
coords, feats, adj, mask = gen_fake_data(n_atoms, d_model, 1)
coords = coords[:, 0]
mask = mask[:, 0]

coords = smolF.zero_com(coords, mask)
coords = coords * mask.unsqueeze(-1)

print("coords", coords.shape)
print("feats", feats.shape)
print("adj", adj.shape)
print("mask", mask.shape)

coords torch.Size([6, 33, 3])
feats torch.Size([6, 33, 256])
adj torch.Size([6, 33, 33])
mask torch.Size([6, 33])


In [36]:
rotation = tuple((np.random.rand(3) * np.pi * 2).tolist())
rotated_coords = [smolF.rotate(cs, rotation) for cs in coords]
rotated_coords = torch.stack(rotated_coords).float()

rotated_coords = smolF.zero_com(rotated_coords, mask)
rotated_coords = rotated_coords * mask.unsqueeze(-1)

print("rotated", rotated_coords.shape)

rotated torch.Size([6, 33, 3])


In [37]:
coords_out, feats_out = dynamics(coords, feats, adj, mask)
coords_out = coords_out.detach().cpu()
feats_out = feats_out.detach().cpu()

coords_out_rotated = [smolF.rotate(cs, rotation) for cs in coords_out]
coords_out_rotated = torch.stack(coords_out_rotated).float()

coords_out_rotated = smolF.zero_com(coords_out_rotated, mask)
coords_out_rotated = coords_out_rotated * mask.unsqueeze(-1)

In [38]:
rotated_coords_out, rotated_feats_out = dynamics(rotated_coords, feats, adj, mask)
rotated_coords_out = rotated_coords_out.detach().cpu()
rotated_feats_out = rotated_feats_out.detach().cpu()

rotated_coords_out = smolF.zero_com(rotated_coords_out, mask)
rotated_coords_out = rotated_coords_out * mask.unsqueeze(-1)

In [39]:
# Test invariant features
invariant = is_equal(feats_out, rotated_feats_out)

print("All invariant:", (invariant.sum() == invariant.numel()).item())
print("MAE:", mae(feats_out, rotated_feats_out))

All invariant: True
MAE: tensor(3.9760e-07)


In [40]:
# Test equivariant features
equivariant = is_equal(coords_out_rotated, rotated_coords_out, eps=1e-4)

print("All equivariant:", (equivariant.sum() == equivariant.numel()).item())
print("MAE:", mae(coords_out_rotated, rotated_coords_out))

All equivariant: True
MAE: tensor(2.5869e-06)
