In [2]:
import h5py
import numpy as np
import os
from moleculib.protein.datum import ProteinDatum
from tqdm import tqdm
import pickle
from collections import defaultdict

In [3]:
base_path = f'../../workspace/new_fold3d/'

In [5]:
from biotite.structure import Atom, array

def h5_to_moleculib(h5File):
    atoms = []
    for (chain_id, atom_name, pos, res_id, res_name) in zip(
        h5File['atom_chain_names'],
        h5File['atom_names'],
        h5File['atom_pos'][0],
        h5File['atom_residue_id'],
        h5File['atom_residue_names'],
    ):
        element= atom_name.decode('UTF-8')[0]
        atom = Atom(
            pos, 
            chain_id=chain_id.decode('UTF-8'),
            res_id=res_id, 
            res_name=res_name.decode('UTF-8'), 
            element=element, 
            hetero=False, 
            atom_name=atom_name.decode('UTF-8')
        )
        atoms.append(atom) 
    atom_array = array(atoms)  
    header = dict(
        idcode='aeho',
        resolution=None
    )
    return ProteinDatum.from_atom_array(atom_array, header=header)


In [6]:
SPLITS = ["train", "valid", "test_fold", "test_family", "test_superfamily"]
label_files = [os.path.join(base_path, '%s.txt' % split) for split in SPLITS]


In [7]:
classmap = os.path.join(base_path, 'class_map.txt')
with open(classmap, "r") as fin:
    lines = [line.strip() for line in fin.readlines()]
    class_map = dict([line.split('\t') for line in lines])
label_list = {}
for fname in label_files:
    label_file = open(fname, 'r')
    for line in label_file.readlines():
        line = line.strip().split('\t')
        name, label = line[0], line[-1]
        label_list[name] = int(class_map[label])

In [9]:
from tqdm.contrib.concurrent import process_map 

def process_file(fname):
    h5File = h5py.File(fname, 'r')
    datum = h5_to_moleculib(h5File)
    id = fname[fname.rfind('/') + 1:fname.rfind('.')]
    label = label_list[id]
    datum.idcode = id
    datum.fold_label = label
    return datum

#### Prepare Dataset

In [17]:
data = defaultdict(list)

In [18]:
for split in SPLITS:
    split_path = os.path.join(base_path, split)
    files = os.listdir(split_path)
    files = [os.path.join(split_path, fname) for fname in files]    
    data[split] = process_map(process_file, files, max_workers=8)


  data[split] = process_map(process_file, files, max_workers=8)


  0%|          | 0/12312 [00:00<?, ?it/s]

  0%|          | 0/736 [00:00<?, ?it/s]

  0%|          | 0/718 [00:00<?, ?it/s]

  data[split] = process_map(process_file, files, max_workers=8)


  0%|          | 0/1272 [00:00<?, ?it/s]

  data[split] = process_map(process_file, files, max_workers=8)


  0%|          | 0/1254 [00:00<?, ?it/s]

In [19]:

with open('fold3d.pyd', 'wb') as fout:
    print('Saving data to fold3d.pyd')
    pickle.dump(data, fout)

Saving data to fold3d.pyd


#### Load Dataset

In [None]:
with open('fold3d.pyd', 'rb') as fin:
    data = pickle.load(fin)

In [119]:
for datum in data['train'][:5]:
    viz_datum_grid([datum], window_size=(300, 300)).show()

In [None]:
for k, v in h5File.items():
    print(k, v.shape)

In [None]:
from biotite.structure import AtomArray, Atom, array

In [None]:
from tempfile import gettempdir
from biotite.database import rcsb
import biotite.structure.io.mmtf as mmtf

In [None]:
filepath = rcsb.fetch('1AKE', "mmtf", gettempdir())
mmtf_file = mmtf.MMTFFile.read(filepath)
atom_array = mmtf.get_structure(mmtf_file, model=1)

In [None]:
atom_array[0]

In [None]:
atom_array

In [None]:
viz_datum([datum], window_size=(300, 300))

In [None]:
datum.atom_coord