In [11]:
# 安装必要的库
!pip install numpy pandas torch scikit-learn scipy prettytable rdkit matplotlib

!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch_geometric
!conda install cudatoolkit==11.1
!pip install --pre dgl==0.6.1 -f https://data.dgl.ai/wheels/repo.html -i  https://pypi.tuna.tsinghua.edu.cn/simple


# 可选：如果遇到 rdkit 安装问题，可以尝试以下命令
# !pip install -q condacolab
# import condacolab
# condacolab.install()
# !conda install -c conda-forge rdkit -y


Looking in links: https://download.pytorch.org/whl/torch_stable.html
[31mERROR: Could not find a version that satisfies the requirement torch==1.8.0+cu111 (from versions: 2.2.0, 2.2.0+cpu, 2.2.0+cpu.cxx11.abi, 2.2.0+cu118, 2.2.0+cu121, 2.2.0+rocm5.6, 2.2.0+rocm5.7, 2.2.1, 2.2.1+cpu, 2.2.1+cpu.cxx11.abi, 2.2.1+cu118, 2.2.1+cu121, 2.2.1+rocm5.6, 2.2.1+rocm5.7, 2.2.2, 2.2.2+cpu, 2.2.2+cpu.cxx11.abi, 2.2.2+cu118, 2.2.2+cu121, 2.2.2+rocm5.6, 2.2.2+rocm5.7, 2.3.0, 2.3.0+cpu, 2.3.0+cpu.cxx11.abi, 2.3.0+cu118, 2.3.0+cu121, 2.3.0+rocm5.7, 2.3.0+rocm6.0, 2.3.1, 2.3.1+cpu, 2.3.1+cpu.cxx11.abi, 2.3.1+cu118, 2.3.1+cu121, 2.3.1+rocm5.7, 2.3.1+rocm6.0, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0, 2.7.1, 2.8.0)[0m[31m
[0m[31mERROR: No matching distribution found for torch==1.8.0+cu111[0m[31m
[0mCollecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.7 MB/s[0m eta [36m

In [3]:

# 导入必要的库
import numpy as np
import pandas as pd
import torch
import os
import time
import pickle
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.stats import pearsonr, spearmanr
from prettytable import PrettyTable
import rdkit
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt


In [8]:
!git clone https://github.com/unsterbliche/DTLCDR.git
# 导入克隆的仓库中的 Python 文件
import sys
sys.path.append('/content/DTLCDR/DTLCDR')  # 添加仓库路径到系统路径
os.chdir('/content/DTLCDR/DTLCDR')

fatal: destination path 'DTLCDR' already exists and is not an empty directory.


In [17]:
!wget https://www.cancerrxgene.org/gdsc1000/GDSC1000_WebResources/Data/preprocessed/Cell_line_RMA_proc_basalExp.txt.zip


--2025-10-15 07:07:54--  https://www.cancerrxgene.org/gdsc1000/GDSC1000_WebResources/Data/preprocessed/Cell_line_RMA_proc_basalExp.txt.zip
Resolving www.cancerrxgene.org (www.cancerrxgene.org)... 193.62.203.105, 193.62.203.106, 2001:630:206:4::105, ...
Connecting to www.cancerrxgene.org (www.cancerrxgene.org)|193.62.203.105|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 142765551 (136M) [application/zip]
Saving to: ‘Cell_line_RMA_proc_basalExp.txt.zip’


2025-10-15 07:08:00 (29.2 MB/s) - ‘Cell_line_RMA_proc_basalExp.txt.zip’ saved [142765551/142765551]



In [18]:
!wget https://raw.githubusercontent.com/jingcheng-du/Gene2vec/master/pre_trained_emb/gene2vec_dim_200_iter_9.txt


--2025-10-15 07:08:13--  https://raw.githubusercontent.com/jingcheng-du/Gene2vec/master/pre_trained_emb/gene2vec_dim_200_iter_9.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56682049 (54M) [text/plain]
Saving to: ‘gene2vec_dim_200_iter_9.txt’


2025-10-15 07:08:13 (496 MB/s) - ‘gene2vec_dim_200_iter_9.txt’ saved [56682049/56682049]



In [22]:

# 预设参数，替代原 argparse 传参方式
device = 'cuda:0'  # 可以改为 'cpu' 选择设备
epoch = 100
model_name = 'DTLCDR'  # 模型名称
split_type = 'warmstart'  # 数据拆分方式，可以选择 'warmstart', 'cellcoldstart', 'drugcoldstart'


In [9]:

# 细胞数据处理类
class process_cell():
    def __init__(self):
        # 初始化细胞表达和编码数据路径
        self.cell_encoder = './process_data/exp_enc.csv'
        self.cell_expression = './process_data/exp.csv'

    # 获取细胞数据
    def get_celldata(self, data):
        cellid = list(data['COSMIC_ID'])
        enc_data = pd.read_csv(self.cell_encoder, index_col=0)
        exp_data = pd.read_csv(self.cell_expression, index_col=0)
        encdata = enc_data.loc[cellid]
        expdata = exp_data.loc[cellid]
        return encdata, expdata


In [19]:

# 预测数据加载
pred_dti_gdsc2 = pd.read_csv('../GCADTI/pred_dti_gdsc.csv', index_col=0)
pred_dti = []
for i in range(len(pred_dti_gdsc2.smiles.unique())):
    pred_dti.append(pred_dti_gdsc2[i*1572:(i+1)*1572].label.tolist())
pred_dti = pd.DataFrame(pred_dti)
pred_dti.index = pred_dti_gdsc2.smiles.unique().tolist()

# 加载 Gene2Vec 数据
gene2vec_dim_200_iter_9 = pd.read_csv('./gene2vec_dim_200_iter_9.txt', sep='	| ', header=None)
gene2vec_dim_200_iter_9.index = gene2vec_dim_200_iter_9[0]
gene2vec_dim_200_iter_9 = gene2vec_dim_200_iter_9.loc[:, 1:]
gene = pd.read_csv('./process_data/exp_enc.csv', index_col=0)
col = gene.columns
np.save('gene2vec_595.npy', np.vstack([np.array(gene2vec_dim_200_iter_9.loc[col]), np.zeros([16906-595, 200])]))


  gene2vec_dim_200_iter_9 = pd.read_csv('./gene2vec_dim_200_iter_9.txt', sep='	| ', header=None)


In [21]:

# 模型字典初始化
modeldict = {'DTLCDR': DTLCDR, 'DTLCDR_cellenc': DTLCDR_cellenc, 'DTLCDR_cellexp': DTLCDR_cellexp, 'DTLCDR_drugdti': DTLCDR_drugdti,
             'DTLCDR_drugGIN': DTLCDR_drugGIN, 'DTLCDR_drugdesc': DTLCDR_drugdesc}


NameError: name 'DTLCDR' is not defined

In [None]:

# 模型类
class Model:
    def __init__(self, modeldir, model, kfold, device, epoch):
        self.model = modeldict[model](pred_dti.shape[1], 596, './gene2vec_595.npy', 3000)
        self.model._build()
        self.device = torch.device(device)
        self.modeldir = modeldir
        self.kfold = kfold
        self.epoch = epoch
        self.record_fileval = os.path.join(self.modeldir, "valid_markdowntable.txt")
        self.record_filetest = os.path.join(self.modeldir, str(self.kfold) + '.txt')
        self.pkl_file = os.path.join(self.modeldir, "loss_curve_iter.pkl")
        self.val_pkl_file = os.path.join(self.modeldir, "val_loss_curve_iter.pkl")
        self.val_loss_history = []

    # 测试函数，返回预测结果与真实值的相关统计
    def test(self, datagenerator, model, mode=None):
        y_label = []
        y_pred = []
        loss_s = 0
        model.eval()

        for i, data in enumerate(datagenerator):
            drug_graph = data[0].to(self.device)
            drug_dti = data[1].to(self.device)
            drug_desc = data[2].to(self.device)
            cell_enc = data[3].to(self.device)
            cell_exp = data[4].to(self.device)
            label = data[5].to(self.device)

            score = model(drug_graph, drug_dti, drug_desc, cell_enc, cell_exp)
            loss_fct = torch.nn.MSELoss()
            score = torch.squeeze(score, 1)
            loss = loss_fct(score, label)
            self.val_loss_history.append(loss.item())
            logits = score.detach().cpu().numpy()
            label_ids = label.to('cpu').numpy()
            y_label += label_ids.flatten().tolist()
            y_pred += logits.flatten().tolist()
            loss_s += loss

        loss_m = loss_s / (i + 1)
        pcc = pearsonr(y_label, y_pred)[0]
        spm = spearmanr(y_label, y_pred)[0]
        mse = mean_squared_error(y_label, y_pred)
        r2 = r2_score(y_label, y_pred)
        mae = mean_absolute_error(y_label, y_pred)
        model.train()

        if mode == 'val':
            return loss_m, np.sqrt(mse), mse, pcc, spm, r2, mae
        elif mode == 'test':
            return loss_m, np.sqrt(mse), mse, pcc, spm, r2, mae
        elif mode == 'predict':
            return y_label, y_pred


In [None]:

# 训练函数
    def train(self, trainset, valset, testset):
        lr = 1e-4
        BATCH_SIZE = 64
        train_epoch = self.epoch
        self.model = self.model.to(self.device)
        opt = torch.optim.Adam(self.model.parameters(), lr=lr)
        loss_history = []

        collate_fn = DRPCollator()
        trainparams = {'batch_size': BATCH_SIZE, 'shuffle': True, 'num_workers': 4, 'drop_last': True, 'pin_memory': True, "collate_fn": collate_fn}
        training_generator = DataLoader(trainset, **trainparams)
        valtestparams = {'batch_size': BATCH_SIZE, 'shuffle': False, 'num_workers': 4, 'drop_last': False, 'pin_memory': True, "collate_fn": collate_fn}
        validation_generator = DataLoader(valset, **valtestparams)
        testing_generator = DataLoader(testset, **valtestparams)

        best_loss = 10000
        model_max = copy.deepcopy(self.model)

        valid_metric_record = []
        valid_metric_header = ['# epoch', "loss", 'rmse_val', 'mse_val', 'pcc_val', 'spm_val', 'r2_val', 'mae_val']

        # 训练循环
        for epo in range(train_epoch):
            for i, data in enumerate(training_generator):
                drug_graph = data[0].to(self.device)
                drug_dti = data[1].to(self.device)
                drug_desc = data[2].to(self.device)
                cell_enc = data[3].to(self.device)
                cell_exp = data[4].to(self.device)
                label = data[5].to(self.device)

                score = self.model(drug_graph, drug_dti, drug_desc, cell_enc, cell_exp)
                loss_fct = torch.nn.MSELoss()
                score = torch.squeeze(score, 1)
                loss = loss_fct(score, label)
                loss_history.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()

                if i % 1000 == 0:
                    print(f'Epoch {epo + 1}, Iteration {i}, Loss: {loss.item()}')

            with torch.no_grad():
                loss_val, rmse_val, mse_val, pcc_val, spm_val, r2_val, mae_val = self.test(validation_generator, self.model, mode='val')
                valid_metric_record.append([f"Epoch {epo + 1}", loss_val, rmse_val, mse_val, pcc_val, spm_val, r2_val, mae_val])

            if loss_val <= best_loss:
                model_max = copy.deepcopy(self.model)
                best_loss = loss_val

        self.model = model_max
        print('Training finished')
        return valid_metric_record


In [None]:

# 预测函数
    def predict(self, dataset):
        self.model = self.model.to(self.device)
        collate_fn = DRPCollator()
        params = {'batch_size': 128, 'shuffle': False, 'num_workers': 2, 'drop_last': False, "collate_fn": collate_fn}
        generator = DataLoader(dataset, **params)
        y_label, y_pred = self.test(generator, self.model, mode='predict')
        return y_label, y_pred

# 执行示例
# 预备训练数据集
modeldir = './model_warmstart/'
if not os.path.exists(modeldir):
    os.mkdir(modeldir)

# 初始化并训练模型
net = Model(modeldir=modeldir, model=model_name, kfold=1, device=device, epoch=epoch)
valid_metrics = net.train(train_set, val_set, test_set)

# 打印训练结果
for metric in valid_metrics:
    print(metric)
