In [1]:
import anndata as ad
import gc
import sys
from scipy.sparse import csc_matrix
from sklearn.model_selection import train_test_split
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
import random
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import heapq
import matplotlib
import os
from tqdm.notebook import tqdm_notebook

In [2]:
import gzip
import dgl
import pickle
import pandas as pd
import scipy.sparse as sp
import dgl.nn.pytorch as dglnn
from dgl.nn import GraphConv
import concurrent.futures
import multiprocessing
from dgl.data.utils import save_graphs
from io import StringIO
from functools import partial
from sklearn.metrics.pairwise import euclidean_distances

Using backend: pytorch


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

batch_size = 50000
pred_num = 134

cuda


In [4]:
train_mod1_file = 'phase2_private_data/predict_modality/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_train_mod1.h5ad'
train_mod2_file = 'phase2_private_data/predict_modality/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_train_mod2.h5ad'
test_mod1_file = 'phase2_private_data/predict_modality/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_test_mod1.h5ad'
test_mod2_file = 'phase2_private_data/predict_modality/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_test_mod2.h5ad'

In [5]:
input_train_mod1 = ad.read_h5ad(train_mod1_file)
input_train_mod2 = ad.read_h5ad(train_mod2_file)
input_test_mod1 = ad.read_h5ad(test_mod1_file)
input_test_mod2 = ad.read_h5ad(test_mod2_file)

In [6]:
print(input_train_mod1)
print(input_train_mod2)
print(input_test_mod1)
print(input_test_mod2)

AnnData object with n_obs × n_vars = 66175 × 13953
    obs: 'batch', 'size_factors'
    var: 'gene_ids', 'feature_types'
    uns: 'dataset_id', 'organism'
    layers: 'counts'
AnnData object with n_obs × n_vars = 66175 × 134
    obs: 'batch', 'size_factors'
    var: 'feature_types'
    uns: 'dataset_id', 'organism'
    layers: 'counts'
AnnData object with n_obs × n_vars = 1000 × 13953
    obs: 'batch', 'size_factors'
    var: 'gene_ids', 'feature_types'
    uns: 'dataset_id', 'organism'
    layers: 'counts'
AnnData object with n_obs × n_vars = 1000 × 134
    obs: 'batch', 'size_factors'
    var: 'feature_types'
    uns: 'dataset_id', 'organism'
    layers: 'counts'


In [7]:
print(sorted(set(input_train_mod1.obs['batch'])))
print(sorted(set(input_train_mod2.obs['batch'])))

['s1d1', 's1d2', 's1d3', 's2d1', 's2d4', 's2d5', 's3d1', 's3d6', 's3d7']
['s1d1', 's1d2', 's1d3', 's2d1', 's2d4', 's2d5', 's3d1', 's3d6', 's3d7']


In [8]:
print(sorted(set(input_test_mod1.obs['batch'])))
print(sorted(set(input_test_mod2.obs['batch'])))

['s4d1', 's4d8', 's4d9']
['s4d1', 's4d8', 's4d9']


In [9]:
RNA_s1d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d1", :]
RNA_s1d2 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d2", :]
RNA_s1d3 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d3", :]
RNA_s2d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d1", :]
RNA_s2d4 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d4", :]
RNA_s2d5 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d5", :]

RNA_s3d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d1", :]
RNA_s3d6 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d6", :]
RNA_s3d7 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d7", :]

In [10]:
pro_s1d1 = input_train_mod2[input_train_mod2.obs["batch"] == "s1d1", :]
pro_s1d2 = input_train_mod2[input_train_mod2.obs["batch"] == "s1d2", :]
pro_s1d3 = input_train_mod2[input_train_mod2.obs["batch"] == "s1d3", :]
pro_s2d1 = input_train_mod2[input_train_mod2.obs["batch"] == "s2d1", :]
pro_s2d4 = input_train_mod2[input_train_mod2.obs["batch"] == "s2d4", :]
pro_s2d5 = input_train_mod2[input_train_mod2.obs["batch"] == "s2d5", :]

