In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch.optim import Adam
import networkx as nx
import obonet
import pandas as pd

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
GO_graph = obonet.read_obo("GNN/go-basic.obo")

go_edges = []
for u, v, data in GO_graph.edges(data=True):
    go_edges.append([u, v])
go_edges_df = pd.DataFrame(go_edges, columns=['Source', 'Target']).dropna()
print(go_edges_df)
col_name = ['GO']
for i in range(1,769):
  col_name.append('feature'+str(i))
go_features_df = pd.read_csv("GNN/go_terms_embeddings.csv", skiprows=1, names=col_name).dropna()
#remove solubility
"""
with open('GNN/soluble_go_terms.txt', 'r') as file:
    soluble_go_terms = file.read().splitlines()

# 删除‘GO’列里名字在soluble_go_terms.txt中的行
go_features_df = go_features_df[~go_features_df['GO'].isin(soluble_go_terms)]

# 保存结果到新的CSV文件（可选）
go_features_df.to_csv("GNN/remove_solubility_go_terms_embeddings.csv", index=False)"""
print(go_features_df.head())

           Source      Target
0      GO:0000001  GO:0048308
1      GO:0000001  GO:0048311
2      GO:0000002  GO:0007005
3      GO:0000003  GO:0008150
4      GO:0000006  GO:0005385
...           ...         ...
83792  GO:2001317  GO:0034309
83793  GO:2001317  GO:0042181
83794  GO:2001317  GO:0120255
83795  GO:2001317  GO:1901362
83796  GO:2001317  GO:2001316

[83797 rows x 2 columns]
           GO  feature1  feature2  feature3  feature4  feature5  feature6  \
0  GO:0000001 -1.168093 -0.355214  0.265877 -0.710051  0.515028 -0.525165   
1  GO:0000002 -1.185879 -0.098765  0.388240 -0.295556  0.327296 -0.119842   
2  GO:0000003  0.063323 -0.199995  0.151511 -0.942141  0.109313  0.015316   
3  GO:0000005  0.163135  0.301527  0.219680  0.094342 -0.129769  0.225696   
4  GO:0000006 -0.641113 -0.541363  0.413941  0.699345  0.461507 -0.497388   

   feature7  feature8  feature9  ...  feature759  feature760  feature761  \
0 -0.186588 -0.161192  0.186984  ...   -1.350874   -0.991801   -0.648123   

In [3]:
col_name = ['protein']
for i in range(1,769):
  col_name.append('feature'+str(i))
gene_features_df = pd.read_csv('GNN/gene_embedding_GeneLLM_2.csv', header=None, names=col_name).dropna()
print(gene_features_df)

col_name = ['Target', 'Source']
go_protein_df = pd.read_csv(
    "GNN/mart_export.txt", 
    skiprows=1, 
    names=col_name, 
    usecols=[1, 2]  # 使用列的索引来指定
).dropna()
print(go_protein_df)

      protein  feature1  feature2  feature3  feature4  feature5  feature6  \
0         FES  0.339602 -0.030744 -0.901381  0.100888  0.886443  0.383596   
1      HADHA  -0.131799 -0.025745 -0.677301 -0.053545  0.971046  0.180315   
2      SLC7A7  0.385693 -0.070692 -0.847796 -0.022054  0.959772  0.085487   
3        LCK   0.650428  0.014479 -0.866163  0.053508  0.951529  0.269402   
4       HSPA2  0.322262  0.017484 -0.849302  0.046401  0.920429  0.463832   
...       ...       ...       ...       ...       ...       ...       ...   
14445   BPY2C -0.840158 -0.042814 -0.853394 -0.049438  0.943925  0.104337   
14446    CLPS -0.270716 -0.036871 -0.915350 -0.013635  0.972046  0.016017   
14447    DNER  0.228932 -0.033579 -0.907262  0.010446  0.961684  0.524211   
14448    SOX7  0.140491  0.033339 -0.806014 -0.072016  0.938781  0.339959   
14449  CXCL14 -0.570266 -0.011502 -0.741149 -0.096209  0.967244  0.426519   

       feature7  feature8  feature9  ...  feature759  feature760  feature76

In [4]:
print(len(go_features_df))
go_features_df.rename(columns={'GO': 'protein'}, inplace=True)
combined_features = pd.concat([gene_features_df, go_features_df])
combined_features

47595


