In [None]:
from copy import deepcopy
from moleculekit.molecule import Molecule
from moleculekit.periodictable import periodictable
import numpy as np
import torch as pt
from torch.utils.benchmark import Timer
from torchmdnet.models.model import create_model


# Hyperparameters shared by both benchmarked models. Keeping them in one place
# ensures the two configurations differ only in the architecture-specific
# settings given below each model.
_common_args = {
    'num_rbf': 32,
    'rbf_type': 'expnorm',
    'trainable_rbf': False,
    'activation': 'silu',
    'neighbor_embedding': True,
    'cutoff_lower': 0.0,
    'max_z': 100,
    'max_num_neighbors': 128,
    'aggr': 'add',
    'derivative': False,
    'atom_filter': -1,
    'prior_model': None,
    'output_model': 'Scalar',
    'reduce_op': 'add',
}

# TensorNet
# A fresh merged dict is built per model so create_model never sees the
# shared dict itself.
model_1 = create_model({
    **_common_args,
    'model': 'tensornet',
    'embedding_dimension': 32,
    'num_layers': 2,
    'num_linears_tensor': 2,
    'num_linears_scalar': 2,
    'cutoff_upper': 4.5,
})

# ET (equivariant transformer)
model_2 = create_model({
    **_common_args,
    'model': 'equivariant-transformer',
    'embedding_dimension': 64,
    'attn_activation': 'silu',
    'num_layers': 4,
    'num_heads': 8,
    'distance_influence': 'both',
    'cutoff_upper': 5.0,
})


def benchmark(model, pdb_file, device, compute_forces=True, batch_size=1):
    """Return the mean time (ms) of one model evaluation on a molecular system.

    The model is deep-copied onto ``device`` so the caller's instance is not
    moved. The structure in ``pdb_file`` is replicated ``batch_size`` times
    and evaluated as a single batch.

    Args:
        model: Module called as ``model(atomic_numbers, positions, batch)``;
            element ``[0]`` of its output is treated as the energy.
        pdb_file: Structure file loaded with ``moleculekit`` ``Molecule``;
            only ``coords[:, :, 0]`` (presumably the first frame) is used.
        device: Torch device to run the benchmark on.
        compute_forces: If True, also time the gradient of the energy with
            respect to the atomic positions (the forces).
        batch_size: Number of copies of the molecule per evaluation.

    Returns:
        Mean wall-clock time per evaluation, in milliseconds.

    Raises:
        ValueError: If ``compute_forces`` is requested with ``batch_size > 1``.
    """
    # Validate before any expensive work; raise instead of assert so the
    # check is not stripped under ``python -O``.
    if compute_forces and batch_size > 1:
        raise ValueError('compute_forces is only supported with batch_size=1')

    model = deepcopy(model).to(device)

    # Get molecular data
    molecule = Molecule(pdb_file)
    atomic_numbers = pt.tensor([periodictable[symbol].number for symbol in molecule.element],
                               dtype=pt.long, device=device)
    positions = pt.tensor(molecule.coords[:, :, 0], dtype=pt.float32, device=device)
    num_atoms = len(atomic_numbers)

    # Setup a batch: every atom of the i-th copy gets batch index i
    batch = pt.arange(batch_size, device=device).repeat_interleave(num_atoms)
    atomic_numbers = pt.tile(atomic_numbers, (batch_size,))
    positions = pt.tile(positions, (batch_size, 1)).detach()

    # Setup the force computation
    positions.requires_grad_(compute_forces)

    # Benchmark
    if compute_forces:
        # Differentiate only w.r.t. positions, as a force evaluation does;
        # ``backward()`` would additionally compute and accumulate gradients
        # for every model parameter and inflate the measured time.
        stmt = ('energy = model(atomic_numbers, positions, batch)\n'
                'forces = pt.autograd.grad(energy[0].sum(), positions)[0]')
    else:
        stmt = 'energy = model(atomic_numbers, positions, batch)'

    timer = Timer(stmt=stmt, globals={
        'pt': pt,
        'model': model,
        'atomic_numbers': atomic_numbers,
        'positions': positions,
        'batch': batch,
    })
    speed = timer.blocked_autorange(min_run_time=10).mean * 1000  # s --> ms

    return speed

# Benchmarking speed
# Fall back to the CPU so the script still runs on machines without CUDA.
device = pt.device('cuda:0' if pt.cuda.is_available() else 'cpu')
systems = [('/systems/alanine_dipeptide.pdb', 'ALA2'),
           ('/systems/chignolin.pdb', 'CLN'),
           ('/systems/dhfr.pdb', 'DHFR'),
           ('/systems/factorIX.pdb', 'FC9')]

methods = [('TensorNet', model_1), ('ET', model_2)]

# method name -> {system name -> mean time per iteration in ms}
speed_methods = {}
for meth, model in methods:
    speed_methods[meth] = {}
    print(f'Method: {meth}')
    for pdb_file, name in systems:
        try:
            speed = benchmark(model, pdb_file, device, compute_forces=True, batch_size=1)
        except Exception as e:
            # Best effort: one failing system must not abort the remaining
            # benchmarks. Report the exception type so bare messages (e.g.
            # from KeyError) stay interpretable.
            print(f'  {name}: failed ({type(e).__name__}: {e})')
        else:
            speed_methods[meth][name] = speed
            print(f'  {name}: {speed} ms/it')
        finally:
            if device.type == 'cuda':
                # Return cached blocks from this (possibly failed) run to the
                # driver so they do not starve the next system.
                pt.cuda.empty_cache()