pro_s3d1 = input_train_mod2[input_train_mod2.obs["batch"] == "s3d1", :]
pro_s3d6 = input_train_mod2[input_train_mod2.obs["batch"] == "s3d6", :]
pro_s3d7 = input_train_mod2[input_train_mod2.obs["batch"] == "s3d7", :]

In [11]:
RNA_s4d1 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d1", :]
RNA_s4d8 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d8", :]
RNA_s4d9 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d9", :]

In [12]:
pro_s4d1 = input_test_mod2[input_test_mod2.obs["batch"] == "s4d1", :]
pro_s4d8 = input_test_mod2[input_test_mod2.obs["batch"] == "s4d8", :]
pro_s4d9 = input_test_mod2[input_test_mod2.obs["batch"] == "s4d9", :]

In [13]:
RNA_s1d1 = RNA_s1d1.X.toarray()
RNA_s1d2 = RNA_s1d2.X.toarray()
RNA_s1d3 = RNA_s1d3.X.toarray()
RNA_s2d1 = RNA_s2d1.X.toarray()
RNA_s2d4 = RNA_s2d4.X.toarray()
RNA_s2d5 = RNA_s2d5.X.toarray()

RNA_s3d1 = RNA_s3d1.X.toarray()
RNA_s3d6 = RNA_s3d6.X.toarray()
RNA_s3d7 = RNA_s3d7.X.toarray()

In [14]:
pro_s1d1 = pro_s1d1.X.toarray()
pro_s1d2 = pro_s1d2.X.toarray()
pro_s1d3 = pro_s1d3.X.toarray()
pro_s2d1 = pro_s2d1.X.toarray()
pro_s2d4 = pro_s2d4.X.toarray()
pro_s2d5 = pro_s2d5.X.toarray()

pro_s3d1 = pro_s3d1.X.toarray()
pro_s3d6 = pro_s3d6.X.toarray()
pro_s3d7 = pro_s3d7.X.toarray()

In [15]:
RNA_s4d1 = RNA_s4d1.X.toarray()
RNA_s4d8 = RNA_s4d8.X.toarray()
RNA_s4d9 = RNA_s4d9.X.toarray()

In [16]:
pro_s4d1 = pro_s4d1.X.toarray()
pro_s4d8 = pro_s4d8.X.toarray()
pro_s4d9 = pro_s4d9.X.toarray()

In [17]:
print(RNA_s1d1.shape)
print(pro_s1d1.shape)

(4721, 13953)
(4721, 134)


In [18]:
print(RNA_s4d1.shape)
print(pro_s4d1.shape)

(334, 13953)
(334, 134)


In [19]:
batch_s1d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d1", :].obs["batch"]
batch_s1d2 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d2", :].obs["batch"]
batch_s1d3 = input_train_mod1[input_train_mod1.obs["batch"] == "s1d3", :].obs["batch"]
batch_s2d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d1", :].obs["batch"]
batch_s2d4 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d4", :].obs["batch"]
batch_s2d5 = input_train_mod1[input_train_mod1.obs["batch"] == "s2d5", :].obs["batch"]

batch_s3d1 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d1", :].obs["batch"]
batch_s3d6 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d6", :].obs["batch"]
batch_s3d7 = input_train_mod1[input_train_mod1.obs["batch"] == "s3d7", :].obs["batch"]

batch_s4d1 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d1", :].obs["batch"]
batch_s4d8 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d8", :].obs["batch"]
batch_s4d9 = input_test_mod1[input_test_mod1.obs["batch"] == "s4d9", :].obs["batch"]

In [20]:
batch_train = list(batch_s1d1)+list(batch_s1d2)+list(batch_s1d3)+list(batch_s2d1)+list(batch_s2d4)+list(batch_s2d5)
batch_val = list(batch_s3d1)+list(batch_s3d6)+list(batch_s3d7)
batch_test = list(batch_s4d1)+list(batch_s4d8)+list(batch_s4d9)

