In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import InMemoryDataset, Data
import torch_geometric.transforms as T
import numpy as np

In [295]:
def pad_(X, max_atoms, dim=1):
    extra = max_atoms - X.shape[0]
    return F.pad(X, (0, extra) * dim)

In [308]:
def create_dummy(num_atoms, total_atoms):
    Z = torch.LongTensor(num_atoms).random_(total_atoms)
    D = torch.rand((num_atoms, num_atoms))
    D.masked_fill_(torch.eye(num_atoms).bool(), 0)
    return Z + 1, (D + D.T) / 2

In [328]:
def create_dummy_batch(min_atoms, max_atoms, total_atoms, bs):
    Zs, Ds, sizes = [], [], []
    for num_atoms in torch.randint(min_atoms, max_atoms, (bs,)):
        Z, D = create_dummy(num_atoms.item(), total_atoms)
        Zs.append(pad_(Z, max_atoms))
        Ds.append(pad_(D, max_atoms, 2))
        sizes.append(num_atoms)
    Zs = torch.stack(Zs)
    Ds = torch.stack(Ds)
    return Zs, Ds, torch.LongTensor(sizes)

In [329]:
def transform_D(D, sz):
    shape = list(D.shape) + [sz]
    return D.unsqueeze(-1).expand(shape)

In [344]:
min_atoms = 5
max_atoms = 10
num_atoms = 100
basis = 3
num_gauss = 5
hidden = 6

In [345]:
Zs, Ds, sizes = create_dummy_batch(min_atoms, max_atoms, num_atoms, 20)

In [346]:
Zs

