In [1]:
# Install dependencies into the running kernel's environment (%pip, unlike !pip,
# always targets the interpreter this notebook is using).
# mdtraj is required by the trajectory-loading cell below.
# NOTE(review): MDAnalysis is not imported directly in this notebook; presumably it is
# kept for Geom2Vec's own utilities — confirm. TODO: pin versions for reproducibility.
%pip install MDAnalysis mdtraj




In [2]:
import mdtraj as md
import numpy as np

# Alanine dipeptide with solvent removed: PDB topology + 0-250 ns XTC trajectory.
top_file = '/project/dinner/zpengmei/Geom2Vec/Tutorial/adp/alanine-dipeptide-nowater.pdb'
traj_file = '/project/dinner/zpengmei/Geom2Vec/Tutorial/adp/alanine-dipeptide-0-250ns-nowater.xtc'

# The XTC file holds coordinates; the PDB passed as `top` supplies the topology.
traj = md.load(traj_file, top=top_file)

# Display the trajectory summary (frame, atom and residue counts).
traj

<mdtraj.Trajectory with 250000 frames, 22 atoms, 3 residues, and unitcells at 0x7f2f426a1510>

In [5]:
# Per-atom atomic numbers in topology order.
atomic_numbers = np.array([atom.element.atomic_number for atom in traj.topology.atoms])

# Heavy atoms only: drop hydrogens (atomic number 1).
mask = atomic_numbers != 1

# Scale coordinates by 10 (nm -> Å) and keep the heavy-atom columns in one step.
xyz = (traj.xyz * 10)[:, mask]
atomic_numbers = atomic_numbers[mask]

print('shape after excluding hydrogens:', xyz.shape)

shape after excluding hydrogens: (250000, 10, 3)


In [6]:
# load the pretrained model to featurize the trajectory

from Geom2Vec.geom2vec.models.torchmd.main_model import create_model, get_args
import torch

# Architecture hyperparameters; these must match the checkpoint being loaded.
# NOTE(review): the checkpoint filename (et_l9_r64_c75_256_...) appears to encode
# layers=9, num_rbf=64, cutoff=7.5 and hidden_channels=256 — keep them in sync.
hidden_channels = 256
layers = 9
nhead = 8
num_rbf = 64
cutoff = 7.5  # interpreted in the units of the positions fed to the model (Å in this notebook)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = get_args(hidden_channels=hidden_channels, num_layers=layers, num_rbf=num_rbf, num_heads=nhead, cutoff=cutoff)
model = create_model(args=args)

checkpoint_path = '/project/dinner/zpengmei/Geom2Vec/geom2vec/checkpoints/et_l9_r64_c75_256_denali.pth'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved from
# GPU tensors still loads on the CPU fallback chosen above.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