In [21]:
train_input = [RNA_s1d1, RNA_s1d2, RNA_s1d3, RNA_s2d1, RNA_s2d4, RNA_s2d5]
train_output = [pro_s1d1, pro_s1d2, pro_s1d3, pro_s2d1, pro_s2d4, pro_s2d5]

val_input = [RNA_s3d1, RNA_s3d6, RNA_s3d7]
val_output = [pro_s3d1, pro_s3d6, pro_s3d7]

test_input = [RNA_s4d1, RNA_s4d8, RNA_s4d9]
test_output = [pro_s4d1, pro_s4d8, pro_s4d9]

In [22]:
train_input = train_input + val_input
train_output = train_output + val_output

In [23]:
train_input = np.concatenate(train_input, axis=0)
val_input = np.concatenate(val_input, axis=0)
test_input = np.concatenate(test_input, axis=0)

train_output = np.concatenate(train_output, axis=0)
val_output = np.concatenate(val_output, axis=0)
test_output = np.concatenate(test_output, axis=0)

In [24]:
RNA_s1d1 = torch.from_numpy(RNA_s1d1)
RNA_s1d2 = torch.from_numpy(RNA_s1d2)
RNA_s1d3 = torch.from_numpy(RNA_s1d3)
RNA_s2d1 = torch.from_numpy(RNA_s2d1)
RNA_s2d4 = torch.from_numpy(RNA_s2d4)
RNA_s2d5 = torch.from_numpy(RNA_s2d5)
RNA_s3d1 = torch.from_numpy(RNA_s3d1)
RNA_s3d6 = torch.from_numpy(RNA_s3d6)
RNA_s3d7 = torch.from_numpy(RNA_s3d7)
RNA_s4d1 = torch.from_numpy(RNA_s4d1)
RNA_s4d8 = torch.from_numpy(RNA_s4d8)
RNA_s4d9 = torch.from_numpy(RNA_s4d9)

In [25]:
pro_s1d1 = torch.from_numpy(pro_s1d1)
pro_s1d2 = torch.from_numpy(pro_s1d2)
pro_s1d3 = torch.from_numpy(pro_s1d3)
pro_s2d1 = torch.from_numpy(pro_s2d1)
pro_s2d4 = torch.from_numpy(pro_s2d4)
pro_s2d5 = torch.from_numpy(pro_s2d5)
pro_s3d1 = torch.from_numpy(pro_s3d1)
pro_s3d6 = torch.from_numpy(pro_s3d6)
pro_s3d7 = torch.from_numpy(pro_s3d7)
pro_s4d1 = torch.from_numpy(pro_s4d1)
pro_s4d8 = torch.from_numpy(pro_s4d8)
pro_s4d9 = torch.from_numpy(pro_s4d9)

In [26]:
RNA_s1d1 = RNA_s1d1.float()
RNA_s1d2 = RNA_s1d2.float()
RNA_s1d3 = RNA_s1d3.float()
RNA_s2d1 = RNA_s2d1.float()
RNA_s2d4 = RNA_s2d4.float()
RNA_s2d5 = RNA_s2d5.float()
RNA_s3d1 = RNA_s3d1.float()
RNA_s3d6 = RNA_s3d6.float()
RNA_s3d7 = RNA_s3d7.float()
RNA_s4d1 = RNA_s4d1.float()
RNA_s4d8 = RNA_s4d8.float()
RNA_s4d9 = RNA_s4d9.float()

pro_s1d1 = pro_s1d1.float()
pro_s1d2 = pro_s1d2.float()
pro_s1d3 = pro_s1d3.float()
pro_s2d1 = pro_s2d1.float()
pro_s2d4 = pro_s2d4.float()
pro_s2d5 = pro_s2d5.float()
pro_s3d1 = pro_s3d1.float()
pro_s3d6 = pro_s3d6.float()
pro_s3d7 = pro_s3d7.float()
pro_s4d1 = pro_s4d1.float()
pro_s4d8 = pro_s4d8.float()
pro_s4d9 = pro_s4d9.float()

