In [3]:
from dataclasses import dataclass
import numpy as np
import io
from collections import defaultdict
from Bio.PDB.PDBParser import PDBParser
from ligmet.utils.constants import metals

@dataclass
class Structure:
    atom_positions: np.ndarray  # [n_atoms, 3]
    atom_names: np.ndarray  # [n_atoms, 1]
    atom_elements: np.ndarray  # [n_atoms, 1]
    atom_residues: np.ndarray  # [n_atoms, 1] if ligand: x
    residue_idxs: np.ndarray #[n_atoms, 1]
    is_ligand: np.ndarray  # [n_atoms, 1]
    metal_positions: np.ndarray  # [n_metals, 3]
    metal_types: np.ndarray  # [n_metals, 1]

@dataclass
class StructureWithGrid:
    atom_positions: np.ndarray  # [n_atoms, 3]
    atom_names: np.ndarray  # [n_atoms, 1]
    atom_elements: np.ndarray  # [n_atoms, 1]
    atom_residues: np.ndarray  # [n_atoms, 1] if ligand: x
    residue_idxs: np.ndarray #[n_atoms, 1]
    is_ligand: np.ndarray  # [n_atoms, 1]
    metal_positions: np.ndarray  # [n_metals, 3]
    metal_types: np.ndarray  # [n_metals, 1]
    grid_positions: np.ndarray #[n_grids, 3]
    
def read_pdb(pdb_path) -> Structure:
    with open(pdb_path, "r") as f:
        pdb_str = f.read()
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    model = list(structure.get_models())[0]

    data = defaultdict(list)

    for chain in model:
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(f"Insertion code found at chain {chain.id}, residue {res.id[1]}")
            if res.id[0] == " ":  # ATOM
                for atom in res:
                    data["atom_positions"].append(atom.coord)
                    data["atom_elements"].append(atom.element)
                    data["atom_residues"].append(res.get_resname())
                    data["atom_names"].append(atom.name)
                    data["is_ligand"].append(0)
                    data["residue_idxs"].append(res.get_id()[1])
            elif "H_" in res.id[0]:  # HETATM except water (which starts with "W_")
                for atom in res.get_atoms():
                    if atom.element in metals:
                        data["metal_positions"].append(atom.coord)
                        data["metal_types"].append(atom.element)
                    else:  # Ligand
                        data["atom_positions"].append(atom.coord)
                        data["atom_elements"].append(atom.element)
                        data["atom_residues"].append(res.get_resname())
                        data["atom_names"].append(atom.name)
                        data["is_ligand"].append(1)
                        data["residue_idxs"].append(res.get_id()[1])


    return Structure(**{k: np.array(v) for k, v in data.items()})

st = read_pdb('/home/qkrgangeun/LigMet/code/src/ligmet/utils/1a05_ligand.pdb')


In [4]:
print(st.atom_positions.shape)
print(st.atom_elements.shape)
print(st.atom_residues.shape)
print(st.residue_idxs.shape)
print(st.is_ligand.shape)
# st.metal_positions
# st.metal_types

(5408, 3)
(5408,)
(5408,)
(5408,)
(5408,)


In [9]:
mask = st.is_ligand
print(mask)
print(st.atom_residues[mask])

[0 0 0 ... 1 1 1]
['MET' 'MET' 'MET' ... 'MET' 'MET' 'MET']


In [23]:
print(mask)
print(st.atom_residues)

[0 0 0 ... 1 1 1]
['MET' 'MET' 'MET' ... 'IPM' 'IPM' 'IPM']


In [1]:
import re
from collections import Counter

# 로그 파일 경로 설정
log_file = '/home/qkrgangeun/LigMet/sh/0403/test2.log'

# "type label tensor(..., device='cuda:0')" 형태의 문자열에서 대괄호 안의 숫자들을 추출하기 위한 정규표현식
pattern = re.compile(r"tensor\(\[([0-9,\s]+)\],\s*device='cuda:0'\)")

