In [1]:
# Install MDAnalysis, which the cells below import as `mda` to read the topology and trajectory.
# NOTE(review): `%pip install` with a pinned version is the usual notebook form because it
# targets the running kernel's environment — confirm the required version before changing.
!pip install MDAnalysis




In [2]:
import MDAnalysis as mda
import numpy as np

# Topology file (.psf) for the system; the path suggests a CHARMM-GUI setup of the 3W7Y insulin dimer.
top_file = '/beagle3/dinner/kjeong/Insulin_md/input/parm/charmm-gui-3w7y-dimer-ins/step3_pbcsetup.psf'
# Trajectory file (.dcd) paired with the topology above.
traj_file = '/beagle3/dinner/kjeong/Insulin_md/run_openmm/output_cat/concat_s00i00.dcd'
# Universe combining topology and trajectory; later cells pass `u` to extract_info.
u = mda.Universe(top_file, traj_file)




In [3]:
from collections import Counter

def count_segments(numbers):
    """Return the length of every run of consecutive equal values.

    This is a run-length encoding that keeps only the lengths: for
    ``[7, 7, 8, 8, 8, 7]`` the result is ``[2, 3, 1]``. A value that
    reappears later starts a new run; runs are never merged.

    Parameters
    ----------
    numbers : sequence or 1-D array
        Values to scan, e.g. one residue identifier per atom.

    Returns
    -------
    numpy.ndarray
        1-D integer array with one entry per run, in order of appearance.
        Its sum equals ``len(numbers)``. Empty input gives an empty array
        (the previous implementation raised ``IndexError``).
    """
    values = np.asarray(numbers)
    if values.size == 0:
        return np.zeros(0, dtype=np.int64)

    # A new run starts at every index whose value differs from its predecessor.
    run_starts = np.flatnonzero(values[1:] != values[:-1]) + 1
    # Add the implicit start of the first run and the end of the last one,
    # so consecutive boundary differences are exactly the run lengths.
    boundaries = np.concatenate(([0], run_starts, [values.size]))
    return np.diff(boundaries)

def extract_info(u, cutoff=15.0, stride=3):
    """Collect coordinates, atomic numbers and per-residue atom counts.

    The atom selection is the protein heavy atoms (mass > 1.5, i.e. no
    hydrogens) followed by every TIP3 water atom closer than ``cutoff`` to
    the protein centre of mass.

    Notes
    -----
    * The water selection is evaluated once, on the frame that is current
      when the function is called; the same atoms are then followed through
      the whole trajectory.
    * Waters are selected atom-by-atom, so a water molecule can be included
      only partially.
    * Distances are plain Euclidean distances; no periodic-image correction
      is applied.

    Parameters
    ----------
    u : MDAnalysis.Universe
        Universe holding topology and trajectory.
    cutoff : float, optional
        Distance threshold (same units as the coordinates) for keeping a
        water atom. Default 15.
    stride : int, optional
        Keep every ``stride``-th frame. Default 3.

    Returns
    -------
    positions : numpy.ndarray
        Coordinates of the selected atoms, shape (n_kept_frames, n_atoms, 3).
    atomic_numbers : numpy.ndarray
        Atomic number of each selected atom, shape (n_atoms,).
    segment_counts : numpy.ndarray
        Number of selected atoms in each consecutive residue, protein
        residues first, then waters.

    Raises
    ------
    ValueError
        If a selected atom has a mass that does not correspond to
        H, C, N, O, P or S.
    """
    # Mass rounded to three decimals -> atomic number.
    mass_to_atomic_number = {
        1.008: 1,    # H
        12.011: 6,   # C
        14.007: 7,   # N
        15.999: 8,   # O
        30.974: 15,  # P
        32.06: 16,   # S
    }

    # Protein heavy atoms (the mass filter removes hydrogens).
    protein = u.select_atoms('protein and prop mass > 1.5 ')
    protein_center = protein.center_of_mass()

    # All TIP3 water atoms (adjust the resname for other water models).
    all_waters = u.select_atoms('resname TIP3')

    # Distance from every water atom to the protein centre of mass.
    distances = np.linalg.norm(all_waters.positions - protein_center, axis=1)
    waters_near_protein = all_waters[distances < cutoff]

    # Protein atoms first, then the nearby water atoms.
    all_water_prot = protein + waters_near_protein

    # Identify elements by mass; only the first three decimals are matched.
    rounded_masses = np.round(all_water_prot.masses, 3).tolist()
    unknown_masses = sorted(set(rounded_masses) - set(mass_to_atomic_number))
    if unknown_masses:
        raise ValueError(
            f'Cannot map atomic masses {unknown_masses} to an element; '
            f'known masses are {sorted(mass_to_atomic_number)}'
        )
    atomic_numbers = [mass_to_atomic_number[mass] for mass in rounded_masses]

    # Atoms per residue. resindices are unique per residue, whereas resids can
    # repeat across segments and would silently merge neighbouring residues.
    segment_counts = count_segments(all_water_prot.resindices)

    # Read only the frames that are kept instead of loading every frame and
    # discarding most of them afterwards.
    positions = np.array(
        [all_water_prot.positions.copy() for ts in u.trajectory[::stride]]
    )

    return positions, np.array(atomic_numbers), np.array(segment_counts)

