In [1]:
# Show GPU inventory and current memory usage before choosing a device.
!nvidia-smi

Thu Aug  7 15:39:26 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:65:00.0 Off |                  Off |
| 74%   33C    P8             15W /  450W |   11853MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L40S                    On  |   00

In [4]:
import os

# Set CUDA / cuBLAS environment variables before torch (or any project module that
# might import it) is loaded: CUDA_VISIBLE_DEVICES only takes effect before the CUDA
# context is created, and CUBLAS_WORKSPACE_CONFIG must be in place before cuBLAS is
# initialised for deterministic matrix multiplies.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import numpy as np
import pandas as pd
import torch

from src.utils.training_utils import set_seed

# Make PyTorch raise an error for any op that lacks a deterministic implementation.
torch.use_deterministic_algorithms(True)

In [9]:
#from src.utils.esmfinetune import VarDataset, VarCollator
from src.utils.esmfinetune import MultiTaskVarDataset, MultiTaskVarCollator
from src.varmodel_MT_AllLoRA_ESM3_proj import CYPVarAM
#from src.varmodel import CYPVarAM

from transformers import AutoTokenizer, EsmModel
from peft import LoraConfig, get_peft_model

from transformers import AutoModelForMaskedLM
# ESM++ small is loaded with its remote code; the tokenizer is exposed as an
# attribute of the model object, so no separate tokenizer download is needed.
esm_model = AutoModelForMaskedLM.from_pretrained(
    'Synthyra/ESMplusplus_small',
    trust_remote_code=True,
)
tokenizer = esm_model.tokenizer
# Project helper; seeds RNGs before get_peft_model / CYPVarAM are constructed below.
set_seed(42)
# LoRA is applied only to the last few transformer blocks of the ESM++ encoder.
N_ESM_BLOCKS = 30    # total blocks assumed for ESMplusplus_small (the original loop bound)
N_LORA_BLOCKS = 5    # number of final blocks that receive LoRA adapters
LORA_SUBMODULES = ("attn.layernorm_qkv.1", "attn.out_proj", "ffn.1", "ffn.3")

# Module-name suffixes such as "25.attn.out_proj". PEFT matches each entry either
# exactly or as a suffix of the full module name. Order: all submodules of the
# first targeted block, then the next block, and so on.
intermed_list = [
    f"{block}.{submodule}"
    for block in range(N_ESM_BLOCKS - N_LORA_BLOCKS, N_ESM_BLOCKS)
    for submodule in LORA_SUBMODULES
]
       

# LoRA adapters: rank 16, alpha 16, no bias parameters trained, restricted to the
# module-name suffixes in intermed_list.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=intermed_list,
)
# Injects the adapters into esm_model's targeted modules and returns the PeftModel
# wrapper that CYPVarAM receives below.
lora_esm_model = get_peft_model(esm_model, config)

# Multi-task variant model on top of the LoRA-wrapped encoder. No manual freezing is
# done here; get_peft_model already marks only adapter weights as trainable.
# input_size=960 matches the ESM++ small embedding width.
model = CYPVarAM(
    esm_model = lora_esm_model,
    drop_att = 0.1,
    drop_pff = 0.1,
    input_size = 960,
    hidden_size = 300,
    num_heads = 6,
    num_tasks = 3,
)



In [10]:
# Reads a .safetensors file into a {parameter name: tensor} dict (on CPU by default).
from safetensors.torch import load_file

# Restore the fine-tuned CYPVarAM weights. strict=True makes this fail loudly if the
# architecture built above does not exactly match the checkpoint's keys.
file_path = "./esm3_MT_small_raw_5layer/checkpoint-11694/model.safetensors"
loaded = load_file(file_path)
model.load_state_dict(loaded, strict=True)

device = "cuda"
model.to(device)

