In [None]:
import numpy as np
import torch  # type: ignore
import dgl  # type: ignore
from scipy.spatial import cKDTree  # type: ignore
from pathlib import Path
import torch.nn.functional as F  # type: ignore
from typing import Tuple, Union
from ligmet.featurizer import Features, Info # type: ignore
from ligmet.utils.constants import metals, standard_residues,ATOMIC_NUMBERS, atype2num,sec_struct_dict  # type: ignore
class PreprocessedDataSet(torch.utils.data.Dataset):
    def __init__(self, data_file: str, features_dir: str, rf_result_dir: str,topk: 16, edge_dist_cutoff: 3.0, pocket_dist: 6.0, rf_threshold: 0.5, eps=1e-6):
        super().__init__()
        self.data_file=Path(data_file)
        self.features_dir=Path(features_dir)
        self.rf_result_dir=Path(rf_result_dir)
        self.topk = topk
        self.edge_dist_cutoff=edge_dist_cutoff
        self.pocket_dist=pocket_dist
        self.rf_threshold=rf_threshold
        self.pdbid_lists=[pdb.strip().split(".pdb")[0] for pdb in open(data_file)]
        self.eps = eps
        self.alpha = 5.78
        
    def __len__(self):
        return len(self.pdbid_lists)
    
    def __getitem__(self, index:int):
        G = []
        L = []
        pdb_id = self.pdbid_lists[index]
        feature_path = self.features_dir / f"{pdb_id}.npz"
        rf_result_path = self.rf_result_dir / f"{pdb_id}.npz"
        data = np.load(feature_path,allow_pickle=True)
        features = Features(
            atom_positions=data['atom_positions'],
            atom_names=data['atom_names'],
            atom_elements=data['atom_elements'],
            atom_residues=data['atom_residues'],
            residue_idxs=data['residue_idxs'],
            chain_ids=data['chain_ids'],
            is_ligand=data['is_ligand'],
            metal_positions=data['metal_positions'],
            metal_types=data['metal_types'],
            grid_positions=data['grid_positions'],
            sasas=data['sasa'],
            qs=data['qs'],
            sec_structs=data['sec_structs'],
            bond_masks=data['bond_masks']
        )
        grid_positions = features.grid_positions
        grid_probs = np.load(rf_result_path)
        grid_mask = grid_probs >= self.rf_threshold
        grids_after_rf = grid_positions[grid_mask]
        features_p, pocket_exist = self.find_pocket(features, grids_after_rf)
        
        if pocket_exist is False:
            raise AttributeError("there is no grids after randomforest")
        
        g = self.make_graph(features_p)
        l_prob, l_type, l_vector = self.make_label(features_p)
        labels = torch.cat([l_prob.unsqueeze(1), l_type.unsqueeze(1), l_vector], dim=1)  # shape [N,5]
        G.append(g)
        L.append(labels)
        if not G:
            raise AttributeError(f"{pdb_id} have none type graph")
        info = Info(
            pdb_id=np.array(pdb_id),
            grids_positions=torch.tensor(grids_after_rf, dtype=torch.float32),
            metal_positions=torch.tensor(features.metal_positions, dtype=torch.float32),
            metal_types=torch.tensor(features.metaltype, dtype=torch.long),
        )
        return G, L, info

    def find_pocket(self, features: Features, grids: np.ndarray):
        c_grids = grids
        atom_pos = features.atom_positions

        gtree = cKDTree(c_grids)
        ptree = cKDTree(atom_pos)
        ii = gtree.query_ball_tree(ptree, self.pocket_dist)

        idx = np.unique(np.concatenate([i for i in ii if i], axis=0)).astype(int)
        if len(idx) == 0:
            return None, False

        c_features = Features(
            atom_positions=atom_pos[idx],
            atom_names=features.atom_names[idx],
            atom_elements=features.atom_elements[idx],
            atom_residues=features.atom_residues[idx],
            residue_idxs=features.residue_idxs[idx],
            chain_ids=features.chain_ids[idx],
            is_ligand=features.is_ligand[idx],
            metal_positions=features.metal_positions, 
            metal_types=features.metal_types,
            grid_positions=c_grids,
            sasas=features.sasas[idx],
            qs=features.qs[idx],
            sec_structs=features.sec_structs[idx],
            bond_masks=features.bond_masks[np.ix_(idx, idx)], 
        )

        return c_features, True
    
    def make_label(self, features:Features)->Union[torch.Tensor, torch.Tensor, torch.Tensor]:
        grid = np.array(features.grid_positions, dtype=np.float32)
        grids = torch.from_numpy(grid)
        metal_pos = torch.from_numpy(features.metal_positions)
        metal_types = torch.from_numpy(features.metaltype)

        diff = grids.unsqueeze(1) - metal_pos.unsqueeze(0)  # [g,m,3]
        dist = torch.sqrt(torch.sum(diff**2, dim=-1)) + self.eps  # [g,m]

        exp_dist = torch.exp(-(dist**2) / self.alpha)
        label_p, _ = torch.max(exp_dist, dim=-1)
        label_prob = torch.where(label_p <= 0.1, torch.tensor(0.0), label_p)

        min_dist, min_idx = torch.min(dist, dim=-1)  # [g,]
        label_type = torch.where(
            min_dist <= 2.0, metal_types[min_idx], torch.tensor(len(metals))
        )
        label_vector = diff[torch.arange(diff.size(0)), min_idx]
        
        return label_prob, label_type, label_vector
    
    def make_graph(self, features: Features) -> dgl.DGLGraph:
        xyz = torch.tensor(np.concatenate([features.atom_positions, features.grid_positions]))
        grid_mask = torch.ones(len(xyz))
        grid_mask[: len(features.sasas)] = 0
        n_feats, n_polar_vec = self.get_node_features(features)
        num_nodes = xyz.shape[0]
        edge_index_src, edge_index_dst, e_feats, rel_vec = self.make_edge(features)
        G = dgl.graph((edge_index_src.to(torch.int32), edge_index_dst.to(torch.int32)),num_nodes=num_nodes)
        G.ndata["xyz"] = xyz.to(torch.float32)
        G.ndata["L0"] = n_feats.to(torch.float32)
        G.ndata["L1"] = n_polar_vec.to(torch.float32)
        G.ndata["grid_mask"] = grid_mask.to(torch.float32)
        G.edata["L0"] = e_feats.to(torch.float32)
        G.edata["L1"] = rel_vec.to(torch.float32)
        
        return G
    
    def make_polarity_vector(self, features: Features) -> np.ndarray:
        xyz = torch.from_numpy(features.atom_positions)
        neigh_masks = torch.from_numpy(features.bond_masks)

        self_idx, nei_idx = torch.nonzero(neigh_masks, as_tuple=True)

        xyz_self = xyz * neigh_masks.sum(dim=1, keepdim=True)
        xyz_nei = -xyz[nei_idx]
        xyz_self.scatter_add_(0, self_idx[:, None].expand(-1, 3), xyz_nei)

        polar_vec = F.normalize(xyz_self, dim=1)
        polarity_vectors = torch.cat(
            [polar_vec, torch.zeros(features.grid_positions.shape)], dim=0
        ).numpy()
        return polarity_vectors

    def get_node_features(self, features: Features) -> Tuple[torch.Tensor, torch.Tensor]:
        # num_res = len(features.atom_names)
        num_grids = len(features.grid_positions)

        sasas = torch.from_numpy(features.sasas)
        qs = torch.from_numpy(features.qs)
        sec_structs = torch.from_numpy(features.sec_structs)
        aatype = torch.from_numpy(features.atom_residues)
        atom_chem_type = torch.from_numpy(features.atom_elements)
        
        # one hot features: aatype, atomtype, 2nd structures
        # assign max int for grids
        aatype = [standard_residues.index(res) if res in standard_residues else len(standard_residues) for res in features.atom_residues ]
        grids_aatype = torch.ones(num_grids) * len(standard_residues)+1
        aatype = torch.cat((aatype, grids_aatype))

        atomtype = torch.from_list([ATOMIC_NUMBERS.get(elem,119) for elem in features.atom_elements])
        grids_atomtype = torch.zeros(num_grids)
        atomtype = torch.cat([atomtype, grids_atomtype], dim=0)
        
        ##TODO: ligand gentype
        grids_atomchemtype = torch.ones(num_grids) * len(atype2num)
        atom_chem_type = torch.cat([atom_chem_type, grids_atomchemtype], dim=0)
        
        grids_2nd = torch.ones(num_grids) * len(sec_struct_dict)
        sec_structs = torch.cat([sec_structs, grids_2nd])

        # one-hot encoding
        aatype = F.one_hot(aatype.to(torch.int64), num_classes=len(standard_residues) + 2)
        atomtype = F.one_hot(atomtype.to(torch.int64), num_classes=len(ATOMIC_NUMBERS) + 2)
        sec_structs = F.one_hot(
            sec_structs.to(torch.int64), num_classes=len(sec_struct_dict) + 1
        )
        atom_chemtype = F.one_hot(
            atom_chem_type.to(torch.int64), num_classes=len(atype2num) + 1
        )
        # real value features: sasas, qs
        # assign 0 for grids
        grids_feat = torch.zeros(num_grids)
        sasas = torch.cat((sasas, grids_feat)).unsqueeze(-1)
        qs = torch.cat((qs, grids_feat)).unsqueeze(-1)
        # sasas can have nan value
        sasas = sasas + self.eps

        n_feats = torch.cat(
            [aatype, atomtype, atom_chemtype, sec_structs, sasas, qs], dim=1
        )
        print(
            "aatype, atomtype, atom_chemtype, sec-str, sasas, qs)",
            aatype.shape,
            atomtype.shape,
            atom_chemtype.shape,
            sec_structs.shape,
            sasas.shape,
            qs.shape,
        )
        polarity_vectors = self.make_polarity_vector(features)
        polarity_vectors = torch.tensor(polarity_vectors)
        return n_feats, polarity_vectors

    def onehot_edge_dist(self, dists: torch.Tensor) -> torch.Tensor:
        bin_edges = np.arange(0, self.dist_cutoff + 0.5, 0.5)
        dist_binned = np.digitize(dists, bins=bin_edges) - 1
        one_hot_dist = F.one_hot(
            torch.from_numpy(dist_binned), num_classes=len(bin_edges)
        )
        return one_hot_dist

    def onehot_edge_type(
        self, edge_index_src: torch.Tensor, edge_index_dst: torch.Tensor, num_atom: int
    ) -> torch.Tensor:
        feat = np.zeros_like(edge_index_src)  # p to p :0
        feat[np.where((edge_index_src < num_atom) & (edge_index_dst >= num_atom))] = (
            1  # p to g :1
        )
        feat[np.where((edge_index_src >= num_atom) & (edge_index_dst < num_atom))] = (
            2  # g to p : 2
        )
        feat[np.where((edge_index_src >= num_atom) & (edge_index_dst >= num_atom))] = (
            3  # g to g :3
        )
        one_hot_feat = F.one_hot(torch.from_numpy(feat).to(torch.int64), num_classes=4)
        return one_hot_feat

    def cov_bond(
        self,
        edge_index_src: torch.Tensor,
        edge_index_dst: torch.Tensor,
        num_atom: int,
        features: Features,
    ) -> torch.Tensor:
        # shape (edge, )
        cov_bond = np.zeros(len(edge_index_src))
        prot_idx_mask = (edge_index_src < num_atom) & (edge_index_dst < num_atom)
        idx = (edge_index_src[prot_idx_mask], edge_index_dst[prot_idx_mask])
        cov_bond[prot_idx_mask] = features.bond_masks[tuple(idx)]
        cov_bond = torch.from_numpy(cov_bond)
        return cov_bond

    def make_edge(
        self, features: Features) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        num_atom = int(np.sum(features.atom_mask))
        num_grids = len(features.grid_positions)
        num_nodes = num_atom + num_grids

        node_pos = np.concatenate([features.atom_positions, features.grid_positions], axis=0)
        k_nearest = min(self.topk + 1, num_nodes)

        tree = cKDTree(node_pos)
        dd, ii = tree.query(
            node_pos, k=k_nearest, distance_upper_bound=self.dist_cutoff
        )
        node_pos = torch.from_numpy(node_pos).to(torch.float32)
        index_tensor = torch.arange(num_nodes, dtype=torch.int32)
        edge_index_src = torch.flatten(torch.from_numpy(ii)).to(torch.int32)
        edge_index_dst = torch.repeat_interleave(index_tensor, k_nearest)
        dists = torch.flatten(torch.from_numpy(dd))

        edge_mask = torch.logical_and(edge_index_src != edge_index_dst, edge_index_src != num_nodes)

        edge_index_src = edge_index_src[edge_mask]
        edge_index_dst = edge_index_dst[edge_mask]
        dists = dists[edge_mask]

        dist_bin = self.onehot_edge_dist(dists)
        onehot_type = self.onehot_edge_type(edge_index_src, edge_index_dst, num_atom)
        covalent_bond = self.cov_bond(edge_index_src, edge_index_dst, num_atom, features)
        covalent_bond = covalent_bond.unsqueeze(-1)
        # relative position
        e_vec = torch.tensor(
            node_pos[edge_index_dst.long()] - node_pos[edge_index_src.long()]
        )

        polarity_vectors = torch.tensor(
            self.make_polarity_vector(features), dtype=torch.float32
        )
        # edge_type을 설정: prot-to-prot, grid-to-grid, grid-to-prot, prot-to-grid 구분
        edge_type_prot_to_prot = (edge_index_src < num_atom) & (
            edge_index_dst < num_atom
        )
        edge_type_grid_to_grid = (edge_index_src >= num_atom) & (
            edge_index_dst >= num_atom
        )
        edge_type_grid_to_prot = (edge_index_src >= num_atom) & (
            edge_index_dst < num_atom
        )
        edge_type_prot_to_grid = (edge_index_src < num_atom) & (
            edge_index_dst >= num_atom
        )

        # 초기화
        start = torch.zeros((len(edge_index_src), 3), dtype=torch.float32)
        end = torch.zeros((len(edge_index_src), 3), dtype=torch.float32)

        # 1. prot to prot 또는 grid to grid
        mask = edge_type_prot_to_prot | edge_type_grid_to_grid
        start[mask] = polarity_vectors[edge_index_dst[mask].long()]
        end[mask] = polarity_vectors[edge_index_src[mask].long()]

        # 2. grid to prot
        mask = edge_type_grid_to_prot
        start[mask] = (
            node_pos[edge_index_src[mask].long()]
            - node_pos[edge_index_dst[mask].long()]
        )
        end[mask] = polarity_vectors[edge_index_dst[mask].long()]

        # 3. prot to grid
        mask = edge_type_prot_to_grid
        start[mask] = polarity_vectors[edge_index_src[mask].long()]
        end[mask] = (
            node_pos[edge_index_dst[mask].long()]
            - node_pos[edge_index_src[mask].long()]
        )

        cos = (
            torch.einsum(
                "ij,ij->i",
                start,
                end,
            ).unsqueeze(-1)
            + self.eps
        )
        sin = (
            torch.norm(
                torch.cross(
                    start,
                    end,
                ),
                dim=1,
                keepdim=True,
            )
            + self.eps
        )
        e_feats = torch.cat([onehot_type, dist_bin, covalent_bond, cos, sin], dim=1)
        print(
            "--",
            onehot_type.shape,
            dist_bin.shape,
            covalent_bond.shape,
            cos.shape,
            sin.shape,
        )
        return edge_index_src, edge_index_dst, e_feats, e_vec

    def collate(self, samples: list) -> Tuple[dgl.DGLGraph, torch.Tensor, Info]:
        graphs, labels, g_pos, m_pos, m_types, pdb_ids = [], [], [], [], [], []

        for G, L, info in samples:
            graphs.extend(G)  # 각 샘플의 그래프 리스트를 하나의 리스트로 결합
            labels.extend(L)  # 각 샘플의 결합된 라벨 리스트를 하나의 리스트로 결합
            g_pos.append(info.grids_positions)
            m_pos.append(info.metal_positions)
            m_types.append(info.metal_types)
            pdb_ids.append(info.pdb_id)
        # 배치 그래프와 배치 라벨 생성
        batched_graphs = dgl.batch(graphs)  # shape [B*N]
        batched_labels = torch.cat(labels, dim=0)  # shape [B*N,2]
        g_poss = torch.cat(g_pos, dim=0)
        m_poss = torch.cat(m_pos, dim=0)
        m_typess = torch.cat(m_types, dim=0)
        pdb_idss = np.array(pdb_ids)
        batched_infos = Info(
            pdb_id=pdb_idss,
            grids_positions=g_poss,
            metal_positions=m_poss,
            metal_types=m_typess,
        )
        return batched_graphs, batched_labels, batched_infos
    
    