In [27]:
# RNA_s1d1 = RNA_s1d1.to(device)
# RNA_s1d2 = RNA_s1d2.to(device)
# RNA_s1d3 = RNA_s1d3.to(device)
# RNA_s2d1 = RNA_s2d1.to(device)
# RNA_s2d4 = RNA_s2d4.to(device)
# RNA_s2d5 = RNA_s2d5.to(device)
# RNA_s3d1 = RNA_s3d1.to(device)
# RNA_s3d6 = RNA_s3d6.to(device)
# RNA_s3d7 = RNA_s3d7.to(device)
# RNA_s4d1 = RNA_s4d1.to(device)
# RNA_s4d8 = RNA_s4d8.to(device)
# RNA_s4d9 = RNA_s4d9.to(device)

# pro_s1d1 = pro_s1d1.to(device)
# pro_s1d2 = pro_s1d2.to(device)
# pro_s1d3 = pro_s1d3.to(device)
# pro_s2d1 = pro_s2d1.to(device)
# pro_s2d4 = pro_s2d4.to(device)
# pro_s2d5 = pro_s2d5.to(device)
# pro_s3d1 = pro_s3d1.to(device)
# pro_s3d6 = pro_s3d6.to(device)
# pro_s3d7 = pro_s3d7.to(device)
# pro_s4d1 = pro_s4d1.to(device)
# pro_s4d8 = pro_s4d8.to(device)
# pro_s4d9 = pro_s4d9.to(device)

In [28]:
g_s1d1 = dgl.graph(([0]*(len(batch_s1d1)), list(range(len(batch_s1d1)))), num_nodes=len(batch_s1d1))
g_s1d2 = dgl.graph(([0]*(len(batch_s1d2)), list(range(len(batch_s1d2)))), num_nodes=len(batch_s1d2))
g_s1d3 = dgl.graph(([0]*(len(batch_s1d3)), list(range(len(batch_s1d3)))), num_nodes=len(batch_s1d3))
g_s2d1 = dgl.graph(([0]*(len(batch_s2d1)), list(range(len(batch_s2d1)))), num_nodes=len(batch_s2d1))
g_s2d4 = dgl.graph(([0]*(len(batch_s2d4)), list(range(len(batch_s2d4)))), num_nodes=len(batch_s2d4))
g_s2d5 = dgl.graph(([0]*(len(batch_s2d5)), list(range(len(batch_s2d5)))), num_nodes=len(batch_s2d5))
g_s3d1 = dgl.graph(([0]*(len(batch_s3d1)), list(range(len(batch_s3d1)))), num_nodes=len(batch_s3d1))
g_s3d6 = dgl.graph(([0]*(len(batch_s3d6)), list(range(len(batch_s3d6)))), num_nodes=len(batch_s3d6))
g_s3d7 = dgl.graph(([0]*(len(batch_s3d7)), list(range(len(batch_s3d7)))), num_nodes=len(batch_s3d7))

g_s4d1 = dgl.graph(([0]*(len(batch_s4d1)), list(range(len(batch_s4d1)))), num_nodes=len(batch_s4d1))
g_s4d8 = dgl.graph(([0]*(len(batch_s4d8)), list(range(len(batch_s4d8)))), num_nodes=len(batch_s4d8))
g_s4d9 = dgl.graph(([0]*(len(batch_s4d9)), list(range(len(batch_s4d9)))), num_nodes=len(batch_s4d9))

In [29]:
start = time.perf_counter()
    
for i in tqdm_notebook(range(1, g_s1d1.num_nodes())):
    g_s1d1 = dgl.add_edges(g_s1d1, torch.tensor([i]*len(batch_s1d1)), torch.tensor(list(range(len(batch_s1d1)))))
    
