In [20]:
import mdtraj as md
traj = md.load("/Users/tommysisk/asyn/lig47/fasudil_111frames.dcd",
               top="/Users/tommysisk/asyn/lig47/lig47.pdb"
              )
coords = traj.atom_slice(traj.top.select("name CA")).xyz
segments = get_segments(20)

In [37]:
import torch
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
from torch.nn.init import zeros_
from functools import partial
import math
import torch.nn as nn
import contextlib
import time

class Timer:
    """import time"""

    def __init__(self,
                 check_interval: "the time (hrs) after the call method should return True" = 1):

        self.start_time = time.time()
        self.interval = check_interval * (60 ** 2)

    def __call__(self):
        if abs(time.time() - self.start_time) > self.interval:
            self.start_time = time.time()
            return True
        else:
            return False

    def time_remaining(self):
        sec = max(0, self.interval - abs(time.time() - self.start_time))
        hrs = sec // (60 ** 2)
        mins_remaining = (sec / 60 - hrs * 60)
        mins = mins_remaining // 1
        secs = (mins_remaining - mins) * 60
        hrs, mins, secs = [int(i) for i in [hrs, mins, secs]]
        print(f"{hrs}:{mins}:{secs}")
        return None

    # for context managment

    @classmethod
    @contextlib.contextmanager
    def timeit(cls):
        start = time.time()
        try:
            yield
        finally:
            print(f"Time Elapsed : {time.time() - start:.5f} seconds")
            del start

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
        print(f"Time elapsed {self.interval} s")
        return self.interval



