In [81]:
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data
import sys
import os

# Add the parent directory of `notebook/` to sys.path
sys.path.append(os.path.abspath(".."))

default_dtype = torch.float64
torch.set_default_dtype(default_dtype)

In [82]:
# A lattice is a 3 x 3 matrix
# The first index is the lattice vector (a, b, c)
# The second index is a Cartesian index over (x, y, z)

# Polonium with Simple Cubic Lattice
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 AA
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']

# Silicon with Diamond Structure
si_lattice = torch.tensor([
    [0.      , 2.734364, 2.734364],
    [2.734364, 0.      , 2.734364],
    [2.734364, 2.734364, 0.      ]
])
si_coords = torch.tensor([
    [1.367182, 1.367182, 1.367182],
    [0.      , 0.      , 0.      ]
])
si_types = ['Si', 'Si']

po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)
print(po)
print(si)

Atoms(symbols='Po', pbc=True, cell=[3.34, 3.34, 3.34])
Atoms(symbols='Si2', pbc=True, cell=[[0.0, 2.734364, 2.734364], [2.734364, 0.0, 2.734364], [2.734364, 2.734364, 0.0]])


#### We use the ase.neighborlist.neighbor_list algorithm and a radial_cutoff distance to define which edges to include in the graph to represent interactions with neighboring atoms. Note that for a convolutional network, the number of layers determines the receptive field, i.e. how “far out” any given atom can see. So even if a we use a radial_cutoff = 3.5, a two layer network effectively sees 2 * 3.5 = 7 distance units (in this case Angstroms) away and a three layer network 3 * 3.5 = 10.5 distance units. We then store our data in torch_geometric.data.Data objects that we will batch with torch_geometric.data.DataLoader below.

In [83]:
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Po': 0, 'Si': 1}
type_onehot = torch.eye(len(type_encoding))

dataset = []

dummy_energies = torch.randn(2, 1, 1)  # dummy energies for example

for crystal, energy in zip([po, si], dummy_energies):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)
    # print(edge_src)
    # print(edge_dst)
    # print(edge_shift)
    data = torch_geometric.data.Data(
        pos=torch.tensor(crystal.get_positions()),
        lattice=torch.tensor(crystal.cell.array).unsqueeze(0),  # We add a dimension for batching
        x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]],  # Using "dummy" inputs of scalars because they are all C
        edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        energy=energy  # dummy energy (assumed to be normalized "per atom")
    )

    dataset.append(data)

print(dataset)



[Data(x=[1, 2], edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], edge_shift=[7, 3], energy=[1, 1]), Data(x=[2, 2], edge_index=[2, 10], pos=[2, 3], lattice=[1, 3, 3], edge_shift=[10, 3], energy=[1, 1])]


In [84]:
dataset[1].x

tensor([[0., 1.],
        [0., 1.]])

In [85]:
batch_size = 2
dataloader = torch_geometric.data.DataLoader(dataset, batch_size=batch_size)

for data in dataloader:
    print(data)
    print(data.batch)
    print(data.pos)
    print(data.x)


DataBatch(x=[3, 2], edge_index=[2, 17], pos=[3, 3], lattice=[2, 3, 3], edge_shift=[17, 3], energy=[2, 1], batch=[3], ptr=[3])
tensor([0, 1, 1])
tensor([[0.0000, 0.0000, 0.0000],
        [1.3672, 1.3672, 1.3672],
        [0.0000, 0.0000, 0.0000]])
tensor([[1., 0.],
        [0., 1.],
        [0., 1.]])




To calculate the vectors associated with each edge for a given torch_geometric.data.Data object representing a single example, we use the following expression:



In [86]:
# edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
# print(data['pos'][edge_dst] - data['pos'][edge_src])
# print(data['edge_shift'])
# edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
#             + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'])) 
# print(edge_vec)

In [None]:
# from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
# from e3nn.nn.models.gate_points_2101 import Network
from utils_model.gate_points_networks import SimpleNetwork

from typing import Dict, Union
import torch_scatter

class SimplePeriodicNetwork(SimpleNetwork):
# class SimplePeriodicNetwork(Network):
    def __init__(self, **kwargs) -> None:
        """The keyword `pool_nodes` is used by SimpleNetwork to determine
        whether we sum over all atom contributions per example. In this example,
        we want use a mean operations instead, so we will override this behavior.
        """
        self.pool = False
        if kwargs['pool_nodes'] == True:
            kwargs['pool_nodes'] = False
            kwargs['num_nodes'] = 1.
            self.pool = True
        super().__init__(**kwargs)

    # Overwriting preprocess method of SimpleNetwork to adapt for periodic boundary data
    def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        if 'batch' in data:
            batch = data['batch']
        else:
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        edge_src = data['edge_index'][0]  # Edge source
        edge_dst = data['edge_index'][1]  # Edge destination

        # We need to compute this in the computation graph to backprop to positions
        # We are computing the relative distances + unit cell shifts from periodic boundaries
        edge_batch = batch[edge_src]
        edge_vec = (data['pos'][edge_dst]
                    - data['pos'][edge_src]
                    + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))

        return batch, data['x'], edge_src, edge_dst, edge_vec

    def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        # if pool_nodes was set to True, use scatter_mean to aggregate
        output = super().forward(data)
        if self.pool == True:

            return torch_scatter.scatter_mean(output, data.batch, dim=0)  # Take mean over atoms per example
        
        else:
            return output



In [100]:
net = SimplePeriodicNetwork(
    irreps_in="64x0e",  # One hot scalars (L=0 and even parity) on each atom to represent atom type
    irreps_out="200x0e",  # Single scalar (L=0 and even parity) to output (for example) energy
    max_radius=radial_cutoff, # Cutoff radius for convolution
    layers = 3,
    lmax = 3,
    num_neighbors=10.0,  # scaling factor based on the typical number of neighbors
    pool_nodes=True,  # We pool nodes to predict total energy
)


In [101]:
print(net)

SimplePeriodicNetwork(
  (mp): MessagePassing(
    (layers): ModuleList(
      (0): Compose(
        (first): Convolution(
          (sc): FullyConnectedTensorProduct(64x0e x 1x0e -> 200x0e+50x1o+50x2e+50x3o | 12800 paths | 12800 weights)
          (lin1): FullyConnectedTensorProduct(64x0e x 1x0e -> 64x0e | 4096 paths | 4096 weights)
          (fc): FullyConnectedNet[10, 100, 256]
          (tp): TensorProduct(64x0e x 1x0e+1x1o+1x2e+1x3o -> 64x0e+64x1o+64x2e+64x3o | 256 paths | 256 weights)
          (lin2): FullyConnectedTensorProduct(64x0e+64x1o+64x2e+64x3o x 1x0e -> 200x0e+50x1o+50x2e+50x3o | 22400 paths | 22400 weights)
          (lin3): FullyConnectedTensorProduct(64x0e+64x1o+64x2e+64x3o x 1x0e -> 1x0e | 64 paths | 64 weights)
        )
        (second): Gate (200x0e+50x1o+50x2e+50x3o -> 50x0e+50x1o+50x2e+50x3o)
      )
      (1): Compose(
        (first): Convolution(
          (sc): FullyConnectedTensorProduct(50x0e+50x1o+50x2e+50x3o x 1x0e -> 350x0e+50x1o+50x1e+50x2o+50x2e+50x3

In [102]:
for data in dataloader:
    print(net(data).shape)  # One scalar per example

AssertionError: Incorrect last dimension for x

In [103]:
print(net(data))

AssertionError: Incorrect last dimension for x