In [2]:
# Input TSV and output pickle for the per-residue protein features.
# These must live in the same dataset directory the later cells read from
# (".../datasets/CSAR36/CSAR36.tsv" and ".../datasets/CSAR36/36_complex.pkl");
# the previous "CSAR-HiQ/" location meant a fresh Restart & Run All would have
# merge_to_train read a stale (or missing) 36_complex.pkl.
INPUT_TSV  = "/workspace/binding_affinity/datasets/CSAR36/CSAR36.tsv"
OUTPUT_PKL = "/workspace/binding_affinity/datasets/CSAR36/36_complex.pkl"
# When True, column 3 of the TSV is also extracted (as BIND) and saved in the pickle.
USE_LABELS = True

import pickle
import pandas as pd
import numpy as np
import torch
import esm
import os
# Pin this notebook to physical GPU 2 (PCI_BUS_ID ordering makes the index match
# nvidia-smi numbering) and expose the team conda env's executables on PATH.
# NOTE: the CUDA variables only take effect if CUDA has not yet been initialised
# in this kernel, so this must run before any torch.cuda call.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
})
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/opt/conda/envs/team05/bin"

# The 20 canonical amino acids in one-letter code, alphabetical order.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

# Formal side-chain charge used for the acidic/basic/neutral one-hot;
# residues not listed are treated as charge 0.
charge_dict = {"D": -1, "E": -1, "K": 1, "R": 1, "H": 1}

# Residues flagged as polar (value 1); any other residue counts as 0.
polar_dict = dict.fromkeys("STNQYCDEKRH", 1)

# Kyte-Doolittle hydropathy index, values listed in AMINO_ACIDS order.
hydro_dict = dict(zip(AMINO_ACIDS, (
    1.8, 2.5, -3.5, -3.5, 2.8,      # A C D E F
    -0.4, -3.2, 4.5, -3.9, 3.8,     # G H I K L
    1.9, -3.5, -1.6, -3.5, -4.5,    # M N P Q R
    -0.8, -0.7, 4.2, -0.9, -1.3,    # S T V W Y
)))

# Molecular weight (Da) of each free amino acid, values listed in AMINO_ACIDS order.
weight_dict = dict(zip(AMINO_ACIDS, (
    89.1, 121.2, 133.1, 147.1, 165.2,     # A C D E F
    75.1, 155.2, 131.2, 146.2, 131.2,     # G H I K L
    149.2, 132.1, 115.1, 146.2, 174.2,    # M N P Q R
    105.1, 119.1, 117.1, 204.2, 181.2,    # S T V W Y
)))

def clean_sequence(seq: str) -> str:
    """
    Normalise a protein chain before ESM tokenisation.

    Each residue is upper-cased first, so lower-case / soft-masked input keeps
    its identity instead of collapsing to 'X'. Any character outside the
    accepted set (20 canonical AA + B, Z, X) is then replaced by 'X' (unknown).
    Upper-casing is applied per character, so the output always has exactly
    len(seq) characters and residue positions are preserved.
    """
    allowed = set("ACDEFGHIKLMNPQRSTVWYBXZ")
    cleaned = []
    for aa in seq:
        up = aa.upper()  # may be multi-char for exotic Unicode; then it is not in `allowed`
        cleaned.append(up if up in allowed else "X")
    return "".join(cleaned)

def get_physchem_features(sequence: str) -> np.ndarray:
    """
    Per-residue physico-chemical descriptors.

    Returns a float32 array of shape [L, 6], L = len(sequence), with columns
    [Acidic, Basic, Neutral, Polar, Hydrophobicity, MolWeight]:
      * Acidic/Basic/Neutral: one-hot of charge_dict (residues absent from it
        count as charge 0, i.e. Neutral).
      * Polar: polar_dict value, default 0.
      * Hydrophobicity / MolWeight: hydro_dict / weight_dict, 0.0 for residues
        missing from those tables (e.g. X, B, Z).
    An empty sequence yields a (0, 6) array (not a 1-D (0,) one), so it can
    still be concatenated column-wise with a (0, D) embedding.
    """
    feats = []
    for aa in sequence:
        charge = charge_dict.get(aa, 0)
        acidic  = 1 if charge == -1 else 0
        basic   = 1 if charge == 1 else 0
        neutral = 1 if charge == 0 else 0
        polar   = polar_dict.get(aa, 0)
        hydro   = float(hydro_dict.get(aa, 0.0))
        weight  = float(weight_dict.get(aa, 0.0))
        feats.append([acidic, basic, neutral, polar, hydro, weight])
    # reshape guarantees a 2-D result even when `sequence` is empty
    return np.array(feats, dtype=np.float32).reshape(-1, 6)

