In [131]:
import glob
import json
import math
import os
import pickle
import re
import sys

import lmdb
import numpy as np
import pandas as pd
import torch
from biopandas.pdb import PandasPdb
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Project-local modules live under /workspace; extend the path before importing them.
sys.path.append('/workspace')
from models.datasets.tokenizer import get_toker

In [151]:
# Sequence-length setting; passed to the dataset constructor further down.
max_length = 512

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


# 1. data

In [145]:
class ArrayTokenizer():
    """Encode a numeric array as a flat list of integer token ids.

    Ids are laid out consecutively, starting at ``offset``:

    * ``start_token``            - 1 id, emitted first
    * ``end_token``              - 1 id, emitted last
    * ``n_offset + k``           - ``n_max + 1`` ids, ``k`` = ``len(array)`` clipped to ``n_max``
    * ``i_offset + k``           - ``vsup - vmin`` ids, ``k`` = integer part of ``value - vmin``
    * ``f_offset + k``           - ``10**fdim`` ids, ``k`` = fractional part scaled by
      ``10**fdim`` and truncated

    ``voc_size`` is the total number of ids reserved by this tokenizer.
    """

    def __init__(self, n_max, vmin, vsup, fdim, offset=0):
        """Compute the id ranges.

        Args:
            n_max: largest array length that gets its own length token.
            vmin: lower bound of the value range; values are shifted by ``-vmin``.
            vsup: upper bound of the value range (values are clipped just below it).
            fdim: number of decimal digits kept for the fractional part.
            offset: first token id this tokenizer may use.
        """
        self.n_max = n_max
        self.vmin = vmin
        self.fdim = fdim
        # Upper clip for the shifted value, slightly below ``vsup - vmin``.
        self.dmax = vsup - 0.1**(self.fdim+1) - vmin

        # Consecutive id ranges: start, end, length, integer part, fractional part.
        self.start_token = offset
        self.end_token = offset + 1
        self.n_offset = offset + 2
        self.i_offset = self.n_offset + (n_max + 1)
        self.f_offset = self.i_offset + (vsup - vmin)
        self.voc_size = self.f_offset + 10**self.fdim - offset

    def tokenize(self, array: np.ndarray):
        """Return ``[start, length, (int, frac) per value ..., end]`` for ``array``.

        The length token uses ``len(array)`` (clipped to ``n_max``), while the
        value tokens cover every element of ``array.ravel()``.
        """
        scale = 10**self.fdim
        length_token = min(len(array), self.n_max) + self.n_offset
        tokens = [self.start_token, length_token]

        for value in array.ravel():
            # Shift into [0, dmax] and clip anything outside that range.
            shifted = max(0, min(self.dmax, value - self.vmin))
            frac, whole = math.modf(shifted)
            tokens.extend((
                int(whole) + self.i_offset,
                int(frac * scale) + self.f_offset,
            ))

        tokens.append(self.end_token)
        return tokens

