In [None]:
# Reload all imported modules (e.g. the local pao_tfn_* modules) before each cell executes,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

# Floating-point tensors created without an explicit dtype (e.g. `torch.tensor(0.0)`) become float64.
torch.set_default_dtype(torch.float64)

# Seed torch's global RNG. DataLoader shuffling uses it, and presumably so does PAONet's
# weight initialisation; seeding makes training runs reproducible.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): `device` is not referenced by the training cell in this notebook, which therefore
# runs on the CPU. The GPU index is hard-coded; adjust it to the machine at hand.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [None]:
import torch.nn.functional
import torch.utils.data
import se3cnn.point_utils

from pao_tfn_dataset import PAODataset
from pao_tfn_net import PAONet

In [None]:
# Train a PAONet for the atomic kind `kind_name` so that the span of its predicted PAO basis
# (the decoded xblock) covers the reference xblocks stored in the dataset.

# assuming MOLOPT-DZVP as primary basis set
prim_basis_shells = {
    'H': [2, 1, 0], # two s-shells, one p-shell, no d-shells
    'O': [2, 2, 1], # two s-shells, two p-shells, one d-shell
}
pao_basis_size = 4
kind_name = "O"

# Training hyper-parameters.
NUM_EPOCHS = 100
BATCH_SIZE = 5
LEARNING_RATE = 1e-2

dataset = PAODataset(kind_name)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

net = PAONet(num_kinds=len(prim_basis_shells),
             pao_basis_size=pao_basis_size,
             prim_basis_shells=prim_basis_shells[kind_name],
             num_hidden=1)
net.train()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # Sums over all batches of the epoch; used for reporting only.
    epoch_missmatch = 0
    epoch_penalty = 0
    epoch_loss = 0
    for batch in dataloader:
        kind_onehot, coords, sample_indices = batch
        # presumably the pairwise coordinate differences between atoms — see se3cnn.point_utils
        diff_M = se3cnn.point_utils.difference_matrix(coords)

        # forward pass
        output_net = net(kind_onehot, diff_M)

        # Both loss terms are summed over all samples of the batch.
        missmatch = torch.tensor(0.0)
        penalty = torch.tensor(0.0)

        #TODO: batchify this to speed things up
        for i, idx in enumerate(sample_indices):  # loop over batch
            # We only care about the xblock of the central atom, which we rolled to the front.
            xblock_net = net.decode_xblock(output_net[i,:,0])

            # We penalize non-unit vectors later, but we are not going to rely on it here.
            xblock_net_unit = torch.nn.functional.normalize(xblock_net)  # unit norm along dim=1
            # Projector onto the span of the predicted basis vectors (exact only if they are orthonormal).
            #TODO: This might not be ideal as it implicitly forces the pao basis vectors to be orthogonal.
            projector = torch.matmul(torch.t(xblock_net_unit), xblock_net_unit)

            # Residual = part of the reference xblock that the predicted basis cannot represent.
            xblock_sample = dataset.sample_xblocks[idx]
            residual = torch.t(xblock_sample) - torch.matmul(projector, torch.t(xblock_sample))
            missmatch += torch.norm(residual)

            # Penalize non-unit basis vectors; accumulated over the batch, like `missmatch`.
            penalty += torch.norm(1 - torch.norm(xblock_net, dim=1))

        loss = missmatch + penalty
        epoch_loss += loss.item()
        epoch_missmatch += missmatch.item()
        epoch_penalty += penalty.item()

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %i  Missmatch: %f  Penalty: %f Loss: %f"%(epoch, epoch_missmatch, epoch_penalty, epoch_loss))