# Inference only: move to the target device and switch to eval mode.
# The returned module is the cell's displayed output.
model.to(device).eval()

  _torch_pytree._register_pytree_node(


TorchMD_Net(
  (representation_model): TorchMD_ET(hidden_channels=256, num_layers=9, num_rbf=64, rbf_type=expnorm, trainable_rbf=False, activation=silu, attn_activation=silu, neighbor_embedding=None, num_heads=8, distance_influence=both, cutoff_lower=0.0, cutoff_upper=7.5)
  (output_model_noise): EquivariantScalar(
    (output_network): ModuleList(
      (0): GatedEquivariantBlock(
        (vec1_proj): Linear(in_features=256, out_features=256, bias=False)
        (vec2_proj): Linear(in_features=256, out_features=128, bias=False)
        (update_net): Sequential(
          (0): Linear(in_features=512, out_features=256, bias=True)
          (1): SiLU()
          (2): Linear(in_features=256, out_features=256, bias=True)
        )
        (act): SiLU()
      )
      (1): GatedEquivariantBlock(
        (vec1_proj): Linear(in_features=128, out_features=128, bias=False)
        (vec2_proj): Linear(in_features=128, out_features=1, bias=False)
        (update_net): Sequential(
          (0): Li

In [8]:
from tqdm import tqdm
from torch_scatter import scatter

# Run the pretrained model over each trajectory and sum-pool the per-atom
# representations into one feature vector per frame.

def get_features(xyz_list, model, batch_size=100, device='cuda', atomic_numbers=None):
    """Featurize trajectories with a pretrained model, one vector per frame.

    Each frame is treated as one graph. The first element of the model output
    is reshaped to (n_samples, n_atoms, -1) and summed over the atom axis.

    Parameters
    ----------
    xyz_list : sequence of array-like
        One coordinate array per trajectory, each of shape (n_frames, n_atoms, 3),
        in the units the model expects (Å in this notebook).
    model : torch.nn.Module
        Called as ``model(z=..., pos=..., batch=...)``; must return a 3-tuple whose
        first element has ``n_samples * n_atoms`` rows.
    batch_size : int
        Number of frames sent to the model per forward pass.
    device : str or torch.device
        Device the model runs on.
    atomic_numbers : array-like of int, shape (n_atoms,), optional
        Atomic numbers matching the atom axis of every trajectory. When omitted,
        the notebook-level ``atomic_numbers`` variable is used, as before.

    Returns
    -------
    list of np.ndarray
        One (n_frames, feature_dim) array per input trajectory, where
        feature_dim is inferred from the model output.

    Raises
    ------
    ValueError
        If a trajectory has no frames or is not shaped (n_frames, n_atoms, 3)
        with n_atoms equal to ``len(atomic_numbers)``.
    """
    if atomic_numbers is None:
        # Backward compatibility: earlier versions read this notebook global implicitly.
        atomic_numbers = globals()['atomic_numbers']
    # Embedding lookups need integer (long) atom types.
    z = torch.as_tensor(atomic_numbers, dtype=torch.long, device=device)
    n_atoms = z.shape[0]

    model.eval()
    features = []
    with torch.no_grad():
        for traj_idx, traj_xyz in enumerate(xyz_list):
            # Keep the trajectory where it is (CPU for numpy input); only one batch
            # at a time is moved to `device`.
            pos_all = torch.as_tensor(traj_xyz, dtype=torch.float32)
            if pos_all.ndim != 3 or pos_all.shape[0] == 0 or pos_all.shape[1:] != (n_atoms, 3):
                raise ValueError(
                    f'trajectory {traj_idx} has shape {tuple(pos_all.shape)}; '
                    f'expected (n_frames > 0, {n_atoms}, 3)'
                )

            out_rep_list = []
            for pos_batch in tqdm(torch.split(pos_all, batch_size, dim=0), desc=f'trajectory {traj_idx}'):
                n_samples = pos_batch.shape[0]
                # Flatten the batch into one disjoint graph: atom types tiled per frame,
                # and `batch` mapping every atom to its frame index within this batch.
                z_batch = z.repeat(n_samples)
                batch_index = torch.arange(n_samples, device=device).repeat_interleave(n_atoms)
                pos_flat = pos_batch.reshape(-1, 3).contiguous().to(device)

                x_rep, _, _ = model(z=z_batch, pos=pos_flat, batch=batch_index)
                # Sum-pool per-atom representations into one vector per frame.
                out_rep = x_rep.reshape(n_samples, n_atoms, -1).sum(1)
                out_rep_list.append(out_rep.cpu().numpy())

            features.append(np.concatenate(out_rep_list, axis=0))
            if torch.device(device).type == 'cuda':
                # Release cached allocator blocks once per trajectory, not every batch.
                torch.cuda.empty_cache()

    return features

In [9]:
# get_features takes a list of trajectories; this notebook featurizes just one.
xyz_list = [xyz]
features = get_features(xyz_list=xyz_list, model=model, device=device, batch_size=100)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [00:46<00:00, 53.22it/s]


In [12]:
# One feature vector per frame: (n_frames, feature_dim).
np.shape(features[0])

(250000, 256)

In [13]:
# Persist the per-frame features under the key 'features'.
# np.savez appends the '.npz' extension because the path does not already end with it.
np.savez(
    '/project/dinner/zpengmei/Geom2Vec/Tutorial/adp/adp_gnn_feat',
    features=features[0],
)