for i in tqdm_notebook(range(1, g_s1d2.num_nodes())):
    g_s1d2 = dgl.add_edges(g_s1d2, torch.tensor([i]*len(batch_s1d2)), torch.tensor(list(range(len(batch_s1d2)))))
    
for i in tqdm_notebook(range(1, g_s1d3.num_nodes())):
    g_s1d3 = dgl.add_edges(g_s1d3, torch.tensor([i]*len(batch_s1d3)), torch.tensor(list(range(len(batch_s1d3)))))

for i in tqdm_notebook(range(1, g_s2d1.num_nodes())):
    g_s2d1 = dgl.add_edges(g_s2d1, torch.tensor([i]*len(batch_s2d1)), torch.tensor(list(range(len(batch_s2d1)))))
    
for i in tqdm_notebook(range(1, g_s2d4.num_nodes())):
    g_s2d4 = dgl.add_edges(g_s2d4, torch.tensor([i]*len(batch_s2d4)), torch.tensor(list(range(len(batch_s2d4)))))
    
for i in tqdm_notebook(range(1, g_s2d5.num_nodes())):
    g_s2d5 = dgl.add_edges(g_s2d5, torch.tensor([i]*len(batch_s2d5)), torch.tensor(list(range(len(batch_s2d5)))))
    
for i in tqdm_notebook(range(1, g_s3d1.num_nodes())):
    g_s3d1 = dgl.add_edges(g_s3d1, torch.tensor([i]*len(batch_s3d1)), torch.tensor(list(range(len(batch_s3d1)))))
    
for i in tqdm_notebook(range(1, g_s3d6.num_nodes())):
    g_s3d6 = dgl.add_edges(g_s3d6, torch.tensor([i]*len(batch_s3d6)), torch.tensor(list(range(len(batch_s3d6)))))
    
for i in tqdm_notebook(range(1, g_s3d7.num_nodes())):
    g_s3d7 = dgl.add_edges(g_s3d7, torch.tensor([i]*len(batch_s3d7)), torch.tensor(list(range(len(batch_s3d7)))))


for i in tqdm_notebook(range(1, g_s4d1.num_nodes())):
    g_s4d1 = dgl.add_edges(g_s4d1, torch.tensor([i]*len(batch_s4d1)), torch.tensor(list(range(len(batch_s4d1)))))
    
for i in tqdm_notebook(range(1, g_s4d8.num_nodes())):
    g_s4d8 = dgl.add_edges(g_s4d8, torch.tensor([i]*len(batch_s4d8)), torch.tensor(list(range(len(batch_s4d8)))))
    
for i in tqdm_notebook(range(1, g_s4d9.num_nodes())):
    g_s4d9 = dgl.add_edges(g_s4d9, torch.tensor([i]*len(batch_s4d9)), torch.tensor(list(range(len(batch_s4d9)))))
    
end = time.perf_counter()
print(f'Constructing graphs finished in {(end-start)/60} mins')

  0%|          | 0/4720 [00:00<?, ?it/s]

  0%|          | 0/4463 [00:00<?, ?it/s]

  0%|          | 0/5483 [00:00<?, ?it/s]

  0%|          | 0/9352 [00:00<?, ?it/s]

  0%|          | 0/5025 [00:00<?, ?it/s]

  0%|          | 0/8205 [00:00<?, ?it/s]

  0%|          | 0/8581 [00:00<?, ?it/s]

  0%|          | 0/9976 [00:00<?, ?it/s]

  0%|          | 0/10361 [00:00<?, ?it/s]

  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/206 [00:00<?, ?it/s]

  0%|          | 0/458 [00:00<?, ?it/s]

Constructing graphs finished in 12734.938691910007 seconds


