In [53]:
import mdtraj as md
# Load the trajectory and reduce it to the alpha-carbon (CA) atoms only.
# NOTE(review): these are absolute, machine-specific paths; point them at a
# configurable data directory before sharing the notebook.
traj = md.load("/Users/tommysisk/asyn/lig47/fasudil_111frames.dcd",
               top="/Users/tommysisk/asyn/lig47/lig47.pdb"
              )
traj = traj.atom_slice(traj.top.select("name CA"))
# `traj` now holds only CA atoms, so its coordinates can be read directly
# instead of selecting and slicing the same atoms a second time.
coords = traj.xyz

In [182]:
import torch
import torch.nn as nn
import numpy as np
import itertools
from torch_scatter import scatter
import torch_geometric
from torch_geometric.data import Data as GeometricData
from torch_geometric.data import InMemoryDataset
from functools import partial
import math

 # Protein Residue stuff

In [247]:
def dict_map(dic, keys):
    """Look up several keys in a dictionary at once.

    Parameters
    ----------
    dic : dict
        Mapping to read from.
    keys : sequence
        Keys to look up, in the order the values should be returned.

    Returns
    -------
    list
        ``[dic[k] for k in keys]``.

    Raises
    ------
    AssertionError
        If any key is missing from ``dic``.
    """
    # Test membership against the dict itself (hash lookup) rather than a list copy of its keys.
    assert all(k in dic for k in keys), "Not all keys exist in dict"
    return [dic[k] for k in keys]

# One-letter amino-acid codes; the position in this list is the integer index
# assigned to each residue type.
single_letter_codes = list("GASPVTCLINDQKEMHFRYW")

# Three-letter abbreviations, in the same order as `single_letter_codes`.
three_letter_codes = ("GLY ALA SER PRO VAL THR CYS LEU ILE ASN "
                      "ASP GLN LYS GLU MET HIS PHE ARG TYR TRP").split()

# three-letter abbreviation -> one-letter code
abr_to_code_ = {abr: code for abr, code in zip(three_letter_codes, single_letter_codes)}

# one-letter code -> integer index
code_to_index_ = {code: position for position, code in enumerate(single_letter_codes)}

def get_codes(traj):
    """Return the sequence of the first entry of ``traj.top.to_fasta()`` as a list of one-character strings."""
    first_chain = list(traj.top.to_fasta())[0]
    return [str(residue) for residue in first_chain]
    
def abr_to_code(keys):
    """Convert three-letter residue abbreviations (e.g. ``"GLY"``) to one-letter codes.

    Raises an AssertionError (from ``dict_map``) if an abbreviation is unknown.
    """
    # Look up in the `abr_to_code_` table; the original passed this function itself as the mapping.
    return dict_map(abr_to_code_, keys)

def code_to_index(codes):
    """Map a residue sequence to integer residue-type indices.

    Parameters
    ----------
    codes : sequence of str
        Either one-letter codes or three-letter abbreviations. The format is
        detected from the length of the first element.

    Returns
    -------
    torch.LongTensor
        One index per residue; empty for an empty sequence.
    """
    # Guard the format probe so an empty sequence yields an empty tensor
    # instead of an IndexError on `codes[0]`.
    if len(codes) > 0 and len(codes[0]) > 1:
        codes = abr_to_code(codes)
    return torch.LongTensor(dict_map(code_to_index_, codes))



def get_residue_bonds(code_sequence):
    """Return one integer "bond type" per pair of consecutive residues.

    Each ordered pair of the 20 residue types gets a unique id in ``[0, 400)``,
    read from a 20 x 20 lookup table indexed by (residue i, residue i + 1).
    """
    residue_ids = code_to_index(code_sequence)
    bond_type_table = torch.arange(400).long().reshape(20, 20)
    first, second = residue_ids[:-1], residue_ids[1:]
    return bond_type_table[first, second]

code_sequence = get_codes(traj)
# `codes` was never defined in this notebook; use the sequence computed on the line above.
index_sequence = code_to_index(code_sequence)

 # Radial Basis functions

In [476]:
class _SoftUnitStep(torch.autograd.Function):
    """Autograd function for ``x -> exp(-1/x)`` where ``x > 0`` and ``0`` elsewhere."""
    # pylint: disable=arguments-differ

    @staticmethod
    def forward(ctx, x) -> torch.Tensor:
        ctx.save_for_backward(x)
        positive = x > 0.0
        out = torch.zeros_like(x)
        # Only evaluate exp(-1/x) where x > 0; everything else stays exactly zero.
        out[positive] = torch.exp(-1 / x[positive])
        return out

    @staticmethod
    def backward(ctx, dy) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        positive = x > 0.0
        x_pos = x[positive]
        grad = torch.zeros_like(x)
        # d/dx exp(-1/x) = exp(-1/x) / x^2 for x > 0, and 0 otherwise.
        grad[positive] = torch.exp(-1 / x_pos) / x_pos.pow(2)
        return grad * dy


def soft_unit_step(x):
    r"""Smooth (:math:`C^\infty`) version of the unit step function.

    .. math::

        x \mapsto \theta(x) e^{-1/x}

    Parameters
    ----------
    x : `torch.Tensor`
        tensor of shape :math:`(...)`

    Returns
    -------
    `torch.Tensor`
        tensor of shape :math:`(...)`; zero wherever ``x <= 0``

    Examples
    --------

    .. jupyter-execute::

        x = torch.linspace(-1.0, 10.0, 1000)
        plt.plot(x, soft_unit_step(x));
    """
    return _SoftUnitStep.apply(x)
        
        
def soft_one_hot_linspace(x: torch.Tensor, start, end, number, basis=None, cutoff=None) -> torch.Tensor:
    r"""Project values onto a set of basis functions spread over ``[start, end]``.

    Returns :math:`\{y_i(x)\}_{i=1}^N` with

    .. math::

        y_i(x) = \frac{1}{Z} f_i(x)

    where :math:`f_i` is the i-th basis function and :math:`Z` is a constant chosen
    (where possible) so that

    .. math::

        \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1

    The ``bessel`` basis cannot be normalized this way because of its :math:`1/x` term.

    Parameters
    ----------
    x : `torch.Tensor`
        tensor of shape :math:`(...)`

    start : float
        minimum value spanned by the basis

    end : float
        maximum value spanned by the basis

    number : int
        number of basis functions :math:`N`

    basis : {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}
        choice of basis family

    cutoff : bool
        must be given explicitly; if ``True`` then for all :math:`x` outside of
        ``(start, end)``, :math:`\forall i, \; f_i(x) \approx 0`

    Returns
    -------
    `torch.Tensor`
        tensor of shape :math:`(..., N)`

    Raises
    ------
    ValueError
        if ``cutoff`` is not a boolean or ``basis`` is not one of the supported names

    Examples
    --------

    .. jupyter-execute::

        x = torch.linspace(-1.0, 2.0, 100)
        plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis='gaussian', cutoff=True))
    """
    # pylint: disable=misplaced-comparison-constant

    if cutoff not in [True, False]:
        raise ValueError("cutoff must be specified")

    if cutoff:
        # Place the centres strictly inside (start, end): build number + 2 grid points
        # and drop both end points.
        grid = torch.linspace(start, end, number + 2, dtype=x.dtype, device=x.device)
        step = grid[1] - grid[0]
        values = grid[1:-1]
    else:
        values = torch.linspace(start, end, number, dtype=x.dtype, device=x.device)
        step = values[1] - values[0]

    # Distance of every input to every centre, in units of the centre spacing.
    diff = (x[..., None] - values) / step

    if basis == "gaussian":
        return torch.exp(-diff.pow(2)) / 1.12

    if basis == "cosine":
        half_wave = torch.cos(math.pi / 2 * diff)
        return half_wave * (diff < 1) * (-1 < diff)

    if basis == "smooth_finite":
        scale = 1.14136 * torch.exp(torch.tensor(2.0))
        return scale * soft_unit_step(diff + 1) * soft_unit_step(1 - diff)

    if basis == "fourier":
        # Position rescaled so that start -> 0 and end -> 1.
        t = (x[..., None] - start) / (end - start)
        norm = math.sqrt(0.25 + number / 2)
        if cutoff:
            freqs = torch.arange(1, number + 1, dtype=t.dtype, device=t.device)
            return torch.sin(math.pi * freqs * t) / norm * (0 < t) * (t < 1)
        freqs = torch.arange(0, number, dtype=t.dtype, device=t.device)
        return torch.cos(math.pi * freqs * t) / norm

    if basis == "bessel":
        r = x[..., None] - start
        c = end - start
        bessel_roots = torch.arange(1, number + 1, dtype=r.dtype, device=r.device) * math.pi
        out = math.sqrt(2 / c) * torch.sin(bessel_roots * r / c) / r
        if cutoff:
            return out * ((r / c) < 1) * (0 < r)
        return out

    raise ValueError(f'basis="{basis}" is not a valid entry')

 # Writhe stuff

