In [3]:
import h5py
import numpy as np
import random

In [5]:
### Load neighborhood data from tiny_pdb_dir (25 random pdbs)

# NOTE(review): absolute, cluster-specific path -- consider making this
# configurable (e.g. via an environment variable or a config cell).
filename = "/gscratch/scrubbed/wgalvin/clean_clone/protein_holography-pytorch/protein_holography_pytorch/neighborhoods_tiny.hdf5"

# Fixed length of every per-atom field below; neighborhoods with fewer atoms
# are presumably zero-padded up to this length -- TODO confirm against the writer.
max_atoms = 1000
# Structured dtype of one neighborhood record. f['data'] is cast to it below,
# so field names/sizes presumably must match the on-disk layout.
# Note: `(6)` / `(max_atoms)` are plain ints (not tuples); numpy treats them as 1-D shapes.
dt = np.dtype([
    ('res_id','S6', (6)), # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
    ('atom_names', 'S4', (max_atoms)),
    ('elements', 'S1', (max_atoms)),
    ('res_ids', 'S6', (max_atoms, 6)), # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
    ('coords', 'f8', (max_atoms, 3)),
    ('SASAs', 'f8', (max_atoms)),
    ('charges', 'f8', (max_atoms)),
])

# Space-padded 4-character atom names whose atoms get_mask always keeps.
BACKBONE_ATOMS = [b' N  ', b' CA ', b' C  ', b' O  ']

# Read the whole dataset into memory and drop duplicate records
# (np.unique also returns the records in sorted order).
with h5py.File(filename, "r") as f:
    data = np.unique(np.array(f['data'], dtype=dt), axis=0)

In [31]:
def get_res_ids(neighborhood):
    """
    Returns the set of unique residue IDs (as strings, see ``id_to_str``)
    found in the ``res_ids`` field of a neighborhood record.

    Padding rows are skipped. These are entries whose fields are all empty
    byte strings, which is what zero-filling an ``S``-dtype array (as
    ``pad_arr`` does) produces. Without this, padding would decode to the
    bogus id ``'_____'`` and be counted as a residue, skewing the proportion
    kept by ``sample_ids``.
    """
    return {
        id_to_str(res_id)
        for res_id in neighborhood['res_ids']
        # all fields empty -> padding row, not a residue
        if any(map(len, res_id))
    }

def id_to_str(res_id):
    """
    Join the UTF-8-decoded byte fields of a residue id with underscores.

    Intended as a one-to-one mapping to a hashable string key (this holds
    as long as no field itself contains an underscore).
    """
    decoded_fields = (field.decode('utf-8') for field in res_id)
    return '_'.join(decoded_fields)

def sample_ids(ids, p, rng=None):
    """
    Returns a random subset of ``int(p * len(ids))`` of the given ids.

    Parameters
    ----------
    ids : iterable
        Ids to sample from (e.g. the set returned by ``get_res_ids``).
        Items must be mutually orderable (they are sorted before shuffling).
    p : float
        Proportion to keep, in [0, 1]; the count is truncated toward zero.
    rng : random.Random, optional
        Source of randomness; defaults to the module-level ``random``
        generator. Pass a seeded ``random.Random`` for reproducible samples.

    Returns
    -------
    list
        The sampled ids, in shuffled order.

    Raises
    ------
    ValueError
        If ``p`` is outside [0, 1] (or is NaN).
    """
    if not 0 <= p <= 1:
        raise ValueError('p must be in [0, 1], got {}'.format(p))
    if rng is None:
        rng = random
    # Sort first: the iteration order of a set of str depends on the
    # per-process hash seed, which would make the shuffle irreproducible
    # even with a seeded RNG.
    pool = sorted(ids)
    rng.shuffle(pool)
    return pool[:int(p * len(pool))]