counter = Counter()

with open(log_file, 'r') as file:
    for line in file:
        # 해당 줄에 관심 있는 문자열이 포함되어 있는 경우에 한해 처리합니다.
        if "type label tensor" in line:
            match = pattern.search(line)
            if match:
                # 대괄호 안의 문자열 추출 (숫자와 쉼표, 공백)
                numbers_str = match.group(1)
                # 쉼표를 기준으로 나눈 후 정수로 변환
                numbers = [int(item.strip()) for item in numbers_str.split(',') if item.strip().isdigit()]
                counter.update(numbers)

# 0부터 11까지 각 숫자의 등장 빈도 출력
for num in range(12):
    print(f"{num}: {counter[num]} occurrences")


0: 1918 occurrences
1: 2306 occurrences
2: 160 occurrences
3: 1212 occurrences
4: 188 occurrences
5: 29 occurrences
6: 81 occurrences
7: 130 occurrences
8: 17 occurrences
9: 0 occurrences
10: 0 occurrences
11: 0 occurrences


In [2]:
import pandas as pd

# CSV 파일 경로 설정
file_path = '/home/qkrgangeun/LigMet/code/text/biolip/metal_binding_sites3.csv'

# CSV 파일을 DataFrame으로 읽어옵니다.
df = pd.read_csv(file_path)

# "Metal Type" 컬럼의 값에 대해 빈도수를 계산합니다.
metal_counts = df['Metal Type'].value_counts()

# 결과 출력
print("Metal Type 별 빈도수:")
print(metal_counts)


Metal Type 별 빈도수:
Metal Type
ZN    34358
CA    28560
MG    17566
MN     8274
FE     4763
CU     4171
CO     1515
NI      319
K       271
Name: count, dtype: int64


validation - precision,recall test: chain1_pre, posweight 10

In [None]:
pdb_id = '5s8q'
result_file = f'/home/qkrgangeun/LigMet/data/biolip/test/0507_rf/{pdb_id}.npz'
grid_file = f'/home/qkrgangeun/LigMet/data/biolip/dl/features/{pdb_id}.npz'
threshold = 0.5

data = np.load(result_file)
feature = np.load(grid_file)
pred = data['pred']
type_pred = data['type_pred']
label = data['label']
type_label = data['type_label']    
metal_positions = data['metal_positions']
metal_types = data['metal_types']
grid = feature['grid_positions']
# position recall
def write_pdb_with_grids(
    pdb_id,
    metal_positions,
    grid_positions,
    grid_predictions,
    grid_type_predictions,
    pdb_input_dir,
    pdb_output_dir,
    pred_threshold=0.5,
):
    """
    1) 원본 PDB 파일(pdb_id.pdb)을 읽어서
    2) metal pred >= pred_threshold 를 만족하는 grid 좌표에 대해
       HETATM 라인을 추가해 저장.
    """
    os.makedirs(pdb_output_dir, exist_ok=True)

    input_pdb_path = os.path.join(pdb_input_dir, f"{pdb_id}.pdb")
    output_pdb_path = os.path.join(pdb_output_dir, f"{pdb_id}.pdb")

    if not os.path.exists(input_pdb_path):
        print(f"[WARNING] {input_pdb_path} not found. Skipping.")
        return

    with open(input_pdb_path, "r") as infile:
        pdb_lines = []
        for line in infile:
            if line.startswith("ATOM") or line.startswith("HETATM"):
                pdb_lines.append(line)

    with open(output_pdb_path, "w") as outfile:
        # 1) 기존 PDB 내용 먼저 기록
        for line in pdb_lines:
            outfile.write(line)

        # 2) 조건을 만족하는 그리드 좌표 기록
        start_idx = 0  # 임의로 레지듀 번호 시작
        for idx, (grid_pos, grid_pred, grid_type_pred) in enumerate(
            zip(grid_positions, grid_predictions, grid_type_predictions)
        ):
            if grid_pred >= pred_threshold:
                metal_type_idx = torch.argmax(torch.tensor(grid_type_pred)).item()
                if metal_type_idx < len(metals):
                    metal_type = metals[metal_type_idx]
                else:
                    metal_type = "UNK"

                atom_idx = start_idx + idx
                x, y, z = grid_pos
                outfile.write(
                    f"HETATM{atom_idx:>5}  {metal_type:>3} GRD A{atom_idx:>4}    "
                    f"{x:8.3f}{y:8.3f}{z:8.3f}  {grid_pred:.2f}  0.00           {metal_type}\n"
                )

    print(f"[INFO] Saved PDB with grids: {output_pdb_path}")