In [477]:
def product(x: np.ndarray, y: np.ndarray):
    """Return every ordered pair (one element of ``x``, one element of ``y``) as an array, ``x`` varying slowest."""
    first_pool, second_pool = tuple(x), tuple(y)
    pairs = [(a, b) for a in first_pool for b in second_pool]
    return np.asarray(pairs)


def combinations(x):
    """Return all unordered pairs of elements of ``x`` (earlier element first) as an array."""
    pool = tuple(x)
    count = len(pool)
    pairs = [(pool[i], pool[j]) for i in range(count) for j in range(i + 1, count)]
    return np.asarray(pairs)


def shifted_pairs(x: np.ndarray, shift: int, ax: int = 1):
    """Pair each element of ``x`` with the element ``shift`` positions later, stacked along axis ``ax``."""
    leading = x[:-shift]
    trailing = x[shift:]
    return np.stack((leading, trailing), axis=ax)


def get_segments(n: int = None,
                 length: int = 1,
                 index0: np.ndarray = None,
                 index1: np.ndarray = None):
    """
    Build the index quadruplets describing pairs of segments.

    Each segment joins a point to the point ``length`` positions later.
    Returns an (n_segment_pairs, 4) tensor whose rows are (start1, end1, start2, end2).

    * only ``n``: all pairs of segments over points ``0 .. n-1``
    * only ``index0``: all pairs of segments over the points in ``index0``
    * ``index0`` and ``index1``: every segment of ``index0`` paired with every segment of ``index1``

    In the first two cases, pairs where the end of the first segment is the start
    of the second segment are removed.
    """
    if index0 is None and index1 is None:
        assert n is not None, \
            "Must provide indices (index0:array, (optionally) index1:array) or the number of points (n: int)"
        # With no explicit indices, the points are simply 0 .. n-1.
        index0 = np.arange(n)
    else:
        assert index0 is not None, ("If providing only one set of indices, must set the index0 argument \n"
                                    "Cannot only supply the index1 argument (doesn't make sense in this context")

    if index1 is not None:
        segments0 = shifted_pairs(index0, length)
        segments1 = shifted_pairs(index1, length)
        return torch.from_numpy(product(segments0, segments1).reshape(-1, 4))

    segments = combinations(shifted_pairs(index0, length)).reshape(-1, 4)
    not_touching = segments[:, 1] != segments[:, 2]
    return torch.from_numpy(segments[not_touching])


##########################################   fastest ways of implementing these linear algebra ops for this purpose  (NOT trivial) ############################################


    
def nnorm(x: torch.Tensor):

    """Normalize vectors stored along the last dimension to unit length.

    Works for any number of leading (batch) dimensions: the norm is computed with
    ``keepdim=True`` so it always broadcasts against ``x``. The original ndim-by-ndim
    branching only handled up to 4 dimensions.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (..., d)

    Returns
    -------
    torch.Tensor
        tensor of the same shape as ``x``. Zero-length vectors are divided by zero
        and are not special-cased here.
    """

    return x / torch.linalg.norm(x, dim=-1, keepdim=True)


def ncross(x: torch.Tensor, y: torch.Tensor):

    """Batched cross product of vectors stored along the last dimension (which must have size 3)."""

    return torch.cross(x, y, dim=-1)


def ndot(x, y):

    """Batched dot product of vectors stored along the last dimension; the last dimension is summed out."""

    return (x * y).sum(dim=-1)


def ndet(v1, v2, v3):
    """Batched scalar triple product ``v1 . (v2 x v3)``.

    To get the signed sine of the angle between v2 and v3, v1 should be a vector
    mutually orthogonal to v2 and v3.
    """
    normal = ncross(v2, v3)
    return ndot(v1, normal)


def uproj(a, b, norm_b: bool = True):
    """Remove from each vector in ``a`` its component along ``b`` (orthogonal projection).

    Parameters
    ----------
    a : torch.Tensor
        vectors to project, last dimension 3.
        Assumes shape (N, k, 3), as in the commented usage below this function; TODO confirm.
    b : torch.Tensor
        one direction per batch entry. Assumes shape (N, 1, 3); TODO confirm.
    norm_b : bool
        normalize ``b`` to unit length first. Pass ``False`` only if ``b`` is already unit length.

    Returns
    -------
    torch.Tensor
        ``a - b * (a . b)``, same shape as ``a``
    """
    b = nnorm(b) if norm_b else b
    # torch.Tensor.transpose swaps exactly two dims; the NumPy-style three-argument
    # call from the original port raises TypeError, so reorder axes with permute.
    overlap = torch.sum(a[:, None, :] * b[:, None, :], -1).permute(0, 2, 1)
    return a - b * overlap


def solid_angle(a, b):
    """Negative arcsin of the (clipped) dot product of the vectors in ``a`` after projecting out ``b``.

    ``a`` is projected orthogonally to ``b`` and normalized; the projected vectors are
    multiplied element-wise along dim 1 and summed over the last dimension.
    """
    projected = nnorm(uproj(a, b))
    cosine_like = torch.prod(projected, 1).sum(-1).clip(-1, 1)
    return -torch.arcsin(cosine_like)
# or arccos(...) - (np.pi / 2)
# slower than current version with ray
# Usage of solid angle in writhe computation (not as time efficient as implementation in use)
# indices = zip([0, 1, 0, 2],
#               [3, 2, 3, 1],
#               [1, 3, 2, 0])

# omega = np.stack([solid_angle(displacements[:, [i, j]], displacements[:, None, n])
#                   for i, j, n in indices], 1).squeeze().sum(-1)
##############################################################################################################################


def writhe_segment(segment=None, xyz=None, smat=None):
    """Compute the writhe (signed crossing) of 2 segments for all frames (index 0) in xyz.

    The operations used here are torch operations, so the coordinate inputs are
    expected to be torch tensors.

    THERE ARE 2 INPUT OPTIONS

    **provide both of the following**

    segment: index array/tensor of shape (4,) giving the indices of the 4 points (alpha carbons)
             in xyz that make up the 2 segments:
             (seg1_start_point, seg1_end_point, seg2_start_point, seg2_end_point)

    xyz: tensor of shape (Nframes, N_alpha_carbons, 3) with the positions of ALL the alpha carbons;
         a single frame of shape (N_alpha_carbons, 3) is also accepted

    **OR just the following**

    smat: tensor of shape (Nframes, 4, 3) (or (4, 3) for one frame): coordinates pre-sliced to
          only the 4 points constituting the 2 segments

    Returns
    -------
    The writhe per frame. Because of the ``squeeze`` calls below, this is a 1-D tensor of
    length Nframes when there is more than one frame and a 0-dim tensor for a single frame.

    Raises
    ------
    AssertionError
        if ``smat`` is not given and ``segment`` or ``xyz`` is missing
    """

    if smat is None:
        assert not ((segment is None) or (xyz is None)), \
            "must input smat or both a segment and xyz coordinates"
        # add a leading frame axis when a single frame is given, then pick the 4 points
        smat = (xyz[None, :, :] if xyz.ndim < 3 else xyz)[:, segment]
    else:
        smat = smat[None, :, :] if smat.ndim < 3 else smat

    # smat = nnorm(smat)
    # with a single frame the squeeze below drops the frame axis, so sum over everything
    sum_dim = None if smat.shape[0] == 1 else 1

    # broadcasting trick
    # negative sign, None placement and order are intentional, don't change without testing equivalent option
    # result: per frame, the 4 unit vectors from each end of segment 1 to each end of segment 2,
    # ordered (p0->p2, p0->p3, p1->p2, p1->p3)
    displacements = nnorm((-smat[:, :2, None, :] + smat[:, None, 2:, :]).reshape(-1, 4, 3))

    # array broadcasting is (surprisingly) slower than list comprehensions
    # when using ray for the following operations (without ray, broadcasting should be faster).

    # unit normals of the displacement pairs (0,1), (1,3), (3,2), (2,0)
    crosses = nnorm(ncross(displacements[:, [0, 1, 3, 2]], displacements[:, [1, 3, 2, 0]]))

    # sum of arcsin of the dot products between cyclically consecutive normals
    omega = torch.arcsin(ndot(crosses[:, [0, 1, 2, 3]], crosses[:, [1, 2, 3, 0]]).clip(-1, 1)).squeeze().sum(sum_dim)

    # sign of the triple product (seg2 direction x seg1 direction) . (p0->p2)
    signs = torch.sign(ndot(ncross(nnorm(smat[:, 3] - smat[:, 2]),
                                nnorm(smat[:, 1] - smat[:, 0])),
                         displacements[:, 0])).squeeze()

    wr = (1 / (2 * torch.pi)) * (omega * signs)

    return wr