CYPVarAM(
  (am_feature): Featurizer(
    (esm): PeftModel(
      (base_model): LoraModel(
        (model): ESMplusplusForMaskedLM(
          (embed): Embedding(64, 960)
          (transformer): TransformerStack(
            (blocks): ModuleList(
              (0-24): 25 x UnifiedTransformerBlock(
                (attn): MultiHeadAttention(
                  (layernorm_qkv): Sequential(
                    (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
                    (1): Linear(in_features=960, out_features=2880, bias=False)
                  )
                  (out_proj): Linear(in_features=960, out_features=960, bias=False)
                  (q_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
                  (k_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
                  (rotary): RotaryEmbedding()
                )
                (ffn): Sequential(
                  (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
           

In [11]:
import pickle

# Variant table: a pickled DataFrame produced upstream (trusted local file).
with open("data/uniprot_cyp_variant_protvar_250709.pkl", "rb") as fh:
    all_cyp_variants_rev = pickle.load(fh)

# Gene -> wild-type sequence, taken from the rows labelled 'WT'.
wt_seq_dict = (
    all_cyp_variants_rev
    .loc[all_cyp_variants_rev['variant'] == 'WT', ['Gene', 'Sequence']]
    .set_index('Gene')
    .to_dict()['Sequence']
)
# Attach the matching WT sequence to every row (KeyError if a gene has no WT row).
all_cyp_variants_rev['wt_seq'] = all_cyp_variants_rev['Gene'].map(wt_seq_dict.__getitem__)
# Restrict the rest of the notebook to CYP2D6.
all_cyp_variants_rev = all_cyp_variants_rev[all_cyp_variants_rev['Gene'].isin(['CYP2D6'])]

In [12]:
import pandas as pd
import numpy as np
import random
from collections import defaultdict

def extract_mutation_positions(mut_position_array):
    """Return the 1-based positions whose entry equals 1.0.

    Args:
        mut_position_array: a per-residue indicator sequence (e.g. a numpy array or
            list) holding 1.0 at mutated residues.

    Returns:
        list[int]: 1-based positions of the entries equal to 1.0, in index order.
        Non-iterable inputs give ``[]``. So do iterables whose elements cannot be
        compared to 1.0 as a scalar, such as the rows of a 2-D array.
    """
    if not hasattr(mut_position_array, '__iter__'):
        return []
    try:
        return [pos for pos, val in enumerate(mut_position_array, start=1) if val == 1.0]
    except (TypeError, ValueError):
        # Best-effort: malformed elements (e.g. array rows whose truth value is
        # ambiguous) mean "no positions" instead of aborting the whole conversion.
        # Narrowed from a bare `except:` so that KeyboardInterrupt and genuine bugs
        # are no longer silently swallowed.
        return []

def convert_single_missense_variants(df):
    """Relabel generic ``'missense'`` rows with explicit substitutions like ``A35T``.

    A row is rewritten when its ``variant`` is ``'missense'`` and its ``status`` is
    ``'Success'``. Its mutated positions come from ``mut_position`` via
    ``extract_mutation_positions``. Each position becomes
    ``<WT residue><1-based position><mutant residue>``, with the WT residue taken
    from ``wt_seq`` and the mutant residue from ``Sequence``. Multiple substitutions
    are joined with ``", "``. Positions outside either sequence are ignored. Rows
    that yield no substitution keep their ``'missense'`` label.

    Args:
        df: DataFrame with ``variant``, ``status``, ``mut_position``, ``Sequence``
            and ``wt_seq`` columns. It is not modified.

    Returns:
        A relabelled copy of ``df``. Progress messages are printed to stdout.
    """
    print("Converting single missense variants to A35T format...")

    relabelled = df.copy()
    n_converted = 0

    for row_label, record in relabelled.iterrows():
        # Only successfully processed, still-generic missense rows are rewritten.
        if not (record['variant'] == 'missense' and record['status'] == 'Success'):
            continue

        positions = extract_mutation_positions(record['mut_position'])
        if not positions:
            continue

        mutant = list(record['Sequence'])
        reference = list(record['wt_seq'])
        substitutions = [
            f"{reference[pos - 1]}{pos}{mutant[pos - 1]}"
            for pos in positions
            if 0 <= pos - 1 < len(mutant) and pos - 1 < len(reference)
        ]

        if substitutions:
            relabelled.at[row_label, 'variant'] = ", ".join(substitutions)
            n_converted += 1

    print(f"Converted {n_converted} single missense variants")
    return relabelled


In [13]:
# Replace generic 'missense' labels with explicit substitutions (e.g. G2R).
all_cyp_variants_rev = all_cyp_variants_rev.pipe(convert_single_missense_variants)

Converting single missense variants to A35T format...
Converted 1332 single missense variants


In [14]:
# Inspect the converted CYP2D6 variant table.
all_cyp_variants_rev

Unnamed: 0,Gene,Sequence,mut_position,variant,am_score,am_label,esm_score,conserv_score,foldx_score,status,mut_sum,esm1v_score,wt_seq
6537,CYP2D6,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",WT,0.0000,wild_type,0.000,1.000,0.000000,wild_type,0.0,0.000000,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
6538,CYP2D6,MRLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",G2R,0.2136,BENIGN,-7.061,0.867,0.026774,Success,1.0,-1.637347,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
6539,CYP2D6,MVLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",G2V,0.1502,BENIGN,-5.508,0.867,0.146168,Success,1.0,-0.716359,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
6540,CYP2D6,MWLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",G2W,0.1895,BENIGN,-7.153,0.867,0.092842,Success,1.0,-0.486700,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
6541,CYP2D6,MGQEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",L3Q,0.0798,BENIGN,-4.553,0.940,0.424936,Success,1.0,-2.643526,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
7865,CYP2D6,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",V495M,0.1450,BENIGN,-6.897,0.649,-0.409300,Success,1.0,-3.354896,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
7866,CYP2D6,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",P496S,0.1115,BENIGN,-6.671,0.669,1.920530,Success,1.0,-3.539027,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
7867,CYP2D6,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",R497C,0.4641,AMBIGUOUS,-9.810,1.000,2.516240,Success,1.0,-6.398916,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
7868,CYP2D6,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",R497H,0.3970,AMBIGUOUS,-7.684,1.000,3.502170,Success,1.0,-6.000917,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...


In [15]:
# One (variant label, variant sequence, WT sequence) row per unique variant.
all_vars = all_cyp_variants_rev[['variant', 'Sequence', 'wt_seq']].drop_duplicates()
# Not used by the loop below; kept for interactive inspection.
wt_texts = all_vars['wt_seq'].values
vr_texts = all_vars['Sequence'].values
custom_collator = MultiTaskVarCollator()

import torch

# Encode every (WT, variant) pair with the fine-tuned featurizer and store temp[0]
# (presumably the single batch item -- TODO confirm against Featurizer), keyed by
# variant label.
# NOTE(review): if two rows share a variant label, the later one overwrites the earlier.
model.eval()
vr_feat_dict = {}
set_seed(0)
# itertuples(name=None) yields plain tuples in column order. This replaces the
# positional Series lookup `item[2]`, which pandas deprecates (FutureWarning now,
# KeyError in a future release).
for variant, vr_seq, wt_seq in all_vars.itertuples(index=False, name=None):
    wt_inputs = tokenizer(wt_seq, return_tensors="pt", add_special_tokens=True).to(device)
    vr_inputs = tokenizer(vr_seq, return_tensors="pt", add_special_tokens=True).to(device)

    with torch.no_grad():
        temp = model.am_feature(
            wt_inputs['input_ids'], wt_inputs['attention_mask'],
            vr_inputs['input_ids'], vr_inputs['attention_mask'],
        )
        vr_feat_dict[variant] = temp[0].detach().to('cpu').numpy()

  wt_seq = item[2]
  vr_seq = item[1]
  vr_feat_dict[item[0]] = temp[0].detach().to('cpu').numpy()


In [None]:
# Summarise instead of dumping every per-variant feature array into the output.
len(vr_feat_dict), {name: feat.shape for name, feat in list(vr_feat_dict.items())[:5]}

In [16]:
import pickle

# Persist the per-variant feature arrays for downstream use.
with open("data/CYP2D6_variant_ESM3_missense_250714.pkl", "wb") as out_file:
    pickle.dump(vr_feat_dict, out_file, protocol=pickle.DEFAULT_PROTOCOL)

In [17]:
import pandas as pd
# CYP2D6 allele x substrate table (cl/vmax/km columns) with WT and variant sequences per row.
cyp2d6 = pd.read_csv("data/cyp2d6_final_ours_preprocessed_250407.csv")
cyp2d6

Unnamed: 0,Allele,Substrate,cl_rev,vmax_rev,km_rev,source,avg_cl_rev,cl,wt_seqs,vr_seqs
0,*1,Amitriptyline,0.322035,3.37200,10.9000,KIT,0.322035,-1.133096,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
1,*1,Aripiprazole,0.082551,2.59250,31.4050,KIST,0.082551,-2.494344,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
2,*1,Atomoxetine,6.303000,55.31000,8.7750,KIT,6.303000,1.841026,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
3,*1,Carvedilol,0.264200,3.21500,12.1700,KIT,0.264200,-1.331049,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
4,*1,Chlorpromazine,4.311000,41.24000,9.5670,KIT,4.311000,1.461170,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
...,...,...,...,...,...,...,...,...,...,...
286,*9,Sertraline,0.164000,0.20840,1.2680,KIT,0.164000,-1.807889,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
287,*9,Tamoxifen,0.058400,,,KIT,0.058400,-2.840439,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
288,*9,Thioridazine,0.047000,0.04493,0.9567,KIT,0.047000,-3.057608,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...
289,*9,Tolterodine,1.709041,1.87770,1.0987,KIST,1.709041,0.535933,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAVIVAIFLLLVDLMHRRQRWAARYPPGPLPLPGLGNL...


In [18]:
# Unique (allele, variant sequence, WT sequence) rows; rebinds `all_vars` for the allele-level pass.
all_vars = cyp2d6[['Allele', 'vr_seqs', 'wt_seqs']].drop_duplicates()

In [19]:
# Raw sequence arrays; not used by the feature-extraction loop below.
wt_texts = all_vars['wt_seqs'].values
vr_texts = all_vars['vr_seqs'].values


In [20]:
# Collator instance; the per-sequence loop below tokenizes directly and does not use it.
custom_collator = MultiTaskVarCollator()

In [22]:
# Substrates present in the CYP2D6 table.
cyp2d6['Substrate'].unique()

array(['Amitriptyline', 'Aripiprazole', 'Atomoxetine', 'Carvedilol',
       'Chlorpromazine', 'Clomipramine', 'Clozapine', 'Dextromethorphan',
       'Doxepin', 'Dronedarone', 'Duloxetine', 'Fluoxetine', 'Gefitinib',
       'Imipramine', 'Meclizine', 'Mexiletine', 'Mirtazapine',
       'Nefazodone', 'Nortriptyline', 'Olanzapine', 'Ondansetron',
       'Paroxetine', 'Perhexiline', 'Perphenazine', 'Pimozide',
       'Primaquine', 'Promethazine', 'Propafenone', 'Ranolazine',
       'Ritonavir', 'Sertraline', 'Tamoxifen', 'Thioridazine',
       'Tolterodine', 'Venlafaxine'], dtype=object)

In [23]:
import torch

# Encode every (WT, allele) sequence pair with the fine-tuned featurizer and store
# temp[0] (presumably the single batch item -- TODO confirm against Featurizer),
# keyed by allele name.
model.eval()
vr_feat_dict = {}
set_seed(0)
# Columns of all_vars are (Allele, vr_seqs, wt_seqs). Tuple unpacking replaces the
# deprecated positional Series lookup `item[2]` (FutureWarning in pandas 2.x).
for allele, vr_seq, wt_seq in all_vars.itertuples(index=False, name=None):
    wt_inputs = tokenizer(wt_seq, return_tensors="pt", add_special_tokens=True).to(device)
    vr_inputs = tokenizer(vr_seq, return_tensors="pt", add_special_tokens=True).to(device)

    with torch.no_grad():
        temp = model.am_feature(
            wt_inputs['input_ids'], wt_inputs['attention_mask'],
            vr_inputs['input_ids'], vr_inputs['attention_mask'],
        )
        vr_feat_dict[allele] = temp[0].detach().to('cpu').numpy()

  wt_seq = item[2]
  vr_seq = item[1]
  vr_feat_dict[item[0]] = temp[0].detach().to('cpu').numpy()


In [None]:
import pickle

# Persist the per-allele feature arrays for downstream use.
with open("data/CYP2D6_MT_KIT_ESM3_250624.pkl", "wb") as out_file:
    pickle.dump(vr_feat_dict, out_file, protocol=pickle.DEFAULT_PROTOCOL)

In [1]:
import os
import sys
import pandas as pd
import numpy as np
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_batch

from rdkit import Chem
from rdkit.Chem import AllChem


sys.path.append("./src/molebert")
from model import GNN, GNN_graphpred
from loader import mol_to_graph_data_obj_simple, allowable_features


In [2]:
import os

# CUDA / cuBLAS settings for deterministic runs.
# NOTE(review): torch is already imported by the preceding cell, so these only take
# effect because no CUDA work has happened yet. Setting them in the very first cell
# would be more robust.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import numpy as np
import pandas as pd
import torch

from src.utils.training_utils import set_seed

# Make PyTorch raise an error for any op that lacks a deterministic implementation.
torch.use_deterministic_algorithms(True)

In [3]:
class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs built from ``<root>/raw/sub2smi.csv``.

    Each ``SMILES`` entry is parsed with RDKit and converted with
    ``mol_to_graph_data_obj_simple``. Graphs keep the CSV row order, which the
    embedding loop relies on to pair graphs with the ``Substrate`` column.

    Args:
        root: dataset directory containing ``raw/sub2smi.csv``.
        transform, pre_transform, pre_filter: PyG hooks forwarded to the base class.
            NOTE(review): ``process`` does not apply ``pre_transform``/``pre_filter``.
        empty: if True, do not load the processed tensors.
        force_reload: accepted but NOT forwarded to ``InMemoryDataset``, so a stale
            processed file is reused. TODO decide whether it should be forwarded.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, empty=False, force_reload=True):
        # Must be set before the base constructor, which may call process() (reads data_path).
        self.data_path = root
        super().__init__(root, transform, pre_transform, pre_filter)
        if not empty:
            # The processed file is written by process() from our own data, so it is trusted.
            # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may refuse
            # to unpickle PyG objects; switch to self.load(...) / weights_only=False if so.
            self.data, self.slices = torch.load(self.processed_paths[0])

    def get(self, idx):
        """Rebuild graph ``idx`` by slicing every collated attribute with ``self.slices``."""
        data = Data()
        for key in self.data.keys():
            item, slices = self.data[key], self.slices[key]
            # Full slice on every dim except this attribute's concatenation dim
            # (e.g. the last dim for edge_index), narrowed to this graph's range.
            s = [slice(None)] * item.dim()
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple; indexing a tensor with a list of slices is deprecated.
            data[key] = item[tuple(s)]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        raise NotImplementedError('Must indicate valid location of raw data. No download allowed')

    def process(self):
        """Parse ``raw/sub2smi.csv`` and save the collated graphs to ``processed_paths[0]``."""
        input_df = pd.read_csv(f'{self.data_path}/raw/sub2smi.csv')
        smiles_list = input_df['SMILES'].tolist()
        # NOTE(review): MolFromSmiles returns None for unparsable SMILES; that presumably
        # fails inside mol_to_graph_data_obj_simple rather than being skipped.
        data_list = [
            mol_to_graph_data_obj_simple(AllChem.MolFromSmiles(smiles))
            for smiles in smiles_list
        ]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [6]:
# Mole-BERT GIN backbone (5 layers, 300-d). Only `.gnn` is used later in this
# notebook, so the 5-task prediction head configured here is never called.
molebert_model = GNN_graphpred(
    num_layer=5,
    emb_dim=300,
    num_tasks=5,
    JK='last',
    drop_ratio=0.1,
    graph_pooling='mean',
    gnn_type='gin',
)

# Load the released Mole-BERT weights (presumably into `.gnn`; see src/molebert/model.py).
molebert_model.from_pretrained('./src/molebert/model_gin/Mole-BERT.pth')
device = "cuda"
molebert_model.to(device)

In [None]:
# Builds (or reuses the cached) graphs from ./src/molebert/dataset/cyp_ours_subs/raw/sub2smi.csv.
gdataset = GraphDataset('./src/molebert/dataset/cyp_ours_subs')

In [None]:
# batch_size=1 and shuffle=False: one graph per batch, in dataset (CSV) order.
g_loader = DataLoader(gdataset, batch_size = 1, shuffle=False, drop_last=False)

In [None]:
# Repeats the device placement from the model-loading cell (harmless if already on GPU).
device = "cuda"
molebert_model.to(device)
# Substrate/SMILES table; its row order matches the graphs GraphDataset builds from
# this same file (as long as the processed cache was built from it).
drug2smi = pd.read_csv("./src/molebert/dataset/cyp_ours_subs/raw/sub2smi.csv")

In [None]:

import numpy as np

# Node-level Mole-BERT embeddings per substrate, keyed by substrate name.
# The loader is unshuffled with batch_size=1, so the k-th batch is the k-th CSV row.
# A stale processed cache would silently mislabel embeddings, hence the check.
assert len(gdataset) == len(drug2smi), "graph dataset and sub2smi.csv are out of sync"

mol_arr_dict = {}
molebert_model.eval()
set_seed(0)
for g, d_name in zip(g_loader, drug2smi['Substrate']):
    g = g.to(device)
    with torch.no_grad():
        node_rep = molebert_model.gnn(g['x'], g['edge_index'], g['edge_attr'])
    mol_arr_dict[d_name] = node_rep.detach().to('cpu').numpy()


In [None]:
import pickle

# Persist the per-substrate Mole-BERT node embeddings.
with open("data/MoleBERT_Substrate_NoPreMoleBERT_KIT_0624.pkl", "wb") as out_file:
    pickle.dump(mol_arr_dict, out_file, protocol=pickle.DEFAULT_PROTOCOL)