# Run the extraction on the universe loaded above. The results are kept as module-level
# names; get_features reads `atomic_numbers` from this scope.
positions, atomic_numbers, segment_counts = extract_info(u)

In [5]:
# Sanity check: count the residues behind the heavy-atom protein selection
# (the mass > 1.5 filter drops hydrogens), for comparison with segment_counts.
protein = u.select_atoms('protein and prop mass > 1.5 ')
# Residues that own the selected protein atoms
protein_residues = protein.residues
# Display the ResidueGroup summary as the cell output
protein_residues

<ResidueGroup with 102 residues>

In [6]:
segment_counts # atoms per consecutive residue: the first 102 entries are the insulin-dimer residues, the remaining entries are the surrounding water atoms

array([ 4,  8,  7,  9,  9,  6,  6,  7,  6,  8,  6,  6,  8, 12,  9,  8,  9,
        8, 12,  6,  9, 11,  7,  8,  9, 10,  8,  6,  4,  6, 10,  8,  7,  9,
        5,  8, 12,  8,  7,  6,  4,  9, 11,  4, 11, 11, 12,  7,  7,  9,  8,
        4,  8,  7,  9,  9,  6,  6,  7,  6,  8,  6,  6,  8, 12,  9,  8,  9,
        8, 12,  6,  9, 11,  7,  8,  9, 10,  8,  6,  4,  6, 10,  8,  7,  9,
        5,  8, 12,  8,  7,  6,  4,  9, 11,  4, 11, 11, 12,  7,  7,  9,  8,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  1,  3,  3,
        3,  3,  3,  3,  1,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  1,  2,  1,  1,  3,  3,  3,  3,  3,  1,  3,  3,  3,  3,  3,  1,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  2,  3,  3,  2,  3,
        3,  3,  3,  3,  1,  3,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  1,  3,  3,  3,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  1,  3,  1,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3

In [7]:
# load the pretrained model to featurize the trajectory

from Geom2Vec.geom2vec.models.torchmd.main_model import create_model, get_args
import torch

# Architecture hyperparameters; presumably these must match the checkpoint loaded below
# (its file name encodes l9 / r64 / c75 / 256) — confirm when switching checkpoints.
hidden_channels = 256
layers = 9
nhead = 8

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = get_args(hidden_channels=hidden_channels, num_layers=layers, num_rbf=64, num_heads=nhead, cutoff=7.5)
model = create_model(args=args)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved from a GPU
# also loads on a CPU-only machine (without it the CPU fallback above cannot work).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('/project/dinner/zpengmei/Geom2Vec/geom2vec/checkpoints/et_l9_r64_c75_256_denali.pth', map_location=device)
model.load_state_dict(checkpoint)
model.to(device)

  _torch_pytree._register_pytree_node(


TorchMD_Net(
  (representation_model): TorchMD_ET(hidden_channels=256, num_layers=9, num_rbf=64, rbf_type=expnorm, trainable_rbf=False, activation=silu, attn_activation=silu, neighbor_embedding=None, num_heads=8, distance_influence=both, cutoff_lower=0.0, cutoff_upper=7.5)
  (output_model_noise): EquivariantScalar(
    (output_network): ModuleList(
      (0): GatedEquivariantBlock(
        (vec1_proj): Linear(in_features=256, out_features=256, bias=False)
        (vec2_proj): Linear(in_features=256, out_features=128, bias=False)
        (update_net): Sequential(
          (0): Linear(in_features=512, out_features=256, bias=True)
          (1): SiLU()
          (2): Linear(in_features=256, out_features=256, bias=True)
        )
        (act): SiLU()
      )
      (1): GatedEquivariantBlock(
        (vec1_proj): Linear(in_features=128, out_features=128, bias=False)
        (vec2_proj): Linear(in_features=128, out_features=1, bias=False)
        (update_net): Sequential(
          (0): Li

In [8]:
from tqdm import tqdm
from torch_scatter import scatter

def get_features(xyz_list, counts, model, z=None, batch_size=10, device=None,
                 n_protein_residues=102):
    """Featurize trajectories with the model and pool the per-atom output.

    Each frame is passed through ``model``; the first of its three outputs
    (one vector per atom) is summed in two ways:

    * over all atoms, giving one vector per frame;
    * per residue (using ``counts``), keeping only the first
      ``n_protein_residues`` residues.

    Parameters
    ----------
    xyz_list : list of array-like
        One array per trajectory, each of shape (n_frames, n_atoms, 3).
    counts : array-like or torch.Tensor of int
        Number of atoms in each consecutive residue; must sum to n_atoms.
    model : torch.nn.Module
        Called as ``model(z=..., pos=..., batch=...)``; must return three
        values, the first being the per-atom representation.
    z : array-like of int, optional
        Atomic number of each atom. When omitted, the notebook-level
        ``atomic_numbers`` is used (the original behaviour).
    batch_size : int, optional
        Number of frames per forward pass. Default 10.
    device : torch.device or str, optional
        Device for the computation. Defaults to the device of the model's
        parameters.
    n_protein_residues : int, optional
        Number of leading residues kept in the per-residue output.
        Default 102.

    Returns
    -------
    features : list of numpy.ndarray
        Per trajectory, array of shape (n_frames, hidden).
    features_masked : list of numpy.ndarray
        Per trajectory, array of shape
        (n_frames, min(n_protein_residues, n_residues), hidden).

    Raises
    ------
    ValueError
        If ``counts`` does not sum to the number of atoms, or a trajectory
        does not have shape (n_frames, n_atoms, 3).
    """
    if device is None:
        # Inputs have to live where the model weights are.
        device = next(model.parameters()).device
    else:
        device = torch.device(device)

    if z is None:
        z = atomic_numbers  # notebook-level array produced by extract_info
    z = torch.as_tensor(z, dtype=torch.long, device=device)
    n_atoms = z.shape[0]

    # as_tensor accepts numpy arrays and tensors alike without a copy-construct warning.
    count = torch.as_tensor(counts, dtype=torch.long, device=device)
    if int(count.sum()) != n_atoms:
        raise ValueError(
            f'counts sums to {int(count.sum())} atoms but {n_atoms} atomic numbers were given'
        )

    # Residue index of every atom; identical for all frames, so built once.
    sca_map = torch.repeat_interleave(
        torch.arange(count.shape[0], device=device), count, dim=0
    )

    features = []
    features_masked = []
    model.eval()
    with torch.no_grad():
        for traj_xyz in xyz_list:
            # Keep the trajectory where it is and move one batch at a time,
            # so the device only ever holds `batch_size` frames of coordinates.
            traj_xyz = torch.as_tensor(traj_xyz, dtype=torch.float32)
            if traj_xyz.ndim != 3 or traj_xyz.shape[1] != n_atoms or traj_xyz.shape[2] != 3:
                raise ValueError(
                    f'expected coordinates of shape (n_frames, {n_atoms}, 3), '
                    f'got {tuple(traj_xyz.shape)}'
                )

            out_rep_list = []
            out_rep_masked_list = []
            for pos_batch in tqdm(torch.split(traj_xyz, batch_size, dim=0)):
                pos_batch = pos_batch.to(device)
                n_samples = pos_batch.shape[0]

                # Flatten the batch of frames into one long list of atoms, with
                # `batch_batch` recording the frame each atom belongs to.
                z_batch = z.repeat(n_samples)
                batch_batch = torch.arange(n_samples, device=device).repeat_interleave(n_atoms)
                x_rep, v_rep, _ = model(
                    z=z_batch,
                    pos=pos_batch.reshape(-1, 3).contiguous(),
                    batch=batch_batch,
                )

                # Back to (frames, atoms, hidden).
                out_rep = x_rep.reshape(n_samples, n_atoms, -1)
                # Sum atoms within each residue and keep only the protein residues.
                out_rep_masked = scatter(out_rep, sca_map, dim=1, reduce='add')[:, :n_protein_residues, :]

                # Move the results to the CPU before storing them.
                out_rep_list.append(out_rep.sum(1).detach().cpu().numpy())
                out_rep_masked_list.append(out_rep_masked.detach().cpu().numpy())

                torch.cuda.empty_cache()

            features.append(np.concatenate(out_rep_list, axis=0))
            features_masked.append(np.concatenate(out_rep_masked_list, axis=0))
    return features, features_masked

In [9]:
# Featurize the trajectory extracted above. get_features takes
# (xyz_list, counts, model) and returns both the per-frame and the
# per-residue features, so both are unpacked here.
xyz_list = [positions]
features, features_masked = get_features(xyz_list, segment_counts, model)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [00:46<00:00, 53.22it/s]


In [12]:
# loop over the trajectory folder with .dcd files

import os
import glob
from tqdm import tqdm
top_file = '/beagle3/dinner/kjeong/Insulin_md/input/parm/charmm-gui-3w7y-dimer-ins/step3_pbcsetup.psf'
folder = '/beagle3/dinner/kjeong/Insulin_md/run_openmm/output_cat/'

# Directory that receives one .npz file per trajectory; replace the placeholder with a real path.
out_dir = 'your_path'
os.makedirs(out_dir, exist_ok=True)

# glob returns files in arbitrary order; sorting makes index i refer to the same
# trajectory on every run.
traj_files = sorted(glob.glob(os.path.join(folder, '*.dcd')))
num_trajs = len(traj_files)

for i, traj_file in enumerate(traj_files):
    u = mda.Universe(top_file, traj_file)
    # These module-level names are reassigned on purpose: get_features reads
    # `atomic_numbers` from this scope.
    positions, atomic_numbers, segment_counts = extract_info(u)
    features, features_masked = get_features([positions], segment_counts, model)
    # The source trajectory path is stored next to the features for traceability.
    np.savez(os.path.join(out_dir, f'features_{i}'), features=features, features_masked=features_masked, traj_file=traj_file)


(250000, 256)