def writhe_segments_along_axis(segments: torch.LongTensor, xyz: torch.Tensor, axis: int = 1):
    """Compute the writhe of every segment pair in ``segments`` for all frames in ``xyz``.

    Parameters
    ----------
    segments : torch.LongTensor
        (n_segment_pairs, 4) index quadruplets, one row per call to ``writhe_segment``
    xyz : torch.Tensor
        coordinates passed through to ``writhe_segment``
    axis : int
        currently unused; results are always stacked along dim 1

    Returns
    -------
    torch.Tensor
        tensor of shape (Nframes, n_segment_pairs)
    """
    # writhe_segment returns a 0-dim tensor when there is only one frame, and 0-dim tensors
    # cannot be stacked along dim 1. Flatten each result to shape (Nframes,) first.
    # noinspection PyTypeChecker
    return torch.stack([writhe_segment(segment, xyz, None).reshape(-1) for segment in segments], 1)
    
# Unfinished, need analogue to CPU treatment with ray parallelization

# def calc_writhe_parallel(segments: np.ndarray, xyz: np.ndarray,) -> "Nframes by Nsegments np.ndarray":
#     """parallelize writhe calculation by segment, avoids making multiple copies of coordinate (xyz) matrix,
#     uses torch.scatter to parallelize"""

#     # TODO finish remaking this function to scatter the computation across multiple GPUs
    
#     #n_split = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()
#     chunks = torch.tensor_split(segments, int(torch.cuda.device_count()))
#     result = torch.concatenate([writhe_segments_along_axis(segments=chunk, xyz=xyz_ref) for chunk in chunks]).T.squeeze()
#     return result


def to_writhe_adj_matrix(writhe_features,
                         n_points,
                         length,
                         segments=None,
                         full_matrix=True):
    """Scatter per-segment-pair writhe values into (n_points, n_points) matrices.

    Parameters
    ----------
    writhe_features : torch.Tensor
        tensor of shape (n_frames, n_segment_pairs)
    n_points : int
        number of points; sets the size of each matrix
    length : int
        segment length, used only to build ``segments`` when they are not given
    segments : torch.LongTensor, optional
        (n_segment_pairs, 4) rows of (start1, end1, start2, end2); built with
        ``get_segments(n_points, length)`` when omitted
    full_matrix : bool
        if True, symmetrize by adding the transpose

    Returns
    -------
    torch.Tensor
        tensor of shape (n_frames, n_points, n_points) with size-1 dimensions squeezed out,
        on the same device and with the same dtype as ``writhe_features``
    """
    writhe_features = torch.as_tensor(writhe_features)
    n = len(writhe_features)

    if segments is None:
        segments = get_segments(n_points, length)

    # Allocate next to the features. A default CPU float32 buffer fails for GPU
    # features and silently downcasts float64 ones.
    adj_matrix = writhe_features.new_zeros((n, n_points, n_points))

    # Each value is written at (start1, start2) and at (end1, end2).
    adj_matrix[:, segments[:, 0], segments[:, 2]] = writhe_features
    adj_matrix[:, segments[:, 1], segments[:, 3]] = writhe_features
    adj_matrix = adj_matrix + adj_matrix.swapaxes(1, 2) if full_matrix else adj_matrix

    return adj_matrix.squeeze()

#def to_writhe_pair_list(writhe_features, n_point, length, segments):
    

 # Writhe Layer that can easily integrate into existing cpaiNN source

In [478]:
class TorchWrithe(nn.Module):

    """
    Writhe layer: computes the writhe between all segment pairs of a chain of
    ``n_atoms`` points and arranges it as a writhe 'adjacency' matrix with one
    (n_atoms, n_atoms) matrix per frame.

    The segment indices and the atom count are stored as buffers, so they follow
    the module across devices and are saved in its state dict.
    """

    def __init__(self, n_atoms: int):
        super().__init__()

        segment_indices = get_segments(n_atoms)
        atom_count = torch.LongTensor([n_atoms])
        self.register_buffer("segments_", segment_indices)
        self.register_buffer("n_atoms_", atom_count)

    @property
    def n_atoms(self):
        """Number of points per frame, as a Python int."""
        return self.n_atoms_.item()

    @property
    def segments(self):
        """(n_segment_pairs, 4) index tensor of segment pairs."""
        return self.segments_

    def to_matrix(self, x):
        """Arrange per-segment-pair writhe values ``x`` into writhe matrices."""
        return to_writhe_adj_matrix(x, self.n_atoms, 1, self.segments)

    def forward(self, x):
        # Any input whose elements can be viewed as frames of n_atoms 3-D points is accepted.
        frames = x.reshape(-1, self.n_atoms, 3)
        writhe_values = writhe_segments_along_axis(self.segments, frames)
        return self.to_matrix(writhe_values)