In [30]:
g_s1d1.ndata['RNA'] = RNA_s1d1
g_s1d2.ndata['RNA'] = RNA_s1d2
g_s1d3.ndata['RNA'] = RNA_s1d3
g_s2d1.ndata['RNA'] = RNA_s2d1
g_s2d4.ndata['RNA'] = RNA_s2d4
g_s2d5.ndata['RNA'] = RNA_s2d5
g_s3d1.ndata['RNA'] = RNA_s3d1
g_s3d6.ndata['RNA'] = RNA_s3d6
g_s3d7.ndata['RNA'] = RNA_s3d7

g_s4d1.ndata['RNA'] = RNA_s4d1
g_s4d8.ndata['RNA'] = RNA_s4d8
g_s4d9.ndata['RNA'] = RNA_s4d9

In [31]:
g_s1d1.ndata['pro'] = pro_s1d1
g_s1d2.ndata['pro'] = pro_s1d2
g_s1d3.ndata['pro'] = pro_s1d3
g_s2d1.ndata['pro'] = pro_s2d1
g_s2d4.ndata['pro'] = pro_s2d4
g_s2d5.ndata['pro'] = pro_s2d5
g_s3d1.ndata['pro'] = pro_s3d1
g_s3d6.ndata['pro'] = pro_s3d6
g_s3d7.ndata['pro'] = pro_s3d7

g_s4d1.ndata['pro'] = pro_s4d1
g_s4d8.ndata['pro'] = pro_s4d8
g_s4d9.ndata['pro'] = pro_s4d9

In [32]:
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# Create the model with given dimensions
# model = GCN(g.ndata['RNA'].shape[1], 16, g.ndata['pro'].shape[1])

In [49]:
# g = g.to('cuda')
# model = GCN(g.ndata['RNA'].shape[1], 16, g.ndata['pro'].shape[1]).to('cuda')
model = GCN(g_s1d1.ndata['RNA'].shape[1], 16, g_s1d1.ndata['pro'].shape[1])
train(g_s1d1, model)

Epoch: 0 loss: tensor(1.8174, grad_fn=<MseLossBackward>)
Epoch: 10 loss: tensor(1.6156, grad_fn=<MseLossBackward>)
Epoch: 20 loss: tensor(1.4693, grad_fn=<MseLossBackward>)
Epoch: 30 loss: tensor(1.3446, grad_fn=<MseLossBackward>)
Epoch: 40 loss: tensor(1.2391, grad_fn=<MseLossBackward>)
Epoch: 50 loss: tensor(1.1499, grad_fn=<MseLossBackward>)
Epoch: 60 loss: tensor(1.0742, grad_fn=<MseLossBackward>)
Epoch: 70 loss: tensor(1.0097, grad_fn=<MseLossBackward>)
Epoch: 80 loss: tensor(0.9545, grad_fn=<MseLossBackward>)
Epoch: 90 loss: tensor(0.9071, grad_fn=<MseLossBackward>)


In [None]:
# def train(g, model):
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

#     RNA_features = g_s1d1.ndata['RNA']
#     for e in range(100):
#         output = model(g_s1d1, RNA_features)
#         loss = F.mse_loss(output, g.ndata['pro'])
#         if e%10==0:
#             print('Epoch:', e, 'loss:', loss)

#         # Backward
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
# model = GCN(g_s1d1.ndata['RNA'].shape[1], 16, g_s1d1.ndata['pro'].shape[1])
# train(g_s1d1, model)

In [None]:
train_graphs = [g_s1d1, g_s1d2, g_s1d3, g_s2d1, g_s2d4, g_s2d5, g_s3d1, g_s3d6, g_s3d7]
val_graphs = [g_s3d1, g_s3d6, g_s3d7]
test_graphs = [g_s4d1, g_s4d8, g_s4d9]

train_val_test_graphs = []

In [None]:
PATH = 'GNN_no_norm_RNA_pro_batch'