class UniMolLigandDataset(Dataset):
    """Map-style dataset over an LMDB file of pickled ligand records.

    Each item is a 1-D ``torch.long`` tensor: the SMILES tokens of the record
    (without the last token) followed by the coordinate tokens of one randomly
    chosen entry of ``record['coordinates']``.
    """

    def __init__(self, lmdb_path, voc_path, max_length):
        """Open the LMDB file read-only and build the two tokenizers.

        Args:
            lmdb_path: path of the LMDB file (opened with ``subdir=False``).
            voc_path: vocabulary file handed to ``get_toker``.
            max_length: accepted but not used anywhere in this class.
        """
        # Assigned before lmdb.open so that __del__ is safe if opening fails.
        self.env = None

        open_options = dict(
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        self.env = lmdb.open(lmdb_path, **open_options)
        self.txn = self.env.begin()
        # Materialise every key once so items can be addressed by position.
        self.keys = [key for key in self.txn.cursor().iternext(values=False)]

        self.smi_toker = get_toker(voc_path)
        # Coordinate ids start right after the SMILES vocabulary.
        self.coord_toker = ArrayTokenizer(
            n_max=100, vmin=-20, vsup=20, fdim=3, offset=self.smi_toker.voc_size
        )
        self.coord_rng = np.random.default_rng(seed=0)

    def __getitem__(self, index):
        """Return the token tensor for the record stored under ``self.keys[index]``."""
        # NOTE: pickle.loads runs arbitrary code; only use with trusted LMDB files.
        record = pickle.loads(self.txn.get(self.keys[index]))
        # Drop the final SMILES token (presumably an end marker - TODO confirm).
        tokens = self.smi_toker.tokenize(record['smi'])[:-1]
        chosen_coords = self.coord_rng.choice(record['coordinates'])
        tokens += self.coord_toker.tokenize(chosen_coords)
        return torch.tensor(tokens, dtype=torch.long)

    def __len__(self):
        """Number of records (keys) in the LMDB file."""
        return len(self.keys)

    def __del__(self):
        """Close the LMDB environment if it was opened."""
        if self.env is None:
            return
        self.env.close()

In [147]:
import functools

# Ligand dataset read from valid.lmdb, tokenised with the bchirals2 vocabulary file.
dataset = UniMolLigandDataset(
    lmdb_path="/workspace/cheminfodata/unimol/ligands/valid.lmdb",
    voc_path="/workspace/cheminfodata/vocs/bchirals2.txt",
    max_length=max_length,
)
# Number of sequences per DataLoader batch.
batch_size = 128

In [148]:
# Pad the variable-length token tensors of a batch with the SMILES tokenizer's pad id.
collate_fn = functools.partial(
    pad_sequence,
    padding_value=dataset.smi_toker.pad_token,
)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Smoke test: walk through every batch once; the tqdm bar reports throughput.
for batch in tqdm(loader):
    continue


100%|██████████| 773/773 [00:07<00:00, 106.54it/s]


In [149]:
# Read every item once by direct indexing (no DataLoader involved); the tqdm
# bar reports the per-item throughput of __getitem__.
for i in tqdm(range(len(dataset))):
    data = dataset[i]

100%|██████████| 98844/98844 [00:27<00:00, 3637.22it/s]


In [126]:
import torch.nn as nn

# Model / optimisation hyper-parameters.
d_model = 768
nhead = 12
dim_feedforward = d_model*4
dropout = 0.1
num_layers = 15
# NOTE(review): `lr` was referenced but never defined in the visible notebook;
# 1e-4 is a placeholder - confirm the intended learning rate.
lr = 1e-4

# Pre-norm (norm_first=True) decoder layer with GELU feed-forward activation.
layer = nn.TransformerDecoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    activation='gelu',
    norm_first=True
)
# Final LayerNorm applied after the last layer; no learnable scale/bias.
norm = nn.LayerNorm(d_model, elementwise_affine=False)
model = nn.TransformerDecoder(layer, num_layers, norm)
model.to(device)

# Adam expects an iterable of parameters (or parameter groups), not the module
# itself; passing `model` directly raises a TypeError.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)



# -1. experiments

### アミノ酸配列を復元できないか？ (Can the amino-acid sequence be recovered?)

In [3]:
# Open the pocket LMDB file read-only and list all of its keys.
lmdb_path = "/workspace/cheminfodata/unimol/pockets/valid.lmdb"
env = lmdb.open(
    lmdb_path,
    subdir=False, readonly=True, lock=False,
    readahead=False, meminit=False, max_readers=256,
)
txn = env.begin()
# Keys only (values=False); the records themselves are fetched later via txn.get.
keys = [k for k in txn.cursor().iternext(values=False)]