print("1. Load data ...")
df = pd.read_csv(INPUT_TSV, sep="\t")
# Columns are addressed by position: 0 -> complex ID, 2 -> protein sequence
# (chains separated by ","), 3 -> saved as BIND when USE_LABELS is set.
# NOTE(review): positional access silently picks the wrong data if the TSV's
# column order ever changes -- consider selecting by header name.
id_col = df.iloc[:, 0]
seq_col = df.iloc[:, 2]
# astype(str) would turn a missing cell into the literal "nan", which the
# feature loop would then embed as a 3-residue chain. Fail loudly instead.
rows_missing = df.index[id_col.isna() | seq_col.isna()].tolist()
if rows_missing:
    raise ValueError(f"{INPUT_TSV}: missing ID or sequence in rows {rows_missing}")
IDs = id_col.astype(str).values
SEQS = seq_col.astype(str).values
if USE_LABELS:
    BIND = df.iloc[:, 3].astype(str).values

print("2. Load pretrained ESM2 model ...")
# Use the GPU exposed via CUDA_VISIBLE_DEVICES when available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ESM-2 650M (33 layers); eval() switches off training-mode layers such as dropout.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = model.to(device)
model.eval()
batch_converter = alphabet.get_batch_converter()

print("3. Compute (physchem + ESM) per residue and concat along L ...")
# Produces one [L_total, 6 + D] float32 array per complex: L_total is the summed
# length of its comma-separated chains, D the ESM embedding width (1280 here,
# judging by the 1286-wide features in the merge output). Chains are embedded
# independently and stacked in their original order.
REPR_LAYER = model.num_layers  # last transformer layer (33 for esm2_t33_650M)
feat_list = []

with torch.no_grad():
    for complex_id, seq in zip(IDs, SEQS):
        chain_seqs = [s.strip() for s in seq.split(",") if s.strip()]
        if not chain_seqs:
            # previously surfaced as a bare IndexError on per_residue_blocks[0]
            raise ValueError(f"[{complex_id}] no protein chain found in sequence field {seq!r}")

        per_residue_blocks = []
        for chain in chain_seqs:
            chain = clean_sequence(chain)
            _, _, tokens = batch_converter([("protein", chain)])
            tokens = tokens.to(device)

            out = model(tokens, repr_layers=[REPR_LAYER], return_contacts=False)
            # [0, 1:-1]: first (only) batch item, dropping the BOS/EOS positions
            emb = out["representations"][REPR_LAYER][0, 1:-1].cpu().numpy().astype(np.float32)

            phys = get_physchem_features(chain)
            per_residue_blocks.append(np.concatenate([phys, emb], axis=1))

        feat_list.append(np.concatenate(per_residue_blocks, axis=0))

print("4. Save to pickle ...")
os.makedirs(os.path.dirname(OUTPUT_PKL), exist_ok=True)
# Column-tuple layout (IDs, SEQS[, BIND], feat_list); _load_complex_pkl later
# auto-detects which position holds the IDs and which holds the features.
payload = (IDs, SEQS, BIND, feat_list) if USE_LABELS else (IDs, SEQS, feat_list)
with open(OUTPUT_PKL, "wb") as f:
    pickle.dump(payload, f)
print(f"✅ Done. Saved: {OUTPUT_PKL}")


1. Load data ...
2. Load pretrained ESM2 model ...
3. Compute (physchem + ESM) per residue and concat along L ...
4. Save to pickle ...
✅ Done. Saved: /workspace/binding_affinity/datasets/generalization/36_complex.pkl


In [5]:
import os
import pandas as pd
from typing import List
from rdkit import Chem
from rdkit.Chem import Descriptors