class OnTheFlyDataSet(torch.utils.data.Dataset):
    def __init__(self, data_file: str, pdb_dir: str, rf_model: str, topk: int, edge_dist_cutoff: float, pocket_dist: float, rf_threshold: float):
        super().__init__()
        self.data_file = Path(data_file)
        self.pdb_dir = Path(pdb_dir)
        self.topk = topk
        self.edge_dist_cutoff=edge_dist_cutoff
        self.pocket_dist=pocket_dist
        self.rf_threshold=rf_threshold
        
    def __len__(self):
        return len()
    
    def __getitem__(self, index:int):
        
        return 1
    
def get_dataset_class(config):
    dataset_type = config["dataset"]["type"]
    
    if dataset_type == "preprocessed":
        return PreprocessedDataSet(**config["dataset"]["preprocessed"])
    elif dataset_type == "on_the_fly":
        return OnTheFlyDataSet(**config["dataset"]["onthefly"])
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

In [1]:
from ligmet.dataset import OnTheFlyDataSet
Dataset = OnTheFlyDataSet(
    data_file='/home/qkrgangeun/LigMet/code/src/ligmet/utils/examples/example.txt',
    pdb_dir='/home/qkrgangeun/LigMet/code/src/ligmet/utils/examples',
    rf_model='random_forest_model',
    topk=16,
    edge_dist_cutoff=3.0,
    pocket_dist=6.0,
    rf_threshold=1.0
)

