In [1]:
import os
import pickle
import glob
import time
from datetime import datetime as dt
import numpy as np
import torch
import torch.optim as optim

from typing import List, Union, Tuple

from schrodinger.structure import StructureReader


def load_latents(files: Union[str, List[str]]) -> Tuple[torch.Tensor, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Loads latent vectors from .pkl files

    Each pickle is read as a dict with the keys 'z', 'scores' and 'curr_gen'.
    Files whose 'scores' dict lacks 'site_rmsd' or 'lig_rmsd' are skipped.

    Args:
        files: paths of the .pkl files to load, either a list of paths or a
            single path (treated as a one-element list)

    Returns:
        z_list: torch.Tensor; list of latent vectors
        E_list: np.ndarray; list of energies for each latent vector (same order)
        sr_list: np.ndarray; list of protein site rmsd for each latent vector (same order)
        lr_list: np.ndarray; list of ligand rmsd for each latent vector (same order)
        gen_list: np.ndarray; generation of GA each latent vector is from (same order)

    Raises:
        RuntimeError: if no file yields a usable entry (``files`` is empty or
            every file was skipped), since there is nothing to stack
    """
    # A bare path would otherwise be iterated character by character.
    if isinstance(files, (str, bytes, os.PathLike)):
        files = [files]

    z_list, E_list, sr_list, lr_list, gen_list = [], [], [], [], []
    for file in files:
        with open(file, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            ld = pickle.load(f)
        scores = ld['scores']
        # Entries without both rmsd scores are dropped.
        if 'site_rmsd' not in scores or 'lig_rmsd' not in scores:
            continue
        z_list.append(ld['z'])
        E_list.append(scores['energy'])
        sr_list.append(scores['site_rmsd'])
        lr_list.append(scores['lig_rmsd'])
        gen_list.append(ld['curr_gen'])

    # torch.stack rejects an empty list with an opaque message; raise the same
    # exception type with an explanation instead.
    if not z_list:
        raise RuntimeError(
            "load_latents: no usable latent files (input was empty, or every "
            "file lacked 'site_rmsd'/'lig_rmsd' scores)"
        )

    return (
        torch.stack(z_list),
        np.array(E_list),
        np.array(sr_list),
        np.array(lr_list),
        np.array(gen_list),
    )

    

In [2]:
# Load latent vector data
latents_pattern = 'pim1_4lmu_pim1_4bzo/pim1_4lmu_pim1_4bzo_optimization/*.pkl'
files = sorted(glob.glob(latents_pattern))

# The pattern is relative to the working directory, and glob returns an empty
# list (no error) when nothing matches. Fail here, where the cause is obvious,
# rather than deep inside load_latents.
if not files:
    raise FileNotFoundError(
        f"No .pkl files match {latents_pattern!r} (cwd: {os.getcwd()})"
    )

z_list, E_list, sr_list, lr_list, gen_list = load_latents(files)


In [3]:
# Peek at the first latent vector (entry 0 of the stacked z_list tensor)
z_list[0]


tensor([ 2.5763,  0.2171, -3.5389,  1.3997,  0.5700,  0.5732,  1.3237, -1.6378,
        -1.2579,  0.3466, -0.0482, -4.1818, -1.4384, -0.5433, -2.5757,  3.2151,
         1.1616,  2.1566, -5.9279, -0.6370,  1.8095, -4.2530, -5.8992, -2.9332,
        -0.6002,  1.1810, -2.4478, -2.2614, -8.9930,  0.8321])

In [5]:
# Structure file to inspect
file = 'pim1_4lmu_pim1_4bzo/pim1_4lmu_pim1_4bzo_optimization/pim1_4lmu_pim1_4bzo_opt_R0-ALL-out.maegz'

# A single structure from the file
st = StructureReader.read(file)

# Every structure in the file, collected into a list
st_list = list(StructureReader(file))

# Write the single structure out as .pdb so it can be
# loaded for viewing in Py3Dmol or nglview
st.write('temp.pdb')


In [6]:
# Number of structures collected from the structure file
len(st_list)


1590