In [59]:
def fit(num_epochs, model, loss_fn):
    val_best = float('inf')

    for epoch in range(num_epochs):
        for train_g in train_graphs:
            for e in range(100):
                output = model(train_g, train_g.ndata['RNA'])
                loss = F.mse_loss(output, train_g.ndata['pro']
                loss.backward()
                opt.step()
                opt.zero_grad()
                                  
        val_temp_loss = 0
        for g in val_graphs:
            loss = F.mse_loss(pred, g.ndata['pro']
            loss = torch.sqrt(loss)
            loss = loss.cpu().detach().numpy()
            val_temp_loss = val_temp_loss + loss
                              
        if val_temp_loss < val_vest:
            torch.save(model.state_dict(), PATH)
            val_best = val_loss
                              
        if epoch % 100 == 0:
            train_val_test_losses = []
            train_val_test_predictions = []
            
            model = model.eval()
            for g in train_val_test_graphs:
                pred = model(g, g.ndata['RNA'])
                loss = loss_fn(pred, g.ndata['pro']
                loss = torch.sqrt(loss)
                loss = loss.cpu().detach().numpy()
                train_val_test_losses.append(loss)
                train_val_test_predictions.appedn(pred.cpu().detach().numpy())
                
            print('Epoch', epoch, 's1d1:', train_val_test_losses[0], 's1d2:', train_val_test_losses[1], 's1d3:', train_val_test_losses[2], 's2d1:', train_val_test_losses[3],
                  's2d4:', train_val_test_losses[4], 's2d5:', train_val_test_losses[5], 's3d1:', train_val_test_losses[6], 's3d6:', train_val_test_losses[7],
                  's3d7:', train_val_test_losses[8], 's4d1:', train_val_test_losses[9], 's4d8:', train_val_test_losses[10], 's4d9:', train_val_test_losses[11])
            
            figure, ax = plt.subplots(1, 12, figsize = (30, 3))
            for i in range(1):
                ax[0].scatter(pro_s1d1.cpu().detach().numpy()[:, 0], train_val_test_predictions[0].cpu().detach().numpy()[:, 0], c='blue', s = 30, alpha=0.05)
                ax[0].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[0].set_title('s1d1 result')
                ax[0].set_xlabel('true')
                ax[0].set_ylabel('pred')

                ax[1].scatter(pro_s1d2.cpu().detach().numpy()[:, 0], train_val_test_predictions[1].cpu().detach().numpy()[:, 0], c='blue', s = 30, alpha=0.05)
                ax[1].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[1].set_title('s1d2 result')
                ax[1].set_xlabel('true')
                ax[1].set_ylabel('pred')

                ax[2].scatter(pro_s1d3.cpu().detach().numpy()[:, 0], train_val_test_predictions[2].cpu().detach().numpy()[:, 0], c='blue', s = 30, alpha=0.05)
                ax[2].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[2].set_title('s1d3 result')
                ax[2].set_xlabel('true')
                ax[2].set_ylabel('pred')
                
                ax[3].scatter(pro_s2d1.cpu().detach().numpy()[:, 0], train_val_test_predictions[3].cpu().detach().numpy()[:, 0], c='red', s = 30, alpha=0.05)
                ax[3].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[3].set_title('s2d1 result')
                ax[3].set_xlabel('true')
                ax[3].set_ylabel('pred')

                ax[4].scatter(pro_s2d4.cpu().detach().numpy()[:, 0], train_val_test_predictions[4].cpu().detach().numpy()[:, 0], c='red', s = 30, alpha=0.05)
                ax[4].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[4].set_title('s2d4 result')
                ax[4].set_xlabel('true')
                ax[4].set_ylabel('pred')

                ax[5].scatter(pro_s2d5.cpu().detach().numpy()[:, 0], train_val_test_predictions[5].cpu().detach().numpy()[:, 0], c='red', s = 30, alpha=0.05)
                ax[5].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[5].set_title('s2d5 result')
                ax[5].set_xlabel('true')
                ax[5].set_ylabel('pred')
                
                ax[6].scatter(pro_s3d1.cpu().detach().numpy()[:, 0], train_val_test_predictions[6].cpu().detach().numpy()[:, 0], c='orange', s = 30, alpha=0.05)
                ax[6].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[6].set_title('s3d1 result')
                ax[6].set_xlabel('true')
                ax[6].set_ylabel('pred')

                ax[7].scatter(pro_s3d6.cpu().detach().numpy()[:, 0], train_val_test_predictions[7].cpu().detach().numpy()[:, 0], c='orange', s = 30, alpha=0.05)
                ax[7].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[7].set_title('s3d6 result')
                ax[7].set_xlabel('true')
                ax[7].set_ylabel('pred')

                ax[8].scatter(pro_s3d7.cpu().detach().numpy()[:, 0], train_val_test_predictions[8].cpu().detach().numpy()[:, 0], c='orange', s = 30, alpha=0.05)
                ax[8].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[8].set_title('s3d7 result')
                ax[8].set_xlabel('true')
                ax[8].set_ylabel('pred')
                
                ax[9].scatter(pro_s4d1.cpu().detach().numpy()[:, 0], train_val_test_predictions[9].cpu().detach().numpy()[:, 0], c='green', s = 30, alpha=0.05)
                ax[9].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[9].set_title('s4d1 result')
                ax[9].set_xlabel('true')
                ax[9].set_ylabel('pred')

                ax[10].scatter(pro_s4d8.cpu().detach().numpy()[:, 0], train_val_test_predictions[10].cpu().detach().numpy()[:, 0], c='green', s = 30, alpha=0.05)
                ax[10].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[10].set_title('s4d8 result')
                ax[10].set_xlabel('true')
                ax[10].set_ylabel('pred')

                ax[11].scatter(pro_s4d9.cpu().detach().numpy()[:, 0], train_val_test_predictions[11].cpu().detach().numpy()[:, 0], c='green', s = 30, alpha=0.05)
                ax[11].plot([-1.5, 3], [-1.5, 3], 'k-')
                ax[11].set_title('s4d9 result')
                ax[11].set_xlabel('true')
                ax[11].set_ylabel('pred')     
            plt.show()
    return s1d1_pred.cpu().detach().numpy(), s1d2.cpu().detach().numpy(), s1d3.cpu().detach().numpy(), s2d1.cpu().detach().numpy()

In [None]:
model = GCN(g_s1d1.ndata['RNA'].shape[1], 16, g_s1d1.ndata['pro'].shape[1])
train(5000, model, )

In [35]:
colors = ['red','green','blue','purple', 'yellow', 'orange']
k = 0
batch_lib = {}
for i in range(len(batch_train)):
    if batch_train[i] not in batch_lib:
        k = k + 1
        batch_lib[batch_train[i]] = k
    batch_train[i] = batch_lib[batch_train[i]]

In [36]:
colors = ['red','green','blue','purple', 'yellow', 'orange']
k = 0
batch_lib = {}
for i in range(len(batch_val)):
    if batch_val[i] not in batch_lib:
        k = k + 1
        batch_lib[batch_val[i]] = k
    batch_val[i] = batch_lib[batch_val[i]]

In [37]:
colors = ['red','green','blue','purple', 'yellow', 'orange']
k = 0
batch_lib = {}
for i in range(len(batch_test)):
    if batch_test[i] not in batch_lib:
        k = k + 1
        batch_lib[batch_test[i]] = k
    batch_test[i] = batch_lib[batch_test[i]]

In [38]:
batch_train = batch_train + batch_val

In [39]:
train_output = train_output[:, 0:pred_num]
train_output.shape

(66175, 134)

In [40]:
val_output = val_output[:, 0:pred_num]
val_output.shape

(28921, 134)

In [41]:
test_output = test_output[:, 0:pred_num]
test_output.shape

(1000, 134)

In [42]:
train_ds = TensorDataset(RNA_s1d1, pro_s1d1)
train_dl = DataLoader(train_ds, batch_size= batch_size, shuffle=True)

In [43]:
input_feature = train_input.shape[1]
output_feature = pred_num

In [45]:
PATH = 'No_norm_model_RNA_pro'