In [123]:
from collections import defaultdict
# Count how often each group of atom names occurs across all records.
# Key: sorted atom names joined by ','; value: number of occurrences.
residuess = defaultdict(int)
flg = False  # not read anywhere in this cell
atom_types = set()  # every distinct atom name seen in any record
for key in tqdm(keys):
    data = pickle.loads(txn.get(key))
    atoms = data['atoms']
    atom_types |= set(list(data['atoms']))

    # Atom names collected for the group currently being assembled
    residue = set()
    for i, (a, r) in enumerate(zip(atoms, data['side'])):
        # Treat the name 'OXT' as a plain 'O'
        if a == 'OXT': a = 'O'
        # Stop scanning this record at the first name whose second character is
        # a digit (presumably non-protein atoms such as "C1'" - TODO confirm)
        if len(a) >= 2 and a[1].isdigit():
            break
        # Also stop at the first 'NA' entry (presumably a sodium ion - TODO confirm)
        if a == 'NA': break
        # Names starting with 'H' are skipped, so only the other atoms are counted
        if a[0] == 'H': continue

        # side == 0 presumably marks backbone atoms (TODO confirm). Seeing such
        # a name a second time is taken as the start of the next group.
        if r == 0:
            if a in residue:
                residue = ','.join(sorted(residue))

                residuess[residue]+=1
                residue = set()

        residue.add(a)
    # Flush the group still open when the record ends or the loop breaks
    if len(residue) > 0:
        residuess[','.join(sorted(residue))]+=1

100%|██████████| 164409/164409 [00:10<00:00, 15292.42it/s]


In [124]:
# Show all distinct atom names in sorted order (an ndarray gives a compact repr).
np.array(sorted(atom_types))