def compute_global_properties(smiles: str) -> list[float]:
    """
    Ten whole-molecule RDKit descriptors for one ligand SMILES.

    Order: [MolWt, MolLogP, TPSA, NumRotatableBonds, HeavyAtomCount,
            FractionCSP3, NumHDonors, NumHAcceptors, RingCount, MolMR].

    Returns ten 0.0 values when the SMILES is unusable: a missing value
    (None, or the float NaN pandas uses for empty cells), an empty or
    whitespace-only string, or a string RDKit cannot parse. All entries are
    returned as Python floats (several descriptors are integer counts).
    """
    n_props = 10
    # A NaN from pandas is a float; Chem.MolFromSmiles raises on non-str input.
    if not isinstance(smiles, str) or not smiles.strip():
        return [0.0] * n_props
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None:  # RDKit could not parse the SMILES
        return [0.0] * n_props
    props = (
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    )
    return [float(p) for p in props]

input_csv  = "/workspace/binding_affinity/datasets/CSAR36/CSAR36.tsv"
output_pkl = "/workspace/binding_affinity/datasets/CSAR36/36_smi_features.pkl"

# Load the complex table and attach one RDKit descriptor vector per ligand row.
df = pd.read_csv(input_csv, sep="\t", engine="python")

if "SMILES" in df.columns:
    smiles_values = df["SMILES"].tolist()
else:
    raise KeyError(f"'SMILES' 컬럼이 없습니다. 현재 컬럼: {list(df.columns)[:10]} ...")

df["global_feat"] = list(map(compute_global_properties, smiles_values))

# Persist the whole frame (original columns + global_feat) for merge_to_train.
os.makedirs(os.path.dirname(output_pkl), exist_ok=True)
df.to_pickle(output_pkl)
print(f"저장 완료 → {output_pkl}")


저장 완료 → /workspace/binding_affinity/datasets/generalization/36_smi_features.pkl


In [1]:
from __future__ import annotations
import os, re, ast, pickle
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from transformers import AutoTokenizer

# Four-character PDB accession: a digit followed by three letters/digits (e.g. "1abc").
# _load_complex_pkl uses it to recognise the ID column.
PDB_RE = re.compile(
    r"^[0-9][A-Za-z0-9]{3}$"
)

def _parse_bs(text: Any, index_base: int, Lp: Optional[int] = None) -> np.ndarray:
    if text is None or (isinstance(text, float) and np.isnan(text)):
        return np.zeros((0,), dtype=np.int64)
    if isinstance(text, (list, tuple, np.ndarray)):
        arr = np.asarray(text, dtype=np.int64)
    else:
        parts = [t.strip() for t in str(text).split(",") if t.strip()]
        arr = np.asarray([int(t) for t in parts], dtype=np.int64)
    if index_base == 1:
        arr = arr - 1
    if Lp is not None:
        arr = arr[(arr >= 0) & (arr < Lp)]
    return arr.astype(np.int64)

def _parse_global_feat(v: Any) -> np.ndarray:
    if v is None:
        return np.zeros((0,), dtype=np.float32)
    if isinstance(v, (list, tuple, np.ndarray)):
        return np.asarray(v, dtype=np.float32)
    s = str(v).strip()
    if not s:
        return np.zeros((0,), dtype=np.float32)
    try:
        x = ast.literal_eval(s)
    except Exception:
        toks = [t for t in re.split(r"[\s,]+", s.strip("[]")) if t]
        x = [float(t) for t in toks]
    return np.asarray(x, dtype=np.float32)

def _choose_target(row: dict|pd.Series, mode: str) -> float:
    if mode == "pAff":
        return float(row["pAff"])
    if mode == "Affinity_nM":
        return float(row["Affinity_nM"])
    if mode == "Affinity_nM_log10":
        nM = max(float(row["Affinity_nM"]), 1e-12)
        return float(9.0 - np.log10(nM))
    raise ValueError(mode)