Unnamed: 0,protein,feature1,feature2,feature3,feature4,feature5,feature6,feature7,feature8,feature9,...,feature759,feature760,feature761,feature762,feature763,feature764,feature765,feature766,feature767,feature768
0,FES,0.339602,-0.030744,-0.901381,0.100888,0.886443,0.383596,-0.192082,-0.032063,-0.154869,...,-0.549204,-0.856123,0.714672,-0.046649,-0.894424,-0.001815,0.739485,0.015581,-0.023863,-0.022002
1,HADHA,-0.131799,-0.025745,-0.677301,-0.053545,0.971046,0.180315,-0.028189,-0.077389,-0.095152,...,0.927885,-0.817812,0.809631,-0.005827,-0.848839,0.024516,0.526404,-0.039926,-0.102787,-0.026980
2,SLC7A7,0.385693,-0.070692,-0.847796,-0.022054,0.959772,0.085487,0.076455,-0.003006,-0.032268,...,0.941094,-0.912443,0.789828,0.046979,-0.715636,0.085842,0.150494,0.025392,-0.066035,-0.028283
3,LCK,0.650428,0.014479,-0.866163,0.053508,0.951529,0.269402,-0.214788,0.045179,-0.506429,...,-0.576739,-0.969558,0.916549,-0.080332,-0.927649,-0.047398,0.741663,-0.000096,-0.096318,-0.056501
4,HSPA2,0.322262,0.017484,-0.849302,0.046401,0.920429,0.463832,-0.050414,-0.033398,0.387791,...,0.387301,-0.860696,0.678607,-0.060695,-0.945793,0.040472,0.831079,-0.001711,-0.079842,-0.011189
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47590,GO:2001313,0.174428,0.194728,-0.284376,0.282102,-0.713190,-0.272055,0.121190,0.129901,-0.983496,...,0.500545,0.429651,-0.292929,-0.464941,-0.740187,0.179149,-0.960807,-0.746958,1.069112,-0.848182
47591,GO:2001314,0.025886,0.306214,-0.254303,0.253673,-0.533680,-0.269355,0.150939,-0.229323,-1.078991,...,0.042979,0.134560,-0.356661,-0.381828,-0.638338,0.077176,-0.788312,-0.683442,1.087031,-0.593092
47592,GO:2001315,0.027134,0.241391,-0.227353,0.317366,-0.726657,-0.197968,0.045653,0.038912,-0.954113,...,0.349853,0.370059,-0.144606,-0.493184,-0.655063,0.217335,-0.841272,-0.821077,1.036363,-0.836614
47593,GO:2001316,0.139543,0.028883,0.899480,0.152932,0.576852,0.330342,0.916943,0.012306,-0.020316,...,-0.354748,-0.083168,0.043640,-0.663565,0.543016,-0.652230,-1.427882,-0.985257,1.673561,0.109659


In [5]:
gene_edges_df = pd.read_csv('GNN/protein_interactions.csv', usecols=[0, 1], names=col_name).dropna()

In [6]:
combined_edges = pd.concat([go_protein_df, go_edges_df, gene_edges_df])
#combined_edges = combined_edges[['Source', 'Target']]
combined_edges

Unnamed: 0,Target,Source
0,MT-TF,GO:0030533
1,MT-TF,GO:0006412
4,MT-RNR2,GO:0003735
5,MT-RNR2,GO:0005840
6,MT-TL1,GO:0030533
...,...,...
13715124,LDB1,SAMD14
13715125,LDB1,KDM6B
13715126,LDB1,WWP2
13715127,LDB1,VPS33B


In [7]:
nodes_in_features = set(combined_features['protein'])

filtered_edges_df = combined_edges[
    combined_edges['Source'].isin(nodes_in_features) & combined_edges['Target'].isin(nodes_in_features)
]


In [8]:
node_id_to_index = {node_id: i for i, node_id in enumerate(combined_features['protein'])}
# 确保edge_index是按照这个新的索引顺序排列的
source_indices = [node_id_to_index[node_id] for node_id in filtered_edges_df['Source']]
target_indices = [node_id_to_index[node_id] for node_id in filtered_edges_df['Target']]
edge_index = torch.tensor([source_indices, target_indices], dtype=torch.long)
edge_index

tensor([[24323, 17460, 27132,  ...,   947,  6874, 13222],
        [ 2077,  2077,  2077,  ..., 10107, 10107, 10107]])

In [9]:
labels_df = pd.read_csv('GNN/phastcons.csv').dropna()
#print(original_labels_df['Conservation'][587])
#original_labels_df['Conservation'][587] = 0
#labels_df = original_labels_df[:588]
#labels_df.rename(columns={'Gene name': 'protein'}, inplace=True)
labels_df.rename(columns={'GeneSymbol': 'protein'}, inplace=True)
labels_df.rename(columns={'Conservation': 'Label'}, inplace=True)
print(len(labels_df))

30467


In [10]:
labels_df.dtypes

protein     object
Label      float64
dtype: object

In [11]:
combined_features