class Dense(nn.Linear):
    r"""Fully connected linear layer with activation function.

    .. math::
       y = activation(x W^T + b)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation: callable = None,
        weight_init = xavier_uniform_,
        bias_init = zeros_,
    ):
        """
        Args:
            in_features: number of input feature :math:`x`.
            out_features: umber of output features :math:`y`.
            bias: If False, the layer will not adapt bias :math:`b`.
            activation: if None, no activation function is used.
            weight_init: weight initializer from current weight.
            bias_init: bias initializer from current bias.
        """
        self.weight_init = weight_init
        self.bias_init = bias_init
        super(Dense, self).__init__(in_features, out_features, bias)

        self.activation = activation
        if self.activation is None:
            self.activation = nn.Identity()

    def reset_parameters(self):
        self.weight_init(self.weight)
        if self.bias is not None:
            self.bias_init(self.bias)

    def forward(self, input: torch.Tensor):
        y = F.linear(input, self.weight, self.bias)
        y = self.activation(y)
        return y


class GatedEquivariantBlock(nn.Module):
    """
    Gated equivariant block as used for the prediction of tensorial properties by PaiNN.
    Transforms scalar and vector representation using gated nonlinearities.

    """

    def __init__(
        self,
        n_sin: int,
        n_vin: int,
        n_sout: int,
        n_vout: int,
        n_hidden: int,
        activation=F.silu,
        sactivation=None,
    ):
        """
        Args:
            n_sin: number of input scalar features
            n_vin: number of input vector features
            n_sout: number of output scalar features
            n_vout: number of output vector features
            n_hidden: number of hidden units
            activation: interal activation function
            sactivation: activation function for scalar outputs
        """
        super().__init__()
        self.n_sin = n_sin
        self.n_vin = n_vin
        self.n_sout = n_sout
        self.n_vout = n_vout
        self.n_hidden = n_hidden
        self.mix_vectors = Dense(n_vin, 2 * n_vout, activation=None, bias=False)
        self.scalar_net = nn.Sequential(
            Dense(n_sin + n_vout, n_hidden, activation=activation),
            Dense(n_hidden, n_sout + n_vout, activation=None),
        )
        self.sactivation = sactivation

    def forward(self, inputs: tuple):
        scalars, vectors = inputs
        vmix = self.mix_vectors(vectors)
        vectors_V, vectors_W = torch.split(vmix, self.n_vout, dim=-1)
        vectors_Vn = torch.norm(vectors_V, dim=-2)

        ctx = torch.cat([scalars, vectors_Vn], dim=-1)
        x = self.scalar_net(ctx)
        s_out, x = torch.split(x, [self.n_sout, self.n_vout], dim=-1)
        v_out = x.unsqueeze(-2) * vectors_W

        if self.sactivation:
            s_out = self.sactivation(s_out)

        return s_out, v_out


def build_gated_equivariant_mlp(
    n_in: int,
    n_out: int,
    n_hidden  = None,
    n_gating_hidden = None,
    n_layers: int = 2,
    activation: callable = F.silu,
    sactivation: callable = F.silu,
):
    """
    Build neural network analog to MLP with `GatedEquivariantBlock`s instead of dense layers.

    Args:
        n_in: number of input nodes.
        n_out: number of output nodes.
        n_hidden: number hidden layer nodes.
            If an integer, same number of node is used for all hidden layers resulting
            in a rectangular network.
            If None, the number of neurons is divided by two after each layer starting
            n_in resulting in a pyramidal network.
        n_layers: number of layers.
        activation: Activation function for gating function.
        sactivation: Activation function for scalar outputs. All hidden layers would
            the same activation function except the output layer that does not apply
            any activation function.
    """
    # get list of number of nodes in input, hidden & output layers
    if n_hidden is None:
        c_neurons = n_in
        n_neurons = []
        for i in range(n_layers):
            n_neurons.append(c_neurons)
            c_neurons = max(n_out, c_neurons // 2)
        n_neurons.append(n_out)
    else:
        # get list of number of nodes hidden layers
        if type(n_hidden) is int:
            n_hidden = [n_hidden] * (n_layers - 1)
        else:
            n_hidden = list(n_hidden)
        n_neurons = [n_in] + n_hidden + [n_out]

    if n_gating_hidden is None:
        n_gating_hidden = n_neurons[:-1]
    elif type(n_gating_hidden) is int:
        n_gating_hidden = [n_gating_hidden] * n_layers
    else:
        n_gating_hidden = list(n_gating_hidden)

    # assign a GatedEquivariantBlock (with activation function) to each hidden layer
    layers = [
        snn.GatedEquivariantBlock(
            n_sin=n_neurons[i],
            n_vin=n_neurons[i],
            n_sout=n_neurons[i + 1],
            n_vout=n_neurons[i + 1],
            n_hidden=n_gating_hidden[i],
            activation=activation,
            sactivation=sactivation,
        )
        for i in range(n_layers - 1)
    ]
    # assign a GatedEquivariantBlock (without scalar activation function)
    # to the output layer
    layers.append(
        snn.GatedEquivariantBlock(
            n_sin=n_neurons[-2],
            n_vin=n_neurons[-2],
            n_sout=n_neurons[-1],
            n_vout=n_neurons[-1],
            n_hidden=n_gating_hidden[-1],
            activation=activation,
            sactivation=None,
        )
    )
    # put all layers together to make the network
    out_net = nn.Sequential(*layers)
    return out_net


In [38]:
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    
    # factor = 10000^(2i/d_model)
    factor = 10000 ** ((torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
    )
    
    # pos / factor
    # timesteps B -> B, 1 -> B, temb_dim
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb

In [39]:
def to_difference_matrix(x, norm: bool=False):
    """
    Converts a batch of coordinates to a batch of difference tensors
    :param X: batch of coordinates [B, N, d]
    :return: batch of difference tensors [B, N, N, d]
    """
    
    x = x[None, :, :] if x.ndim < 3 else x

    if norm:
        diff = x[:, :, None, :] - x[:, None, :, :]
        return diff / (1 + torch.linalg.norm(diff, axis=-1))
    
    else:
        return x[:, :, None, :] - x[:, None, :, :]


def to_distmat(x):
    """
    Converts a batch of coordinates to a batch of EDMs
    :param x: batch of coordinates [B, N, d]
    :return: batch of EDMs [B, N, N]
    """
    return torch.squeeze(torch.linalg.norm(to_difference_matrix(x), dim=-1), -1)



In [40]:
import itertools
import warnings
import numpy as np
import torch
import torch.nn as nn


def product(x: np.ndarray, y: np.ndarray):
    return np.asarray(list(itertools.product(x, y)))


def combinations(x):
    return np.asarray(list(itertools.combinations(x, 2)))


def shifted_pairs(x: np.ndarray, shift: int, ax: int = 1):
    return np.stack([x[:-shift], x[shift:]], ax)


def get_segments(n: int = None,
                 length: int = 1,
                 index0: np.ndarray = None,
                 index1: np.ndarray = None):
    """
    Function to retrieve indices of segment pairs for various use cases.
    Returns an (n_segment_pairs, 4) array where each row (quadruplet) contains : (start1, end1, start2, end2)
    """

    
    if all(i is None for i in (index0, index1)):
        assert n is not None, \
            "Must provide indices (index0:array, (optionally) index1:array) or the number of points (n: int)"
        segments = combinations(shifted_pairs(np.arange(n), length)).reshape(-1, 4)
        return torch.from_numpy(segments[~(segments[:, 1] == segments[:, 2])])

    else:
        assert index0 is not None, ("If providing only one set of indices, must set the index0 argument \n"
                                    "Cannot only supply the index1 argument (doesn't make sense in this context")
        if index1 is not None:
            return torch.from_numpy(product(*[shifted_pairs(i, length) for i in (index0, index1)]).reshape(-1, 4))
        else:
            segments = combinations(shifted_pairs(index0, length)).reshape(-1, 4)
            return torch.from_numpy(segments[~(segments[:, 1] == segments[:, 2])])


##########################################   fastest ways of implementing these linear algebra ops for this purpose  (NOT trivial) ############################################


    
def nnorm(x: torch.Tensor):
    
    """Convenience function for (batched) normalization of vectors stored in arrays with last dimension 3"""
    
    norm = torch.linalg.norm(x, axis=-1)
    
    if x.ndim == 4:
        return x / norm[:, :, :, None]
    elif x.ndim == 3:
        return x / norm[:, :, None]
    elif x.ndim == 2:
        return x / norm[:, None]
    else:
        return x / norm


def ncross(x: torch.Tensor, y: torch.Tensor):

    """Convenience function for (batched) cross products of vectors stored in arrays with last dimension 3""" 

    # c = np.array(list(map(cross,x,y)))
    c = torch.cross(x, y, axis=-1)
    return c


def ndot(x, y):

    """Convenience function for (batched) dot products of vectors stored in arrays with last dimension 3"""

    # d = np.array(list(map(dot,x,y)))[:,None]
    d = torch.sum(x * y, axis=-1)
    return d


def ndet(v1, v2, v3):
    """for the triple product and finding the signed sin of the angle between v2 and v3, v1 should
    be set equal to a vector mutually orthogonal to v2,v3"""
    #     det = np.array(list(map(lambda x,y,z:np.linalg.det(np.array([x,y,z])),
    #                         v1,v2,v3)))[:,None]
    det = ndot(v1, ncross(v2, v3))
    return det


def uproj(a, b, norm_b: bool = True):
    """Convenience function for (batched) othogonal projection of vectors stored in arrays with last dimension 3
    where a is a set of vectors which we be othogonally projected onto a single (batched) set of vectors, b"""
    b = nnorm(b) if norm_b else b
    # faster than np.matmul when using ray
    return a - b * torch.sum(a[:, None, :] * b[:, None, :], -1).transpose(0, 2, 1)


solid_angle = lambda a, b: -torch.arcsin(torch.prod(nnorm(uproj(a, b)), 1).sum(-1).clip(-1, 1))
# or arccos(...) - (np.pi / 2)
# slower than current version with ray
# Usage of solid angle in writhe computation (not as time efficient as implementation in use)
# indices = zip([0, 1, 0, 2],
#               [3, 2, 3, 1],
#               [1, 3, 2, 0])

# omega = np.stack([solid_angle(displacements[:, [i, j]], displacements[:, None, n])
#                   for i, j, n in indices], 1).squeeze().sum(-1)
##############################################################################################################################


def writhe_segment(segment=None, xyz=None, smat=None):
    """compute the writhe (signed crossing) of 2 segments for all frames (index 0) in xyz (xyz can contain just one frame)
    
    THERE ARE 2 INPUT OPTIONS
    
    **provide both of the following**
    
    segment: numpy array of shape (4,) giving the indices of the 4 alpha carbons in xyz creating 2 segments:::
             array([seg1_start_point,seg1_end_point,seg2_start_point,seg2_end_point])

    xyz: numpy array of shape (Nframes, N_alpha_carbons, 3),coordinate array giving the positions of ALL the alpha carbons
    
    **OR just the following**
    
    smat ::: numpy array of shape (Nframes, 4, 3) : sliced coordinate matrix: coordinate array that is pre-sliced
    with only the positions of the 4 alpha carbons constituting the 2 segments to compute the writhe between """

    if smat is None:
        assert not ((segment is None) or (xyz is None)), \
            "must input smat or both a segment and xyz coordinates"
        smat = (xyz[None, :, :] if xyz.ndim < 3 else xyz)[:, segment]
    else:
        smat = smat[None, :, :] if smat.ndim < 3 else smat

    # smat = nnorm(smat)
    sum_dim = None if smat.shape[0] == 1 else 1

    # broadcasting trick
    # negative sign, None placement and order are intentional, don't change without testing equivalent option
    displacements = nnorm((-smat[:, :2, None, :] + smat[:, None, 2:, :]).reshape(-1, 4, 3))

    # array broadcasting is (surprisingly) slower than list comprehensions
    # when using ray for the following operations (without ray, broadcasting should be faster).

    crosses = nnorm(ncross(displacements[:, [0, 1, 3, 2]], displacements[:, [1, 3, 2, 0]]))

    omega = torch.arcsin(ndot(crosses[:, [0, 1, 2, 3]], crosses[:, [1, 2, 3, 0]]).clip(-1, 1)).squeeze().sum(sum_dim)

    signs = torch.sign(ndot(ncross(nnorm(smat[:, 3] - smat[:, 2]),
                                nnorm(smat[:, 1] - smat[:, 0])),
                         displacements[:, 0])).squeeze()

    wr = (1 / (2 * torch.pi)) * (omega * signs)

    return wr


def writhe_segments_along_axis(segments: torch.LongTensor, xyz: torch.Tensor, axis: int = 1):
    """helper function for parallelization to compute writhe over chuncks of segments for all frames in xyz"""
    # noinspection PyTypeChecker
    return torch.stack([writhe_segment(segment, xyz, None) for segment in segments], 1)
    
# Unfinished, need analouge to cpu treatment with ray parallelization

def calc_writhe_parallel(segments: np.ndarray, xyz: np.ndarray,) -> "Nframes by Nsegments np.ndarray":
    """parallelize writhe calculation by segment, avoids making multiple copies of coordinate (xyz) matrix,
    uses torch.scatter to parallelize"""

    # TODO finish remaking this function to scatter the computation across multiple GPUs
    
    #n_split = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()
    chunks = torch.tensor_split(segments, int(torch.cuda.device_count()))
    result = torch.concatenate([writhe_segments_along_axis(segments=chunk, xyz=xyz_ref) for chunk in chunks]).T.squeeze()
    return result


def to_writhe_adj_matrix(writhe_features, n_points, length, segments=None):
    
    n = len(writhe_features)

    if segments is None:
        segments = get_segments(n_points, length)

    adj_matrix = torch.zeros((n, n_points, n_points))

    adj_matrix[:, segments[:, 0], segments[:, 2]] = writhe_features
    adj_matrix[:, segments[:, 1], segments[:, 3]] = writhe_features
    adj_matrix = adj_matrix + adj_matrix.swapaxes(1, 2)

    return adj_matrix.squeeze()

In [41]:
class _SoftUnitStep(torch.autograd.Function):
    # pylint: disable=arguments-differ

    @staticmethod
    def forward(ctx, x) -> torch.Tensor:
        ctx.save_for_backward(x)
        y = torch.zeros_like(x)
        m = x > 0.0
        y[m] = (-1 / x[m]).exp()
        return y

    @staticmethod
    def backward(ctx, dy) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        dx = torch.zeros_like(x)
        m = x > 0.0
        xm = x[m]
        dx[m] = (-1 / xm).exp() / xm.pow(2)
        return dx * dy


def soft_unit_step(x):
    r"""smooth :math:`C^\infty` version of the unit step function

    .. math::

        x \mapsto \theta(x) e^{-1/x}


    Parameters
    ----------
    x : `torch.Tensor`
        tensor of shape :math:`(...)`

    Returns
    -------
    `torch.Tensor`
        tensor of shape :math:`(...)`

    Examples
    --------

    .. jupyter-execute::
        :hide-code:

        import torch
        from e3nn.math import soft_unit_step
        import matplotlib.pyplot as plt

    .. jupyter-execute::

        x = torch.linspace(-1.0, 10.0, 1000)
        plt.plot(x, soft_unit_step(x));
    """
    return _SoftUnitStep.apply(x)
        
        
def soft_one_hot_linspace(x: torch.Tensor, start, end, number, basis=None, cutoff=None) -> torch.Tensor:
    r"""Projection on a basis of functions

    Returns a set of :math:`\{y_i(x)\}_{i=1}^N`,

    .. math::

        y_i(x) = \frac{1}{Z} f_i(x)

    where :math:`x` is the input and :math:`f_i` is the ith basis function.
    :math:`Z` is a constant defined (if possible) such that,

    .. math::

        \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1

    See the last plot below.
    Note that ``bessel`` basis cannot be normalized.

    Parameters
    ----------
    x : `torch.Tensor`
        tensor of shape :math:`(...)`

    start : float
        minimum value span by the basis

    end : float
        maximum value span by the basis

    number : int
        number of basis functions :math:`N`

    basis : {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}
        choice of basis family; note that due to the :math:`1/x` term, ``bessel`` basis does not satisfy the normalization of
        other basis choices

    cutoff : bool
        if ``cutoff=True`` then for all :math:`x` outside of the interval defined by ``(start, end)``,
        :math:`\forall i, \; f_i(x) \approx 0`

    Returns
    -------
    `torch.Tensor`
        tensor of shape :math:`(..., N)`

    Examples
    --------

    .. jupyter-execute::
        :hide-code:

        import torch
        from e3nn.math import soft_one_hot_linspace
        import matplotlib.pyplot as plt

    .. jupyter-execute::

        bases = ['gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel']
        x = torch.linspace(-1.0, 2.0, 100)

    .. jupyter-execute::

        fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)

        for axs, b in zip(axss, bases):
            for ax, c in zip(axs, [True, False]):
                plt.sca(ax)
                plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c))
                plt.plot([-0.5]*2, [-2, 2], 'k-.')
                plt.plot([1.5]*2, [-2, 2], 'k-.')
                plt.title(f"{b}" + (" with cutoff" if c else ""))

        plt.ylim(-1, 1.5)
        plt.tight_layout()

    .. jupyter-execute::

        fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)

        for axs, b in zip(axss, bases):
            for ax, c in zip(axs, [True, False]):
                plt.sca(ax)
                plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c).pow(2).sum(1))
                plt.plot([-0.5]*2, [-2, 2], 'k-.')
                plt.plot([1.5]*2, [-2, 2], 'k-.')
                plt.title(f"{b}" + (" with cutoff" if c else ""))

        plt.ylim(0, 2)
        plt.tight_layout()
    """
    # pylint: disable=misplaced-comparison-constant

    if cutoff not in [True, False]:
        raise ValueError("cutoff must be specified")

    if not cutoff:
        values = torch.linspace(start, end, number, dtype=x.dtype, device=x.device)
        step = values[1] - values[0]
    else:
        values = torch.linspace(start, end, number + 2, dtype=x.dtype, device=x.device)
        step = values[1] - values[0]
        values = values[1:-1]

    diff = (x[..., None] - values) / step

    if basis == "gaussian":
        return diff.pow(2).neg().exp().div(1.12)

    if basis == "cosine":
        return torch.cos(math.pi / 2 * diff) * (diff < 1) * (-1 < diff)

    if basis == "smooth_finite":
        return 1.14136 * torch.exp(torch.tensor(2.0)) * soft_unit_step(diff + 1) * soft_unit_step(1 - diff)

    if basis == "fourier":
        x = (x[..., None] - start) / (end - start)
        if not cutoff:
            i = torch.arange(0, number, dtype=x.dtype, device=x.device)
            return torch.cos(math.pi * i * x) / math.sqrt(0.25 + number / 2)
        else:
            i = torch.arange(1, number + 1, dtype=x.dtype, device=x.device)
            return torch.sin(math.pi * i * x) / math.sqrt(0.25 + number / 2) * (0 < x) * (x < 1)

    if basis == "bessel":
        x = x[..., None] - start
        c = end - start
        bessel_roots = torch.arange(1, number + 1, dtype=x.dtype, device=x.device) * math.pi
        out = math.sqrt(2 / c) * torch.sin(bessel_roots * x / c) / x

        if not cutoff:
            return out
        else:
            return out * ((x / c) < 1) * (0 < x)

    raise ValueError(f'basis="{basis}" is not a valid entry')




class TorchWrithe(nn.Module):

    """
    Compute writhe for a set of coordinates. 
    Return an (batch, n_atoms, n_atoms) writhe 'adjacentcy' matrix.
    
    """
    
    def __init__(self, n_atoms: int):
        super().__init__()
        self.segments = get_segments(n_atoms)
        self.n_atoms = n_atoms
    
    def to_matrix(self, x):
        return to_writhe_adj_matrix(x, self.n_atoms, 1, self.segments)
    
    def forward(self, x):
        return self.to_matrix(writhe_segments_along_axis(self.segments, x))


class SelfAttention(nn.Module):
    """
    Simple implementation of self attention where we skip the 'value' projection and return only the attention logits (positive)
    """
    
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        #self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        queries = self.query(x)
        keys = self.key(x)
        #values = self.value(x)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        #weighted = torch.bmm(attention, values)
        return attention

class WritheEmbedding(nn.Module):
    
    """
    Embed each value of writhe by a super position of a basis set of functions (vectors).
    The weight of each function in the super position is determined by a soft one hot embedding (RBF) of each value of writhe.
    OUTPUT : (batch, n_atoms, n_atoms, bins)
    """
    def __init__(self, embed_dim: int, bins: int=300, basis: str = "gaussian", cutoff: bool=False):
        super().__init__()

        self.soft_one_hot = partial(soft_one_hot_linspace,
                                    start=-1,
                                    end=1,
                                    number=bins,
                                    basis=basis,
                                    cutoff=cutoff)
        
        std = 1. / math.sqrt(embed_dim)

        
        self.functions = nn.Parameter(torch.Tensor(1, 1, 1, bins, embed_dim).uniform_(-std, std), requires_grad=True)
    
    def get_weights(self, x):
        return self.soft_one_hot(x).unsqueeze(-1)
        
    
    def forward(self, x):
        return (self.get_weights(x) * self.functions).sum(-2)
        

class WritheEmbeddedAttention(nn.Module):

    """
    1) Compute writhe from a set of coordinates.
    2) Use Radial Basis Functions to embed each value of writhe the same dimension as node embedding.
    3) Compute attention logits from node embeddings.
    4) Weight embedded writhe values with attention logits and sum for each node with every other node. 
    """
    
    def __init__(self,
                 node_embed_dim: int,
                 bins: int,
                 n_atoms: int,
                 ):
        
        super().__init__()
        
        self.writhe = TorchWrithe(n_atoms)
        self.attn = SelfAttention(node_embed_dim)
        self.embed = WritheEmbedding(embed_dim=node_embed_dim, bins=bins)
    
    def forward(self, node_embeddings: torch.Tensor, xyz: torch.Tensor):
        attn = self.attn(node_embeddings).unsqueeze(-1)
        writhe_embed = self.embed(self.writhe(xyz))
        return node_embeddings + (attn * writhe_embed).sum(1)
        

 # Dummy Dataset for testing layer (initial batch size of 100 sample)

In [42]:
x = torch.randn(100*20*20).reshape(100, 20, 20)
node_embeddings = torch.rand(100, 20, 128)
xyz = torch.rand(100, 20, 3)

print(f" writhe shape : {TorchWrithe(20)(xyz).shape}\n",
f"writhe embed shape : {WritheEmbedding(128)(TorchWrithe(20)(xyz)).shape}\n",
f"attention shape : {SelfAttention(128)(node_embeddings).shape}\n")

 writhe shape : torch.Size([100, 20, 20])
 writhe embed shape : torch.Size([100, 20, 20, 128])
 attention shape : torch.Size([100, 20, 20])



In [43]:
with Timer():
    WritheEmbeddedAttention(node_embed_dim=128, bins=100, n_atoms=20)(node_embeddings, xyz)

Time elapsed 0.18594837188720703 s


 # Scale up to a batch size of 1,000 and 100 bins

In [44]:
import timeit
x = torch.randn(1000*20*20).reshape(1000, 20, 20)
node_embeddings = torch.rand(1000, 20, 128)
xyz = torch.rand(1000, 20, 3)
with Timer():
    WritheEmbeddedAttention(node_embed_dim=128, bins=100, n_atoms=20)(node_embeddings, xyz)

Time elapsed 7.050119161605835 s