def _load_complex_pkl(path: str, force_idx: tuple[int|None,int|None,int|None,int|None]=(None,None,None,None)) -> Dict[str, Dict[str, Any]]:
    """complex.pkl(tuple-of-columns) → {PDB: {'prot_feats',(Lp,C) float32, 'prot_mask',(Lp,) bool}}"""
    with open(path, "rb") as f:
        obj = pickle.load(f)
    if not isinstance(obj, (list, tuple)):
        raise TypeError(f"complex.pkl top-level must be tuple/list, got {type(obj).__name__}")
    cols = list(obj)
    ncols = len(cols)
    lengths = [len(c) if hasattr(c, "__len__") else None for c in cols]
    valid_idx = [i for i,L in enumerate(lengths) if L is not None]
    if not valid_idx:
        raise ValueError("No list-like columns in complex.pkl")
    N = lengths[valid_idx[0]]
    if not all(lengths[i] == N for i in valid_idx):
        raise ValueError(f"Columns have different lengths: {[lengths[i] for i in valid_idx]}")

    def kind_auto(col):
        m = min(20, len(col))
        cnt_id=cnt_feat2d=cnt_mask1d=cnt_seq=0
        for i in range(m):
            v = col[i]
            if isinstance(v, str) and PDB_RE.match(v): cnt_id+=1
            if isinstance(v, np.ndarray) and v.ndim==2 and np.issubdtype(v.dtype, np.number): cnt_feat2d+=1
            if isinstance(v, np.ndarray) and v.ndim==1 and (v.dtype==bool or np.issubdtype(v.dtype, np.integer)):
                x = v[:min(len(v),20)]
                if x.size>0 and np.isin(np.unique(x), [0,1,True,False]).all(): cnt_mask1d+=1
            if isinstance(v, str) and len(v)>=50: cnt_seq+=1
        if cnt_id>=m*0.7: return "ids"
        if cnt_feat2d>=m*0.7: return "feats"
        if cnt_mask1d>=m*0.7: return "mask"
        if cnt_seq>=m*0.7: return "seq"
        return "other"

    id_i, feat_i, mask_i, seq_i = force_idx
    if id_i is None or feat_i is None:
        kinds = [kind_auto(cols[i]) for i in range(ncols)]
        if id_i   is None: id_i   = next((i for i,k in enumerate(kinds) if k=="ids"), None)
        if feat_i is None: feat_i = next((i for i,k in enumerate(kinds) if k=="feats"), None)
        if mask_i is None: mask_i = next((i for i,k in enumerate(kinds) if k=="mask"), None)

    if id_i is None or feat_i is None:
        raise ValueError(f"Could not find ids/features in complex.pkl (force_idx={force_idx})")

    ids, feats_col = cols[id_i], cols[feat_i]
    mask_col = cols[mask_i] if (mask_i is not None) else None

    out = {}
    for i in range(N):
        pid   = str(ids[i]).strip()
        feats = feats_col[i].astype(np.float32) if isinstance(feats_col[i], np.ndarray) else np.asarray(feats_col[i], dtype=np.float32)
        if feats.ndim != 2:
            raise ValueError(f"[{pid}] features is not 2D: {feats.shape}")
        Lp = feats.shape[0]
        if mask_col is not None and isinstance(mask_col[i], np.ndarray) and mask_col[i].ndim==1 and len(mask_col[i])==Lp:
            pmask = mask_col[i].astype(bool)
        else:
            pmask = np.ones((Lp,), dtype=bool)
        out[pid] = {"prot_feats": feats, "prot_mask": pmask}
    return out