Unnamed: 0,protein,feature1,feature2,feature3,feature4,feature5,feature6,feature7,feature8,feature9,...,feature759,feature760,feature761,feature762,feature763,feature764,feature765,feature766,feature767,feature768
0,FES,0.339602,-0.030744,-0.901381,0.100888,0.886443,0.383596,-0.192082,-0.032063,-0.154869,...,-0.549204,-0.856123,0.714672,-0.046649,-0.894424,-0.001815,0.739485,0.015581,-0.023863,-0.022002
1,HADHA,-0.131799,-0.025745,-0.677301,-0.053545,0.971046,0.180315,-0.028189,-0.077389,-0.095152,...,0.927885,-0.817812,0.809631,-0.005827,-0.848839,0.024516,0.526404,-0.039926,-0.102787,-0.026980
2,SLC7A7,0.385693,-0.070692,-0.847796,-0.022054,0.959772,0.085487,0.076455,-0.003006,-0.032268,...,0.941094,-0.912443,0.789828,0.046979,-0.715636,0.085842,0.150494,0.025392,-0.066035,-0.028283
3,LCK,0.650428,0.014479,-0.866163,0.053508,0.951529,0.269402,-0.214788,0.045179,-0.506429,...,-0.576739,-0.969558,0.916549,-0.080332,-0.927649,-0.047398,0.741663,-0.000096,-0.096318,-0.056501
4,HSPA2,0.322262,0.017484,-0.849302,0.046401,0.920429,0.463832,-0.050414,-0.033398,0.387791,...,0.387301,-0.860696,0.678607,-0.060695,-0.945793,0.040472,0.831079,-0.001711,-0.079842,-0.011189
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47590,GO:2001313,0.174428,0.194728,-0.284376,0.282102,-0.713190,-0.272055,0.121190,0.129901,-0.983496,...,0.500545,0.429651,-0.292929,-0.464941,-0.740187,0.179149,-0.960807,-0.746958,1.069112,-0.848182
47591,GO:2001314,0.025886,0.306214,-0.254303,0.253673,-0.533680,-0.269355,0.150939,-0.229323,-1.078991,...,0.042979,0.134560,-0.356661,-0.381828,-0.638338,0.077176,-0.788312,-0.683442,1.087031,-0.593092
47592,GO:2001315,0.027134,0.241391,-0.227353,0.317366,-0.726657,-0.197968,0.045653,0.038912,-0.954113,...,0.349853,0.370059,-0.144606,-0.493184,-0.655063,0.217335,-0.841272,-0.821077,1.036363,-0.836614
47593,GO:2001316,0.139543,0.028883,0.899480,0.152932,0.576852,0.330342,0.916943,0.012306,-0.020316,...,-0.354748,-0.083168,0.043640,-0.663565,0.543016,-0.652230,-1.427882,-0.985257,1.673561,0.109659


In [12]:
labels_df = labels_df[
    labels_df['protein'].isin(nodes_in_features)]
print(len(labels_df))
labels_df = labels_df.reset_index(drop=True)

label_indices = [node_id_to_index[node_id] for node_id in labels_df['protein']]
print(len(label_indices))
num_nodes = len(combined_features)
labels = torch.full((num_nodes,), -1, dtype=torch.float)
for i, index in enumerate(labels_df['Label']):
    labels[label_indices[i]] = index

labels_tensor = torch.tensor(labels)
print(labels_tensor)
print(labels)

13377
13377
tensor([ 0.6410, -1.0000,  0.4188,  ..., -1.0000, -1.0000, -1.0000])
tensor([ 0.6410, -1.0000,  0.4188,  ..., -1.0000, -1.0000, -1.0000])


  labels_tensor = torch.tensor(labels)


In [13]:
features = combined_features.iloc[:, 1:].values
features_tensor = torch.tensor(features, dtype=torch.float)
features

array([[ 0.33960226, -0.03074448, -0.90138096, ...,  0.01558092,
        -0.02386307, -0.02200161],
       [-0.13179901, -0.02574519, -0.67730105, ..., -0.03992649,
        -0.10278717, -0.02697964],
       [ 0.38569278, -0.07069244, -0.8477959 , ...,  0.0253919 ,
        -0.06603534, -0.02828273],
       ...,
       [ 0.02713387,  0.24139147, -0.22735251, ..., -0.82107705,
         1.036363  , -0.83661443],
       [ 0.13954346,  0.02888298,  0.89947975, ..., -0.9852566 ,
         1.6735605 ,  0.10965873],
       [ 0.08306409,  0.09089889,  0.8885408 , ..., -0.79955566,
         1.5193683 ,  0.2632099 ]])

In [14]:
labels_tensor

tensor([ 0.6410, -1.0000,  0.4188,  ..., -1.0000, -1.0000, -1.0000])

In [15]:
from torch_geometric.data import Data
y = torch.rand((62045), dtype=torch.float)
y[0] = 0
data = Data(x=features_tensor, edge_index=edge_index, y=labels_tensor)#