def get_mask(res_ids, atom_names, ids, max_atoms):
    """
    Returns a boolean mask of length ``max_atoms`` selecting the atoms to keep.

    ``mask[i]`` is True when atom ``i`` belongs to a residue whose string id
    (see ``id_to_str``) is in ``ids``, or when its name is one of
    ``BACKBONE_ATOMS``. Entries at or past ``min(len(res_ids), len(atom_names))``
    stay False.

    Parameters
    ----------
    res_ids : sequence
        Per-atom residue ids (sequences of byte-string fields).
    atom_names : sequence
        Per-atom names, parallel to ``res_ids``.
    ids : iterable of str
        Residue ids whose atoms are kept.
    max_atoms : int
        Length of the returned mask; must be >= the number of atoms iterated,
        otherwise an IndexError is raised.
    """
    # Membership is tested once per atom; a set avoids rescanning a
    # list-typed ``ids`` for every atom.
    keep_ids = set(ids)
    mask = np.zeros(max_atoms, dtype=bool)
    for i, (res_id, atom_name) in enumerate(zip(res_ids, atom_names)):
        mask[i] = id_to_str(res_id) in keep_ids or atom_name in BACKBONE_ATOMS
    return mask


def pad(arrays, max_atoms):
    """
    Pad every array in ``arrays`` along its first axis to ``max_atoms``
    rows (via ``pad_arr``) and return the results as a LIST.
    """
    padded = []
    for array in arrays:
        padded.append(pad_arr(array, max_atoms))
    return padded

def pad_arr(arr, padded_length):
    """
    Zero-pads ``arr`` along its first axis to ``padded_length`` rows.

    Parameters
    ----------
    arr : array_like
        Input data, converted with ``np.asarray``. Its dtype and trailing
        dimensions are preserved.
    padded_length : int
        Length of the first axis of the result.

    Returns
    -------
    np.ndarray
        New array of shape ``(padded_length, *arr.shape[1:])``. The first
        ``len(arr)`` rows are a copy of ``arr``; the rest are zeros (empty
        bytes for string dtypes).

    Raises
    ------
    ValueError
        If ``padded_length`` is smaller than ``len(arr)``.
    """
    arr = np.asarray(arr)
    orig_length = arr.shape[0]

    # Fail loudly with a clear message instead of printing and then
    # crashing on an opaque broadcast error during the copy below.
    if padded_length < orig_length:
        raise ValueError(
            'Padded length of {} is smaller than original length of array {}'
            .format(padded_length, orig_length))

    padded = np.zeros((padded_length, *arr.shape[1:]), dtype=arr.dtype)
    padded[:orig_length] = arr
    return padded

def downsample(neighborhood, p, max_atoms=1000, remove_central=True):
    """
    Takes a neighborhood and removes the side chains of all but a
    proportion ``p`` of its residues.

    Backbone atoms (``BACKBONE_ATOMS``) are kept for every residue.

    Parameters
    ----------
    neighborhood : np.void
        One structured record with the fields ``res_id``, ``atom_names``,
        ``elements``, ``res_ids``, ``coords``, ``SASAs`` and ``charges``
        (e.g. one element of ``data``).
    p : float
        Proportion of residues whose side chains are kept (see ``sample_ids``).
    max_atoms : int
        Length the per-atom fields are padded to. Must equal the per-atom
        field length of ``neighborhood.dtype``.
    remove_central : bool
        If True, the central residue (``neighborhood['res_id']``) is never
        sampled, so only its backbone atoms remain.

    Returns
    -------
    np.void
        A new record with the same dtype as ``neighborhood``. The input is
        not modified.

    ```
    # USAGE:
    for neighborhood in data:
        neighborhood = downsample(neighborhood, .5)
        ...
    ```
    """
    ids = get_res_ids(neighborhood)

    if remove_central:
        # discard (not remove): a neighborhood lacking its central residue
        # would otherwise raise KeyError.
        ids.discard(id_to_str(neighborhood['res_id']))

    # ids whose side chains survive
    kept_ids = sample_ids(ids, p)

    # Boolean indexing needs a mask matching the input's per-atom length.
    n_atoms = len(neighborhood['atom_names'])
    mask = get_mask(neighborhood['res_ids'], neighborhood['atom_names'], kept_ids, n_atoms)

    per_atom_fields = ('atom_names', 'elements', 'res_ids', 'coords', 'SASAs', 'charges')
    padded = pad([neighborhood[field][mask] for field in per_atom_fields], max_atoms)

    # Build the output from the input's own dtype (rather than the notebook
    # global ``dt``) and assign fields by name, not by tuple position.
    out = np.zeros(1, dtype=neighborhood.dtype)
    out['res_id'][0] = neighborhood['res_id']
    for field, values in zip(per_atom_fields, padded):
        out[field][0] = values
    return out[0]

    