def merge_to_train(
    smi_pkl: str,
    complex_pkl: str,
    out_path: str,
    *,
    index_base: int = 0,
    use_y: str = "pAff",
    chemberta_model: str = "seyonec/ChemBERTa-zinc-base-v1",
    max_len: int = 256,
    force_complex_idx: tuple[int|None,int|None,int|None,int|None]=(None,None,None,None),
):
    """
    Join ligand features (smi_pkl) with protein features (complex_pkl) into a
    list of training samples and pickle that list to out_path.

    smi_pkl must be a pickled DataFrame with columns PDB, Sequence, BS, SMILES,
    global_feat and the target column implied by use_y ("pAff" for "pAff";
    "Affinity_nM" for "Affinity_nM" and "Affinity_nM_log10"). Rows whose PDB ID
    has no entry in complex_pkl are skipped and counted.

    Each sample dict holds: id; prot_feats (Lp, C) float32; prot_mask (Lp,) bool;
    br_indices (0-based binding-site indices restricted to [0, Lp)); lig_token_ids
    and lig_mask (ChemBERTa tokens, padded/truncated to max_len); lig_global
    (float32 descriptor vector); y (float32 target).

    Raises:
        ValueError: unknown use_y (checked before any file is read).
        TypeError: smi_pkl does not contain a DataFrame.
        KeyError: smi_pkl lacks a required column, including the target column.
    """
    # Map each supported target mode to the DataFrame column it reads (see _choose_target).
    target_col = {
        "pAff": "pAff",
        "Affinity_nM": "Affinity_nM",
        "Affinity_nM_log10": "Affinity_nM",
    }.get(use_y)
    if target_col is None:
        # Validate up front: previously a bad use_y only raised inside the loop,
        # and not at all when no PDB ID matched.
        raise ValueError(f"unknown use_y {use_y!r}; expected 'pAff', 'Affinity_nM' or 'Affinity_nM_log10'")

    df = pd.read_pickle(smi_pkl)  # unpickling: only use trusted files
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"smi_pkl must be a DataFrame pkl, got {type(df).__name__}")
    req = {"PDB", "Sequence", "BS", "SMILES", "global_feat", target_col}
    miss = req - set(df.columns)
    if miss:
        raise KeyError(f"{smi_pkl} missing columns: {sorted(miss)}")

    prot_map = _load_complex_pkl(complex_pkl, force_idx=force_complex_idx)

    # Tokenise every SMILES in one batch; df row i <-> all_ids[i] / all_masks[i].
    tokzr = AutoTokenizer.from_pretrained(chemberta_model, use_fast=True)
    smiles_list = df["SMILES"].astype(str).tolist()
    tok = tokzr(smiles_list, padding="max_length", truncation=True,
                max_length=max_len, return_attention_mask=True, return_tensors=None)
    all_ids, all_masks = tok["input_ids"], tok["attention_mask"]

    samples, not_found = [], 0
    for i, row in enumerate(df.itertuples(index=False)):
        s = row._asdict()
        pdb = str(s["PDB"]).strip()
        if pdb not in prot_map:
            not_found += 1
            continue

        feats = prot_map[pdb]["prot_feats"]
        pmask = prot_map[pdb]["prot_mask"]
        Lp = feats.shape[0]
        br_idx = _parse_bs(s["BS"], index_base, Lp=Lp)
        lig_ids = np.asarray(all_ids[i], dtype=np.int64)
        lig_msk = np.asarray(all_masks[i], dtype=np.int64).astype(bool)
        lig_glb = _parse_global_feat(s["global_feat"]).astype(np.float32)
        yval    = np.float32(_choose_target(s, use_y))

        samples.append({
            "id": pdb,
            "prot_feats": feats.astype(np.float32),
            "prot_mask":  pmask.astype(bool),
            "br_indices": br_idx.astype(np.int64),
            "lig_token_ids": lig_ids,
            "lig_mask": lig_msk,
            "lig_global": lig_glb,
            "y": yval,
        })

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(samples, f)

    kept = len(samples)
    print(f"Saved {kept} samples -> {out_path} (skipped {not_found} without protein features)")
    if kept:
        Lp, C = samples[0]["prot_feats"].shape
        Ll = samples[0]["lig_token_ids"].shape[0]
        Fg = samples[0]["lig_global"].shape[0]
        print(f"Example shapes: prot_feats=({Lp},{C}), lig_len={Ll}, lig_global_dim={Fg}")


In [2]:
# Build the CSAR36 training pickle from the ligand-feature and protein-feature files.
merge_to_train(
    "/workspace/binding_affinity/datasets/CSAR36/36_smi_features.pkl",  # smi_pkl
    "/workspace/binding_affinity/datasets/CSAR36/36_complex.pkl",       # complex_pkl
    "/workspace/binding_affinity/datasets/CSAR36/CSAR36.pkl",           # out_path
    use_y="pAff",
    index_base=0,  # BS indices treated as 0-based -- confirm against the TSV
    max_len=256,
    chemberta_model="seyonec/ChemBERTa-zinc-base-v1",
)




Saved 36 samples -> /workspace/binding_affinity/datasets/generalization/CSAR36.pkl (skipped 0 without protein features)
Example shapes: prot_feats=(443,1286), lig_len=256, lig_global_dim=10