print("x:", data.x.shape, data.x.dtype)
print("edge_index:", data.edge_index.shape, data.edge_index.dtype)
print("labels:", data.y.shape, data.y.dtype)

x: torch.Size([62045, 768]) torch.float32
edge_index: torch.Size([2, 9914754]) torch.int64
labels: torch.Size([62045]) torch.float32


In [16]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

In [18]:
import numpy as np
known_mask = data.y != -1.0
known_mask
known_indices = np.where(known_mask)[0]
np.random.shuffle(known_indices)

num_train = int(0.8 * len(known_indices))  # 计算80%的数量
train_indices = known_indices[:num_train]
test_indices = known_indices[num_train:]
train_mask = torch.zeros_like(data.y, dtype=torch.bool)
test_mask = torch.zeros_like(data.y, dtype=torch.bool)

train_mask[train_indices] = True
test_mask[test_indices] = True


In [25]:

import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(data, train_mask):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out.squeeze()[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

labeled_indices = label_indices
random.shuffle(labeled_indices)
num_labeled = len(labeled_indices)
num_train = int(num_labeled * 0.8)
num_test = num_labeled - num_train
print(num_test)

# 创建训练和测试掩码
train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)

train_mask[labeled_indices[:num_train]] = True
test_mask[labeled_indices[num_train:num_train+num_test]] = True
print(test_mask)

from scipy.stats import pearsonr
def evaluate_model(model, data, known_mask):
    model.eval()
    with torch.no_grad():
        predictions = model(data.x, data.edge_index).squeeze()
        # 选择已知标签的预测结果和真实标签
        predictions_known = predictions[known_mask]
        labels_known = data.y[known_mask]
        
        # 计算MSE和MAE
        mse = F.mse_loss(predictions_known, labels_known)
        mae = F.l1_loss(predictions_known, labels_known)
        
        # 计算Pearson相关系数
        if len(predictions_known) > 1 and torch.isfinite(predictions_known).all() and torch.isfinite(labels_known).all():
            pearson_corr, _ = pearsonr(predictions_known.cpu().numpy(), labels_known.cpu().numpy())
        else:
            pearson_corr = float('nan')  # 如果数据无效或只有一个数据点，则返回nan

    return mse.item(), mae.item(), pearson_corr


# 训练循环
for i in range(1):
    result = []
    model = GCN(num_features=768, hidden_dim=64, output_dim=1).to(device)
    data = data.to(device)
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss()
    for epoch in range(1000):
        loss = train(data, train_mask)
        if epoch % 50 == 0:
            mse, mae, p_corr = evaluate_model(model, data, test_mask)
            print(f'Epoch {epoch}: Loss {loss:.4f}, MSE: {mse:.4f}, MAE: {mae:.4f}, pearson_corr: {p_corr:.4f}')
    result.append({'mse': mse,
                    'mae': mae,
                    'pearson_corr': p_corr}) 
print(result)

2676
tensor([False, False, False,  ..., False, False, False])
Epoch 0: Loss 1.0575, MSE: 65.5424, MAE: 7.6644, pearson_corr: 0.2902
Epoch 50: Loss 0.0466, MSE: 0.0496, MAE: 0.1802, pearson_corr: 0.2952
Epoch 100: Loss 0.0399, MSE: 0.0413, MAE: 0.1642, pearson_corr: 0.3682
Epoch 150: Loss 0.0368, MSE: 0.0382, MAE: 0.1578, pearson_corr: 0.4008
Epoch 200: Loss 0.0335, MSE: 0.0352, MAE: 0.1513, pearson_corr: 0.4368
Epoch 250: Loss 0.0312, MSE: 0.0337, MAE: 0.1473, pearson_corr: 0.4580
Epoch 300: Loss 0.0294, MSE: 0.0325, MAE: 0.1435, pearson_corr: 0.4835
Epoch 350: Loss 0.0290, MSE: 0.0321, MAE: 0.1431, pearson_corr: 0.4943
Epoch 400: Loss 0.0278, MSE: 0.0312, MAE: 0.1404, pearson_corr: 0.5020
Epoch 450: Loss 0.0276, MSE: 0.0311, MAE: 0.1395, pearson_corr: 0.5067
Epoch 500: Loss 0.0267, MSE: 0.0310, MAE: 0.1410, pearson_corr: 0.5103
Epoch 550: Loss 0.0261, MSE: 0.0302, MAE: 0.1377, pearson_corr: 0.5149
Epoch 600: Loss 0.0258, MSE: 0.0300, MAE: 0.1373, pearson_corr: 0.5182
Epoch 650: Loss 0

In [23]:
print(result)

[{'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 63.39752197265625, 'mae': 7.54330587387085, 'pearson_corr': 0.24959241099269375}, {'mse': 6