array(['C', "C1'", 'C1B', 'C1D', 'C2', "C2'", 'C2A', 'C2B', 'C2D', 'C2N',
       "C3'", 'C3B', 'C3D', 'C3N', 'C4', "C4'", 'C4A', 'C4B', 'C4D',
       'C4N', 'C5', "C5'", 'C5A', 'C5B', 'C5D', 'C5N', 'C6', 'C6A', 'C6N',
       'C7N', 'C8', 'C8A', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1',
       'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2', 'CL', 'CU', 'CZ', 'CZ2',
       'CZ3', 'H', 'H1', "H1'", 'H1B', 'H1D', 'H2', "H2'", 'H2A', 'H2B',
       'H2D', 'H2N', 'H3', "H3'", 'H3B', 'H3D', "H4'", 'H42N', 'H4B',
       'H4D', 'H4N', "H5'", "H5''", 'H51A', 'H51N', 'H52A', 'H52N', 'H5N',
       'H61A', 'H62A', 'H6N', 'H71N', 'H72N', 'H8', 'H8A', 'HA', 'HA2',
       'HA3', 'HB', 'HB1', 'HB2', 'HB3', 'HD1', 'HD11', 'HD12', 'HD13',
       'HD2', 'HD21', 'HD22', 'HD23', 'HD3', 'HE', 'HE1', 'HE2', 'HE21',
       'HE22', 'HE3', 'HG', 'HG1', 'HG11', 'HG12', 'HG13', 'HG2', 'HG21',
       'HG22', 'HG23', 'HG3', 'HH', 'HH11', 'HH12', 'HH2', 'HH21', 'HH22',
       'HN1', 'HN21', 'HN22', "HO2'", 'HO2A', 'HO2N',

In [116]:
# Number of distinct atom-name groups that were counted
print(len(residuess))
df = pd.Series(residuess)
# Sort ascending, reverse to get the largest counts first, and write as TSV
df.sort_values()[::-1].to_csv("residues.tsv", sep='\t')

60


In [57]:
# NOTE(review): `atoms2` was not defined in any visible cell, so this cell
# raised NameError on a fresh kernel. It is rebuilt here as the sorted distinct
# atom names, which matches the sorted outputs printed for atoms_i / atoms_a -
# confirm this was the original definition.
atoms2 = sorted(atom_types)

# Split the names by shape: single character, digit in second position, and
# two or more characters without a digit in second position.
atoms_1 = [a for a in atoms2 if len(a) == 1]
atoms_i = [a for a in atoms2 if len(a) >= 2 and a[1].isdigit()]
atoms_a = [a for a in atoms2 if len(a) >= 2 and not a[1].isdigit()]

In [120]:
# Show the atom names currently bound to `atoms` (set by the record loop).
np.array(atoms)

array(['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ',
       'OH', 'H', 'HA', 'HB2', 'HB3', 'HD1', 'HD2', 'HE1', 'HE2', 'HH',
       'N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', 'H', 'HA',
       'HB2', 'HB3', 'HG2', 'HG3', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'SD',
       'CE', 'H', 'HA', 'HB2', 'HB3', 'HG2', 'HG3', 'HE1', 'HE2', 'HE3',
       'N', 'CA', 'C', 'O', 'CB', 'OG', 'H', 'HA', 'HB2', 'HB3', 'HG',
       'N', 'CA', 'C', 'O', 'CB', 'CG', 'H', 'HA', 'HB2', 'HB3', 'ND1',
       'CD2', 'CE1', 'NE2', 'HD1', 'HD2', 'HE1', 'N', 'CA', 'C', 'O',
       'CB', 'CG', 'OD1', 'OD2', 'H', 'HA', 'HB2', 'HB3', 'CU', 'O', 'H1',
       'H2', 'O', 'H1', 'H2', 'O', 'H1', 'H2'], dtype='<U3')

In [114]:

# Positions of 'NA' entries in `atoms`; this result is not displayed because it
# is not the last expression of the cell.
np.where(np.array(atoms) == 'NA')
# Atom names from index 261 onward (index hard-coded for the record inspected)
atoms[261:]

['NA',
 'O',
 'H1',
 'H2',
 'O',
 'H1',
 'H2',
 'O',
 'H1',
 'H2',
 'O',
 'H1',
 'H2',
 'O',
 'H1',
 'H2']

In [119]:
# Names whose second character is a digit
print(np.array(atoms_i))

["C1'" 'C1B' 'C1D' 'C2' "C2'" 'C2A' 'C2B' 'C2D' 'C2N' "C3'" 'C3B' 'C3D'
 'C3N' 'C4' "C4'" 'C4A' 'C4B' 'C4D' 'C4N' 'C5' "C5'" 'C5A' 'C5B' 'C5D'
 'C5N' 'C6' 'C6A' 'C6N' 'C7N' 'C8' 'C8A' 'H1' "H1'" 'H1B' 'H1D' 'H2' "H2'"
 'H2A' 'H2B' 'H2D' 'H2N' 'H3' "H3'" 'H3B' 'H3D' "H4'" 'H42N' 'H4B' 'H4D'
 'H4N' "H5'" "H5''" 'H51A' 'H51N' 'H52A' 'H52N' 'H5N' 'H61A' 'H62A' 'H6N'
 'H71N' 'H72N' 'H8' 'H8A' 'N1' 'N1A' 'N1N' 'N2' 'N3' 'N3A' 'N6A' 'N7'
 'N7A' 'N7N' 'N9' 'N9A' 'O1A' 'O1B' 'O1N' "O2'" 'O2A' 'O2B' 'O2D' 'O2N'
 'O3' "O3'" 'O3A' 'O3B' 'O3D' "O4'" 'O4B' 'O4D' "O5'" 'O5B' 'O5D' 'O6'
 'O7N']


In [60]:
# Names of two or more characters whose second character is not a digit
print(np.array(atoms_a))

['CA' 'CB' 'CD' 'CD1' 'CD2' 'CE' 'CE1' 'CE2' 'CE3' 'CG' 'CG1' 'CG2' 'CH2'
 'CL' 'CU' 'CZ' 'CZ2' 'CZ3' 'HA' 'HA2' 'HA3' 'HB' 'HB1' 'HB2' 'HB3' 'HD1'
 'HD11' 'HD12' 'HD13' 'HD2' 'HD21' 'HD22' 'HD23' 'HD3' 'HE' 'HE1' 'HE2'
 'HE21' 'HE22' 'HE3' 'HG' 'HG1' 'HG11' 'HG12' 'HG13' 'HG2' 'HG21' 'HG22'
 'HG23' 'HG3' 'HH' 'HH11' 'HH12' 'HH2' 'HH21' 'HH22' 'HN1' 'HN21' 'HN22'
 "HO2'" 'HO2A' 'HO2N' "HO3'" 'HO3A' 'HO3N' 'HOA2' 'HOB2' 'HOB3' 'HZ' 'HZ1'
 'HZ2' 'HZ3' 'NA' 'ND1' 'ND2' 'NE' 'NE1' 'NE2' 'NH1' 'NH2' 'NZ' 'OD1'
 'OD2' 'OE1' 'OE2' 'OG' 'OG1' 'OH' 'OXT' 'SD' 'SG']