class SelfAttention(nn.Module):
    """
    Minimal self attention that skips the 'value' projection and returns only the
    attention weights: the softmax over the last dimension of the scaled
    query-key scores, so every row is positive and sums to 1.

    Input:  (batch, n, input_dim)
    Output: (batch, n, n)
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim, bias=True)
        self.key = nn.Linear(input_dim, input_dim, bias=True)
        # No value projection: only the attention weights are needed downstream.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        scale = self.input_dim ** 0.5
        scores = torch.bmm(q, k.transpose(1, 2)) / scale
        return self.softmax(scores)

class WritheEmbedding(nn.Module):

    """
    Embed each writhe value as a superposition of learned basis vectors.

    Every value is expanded with a soft one-hot encoding over ``bins`` basis
    functions spanning [-1, 1]. Those weights multiply a learned
    (bins, embed_dim) table and the bins axis is summed out.

    INPUT  : (..., ) writhe values, e.g. (batch, n_atoms, n_atoms)
    OUTPUT : (..., embed_dim), e.g. (batch, n_atoms, n_atoms, embed_dim)
    """

    def __init__(self, embed_dim: int, bins: int=300, basis: str = "gaussian", cutoff: bool=False):
        super().__init__()

        self.soft_one_hot = partial(
            soft_one_hot_linspace,
            start=-1,
            end=1,
            number=bins,
            basis=basis,
            cutoff=cutoff,
        )

        # Learned basis vectors, initialised uniformly in [-1/sqrt(embed_dim), 1/sqrt(embed_dim)].
        std = 1. / math.sqrt(embed_dim)
        basis_vectors = torch.empty(1, 1, 1, bins, embed_dim).uniform_(-std, std)
        self.functions = torch.nn.Parameter(basis_vectors, requires_grad=True)

    def get_weights(self, x):
        """Soft one-hot weights of shape (..., bins, 1)."""
        return self.soft_one_hot(x).unsqueeze(-1)

    def forward(self, x):
        weighted = self.get_weights(x) * self.functions
        # Sum over the bins axis, leaving the embedding axis.
        return weighted.sum(-2)
        
        

class WritheEmbeddedAttention(nn.Module):

    """
    1) Compute writhe from a set of coordinates.
    2) Embed each writhe value to the same dimension as the node embeddings.
    3) Compute attention weights from the node embeddings.
    4) Weight the embedded writhe values by the attention weights, sum over dim 1,
       and add the result to the node embeddings.
    """

    def __init__(self,
                 node_embed_dim: int,
                 bins: int,
                 n_atoms: int,
                 ):

        super().__init__()

        self.writhe = TorchWrithe(n_atoms)
        self.attn = SelfAttention(node_embed_dim)
        self.embed = WritheEmbedding(embed_dim=node_embed_dim, bins=bins)

    def forward(self, node_embeddings: torch.Tensor, xyz: torch.Tensor):
        # Trailing axis so the weights broadcast over the embedding dimension.
        weights = self.attn(node_embeddings).unsqueeze(-1)
        writhe_matrix = self.writhe(xyz)
        writhe_embed = self.embed(writhe_matrix)
        update = (weights * writhe_embed).sum(1)
        return node_embeddings + update



class AddEmbeddedWrithe(nn.Module):

    """
    Add writhe embeddings as edge features to a batch of graphs.
    Follows the structure of other feature additions in ITO classes.
    1) Compute writhe from the coordinates in ``batch.x``.
    2) Embed each writhe value to ``embed_dim`` dimensions.
    3) Store the embedding of each edge's (source, destination) pair in ``batch.writhe``.
    """

    def __init__(self,
                 embed_dim: int,
                 bins: int,
                 n_atoms: int,
                 ):

        super().__init__()

        self.writhe = TorchWrithe(n_atoms)
        self.embed = WritheEmbedding(embed_dim=embed_dim, bins=bins)

    def forward(self, batch):

        # Node ids run over the whole batch: integer division by n_atoms gives the
        # graph each edge belongs to, and the remainder gives the node position
        # inside that graph.
        graph_index = batch.edge_index[0] // self.writhe.n_atoms
        local_src, local_dst = batch.edge_index % self.writhe.n_atoms

        embedded = self.embed(self.writhe(batch.x))
        batch.writhe = embedded[graph_index, local_src, local_dst, :]
        return batch
        

 # cpaiNN source

In [479]:
class MLP(torch.nn.Module):
    """Three-layer perceptron (Linear -> LayerNorm -> SiLU, twice, then Linear).

    With ``skip_connection=True`` the input is added to the output, which requires
    ``f_in == f_out``.
    """

    def __init__(self, f_in, f_hidden, f_out, skip_connection=False):
        super().__init__()
        self.skip_connection = skip_connection

        stack = [
            torch.nn.Linear(f_in, f_hidden),
            torch.nn.LayerNorm(f_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(f_hidden, f_hidden),
            torch.nn.LayerNorm(f_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(f_hidden, f_out),
        ]
        self.mlp = torch.nn.Sequential(*stack)

    def forward(self, x):
        out = self.mlp(x)
        return x + out if self.skip_connection else out


class AddEdgeIndex(torch.nn.Module):
    """Attach an ``edge_index`` to a copy of the batch.

    ``n_neighbors`` falls back to 10000 when falsy and ``cutoff`` to infinity when None.
    NOTE(review): ``generate_edge_index`` is not defined on this class in the code
    shown here; presumably a subclass or mixin supplies it. Confirm before use.
    """

    def __init__(self, n_neighbors=None, cutoff=None):
        super().__init__()
        self.n_neighbors = n_neighbors or 10000
        self.cutoff = float("inf") if cutoff is None else cutoff

    def forward(self, batch):
        # Work on a clone so the caller's batch is not modified.
        batch = batch.clone()
        batch.edge_index = self.generate_edge_index(batch).to(batch.x.device)
        return batch


class AddSpatialEdgeFeatures(torch.nn.Module):
    """Add per-edge distance (``edge_dist``) and damped direction (``edge_dir``) to the batch.

    The direction is the displacement divided by ``1 + distance``, not by the distance itself.
    """

    def forward(self, batch, *_, **__):
        first, second = batch.edge_index[0], batch.edge_index[1]
        displacement = batch.x[first] - batch.x[second]
        distance = displacement.norm(dim=-1)

        batch.edge_dist = distance
        batch.edge_dir = displacement / (1 + distance.unsqueeze(-1))
        return batch


class InvariantFeatures(torch.nn.Module):
    """
    Base class for embedding a batch attribute into invariant features.

    Child classes must set ``self.embedding``. The feature named ``feature_name``
    must already be in the batch. The embedded result is stored under
    ``invariant_<type>_features``, concatenated along the last dimension with
    whatever is already stored there.
    """

    def __init__(self, feature_name, type_="node"):
        super().__init__()
        self.feature_name = feature_name
        self.type = type_

    def forward(self, batch):
        new_features = self.embedding(batch[self.feature_name])
        key = f"invariant_{self.type}_features"

        if not hasattr(batch, key):
            batch[key] = new_features
            return batch

        batch[key] = torch.cat([batch[key], new_features], dim=-1)
        return batch


class NominalEmbedding(InvariantFeatures):
    """Embed a categorical (integer-coded) feature with a learned lookup table."""

    def __init__(self, feature_name, n_features, n_types, feature_type="node"):
        super().__init__(feature_name, feature_type)
        # One n_features-dimensional vector per category.
        self.embedding = torch.nn.Embedding(num_embeddings=n_types, embedding_dim=n_features)


class DeviceTracker(torch.nn.Module):
    """Module that exposes its current device, even when it has no parameters.

    A dummy buffer is registered; since buffers move with the module, its device
    is the module's device.
    """

    def __init__(self):
        super().__init__()
        marker = torch.tensor(1)
        self.register_buffer("device_tracker", marker)

    @property
    def device(self):
        return self.device_tracker.device


class PositionalEncoder(DeviceTracker):
    """Sin/cos positional encoding of a 1-D tensor of values.

    For each rank in ``0 .. dim // 2 - 1`` the value is mapped to
    ``(cos, sin)`` of ``x / length * rank * pi``; the pairs are concatenated
    along dim 1.
    """

    def __init__(self, dim, length=10):
        super().__init__()
        assert dim % 2 == 0, "dim must be even for positional encoding for sin/cos"

        self.dim = dim
        self.length = length
        self.max_rank = dim // 2

    def forward(self, x):
        per_rank = []
        for rank in range(self.max_rank):
            per_rank.append(self.positional_encoding(x, rank))
        return torch.cat(per_rank, dim=1)

    def positional_encoding(self, x, rank):
        """Return the stacked (cos, sin) pair for a single rank."""
        phase = x / self.length * rank * np.pi
        sin = torch.sin(phase)
        cos = torch.cos(phase)
        # The input must live on the same device as the module.
        assert (
            cos.device == self.device
        ), f"batch device {cos.device} != model device {self.device}"
        return torch.stack((cos, sin), dim=1)


class PositionalEmbedding(InvariantFeatures):
    """Embed a scalar node feature with a sin/cos ``PositionalEncoder``."""

    def __init__(self, feature_name, n_features, length):
        super().__init__(feature_name)
        assert n_features % 2 == 0, "n_features must be even"
        # Half the features are cosines, half are sines.
        self.rank = n_features // 2
        self.embedding = PositionalEncoder(n_features, length)


class CombineInvariantFeatures(torch.nn.Module):
    """Mix the concatenated invariant node features down to ``n_features_out`` with an MLP."""

    def __init__(self, n_features_in, n_features_out):
        super().__init__()
        # The hidden width equals the output width.
        self.mlp = MLP(n_features_in, n_features_out, n_features_out)

    def forward(self, batch):
        combined = self.mlp(batch.invariant_node_features)
        batch.invariant_node_features = combined
        return batch


class AddEquivariantFeatures(DeviceTracker):
    """Initialise zero-valued equivariant node features of shape (n_nodes, n_features, 3)."""

    def __init__(self, n_features):
        super().__init__()
        self.n_features = n_features

    def forward(self, batch):
        n_nodes = batch.batch.shape[0]
        # Allocate directly on the module's device.
        batch.equivariant_node_features = torch.zeros(
            n_nodes, self.n_features, 3, device=self.device
        )
        return batch


class PaiNNTLScore(torch.nn.Module):
    """Score model conditioned on a reference graph.

    ``embed`` encodes the conditioning graph (``batch["cond"]``). Its node and edge
    features and edge index are copied onto the graph being scored
    (``batch["corr"]``), which ``net`` then processes together with positional
    embeddings of ``ts_diff`` and ``lag``. The resulting equivariant node features
    are added to ``corr.x``.

    The embedding layers are the classes defined in this notebook. The original
    code referenced them through an ``embedding`` module that is never imported here.
    """

    @property
    def device(self):
        return next(self.parameters()).device

    def __init__(
        self,
        n_features=32,
        embedding_layers=2,
        score_layers=5,
        max_lag=1000,
        diff_steps=1000,
        n_types=167,
        dist_encoding="positional_encoding",
    ):
        super().__init__()
        layers = [
            AddSpatialEdgeFeatures(),
            NominalEmbedding(
                "bonds", n_features, n_types=4, feature_type="edge"
            ),
            NominalEmbedding("atoms", n_features, n_types=n_types),
            AddEquivariantFeatures(n_features),
            #  CombineInvariantFeatures(2 * n_features, n_features),
            PaiNNBase(
                n_features=n_features,
                n_features_out=n_features,
                n_layers=embedding_layers,
                dist_encoding=dist_encoding,
            ),
        ]

        self.embed = torch.nn.Sequential(*layers)

        self.net = torch.nn.Sequential(
            AddSpatialEdgeFeatures(),
            PositionalEmbedding("ts_diff", n_features, diff_steps),
            PositionalEmbedding("lag", n_features, max_lag),
            # node features from `embed` + ts_diff embedding + lag embedding
            CombineInvariantFeatures(3 * n_features, n_features),
            PaiNNBase(
                n_features=n_features,
                dist_encoding=dist_encoding,
                n_layers=score_layers,
            ),
        )

    def forward(self, batch):
        cond = batch["cond"].clone().to(self.device)
        corr = batch["corr"].clone().to(self.device)

        # Broadcast the per-graph lag / diffusion step to every node of that graph.
        batch_idx = batch["cond"].batch
        corr.lag = batch["lag"][batch_idx].squeeze()
        corr.ts_diff = batch["ts_diff"][batch_idx].squeeze()

        embedded = self.embed(cond)
        corr.invariant_node_features = embedded.invariant_node_features
        corr.equivariant_node_features = embedded.equivariant_node_features
        corr.invariant_edge_features = embedded.invariant_edge_features
        corr.edge_index = embedded.edge_index

        dx = self.net(corr).equivariant_node_features.squeeze()

        corr.x += dx
        return corr


class PaiNNBase(torch.nn.Module):
    """Stack of ``n_layers`` (Message, Update) blocks followed by a Readout."""

    @property
    def device(self):
        return next(self.parameters()).device

    def __init__(
        self,
        n_features=128,
        n_layers=5,
        n_features_out=1,
        length_scale=10,
        dist_encoding="positional_encoding",
        use_edge_features=True,
    ):
        super().__init__()
        message_kwargs = dict(
            n_features=n_features,
            length_scale=length_scale,
            dist_encoding=dist_encoding,
            use_edge_features=use_edge_features,
        )

        blocks = []
        for _ in range(n_layers):
            blocks.extend([Message(**message_kwargs), Update(n_features)])
        blocks.append(Readout(n_features, n_features_out))

        self.layers = torch.nn.Sequential(*blocks)

    def forward(self, batch):
        return self.layers(batch)


class Message(torch.nn.Module):
    """PaiNN message block.

    For every edge, the source node's invariant features (optionally concatenated
    with the edge's invariant features) are passed through ``phi`` and modulated by
    ``w`` applied to an encoding of the edge length. The result is split into four
    parts: gates for the source node's equivariant features, scales for the edge
    direction, an invariant node update and an invariant edge update. Node updates
    are summed over incoming edges.
    """

    def __init__(
        self,
        n_features=128,
        length_scale=10,
        dist_encoding="positional_encoding",
        use_edge_features=True,
    ):
        super().__init__()
        self.n_features = n_features
        self.use_edge_features = use_edge_features

        assert dist_encoding in (
            a := ["positional_encoding", "soft_one_hot"]
        ), f"positional_encoder must be one of {a}"

        if dist_encoding in ["positional_encoding", None]:
            # Use the PositionalEncoder defined in this notebook; no `embedding`
            # module is imported here.
            self.positional_encoder = PositionalEncoder(
                n_features, length=length_scale
            )
        elif dist_encoding == "soft_one_hot":
            # NOTE(review): SoftOneHotEncoder is not defined in the visible part of this
            # notebook, so this branch still depends on an external `embedding` module.
            self.positional_encoder = embedding.SoftOneHotEncoder(
                n_features, max_radius=length_scale
            )

        phi_in_features = 2 * n_features if use_edge_features else n_features
        self.phi = MLP(phi_in_features, n_features, 4 * n_features)
        self.w = MLP(n_features, n_features, 4 * n_features)

    def forward(self, batch):
        src_node = batch.edge_index[0]
        dst_node = batch.edge_index[1]

        in_features = batch.invariant_node_features[src_node]

        if self.use_edge_features:
            in_features = torch.cat(
                [in_features, batch.invariant_edge_features], dim=-1
            )

        positional_encoding = self.positional_encoder(batch.edge_dist)

        gates, scale_edge_dir, ds, de = torch.split(
            self.phi(in_features) * self.w(positional_encoding),
            self.n_features,
            dim=-1,
        )
        gated_features = multiply_first_dim(
            gates, batch.equivariant_node_features[src_node]
        )
        scaled_edge_dir = multiply_first_dim(
            scale_edge_dir, batch.edge_dir.unsqueeze(1).repeat(1, self.n_features, 1)
        )

        dv = scaled_edge_dir + gated_features
        # Fix the output size to the number of nodes. Without dim_size, scatter sizes
        # its output from the largest destination index, which is too short whenever
        # the highest-index nodes have no incoming edge.
        dv = scatter(
            dv, dst_node, dim=0, dim_size=batch.equivariant_node_features.shape[0]
        )
        ds = scatter(
            ds, dst_node, dim=0, dim_size=batch.invariant_node_features.shape[0]
        )

        batch.equivariant_node_features += dv
        batch.invariant_node_features += ds
        batch.invariant_edge_features += de

        return batch


def multiply_first_dim(w, x):
    """Multiply two tensors, broadcasting over their leading (not trailing) dimensions.

    Equivalent to ``(w.T * x.T).T``, where ``.T`` reverses all dimensions. For example,
    ``w`` of shape (N, F) and ``x`` of shape (N, F, 3) give ``w[..., None] * x``. The
    arguments may be passed in either order.

    Implemented by appending size-1 dimensions to the lower-rank tensor. This avoids
    ``.T`` on tensors that are not 2-D, which is deprecated, and removes the need for
    the ``warnings`` module, which this notebook never imports.
    """
    extra = x.ndim - w.ndim
    if extra > 0:
        w = w[(...,) + (None,) * extra]
    elif extra < 0:
        x = x[(...,) + (None,) * (-extra)]
    return w * x


class Update(torch.nn.Module):
    """PaiNN update block: mixes each node's equivariant and invariant features.

    The equivariant features are linearly mixed twice (``u`` and ``v``). The norms of
    the ``v`` branch, together with the invariant features, feed an MLP that produces
    gates for the ``u`` branch and the invariant update.
    """

    def __init__(self, n_features=128):
        super().__init__()
        self.u = EquivariantLinear(n_features, n_features)
        self.v = EquivariantLinear(n_features, n_features)
        self.n_features = n_features
        # Use the MLP defined in this notebook; no `embedding` module is imported here.
        self.mlp = MLP(2 * n_features, n_features, 3 * n_features)

    def forward(self, batch):
        v = batch.equivariant_node_features
        s = batch.invariant_node_features

        vv = self.v(v)
        uv = self.u(v)

        # Norm over the spatial axis gives rotation-invariant quantities.
        vv_norm = vv.norm(dim=-1)
        vv_squared_norm = vv_norm**2

        mlp_in = torch.cat([vv_norm, s], dim=-1)

        gates, scale_squared_norm, add_invariant_features = torch.split(
            self.mlp(mlp_in), self.n_features, dim=-1
        )

        delta_v = multiply_first_dim(uv, gates)
        delta_s = vv_squared_norm * scale_squared_norm + add_invariant_features

        batch.invariant_node_features = batch.invariant_node_features + delta_s
        batch.equivariant_node_features = batch.equivariant_node_features + delta_v

        return batch


class EquivariantLinear(torch.nn.Module):
    """Bias-free linear map over the feature axis of (..., n_features, 3) tensors.

    The same weights are applied to each spatial component.
    """

    def __init__(self, n_features_in, n_features_out):
        super().__init__()
        self.linear = torch.nn.Linear(n_features_in, n_features_out, bias=False)

    def forward(self, x):
        # Move the feature axis last for nn.Linear, then restore the original layout.
        features_last = x.transpose(-1, -2)
        mixed = self.linear(features_last)
        return mixed.transpose(-1, -2)


class Readout(torch.nn.Module):
    """Project node features to ``n_features_out`` output channels.

    ``self.mlp`` maps the invariant features to ``2 * n_features_out`` values,
    split into the new invariant features and a set of gates.  The equivariant
    features are channel-mixed by ``self.V`` and multiplied by those gates.

    :param n_features: number of input feature channels
    :param n_features_out: number of output channels for both feature kinds
    """

    def __init__(self, n_features=128, n_features_out=13):
        super().__init__()
        self.mlp = embedding.MLP(n_features, n_features, 2 * n_features_out)
        self.V = EquivariantLinear(  # pylint:disable=invalid-name
            n_features, n_features_out
        )
        self.n_features_out = n_features_out

    def forward(self, batch):
        """Overwrite both node-feature attributes on ``batch`` and return it."""
        mlp_out = self.mlp(batch.invariant_node_features)
        scalars_out, gates = torch.split(mlp_out, self.n_features_out, dim=-1)

        mixed_vectors = self.V(batch.equivariant_node_features)
        vectors_out = multiply_first_dim(mixed_vectors, gates)

        batch.invariant_node_features = scalars_out
        batch.equivariant_node_features = vectors_out
        return batch


class PaiNNScore(torch.nn.Module):
    """Embedding layers followed by a ``PaiNNBase`` with one equivariant output channel.

    ``forward`` adds the squeezed equivariant output of the network to the
    node positions ``x`` of a copy of ``batch["corr"]``.

    :param n_features: width of every embedding and of the PaiNN layers
    :param score_layers: passed to ``PaiNNBase`` as ``n_layers``
    :param diff_steps: passed to the ``"ts_diff"`` positional embedding
    :param n_types: number of types for the ``"atoms"`` embedding
    :param dist_encoding: passed through to ``PaiNNBase``
    """

    def __init__(
        self,
        n_features=32,
        score_layers=3,
        diff_steps=1000,
        n_types=167,
        dist_encoding="positional_encoding",
    ):
        super().__init__()
        self.net = torch.nn.Sequential(
            embedding.AddSpatialEdgeFeatures(),
            embedding.NominalEmbedding(
                "bonds", n_features, n_types=4, feature_type="edge"
            ),
            embedding.NominalEmbedding("atoms", n_features, n_types=n_types),
            embedding.PositionalEmbedding("ts_diff", n_features, diff_steps),
            embedding.AddEquivariantFeatures(n_features),
            embedding.CombineInvariantFeatures(2 * n_features, n_features),
            PaiNNBase(
                n_features=n_features,
                n_features_out=1,
                n_layers=score_layers,
                dist_encoding=dist_encoding,
            ),
        )

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, batch):
        """Return a copy of ``batch["corr"]`` whose ``x`` has been shifted by the network output.

        :param batch: mapping providing ``"corr"`` (graph batch with a ``batch``
            index vector) and ``"ts_diff"`` (indexed per graph)
        """
        source = batch["corr"]
        corr = source.clone().to(self.device)
        # broadcast the per-graph value to every node through the batch index
        corr.ts_diff = batch["ts_diff"][source.batch].squeeze()

        net_out = self.net(corr)
        displacement = net_out.equivariant_node_features.squeeze()
        corr.x += displacement

        return corr

 # Add RBF Embedded Writhe into cpainNN work flow

In [525]:
# make data objects: one graph per trajectory frame, nodes are residues and
# edges are all residue pairs (i < j)
n_residues = len(index_sequence)  # derived from the sequence instead of a hard-coded 20
# NOTE(review): get_residue_bonds returns one entry per consecutive residue pair
# (n_residues - 1 per graph), while edge_index holds every i < j pair -- the
# printed batch shows bonds=[190] vs edge_index=[2, 1900].  Confirm whether the
# "bonds" edge embedding is meant to align with edge_index.
data_objs = [GeometricData(x=torch.Tensor(x),
                           atoms=index_sequence,
                           edge_index=torch.triu_indices(n_residues, n_residues, 1).long(),
                           bonds=get_residue_bonds(code_sequence),)
            for x in coords]

dataset = GraphDataSet(data_objs)
loader = DataLoader(dataset, batch_size=10)
batch = next(iter(loader))

# make input batch and add features including new writhe feature
n_features = 10
n_residue_types = 20
# get_residue_bonds enumerates ordered residue-type pairs: 20 * 20 = 400 bond types
n_bond_types = n_residue_types ** 2
layers = [AddSpatialEdgeFeatures(),
          NominalEmbedding("bonds", n_features, n_types=n_bond_types, feature_type="edge"),
          NominalEmbedding("atoms", n_features, n_types=n_residue_types),
          AddEquivariantFeatures(n_features),
          AddEmbeddedWrithe(embed_dim=n_features, bins=100, n_atoms=n_residues)
         ]
# each layer takes the batch and returns it with extra attributes attached
for layer in layers:
    batch = layer(batch)

batch

DataBatch(x=[200, 3], edge_index=[2, 1900], atoms=[200], bonds=[190], batch=[200], ptr=[11], edge_dist=[1900], edge_dir=[1900, 3], invariant_edge_features=[190, 10], invariant_node_features=[200, 10], equivariant_node_features=[200, 10, 3], writhe=[1900, 10])

 # Prep inputs for testing SE3 Equivariant attention implementation in torch

In [305]:
# Node features for the attention modules.  Key order (vectors, then scalars)
# is kept because some consumers read the dict values by position.
attention_features = {
    "vectors": batch.equivariant_node_features,
    "scalars": batch.invariant_node_features.unsqueeze(-1),
}

In [306]:
# Same node features plus the embedded writhe as an "edges" entry; every entry
# gets a trailing singleton axis except the vectors, which end in the 3 coordinates.
attention_features_with_edge = {
    "vectors": batch.equivariant_node_features,
    "scalars": batch.invariant_node_features.unsqueeze(-1),
    "edges": batch.writhe.unsqueeze(-1),
}

 # SE3 Equivariant attention pytorch implementation

 # The class names say SO(3) but with the addition of cross products / determinants we get SE(3)

In [485]:
# NOTE(review): `vectors` and `scalars` must already exist in the kernel when this
# cell runs -- confirm an earlier cell defines them, otherwise this line fails on
# Restart & Run All.
features = dict(vectors=vectors, scalars=scalars)
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> LayerNorm -> ReLU -> Linear.

    The hidden width equals the output width.

    :param in_dim: size of the last axis of the input
    :param out_dim: size of the hidden layer and of the output
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.net(x)

class SODInvariantScalars(nn.Module):
    """ NN parameterized SO(d)-invariant scalar function.
    g(VV^T, scalars, minors(V)) -> scalar

    The vector features enter the network only through pairwise dot products
    (upper triangle of the Gram matrix) and determinants of vector triples;
    scalar features are passed through unchanged.
    """
    def __init__(self, n_vectors: int,
                 n_scalars: int,
                 out_dim: int):
        """
        :param n_vectors: channels of input vector features
        :param n_scalars: channels of input scalar features
        :param out_dim: dimension of the output feature
        """
        super(SODInvariantScalars, self).__init__()
        self.n_vectors = n_vectors
        self.n_scalars = n_scalars
        self.out_dim = out_dim

        # indices of the unique entries (upper triangle incl. diagonal) of the Gram matrix
        triu_idx = torch.triu_indices(n_vectors, n_vectors)
        assert triu_idx.shape[1] == int(n_vectors * (n_vectors + 1) / 2)
        self.register_buffer('triu_idx', triu_idx)

        # every unordered triple of vector channels -> one 3x3 determinant each
        sub_idx = torch.combinations(torch.arange(n_vectors), 3, False)

        if sub_idx.shape[0] != 0:
            self.register_buffer('sub_idx', torch.flatten(sub_idx))
        else:
            # fewer than three vector channels: no determinants can be formed
            self.register_buffer('sub_idx', None)

        # dot products + determinants + raw scalar features
        self.in_dim = n_scalars + self.triu_idx.shape[1] + sub_idx.shape[0]

        self.net = MLP(self.in_dim, self.out_dim)

    @staticmethod
    def _compute_determinant(A: torch.Tensor):
        """ Compute the determinant of batched 3x3 matrices
        :param A: ..., 3, 3
        :return: tensor of shape ``...`` holding one determinant per matrix
        """
        output = A[..., 0, 0] * A[..., 1, 1] * A[..., 2, 2] + \
                 A[..., 0, 1] * A[..., 1, 2] * A[..., 2, 0] + \
                 A[..., 0, 2] * A[..., 1, 0] * A[..., 2, 1] - \
                 A[..., 0, 2] * A[..., 1, 1] * A[..., 2, 0] - \
                 A[..., 0, 1] * A[..., 1, 0] * A[..., 2, 2] - \
                 A[..., 0, 0] * A[..., 1, 2] * A[..., 2, 1]
        return output

    def forward(self, features):
        """
        :param features: dict,
            'vectors': B, n_vectors, 3
            'scalars': B, n_scalars, 1
            Looked up by key when both keys are present (extra entries are
            ignored); otherwise the dict must hold exactly two values in the
            order (vectors, scalars).
        :return:
            B, out_dim
        """
        if "vectors" in features and "scalars" in features:
            vectors, scalars = features["vectors"], features["scalars"]
        else:
            # legacy positional contract for dicts using other key names
            vectors, scalars = features.values()

        batch_size = vectors.shape[0]
        net_input_features = []

        # dot products: unique entries of the Gram matrix V V^T
        gram = torch.bmm(vectors, vectors.transpose(2, 1))
        net_input_features.append(gram[..., self.triu_idx[0], self.triu_idx[1]])

        # determinants of all vector triples; skipped when no triple exists
        # (the buffer is None then and cannot be used as an index)
        if self.sub_idx is not None:
            triples = torch.index_select(vectors, 1, self.sub_idx)
            triples = triples.reshape(batch_size, -1, 3, 3)
            net_input_features.append(self._compute_determinant(triples))

        # scalar features go directly into the input with no changes
        net_input_features.append(scalars[..., 0])

        # B, in_dim -> B, out_dim
        return self.net(torch.cat(net_input_features, dim=-1))

In [308]:
# Smoke test: a freshly initialised (untrained) module applied to the node features.
SODInvariantScalars(n_vectors=n_features, n_scalars=n_features, out_dim=n_features)(attention_features)

tensor([[-0.4049, -0.1439, -0.5524,  ..., -0.1914,  0.2513, -0.6142],
        [-0.5472, -0.2710, -0.4868,  ...,  0.1545,  0.1505, -0.7519],
        [-0.7155, -0.4028, -0.0791,  ..., -0.3790,  0.0344, -0.4619],
        ...,
        [-0.5246, -0.2488,  0.1019,  ...,  0.2331, -0.7568, -0.6138],
        [-0.7155, -0.4028, -0.0791,  ..., -0.3790,  0.0344, -0.4619],
        [-0.8707, -0.4729,  0.1103,  ..., -0.3236,  0.0674, -0.3891]],
       grad_fn=<AddmmBackward0>)

In [309]:
class SO3LayerNorm(nn.Module):
    """Layer normalisation for a dict of vector and scalar features.

    Vectors are split into length and direction; only the lengths go through
    ``nn.LayerNorm`` and are then multiplied back onto the directions.  Scalars
    are layer-normalised over their channel axis.
    """

    def __init__(self, feature_dims):
        """
        :param feature_dims: dict mapping feature name ('vectors' / 'scalars')
            to its channel count; entries with 0 channels get no LayerNorm
        """
        super(SO3LayerNorm, self).__init__()

        self.feature_dims = feature_dims
        self.eps = 1e-12
        self.LN_modules = nn.ModuleDict(
            {name: nn.LayerNorm(n_channels)
             for name, n_channels in feature_dims.items()
             if n_channels != 0}
        )

    def forward(self, features, **kwargs):
        """
        :param features: dict
            'vectors': B, m_in['vectors'], 3
            'scalars': B, m_in['scalars'], 1
        :param kwargs: accepted and ignored
        :return: dict with the same layout, holding only the feature kinds
            that have a LayerNorm registered
        """
        normalized = {}

        if 'vectors' in self.LN_modules:
            vecs = features['vectors']                                   # B, m, dim
            lengths = torch.sqrt(vecs.square().sum(dim=-1) + self.eps)   # B, m
            directions = vecs / lengths.unsqueeze(-1)                    # B, m, dim
            rescaled = self.LN_modules['vectors'](lengths)               # B, m
            normalized["vectors"] = rescaled.unsqueeze(-1) * directions  # B, m, dim

        if 'scalars' in self.LN_modules:
            flat = features['scalars'][..., 0]                           # B, m
            normalized["scalars"] = self.LN_modules['scalars'](flat).unsqueeze(-1)  # B, m, 1

        return normalized


In [310]:
# Smoke test on the node features; the vector part of the output printed below is all zeros.
SO3LayerNorm(feature_dims=dict(vectors=n_features, scalars=n_features))(attention_features)

{'vectors': tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]], gr

In [368]:
class SO3EquivariantVector(nn.Module):
    """Vector-valued function built from invariant weights and the input vectors.

    An ``SODInvariantScalars`` network produces, per row, a weight matrix and
    (optionally) new scalar features.  The output vectors are weighted sums of
    the input vectors and, when enabled, of all pairwise cross products of the
    input vectors.  Scalar outputs are the remaining network outputs, returned
    unmodified.
    """

    def __init__(self,
                 n_vectors: int,
                 n_vectors_out: int,
                 n_scalars: int,
                 n_scalars_out: int,
                 cross_product: bool = True,
                 input_LN: bool = False):
        """
        :param n_vectors: channels of input vector features
        :param n_vectors_out: channels of output vector features
        :param n_scalars: channels of input scalar features
        :param n_scalars_out: channels of output scalar features
        :param cross_product: whether to append pairwise cross products of the
            input vectors to the set of vectors being combined
        :param input_LN: normalize the input vectors and scalars individually
            with ``SO3LayerNorm`` before anything else
        """
        super(SO3EquivariantVector, self).__init__()
        assert (n_vectors_out == 0) or (n_vectors > 0)

        self.n_vectors = n_vectors
        self.n_vectors_out = n_vectors_out
        self.n_scalars = n_scalars
        self.n_scalars_out = n_scalars_out

        self.cross_product = bool(cross_product)
        # one cross product per unordered pair of input vectors
        self.n_cross_prod = int(n_vectors * (n_vectors - 1) / 2) if self.cross_product else 0

        # (2, n_pairs) index pairs i < j used to pick the cross-product operands
        self.register_buffer('cross_prod_idx',
                             torch.triu_indices(n_vectors, n_vectors, offset=1))

        # number of vectors being combined; also used to scale the output
        self.normalize_term = self.n_cross_prod + self.n_vectors

        # the scalar net emits a (n_vectors_out x normalize_term) weight matrix
        # followed by the new scalar features
        self.out_dim_vec = self.n_vectors_out * self.normalize_term
        self.out_dim_s = self.n_scalars_out
        self.out_dim = self.out_dim_vec + self.out_dim_s
        self.scalar_net = SODInvariantScalars(self.n_vectors, self.n_scalars, self.out_dim)

        self.input_LN = input_LN
        if self.input_LN:
            self.input_layer_norm = SO3LayerNorm(
                feature_dims=dict(vectors=n_vectors, scalars=n_scalars))

    def forward(self, features):
        """
        :param features: dict
            'vectors': B, n_vectors, 3
            'scalars': B, n_scalars, 1
        :return: dict
            'vectors': B, n_vectors_out, 3   (only if vector outputs were requested)
            'scalars': B, n_scalars_out, 1   (only if n_scalars_out > 0)
        """
        if self.input_LN:
            features = self.input_layer_norm(features)

        weights = self.scalar_net(features)  # B, out_dim
        vector_weights, scalar_weights = torch.split(
            weights, [self.out_dim_vec, self.out_dim_s], dim=-1)

        new_features = {}

        if self.out_dim_vec > 0:
            vecs = features["vectors"]
            n_rows = vecs.shape[0]
            mixing = vector_weights.reshape(
                n_rows, self.n_vectors_out, self.n_cross_prod + self.n_vectors)

            basis = vecs  # B, n_vectors, 3
            if self.n_vectors > 1 and self.cross_product:
                first, second = self.cross_prod_idx
                crosses = torch.linalg.cross(vecs[:, first], vecs[:, second], dim=-1)
                basis = torch.cat([vecs, crosses], dim=-2)  # B, n_vectors + n_cross_prod, 3

            # weighted sum of the basis vectors, scaled by the basis size
            new_features["vectors"] = torch.bmm(mixing, basis) / self.normalize_term

        if self.out_dim_s > 0:
            new_features["scalars"] = scalar_weights.unsqueeze(-1)  # B, n_scalars_out, 1

        return new_features
    

 # Dummy test (equivariant vectors start as zero, so this is meaningless testing)

In [356]:
# Positional arguments: n_vectors, n_vectors_out, n_scalars, n_scalars_out,
# cross_product=True, input_LN=True.
SO3EquivariantVector(n_features, n_features, n_features, n_features, True, True)(attention_features)

{'vectors': tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]], gr

 # Used RBF Embedded Writhe as the edge feature in the equivariant attention scheme

 # Write layer for integration into existing cpaiNN work flow

In [361]:
class PairwiseSO3Conv(nn.Module):

    """ Generate pairwise (per-edge) features.
    f_ji = h('f_j', 's_j', x_i - x_j, edge_ji)   -> {'vectors', 'scalars'}

    For every edge, the source node's vector features are concatenated with the
    edge direction, and the source node's scalar features are concatenated with
    the edge's embedded writhe; the result is passed to ``SO3EquivariantVector``.
    """

    def __init__(self,
                 n_vectors: int,
                 n_vectors_out: int,
                 n_scalars: int,
                 n_scalars_out: int,
                 edge_dim: int = 0,
                 cross_product: bool = True,
                ):

        """
        :param n_vectors: channels of input vector features
        :param n_vectors_out: channels of output vector features
        :param n_scalars: channels of input scalar features
        :param n_scalars_out: channels of output scalar features
        :param edge_dim: channels of the per-edge feature (``batch.writhe``);
            0 means no edge feature is appended.  (The previous default was the
            type ``int``, which made the constructor fail when it was omitted.)
        :param cross_product: forwarded to ``SO3EquivariantVector``
        """
        super(PairwiseSO3Conv, self).__init__()
        self.n_vectors = n_vectors
        self.n_vectors_out = n_vectors_out
        self.n_scalars = n_scalars
        self.n_scalars_out = n_scalars_out
        self.edge_dim = edge_dim

        net_in_vec = n_vectors + 1  # cat(f_j, x_i - x_j)
        net_in_s = n_scalars + edge_dim  # cat(s_j, edge_ji)

        self.net = SO3EquivariantVector(net_in_vec,
                                        n_vectors_out,
                                        net_in_s,
                                        n_scalars_out,
                                        cross_product)

    def forward(self, batch):
        """
        :param batch: graph batch providing ``edge_index``, ``edge_dir``,
            ``equivariant_node_features``, ``invariant_node_features`` and,
            when ``edge_dim > 0``, ``writhe``
        :return: dict of per-edge features as produced by ``SO3EquivariantVector``
            ('vectors': num_edges, n_vectors_out, 3; 'scalars': num_edges, n_scalars_out, 1)
        """
        src_node = batch.edge_index[0]

        # source-node vectors plus the edge direction as one extra channel
        vectors = torch.cat(
            [batch.equivariant_node_features[src_node], batch.edge_dir[:, None, :]],
            dim=1,
        )  # num_edges, n_vectors + 1, 3

        scalar_parts = [batch.invariant_node_features[src_node].unsqueeze(-1)]
        if self.edge_dim > 0:
            # writhe already has one row per edge, so it is used as-is; indexing
            # it with node ids (src_node) would select unrelated rows.
            # NOTE(review): assumes writhe rows follow the column order of
            # edge_index -- confirm against AddEmbeddedWrithe.
            scalar_parts.append(batch.writhe.unsqueeze(-1))
        scalars = torch.cat(scalar_parts, dim=-2)  # num_edges, n_scalars + edge_dim, 1

        return self.net({'vectors': vectors, 'scalars': scalars})


# Make our attention inputs ... these can be simplified / joined in a final function such that the input is simply a batch of graphs

In [522]:
# key and value are built from the graph batch by PairwiseSO3Conv; query is built
# from the node features only.  key and value come from two separately
# initialised modules, so they do not share weights.
key = PairwiseSO3Conv(n_features, n_features,n_features,n_features,n_features)(batch)
query = SO3EquivariantVector(*(4*[n_features]))(attention_features)
value = PairwiseSO3Conv(n_features, n_features,n_features,n_features,n_features)(batch)
#attention_module = AttentionModule(1)
#key, query = map(attention_module.vectorize_dict, (key, query))

In [None]:
class AttentionModule(nn.Module):
    """An SO(3)-equivariant self-attention module.

    Queries live on nodes, keys and values on edges.  For every edge the
    attention logit is the scaled dot product between the destination node's
    query and the edge's key; logits are soft-maxed over all edges sharing a
    destination node and used to pool the edge values onto that node.
    """

    def __init__(self, heads: int):
        """
        :param heads: Number of attention heads.
        """
        super(AttentionModule, self).__init__()
        self.heads = heads
        # registered but not applied in forward (p=0.0 would make it a no-op anyway)
        self.attn_dropout = nn.Dropout(p=0.0)

    def vectorize_dict(self, data_dict):
        """
        Vectorize data in the data_dict and concatenate them together.
        :param data_dict:
            'vectors': B, m_vec, 3
            'scalars': B, m_s, 1
        :return:
            B, heads, m_vec // heads * 3 + m_s // heads * 1
        """
        container = []
        for value in data_dict.values():
            n_rows, m_in, _ = value.shape
            assert m_in % self.heads == 0, 'm_in is not divisible by heads.'
            container.append(value.reshape(n_rows, self.heads, -1))
        return torch.cat(container, dim=-1)

    def forward(self, q: dict, k: dict, v: dict, edge_index=None):
        """
        :param q: dict, query (one row per node)
            'vectors': num_nodes, m_qk_vec, 3
            'scalars': num_nodes, m_qk_s, 1
        :param k: dict, key (one row per edge)
            'vectors': num_edges, m_qk_vec, 3
            'scalars': num_edges, m_qk_s, 1
        :param v: dict, value (one row per edge)
            'vectors': num_edges, m_v_vec, 3
            'scalars': num_edges, m_v_s, 1
        :param edge_index: (2, num_edges) long tensor of (source, destination)
            node ids.  When omitted, the notebook-level ``batch.edge_index`` is
            used, which is what this method previously always did.
        :return: dict with one entry per entry of ``v``
            {'vectors': num_nodes, m_v_vec, 3,
             'scalars': num_nodes, m_v_s, 1}
        """
        if edge_index is None:
            # backward-compatible fallback onto the global graph batch
            edge_index = batch.edge_index
        _, dst_node = edge_index

        query, key = (self.vectorize_dict(i) for i in (q, k))
        num_nodes = query.shape[0]
        div_term = math.sqrt(query.shape[-1])

        # num_edges, heads
        logits = (query[dst_node] * key).sum(-1) / div_term

        # Softmax over the incoming edges of each destination node.  The
        # per-node maximum is subtracted first so exp() cannot overflow, and
        # dim=0 is passed explicitly (scatter defaults to the last axis, which
        # is the head axis here).
        shift = scatter(logits.detach(), dst_node, dim=0, dim_size=num_nodes, reduce="max")
        attn = torch.exp(logits - shift[dst_node])
        attn = attn / scatter(attn, dst_node, dim=0, dim_size=num_nodes)[dst_node]
        attn = attn.unsqueeze(-1)  # num_edges, heads, 1

        # Apply attention weights to value embeddings
        output_dict = {}
        for data_type, data_item in v.items():
            num_edges, m_in, dim = data_item.shape
            assert m_in % self.heads == 0, 'm_in is not divisible by heads in the value embedding.'
            weighted = data_item.reshape(num_edges, self.heads, -1, dim) * attn.unsqueeze(-1)
            pooled = scatter(weighted, dst_node, dim=0, dim_size=num_nodes)
            output_dict[data_type] = pooled.reshape(-1, m_in, dim)

        # returned only after every value type has been pooled
        return output_dict


In [524]:
# Single-head attention over the query / key / value dicts built in the cell above.
AttentionModule(1)(query, key, value)

{'vectors': tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         [[ 1.3809e-03, -2.7748e-04,  2.0714e-03],
          [-1.4380e-03,  2.8893e-04, -2.1569e-03],
          [-3.5203e-04,  7.0735e-05, -5.2805e-04],
          ...,
          [ 6.3171e-04, -1.2693e-04,  9.4756e-04],
          [-9.6343e-04,  1.9359e-04, -1.4451e-03],
          [ 1.1546e-04, -2.3200e-05,  1.7319e-04]],
 
         [[ 1.6486e-03, -1.1649e-03,  3.9034e-04],
          [-1.8759e-03,  1.3749e-03, -3.2096e-04],
          [-3.1356e-04,  1.8921e-04, -1.5508e-04],
          ...,
          [ 1.4566e-03, -1.2362e-03, -1.7207e-04],
          [-2.1720e-03,  1.8380e-03,  2.4336e-04],
          [-1.6592e-04,  2.0148e-04,  1.7114e-04]],
 
         ...,
 
 

from dgl import function as fn
from dgl.nn.functional import edge_softmax

from torch_geometric.utils import to_dgl, from_dgl
# Scratch: the same edge-softmax attention weights computed with DGL.
# NOTE(review): `query` and `key` are dicts when they come straight from the
# cell that builds them; the shape comments below describe tensors such as
# AttentionModule.vectorize_dict returns -- confirm they are vectorized first.
G = to_dgl(batch)

# local_scope: node/edge data assigned inside is dropped from G on exit
with G.local_scope():
    G.ndata['query'] = query    # num_nodes, heads, dim
    G.edata['key'] = key     # num_edges, heads, dim
    div_term = math.sqrt(G.ndata['query'].shape[-1])

    # Compute the attention weights
    G.apply_edges(fn.e_dot_v('key', 'query', 'attn'))
    attn = G.edata.pop('attn') / div_term
    attn = edge_softmax(G, attn)

class LastDimNorm(nn.Module):
    """Instance-normalise a ``(B, m, dim)`` tensor over its channel axis ``m``.

    ``nn.InstanceNorm1d(dim)`` expects ``(B, dim, length)``, so the last two
    axes are swapped before normalisation and swapped back afterwards.

    :param dim: size of the last axis of the input (3 for Cartesian vectors)
    """

    def __init__(self, dim: int = 3):
        super().__init__()
        self.norm = nn.InstanceNorm1d(dim)

    def forward(self, x):
        """
        :param x: tensor of shape ``(B, m, dim)``
        :return: normalised tensor of the same shape
        """
        return self.norm(x.transpose(2, 1)).transpose(2, 1)
    