# ========== 모든 PDB에 대해 write_pdb_with_grids 수행 ==========
def save_all_pdb_with_grids(results, infos, pdb_input_dir, pdb_output_dir):
    """
    각 샘플별로 write_pdb_with_grids()를 호출하여
    pred >= 특정 threshold 만족하는 그리드를 HETATM으로 기록
    """
    for i, info in enumerate(infos):
        pdb_id = info.pdb_id[0]
        metal_positions = info.metal_positions.cpu().numpy()
        grid_positions = info.grids_positions.cpu().numpy()
        grid_predictions = results[i][0].cpu().numpy()  # pred
        grid_type_preds = results[i][1].cpu().numpy()  # type_pred

        write_pdb_with_grids(
            pdb_id,
            metal_positions,
            grid_positions,
            grid_predictions,
            grid_type_preds,
            pdb_input_dir,
            pdb_output_dir,
            pred_threshold=0.0,  # 필요에 따라 수정
        )



In [None]:
import re

# 로그 파일 경로
log_path = "/home/qkrgangeun/LigMet/benchmark/test_chain1_pre3.log"

# 그룹 분류용 딕셔너리 초기화
group_precision_recall = {
    "A": [],  # recall > 0.7 and precision > 0.5
    "B": [],  # recall > 0.7 and precision <= 0.5
    "C": [],  # recall <= 0.7 and precision > 0.5
    "D": [],  # recall <= 0.7 and precision <= 0.5
}

group_type_accuracy = {
    "HIGH": [],  # type_accuracy > 0.5
    "LOW": []    # type_accuracy <= 0.5
}

with open(log_path, "r") as f:
    lines = f.readlines()

pdb_id = None
for i, line in enumerate(lines):
    # PDB ID 찾기
    if line.startswith("=== PDB:"):
        pdb_id_match = re.search(r"\['(.+?)'\]", line)
        if pdb_id_match:
            pdb_id = pdb_id_match.group(1)

    # precision & recall 값 찾기
    if pdb_id and "threshold 0.5 | precision:" in line:
        pr_match = re.search(r"precision: ([0-9.]+) \| recall: ([0-9.]+)", line)
        if pr_match:
            precision = float(pr_match.group(1))
            recall = float(pr_match.group(2))
            # 그룹 분류
            if recall > 0.7:
                if precision > 0.5:
                    group_precision_recall["A"].append(pdb_id)
                else:
                    group_precision_recall["B"].append(pdb_id)
            else:
                if precision > 0.5:
                    group_precision_recall["C"].append(pdb_id)
                else:
                    group_precision_recall["D"].append(pdb_id)

    # type_accuracy 값 찾기
    if pdb_id and "threshold 0.5 | type_accuracy:" in line:
        acc_match = re.search(r"type_accuracy: ([0-9.]+)", line)
        if acc_match:
            type_acc = float(acc_match.group(1))
            if type_acc > 0.5:
                group_type_accuracy["HIGH"].append(pdb_id)
            else:
                group_type_accuracy["LOW"].append(pdb_id)

# 결과 출력
print("=== Precision/Recall 그룹 ===")
for group, ids in group_precision_recall.items():
    print(f"Group {group}: {ids}")

print("\n=== Type Accuracy 그룹 ===")
for group, ids in group_type_accuracy.items():
    print(f"{group}: {ids}")