tensor([[ 46,  27,  75,   7,  48,  48,   0,   0,   0,   0],
        [  6,  64,  82,  16,  83,  23,   0,   0,   0,   0],
        [ 35,   9,  61,  21,  48,   0,   0,   0,   0,   0],
        [ 26,  84,  19,  85,  75,  97,  55,  50,  98,   0],
        [ 45,  29,  19,  24,  61,  13,  95,  35,  12,   0],
        [ 36,  17,  98,  61,  38,  79,  90,  58,  35,   0],
        [ 74,  50,  97,   8,  64,   0,   0,   0,   0,   0],
        [ 51,  82,  68,  98,  45,  42,  22,   0,   0,   0],
        [ 65,  36,  27,  72,  76,  63,  43,  67,   0,   0],
        [ 85,  62,  58,  30,  69,  17,  52,   0,   0,   0],
        [ 42,  99,  65,  84,  54,  83,  23,  78,  11,   0],
        [ 47,   8,  86,  68,  91,  33,  78,  26,   0,   0],
        [ 86,  62,  96,  24,  62,  38,  50,   1,  68,   0],
        [ 67,  56,  65,  21,  44,   5,  60,   0,   0,   0],
        [ 71,  24,   3,  92,  57,  40, 100,  43,  14,   0],
        [ 84,  83,  95,  41,  26,  10,  91,  18,   0,   0],
        [ 14,   5,  84,  25,   7,  21,  

In [314]:
sizes

tensor([9, 9, 7, 8, 7, 6, 6, 7, 7, 5])

In [315]:
def create_mask(sizes, full_size):
    masks = []
    
    for size in sizes:
        mask = torch.zeros((full_size, full_size))
        mask[np.diag_indices(size)] = 1
        mask[:size, :size] -= 1
        mask.abs_()
        masks.append(mask)
    
    return torch.stack(masks) if len(sizes) > 1 else mask

In [316]:
def create_mask_1d(sizes, full_size):
    masks = []
    
    for size in sizes:
        mask = torch.zeros((full_size,))
        mask[:size] = 1
        masks.append(mask)
    
    return torch.stack(masks) if len(sizes) > 1 else mask

In [347]:
class InteractionBlock(nn.Module):
    def __init__(self, basis, hidden):
        super().__init__()
        self.cf = nn.Linear(basis, hidden)
        self.fc = nn.Linear(hidden, basis, False)
    
    def forward(self, C, D_hat, sizes):
        X = self.cf(C)
        X = X.unsqueeze(-2) * D_hat
        X = torch.tanh(self.fc(X))
        
        num_batch = C.shape[0] if len(C.shape) > 2 else 1
        mask = create_mask(sizes, max_atoms)
        return (mask.unsqueeze(-1) * X).sum(-3)

In [352]:
class MDTNN(nn.Module):
    def __init__(self, basis, num_atoms, num_gauss, hidden, T=3):
        super().__init__()
        self.basis = basis
        self.T = T
        
        self.C_embed = nn.Embedding(num_atoms + 1, basis)
        self.df = nn.Linear(num_gauss, hidden)
        self.interaction = InteractionBlock(basis, hidden)
        self.mlp = nn.Sequential(nn.Linear(basis, hidden),
                                 nn.Tanh(),
                                 nn.Linear(hidden, 1))
    
    def forward(self, Z, D, sizes):
        C = self.C_embed(Z)
        d_hat = self.df(D)
        
        print(C.shape)
        print(d_hat.shape)
        for _ in range(self.T):
            C += self.interaction(C, d_hat, sizes)
            
        E = self.mlp(C).squeeze()
        mask = create_mask_1d(sizes, max_atoms)
        
        return (mask * E).sum(-1)#.squeeze()

In [353]:
model = MDTNN(basis, num_atoms, num_gauss, hidden)

In [354]:
Zs

tensor([[ 46,  27,  75,   7,  48,  48,   0,   0,   0,   0],
        [  6,  64,  82,  16,  83,  23,   0,   0,   0,   0],
        [ 35,   9,  61,  21,  48,   0,   0,   0,   0,   0],
        [ 26,  84,  19,  85,  75,  97,  55,  50,  98,   0],
        [ 45,  29,  19,  24,  61,  13,  95,  35,  12,   0],
        [ 36,  17,  98,  61,  38,  79,  90,  58,  35,   0],
        [ 74,  50,  97,   8,  64,   0,   0,   0,   0,   0],
        [ 51,  82,  68,  98,  45,  42,  22,   0,   0,   0],
        [ 65,  36,  27,  72,  76,  63,  43,  67,   0,   0],
        [ 85,  62,  58,  30,  69,  17,  52,   0,   0,   0],
        [ 42,  99,  65,  84,  54,  83,  23,  78,  11,   0],
        [ 47,   8,  86,  68,  91,  33,  78,  26,   0,   0],
        [ 86,  62,  96,  24,  62,  38,  50,   1,  68,   0],
        [ 67,  56,  65,  21,  44,   5,  60,   0,   0,   0],
        [ 71,  24,   3,  92,  57,  40, 100,  43,  14,   0],
        [ 84,  83,  95,  41,  26,  10,  91,  18,   0,   0],
        [ 14,   5,  84,  25,   7,  21,  

In [355]:
model(Zs, transform_D(Ds, num_gauss), sizes)

torch.Size([20, 10, 3])
torch.Size([20, 10, 10, 6])


tensor([2.3407, 2.9144, 2.4607, 4.6193, 5.1044, 3.1283, 2.4735, 3.9668, 4.6154,
        2.5609, 5.5254, 2.5080, 3.8188, 2.4784, 3.5806, 4.6213, 3.8068, 2.5653,
        2.2554, 2.9945], grad_fn=<SumBackward1>)

In [275]:
F.pad(D, (0, 5, 0, 5))

tensor([[0.0000, 0.3930, 0.5242, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3930, 0.0000, 0.6171, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5242, 0.6171, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [268]:
D.shape

torch.Size([10, 10])

In [292]:
F.pad(D[0], (0, 1))

tensor([0.0000, 0.3930, 0.5242, 0.0000])

In [342]:
create_mask(sizes, max_atoms)

NameError: name 'max_atoms' is not defined