In [2]:
Dataset[1]



max(idx) 5407
len(features.sasas) 5408
len(atom_names) 5408
len(qs) 5408
len(sec_structs) 5408
len(gen_types) 5408
len(bond_masks) 5408
len(is_ligand == 1) 24
tensor([7., 7., 7.,  ..., 9., 9., 9.])
aatype, atomtype, atom_chemtype, sec-str, sasas, qs) torch.Size([24369, 30]) torch.Size([24369, 120]) torch.Size([24369, 62]) torch.Size([24369, 10]) torch.Size([24369, 1]) torch.Size([24369, 1])
-- torch.Size([359490, 4]) torch.Size([359490, 7]) torch.Size([359490, 1]) torch.Size([359490, 1]) torch.Size([359490, 1])


  e_vec = torch.tensor(


([Graph(num_nodes=24369, num_edges=359490,
        ndata_schemes={'xyz': Scheme(shape=(3,), dtype=torch.float32), 'L0': Scheme(shape=(224,), dtype=torch.float32), 'L1': Scheme(shape=(3,), dtype=torch.float32), 'grid_mask': Scheme(shape=(), dtype=torch.float32)}
        edata_schemes={'L0': Scheme(shape=(14,), dtype=torch.float32), 'L1': Scheme(shape=(3,), dtype=torch.float32)})],
 [tensor([[ 0.0000, 10.0000, -7.0964, -6.6930, 33.5870],
          [ 0.0000, 10.0000, -6.4339, -6.6930, 33.5870],
          [ 0.0000, 10.0000, -8.5594, -6.3890, 34.3202],
          ...,
          [ 0.0000, 10.0000,  1.4016, -5.4000, -2.0850],
          [ 0.0000, 10.0000,  2.0641, -5.4000, -2.0850],
          [ 0.0000, 10.0000,  2.0151, -3.8800, -2.8967]])],
 Info(pdb_id=array('1a05_ligand', dtype='<U11'), grids_positions=tensor([[ 2.8266, 19.1820, 75.3980],
         [ 3.4891, 19.1820, 75.3980],
         [ 1.3636, 19.4860, 76.1312],
         ...,
         [ 0.3646, 20.2520, 21.6750],
         [ 1.0271, 20.2520,

In [6]:
print(len(Dataset))

3


In [None]:
import numpy as np
from scipy.spatial import cKDTree
from collections import defaultdict
from Bio.PDB import PDBParser
import io
from pathlib import Path
import ligmet.utils.pdb import read_pdb
metals = {"ZN", "MG", "FE", "CA", "CU", "MN", "CO", "NI", "NA", "K"}  # Metal 원소 리스트


def find_binding_residues(pdb_path, cutoff=3.0):
    """Metal 주변 3Å 이내의 Binding Residues를 찾는 함수"""
    structure = read_pdb(pdb_path)

    # Metal이 없는 경우 예외 처리
    if len(structure["metal_positions"]) == 0:
        return set()

    # KDTree 생성 (검색 최적화)
    tree = cKDTree(structure["atom_positions"])
    
    # Metal 위치 주변 cutoff 내 원자 찾기
    binding_residues = set()
    for metal_pos in structure["metal_positions"]:
        neigh_idx = tree.query_ball_point(metal_pos, cutoff)
        for idx in neigh_idx:
            res_name = structure["atom_residues"][idx]
            res_idx = structure["residue_idxs"][idx]
            binding_residues.add((res_name, res_idx))  # Residue ID까지 포함하여 저장

    # Residue 이름만 반환
    return [res_name for res_name, _ in binding_residues]

# 사용 예시
pdb_dir = Path("/path/to/pdb_files")  # 실제 PDB 경로로 변경
for pdb_file in pdb_dir.glob("*.pdb"):
    binding_residues = find_binding_residues(pdb_file)
    print(f"{pdb_file.stem}: {binding_residues}")  # PDB ID와 Binding Residues 출력


In [None]:
import csv
from pathlib import Path

# 저장할 CSV 파일 경로
output_csv = "/home/qkrgangeun/LigMet/code/text/biolip/metal_binding_sites.csv"

# PDB 파일이 저장된 디렉토리
pdb_dir = Path("/home/qkrgangeun/LigMet/code/src/ligmet/utils/examples")  # 실제 PDB 파일 경로로 변경

# 결과 저장을 위한 리스트
data = []

for pdb_file in pdb_dir.glob("*.pdb"):
    pdb_id = pdb_file.stem  # PDB ID 추출
    structure = read_pdb(pdb_file)

    # Metal이 없는 경우 건너뜀
    if len(structure["metal_positions"]) == 0:
        continue

    # Metal 원자 처리
    for metal_pos, metal_type in zip(structure["metal_positions"], structure["metal_types"]):
        binding_residues = find_binding_residues(pdb_file)

        # CSV 저장 데이터 생성
        data.append([pdb_id, metal_type, metal_pos.tolist(), binding_residues])

# CSV 파일로 저장
with open(output_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["PDB ID", "Metal Type", "Metal Position", "Binding Residues"])
    for row in data:
        writer.writerow(row)

print(f"✅ Metal binding site 정보가 {output_csv}에 저장되었습니다.")


In [1]:
import os
import numpy as np
from collections import defaultdict

# 파일 경로
pdb_list_file = '/home/qkrgangeun/LigMet/code/text/biolip/filtered/train_pdbs_chain_1_filtered.txt'
metal_data_dir = '/home/qkrgangeun/LigMet/data/biolip/metal_label/'

# PDB ID 목록 읽기
with open(pdb_list_file, 'r') as f:
    pdb_ids = [line.strip() for line in f if line.strip()]

# metal별 포함된 PDB ID 목록 저장용 딕셔너리
metal_to_pdbs = defaultdict(set)

# 각 PDB ID에 대해 metal_types 정보 수집
for pdb_id in pdb_ids:
    npz_path = os.path.join(metal_data_dir, f'{pdb_id}.npz')
    if not os.path.exists(npz_path):
        continue  # 파일이 없으면 생략
    try:
        data = np.load(npz_path, allow_pickle=True)
        metal_types = data.get('metal_types', [])
        for metal in metal_types:
            metal_to_pdbs[metal].add(pdb_id)
    except Exception as e:
        print(f'Error processing {pdb_id}: {e}')

# 결과 출력
print("### Metal 별로 포함된 PDB ID 수 ###")
for metal, pdb_set in metal_to_pdbs.items():
    print(f'{metal}: {len(pdb_set)}개 PDB')

# 각 metal에 대해 포함된 PDB ID 예시 출력 (선택적)
for metal, pdb_set in metal_to_pdbs.items():
    print(f'\n[Metal: {metal}] 포함된 PDB 예시 (최대 10개):')
    for pdb_id in list(pdb_set)[:10]:
        print(pdb_id)

# 필요 시 저장 (예: metal_to_pdbs.json 등으로 저장 가능)
# import json
# with open('metal_to_pdbs.json', 'w') as f:
#     json.dump({k: list(v) for k, v in metal_to_pdbs.items()}, f, indent=2)


### Metal 별로 포함된 PDB ID 수 ###
MN: 1082개 PDB
ZN: 5272개 PDB
MG: 2290개 PDB
CA: 4005개 PDB
CU: 507개 PDB
FE: 511개 PDB
CO: 277개 PDB
K: 25개 PDB
NI: 28개 PDB

[Metal: MN] 포함된 PDB 예시 (최대 10개):
3v91
6dq9
4o7x
5e3u
1i3h
4wte
2jdz
1yyd
2zxp
5ivy

[Metal: ZN] 포함된 PDB 예시 (최대 10개):
3s2l
1xx4
5llg
4xgl
1eh6
3bkk
1akl
2hxv
4oja
6fgs

[Metal: MG] 포함된 PDB 예시 (최대 10개):
5v0d
1ig5
1o03
1zet
5m3u
1ihu
3u7e
2hru
2a31
6bbp

[Metal: CA] 포함된 PDB 예시 (최대 10개):
1kuh
1yzp
6ioz
4asm
5xsa
5olb
2jhm
5g56
1ql9
5b4y

[Metal: CU] 포함된 PDB 예시 (최대 10개):
1w7c
5icu
1jxd
1f1d
4x4k
4yso
5zll
4dpb
4ysp
4hhg

[Metal: FE] 포함된 PDB 예시 (최대 10개):
6f65
1oq9
1qiq
3n9t
1s2z
1brf
3hhy
4x1b
3hfb
5tk5

[Metal: CO] 포함된 PDB 예시 (최대 10개):
6d3j
5yr5
3mz7
1mat
1c21
1qxw
3a3x
1kej
4u6e
3wrs

[Metal: K] 포함된 PDB 예시 (최대 10개):
4d7n
1mei
1krj
1mew
1hpm
3fpb
4jpf
2fxi
1gjv
3zdd

[Metal: NI] 포함된 PDB 예시 (최대 10개):
2y39
4m5b
3kbw
1rze
3kco
2gqk
5wk0
2gql
1ru3
3skd


In [2]:
import os
import numpy as np
import pickle
from collections import defaultdict

# 파일 경로
pdb_list_file = '/home/qkrgangeun/LigMet/code/text/biolip/filtered/train_pdbs_chain_1_filtered.txt'
metal_data_dir = '/home/qkrgangeun/LigMet/data/biolip/metal_label/'

# PDB ID 목록 읽기
with open(pdb_list_file, 'r') as f:
    pdb_ids = [line.strip() for line in f if line.strip()]

# metal → set(pdb_id)
metal_to_pdbs = defaultdict(set)

# PDB ID → list(metal)
pdb_id_to_metals = defaultdict(list)

# 데이터 수집
for pdb_id in pdb_ids:
    npz_path = os.path.join(metal_data_dir, f'{pdb_id}.npz')
    if not os.path.exists(npz_path):
        continue
    try:
        data = np.load(npz_path, allow_pickle=True)
        metal_types = data.get('metal_types', [])
        for metal in metal_types:
            metal_to_pdbs[metal].add(pdb_id)
            pdb_id_to_metals[pdb_id].append(metal)
    except Exception as e:
        print(f'Error processing {pdb_id}: {e}')

# Pickle로 저장
with open('/home/qkrgangeun/LigMet/data/biolip/metal_to_pdbs.pkl', 'wb') as f:
    pickle.dump(metal_to_pdbs, f)

with open('/home/qkrgangeun/LigMet/data/biolip/pdb_id_to_metals.pkl', 'wb') as f:
    pickle.dump(pdb_id_to_metals, f)

print("✅ metal_to_pdbs.pkl 및 pdb_id_to_metals.pkl 저장 완료.")


✅ metal_to_pdbs.pkl 및 pdb_id_to_metals.pkl 저장 완료.


In [1]:
import numpy as np
data = np.load("/home/qkrgangeun/LigMet/data/biolip/dl/features/5xwm.npz")
for key in data:
    print(key, data[key])
print(data["bond_masks"])

atom_positions [[-14.514  22.336  70.118]
 [-14.169  23.759  70.137]
 [-13.261  24.184  68.955]
 ...
 [  4.067  65.245  -7.034]
 [  9.564  77.131  -5.367]
 [  4.305  51.884  -8.604]]
atom_names ['N' 'CA' 'C' ... 'CD2' 'CL' 'CL']
atom_elements ['N' 'C' 'C' ... 'C' 'CL' 'CL']
atom_residues ['GLU' 'GLU' 'GLU' ... 'LEU' 'CL' 'CL']
residue_idxs [  1   1   1 ... 368 369 370]
chain_ids ['A' 'A' 'A' ... 'D' 'D' 'D']
is_ligand [False False False ... False  True  True]
metal_positions [[-14.655  35.938  50.437]
 [ -1.886  39.888  31.409]
 [  6.432  44.094  44.584]
 [ 12.504  62.446   9.381]
 [  8.933  72.885  -2.576]
 [  1.476  51.752  -7.974]
 [ -9.123  39.302  45.964]
 [ 12.179  48.498  48.131]
 [ 10.749  79.756  -6.108]
 [  5.982  57.212  -4.151]]
metal_types ['ZN' 'ZN' 'ZN' 'ZN' 'ZN' 'ZN' 'ZN' 'ZN' 'ZN' 'ZN']
grid_positions [[-13.816576   20.896      70.118    ]
 [-13.119152   20.896      70.118    ]
 [-15.356539   21.216      70.88983  ]
 ...
 [  8.334724   52.124     -11.167381 ]
 [  3.347

In [6]:
import numpy as np
from pathlib import Path
from ligmet.utils.pdb import read_pdb, Structure, StructureWithGrid
from ligmet.utils.grid import *
from ligmet.featurizer import * # type: ignore
from openbabel import openbabel
from dataclasses import asdict
import traceback
pdb_id = '5xwm'
pdb_dir = Path("/home/qkrgangeun/LigMet/data/biolip/pdb")
pdb_path = pdb_dir / f"{pdb_id}.pdb"
structure = read_pdb(pdb_path)
pdb_io, protein_io, ligand_io = make_pdb(structure)
ligand_pdb_str = ligand_io.getvalue()

ligand_mol = None
if ligand_pdb_str.strip():
    ob_conversion = openbabel.OBConversion()
    ob_conversion.SetInFormat("pdb")
    ob_mol = openbabel.OBMol()
    ob_conversion.ReadString(ob_mol, ligand_pdb_str)
    ligand_mol = ob_mol

new_pdb_path = process_pdb(pdb_io)
new_structure = read_pdb(new_pdb_path)
bond_masks = cov_bonds_mask(new_structure, ligand_mol)

In [9]:
sum(bond_masks[0])

2.0

In [10]:
def bondmask_to_neighidx(bond_mask: np.ndarray) -> np.ndarray:
    rows, cols = np.where(np.triu(bond_mask) > 0)
    return np.stack([rows, cols], axis=0).astype(np.int32)

In [12]:
neigh = bondmask_to_neighidx(bond_masks)
neigh

array([[    0,     0,     1, ..., 11398, 11399, 11399],
       [    1,  2752,     2, ..., 11399, 11400, 11401]], dtype=int32)

In [14]:
dd = bondmask_to_neighidx(neigh)
dd

array([[    0,     0,     0, ...,     1,     1,     1],
       [    2,     3,     4, ..., 16369, 16370, 16371]], dtype=int32)