In [1]:
import os, gc, pickle, scipy.sparse
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from colorama import Fore, Back, Style
from matplotlib.ticker import MaxNLocator

from sklearn.model_selection import GroupKFold
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import mean_squared_error

In [2]:
metadata_df = pd.read_csv(r'metadata.csv', index_col='cell_id')
metadata_df = metadata_df[metadata_df.technology=="citeseq"]
print(metadata_df.shape)
print(metadata_df.head(5))

(119651, 4)
              day  donor cell_type technology
cell_id                                      
c2150f55becb    2  27678       HSC    citeseq
65b7edf8a4da    2  27678       HSC    citeseq
c1b26cb1057b    2  27678      EryP    citeseq
917168fa6f83    2  27678      NeuP    citeseq
2b29feeca86d    2  27678      EryP    citeseq


In [3]:
Y = pd.read_hdf('train_cite_targets.h5')
print(Y.shape)
print(Y.head(5))

(70988, 140)
gene_id           CD86     CD274     CD270     CD155     CD112      CD47  \
cell_id                                                                    
45006fe3e4c8  1.167804  0.622530  0.106959  0.324989  3.331674  6.426002   
d02759a80ba2  0.818970  0.506009  1.078682  6.848758  3.524885  5.279456   
c016c6b0efa5 -0.356703 -0.422261 -0.824493  1.137495  0.518924  7.221962   
ba7f733a4f75 -1.201507  0.149115  2.022468  6.021595  7.258670  2.792436   
fbcf2443ffb2 -0.100404  0.697461  0.625836 -0.298404  1.369898  3.254521   

gene_id            CD48      CD40     CD154      CD52  ...      CD94  \
cell_id                                                ...             
45006fe3e4c8   1.480766 -0.728392 -0.468851 -0.073285  ... -0.448390   
d02759a80ba2   4.930438  2.069372  0.333652 -0.468088  ...  0.323613   
c016c6b0efa5  -0.375034  1.738071  0.142919 -0.971460  ...  1.348692   
ba7f733a4f75  21.708519 -0.137913  1.649969 -0.754680  ...  1.504426   
fbcf2443ffb2  -1.65938

In [4]:
X = pd.read_hdf('train_cite_inputs.h5')
print(X.shape)
print(X.head(5))

(70988, 22050)
gene_id       ENSG00000121410_A1BG  ENSG00000268895_A1BG-AS1  \
cell_id                                                        
45006fe3e4c8                   0.0                       0.0   
d02759a80ba2                   0.0                       0.0   
c016c6b0efa5                   0.0                       0.0   
ba7f733a4f75                   0.0                       0.0   
fbcf2443ffb2                   0.0                       0.0   

gene_id       ENSG00000175899_A2M  ENSG00000245105_A2M-AS1  \
cell_id                                                      
45006fe3e4c8                  0.0                      0.0   
d02759a80ba2                  0.0                      0.0   
c016c6b0efa5                  0.0                      0.0   
ba7f733a4f75                  0.0                      0.0   
fbcf2443ffb2                  0.0                      0.0   

gene_id       ENSG00000166535_A2ML1  ENSG00000128274_A4GALT  \
cell_id                               

In [5]:
cell_index = X.index
meta = metadata_df.reindex(cell_index)
print(f"Original X shape: {str(X.shape):14} {X.size*4/1024/1024/1024:2.3f} GByte")
gc.collect()
X = scipy.sparse.csr_matrix(X.values)
gc.collect()

Original X shape: (70988, 22050) 5.831 GByte


0

In [6]:
print(f"Shape of both before SVD: {X.shape}")
svd = TruncatedSVD(n_components=512, random_state=1) # 512
X_svd = svd.fit_transform(X)
print(f"Shape of both after SVD:  {X_svd.shape}")

Shape of both before SVD: (70988, 22050)
Shape of both after SVD:  (70988, 512)


In [7]:
#存成pckl形式节省时间
import pickle
f = open('X_svd.pckl', 'wb')
pickle.dump(X_svd, f)
f.close()

In [11]:
def correlation_score(y_true, y_pred):
    if type(y_true) == pd.DataFrame: y_true = y_true.values
    if type(y_pred) == pd.DataFrame: y_pred = y_pred.values
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

In [12]:
import torch
import torch.nn as nn
from torch.optim import SGD,Adam
from torch.autograd import Variable
import torch.nn.functional as F
class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.hidden1=nn.Linear(in_features=512,out_features=256,bias=True)
        self.hidden2=nn.Linear(256,140)
        
    def forward(self,x):
        x=F.relu(self.hidden1(x))
        output=self.hidden2(x)
        return output
mlp=MLP()
print(mlp)        
        
        

MLP(
  (hidden1): Linear(in_features=512, out_features=256, bias=True)
  (hidden2): Linear(in_features=256, out_features=140, bias=True)
)


In [18]:
import torch.utils.data as Data
X=torch.from_numpy(X_svd.astype(np.float32))
Y=np.array(Y)
Y=torch.from_numpy(Y.astype(np.float32))
train_data=Data.TensorDataset(X,Y)
train_loader=Data.DataLoader(dataset=train_data,batch_size=4096,shuffle=True,num_workers=0)

In [19]:
import warnings
warnings.filterwarnings("ignore")
device = torch.device('cuda:0')
mlp=mlp.to(device)

optimizer=torch.optim.SGD(mlp.parameters(),lr=0.005)
loss_func=nn.MSELoss().to(device)
train_loss_all=[]
for epoch in range(200):
    train_loss=0
    train_num=0
    for step, (b_x,b_y) in enumerate(train_loader):
        b_x=b_x.to(device)
        b_y=b_y.to(device)
        output=mlp(b_x)
        loss=loss_func(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss.item()*b_x.size(0)
        train_num+=b_x.size(0)
        loss1=train_loss/train_num
    train_loss_all.append(train_loss/train_num)
    if epoch%5==0:
        print('Epoch={} loss={}'.format(epoch,loss1 ))
    

Epoch=0 loss=2.7131641620836535
Epoch=5 loss=2.7175598742357745
Epoch=10 loss=2.783204019267828
Epoch=15 loss=2.717164559479048
Epoch=20 loss=2.6871127448982337
Epoch=25 loss=2.68901474933836
Epoch=30 loss=2.6753204198409426
Epoch=35 loss=2.712481170639143
Epoch=40 loss=2.655567690091126
Epoch=45 loss=2.667031989498206
Epoch=50 loss=2.6772681033167296
Epoch=55 loss=2.67260834886879
Epoch=60 loss=2.652836464798443
Epoch=65 loss=2.6644095359281423
Epoch=70 loss=2.638566842890326
Epoch=75 loss=2.6301489575195225
Epoch=80 loss=2.6244608960689315
Epoch=85 loss=2.621786571297302
Epoch=90 loss=2.691368432718579
Epoch=95 loss=2.615667095760389
Epoch=100 loss=2.650921754100299
Epoch=105 loss=2.620020844670775
Epoch=110 loss=2.634805714626638
Epoch=115 loss=2.6160130263140635
Epoch=120 loss=2.5998920588021264
Epoch=125 loss=2.6003416876406806
Epoch=130 loss=2.5947563621220926
Epoch=135 loss=2.597401530949399
Epoch=140 loss=2.642548188106197
Epoch=145 loss=2.5877286801037536
Epoch=150 loss=2.5888

In [24]:
torch.save(mlp,'mlp.pt')

In [25]:
X_test = pd.read_hdf('test_cite_inputs.h5')
print(X_test.shape)
print(X_test.head(5))

(48663, 22050)
gene_id       ENSG00000121410_A1BG  ENSG00000268895_A1BG-AS1  \
cell_id                                                        
c2150f55becb                   0.0                       0.0   
65b7edf8a4da                   0.0                       0.0   
c1b26cb1057b                   0.0                       0.0   
917168fa6f83                   0.0                       0.0   
2b29feeca86d                   0.0                       0.0   

gene_id       ENSG00000175899_A2M  ENSG00000245105_A2M-AS1  \
cell_id                                                      
c2150f55becb                  0.0                      0.0   
65b7edf8a4da                  0.0                      0.0   
c1b26cb1057b                  0.0                      0.0   
917168fa6f83                  0.0                      0.0   
2b29feeca86d                  0.0                      0.0   

gene_id       ENSG00000166535_A2ML1  ENSG00000128274_A4GALT  \
cell_id                               

In [26]:
X_test = scipy.sparse.csr_matrix(X_test.values)
gc.collect()
print(f"Shape of both before SVD: {X_test.shape}")
svd = TruncatedSVD(n_components=512, random_state=1) # 512
X_test_svd = svd.fit_transform(X_test)
print(f"Shape of both after SVD:  {X_test_svd.shape}")

Shape of both before SVD: (48663, 22050)
Shape of both after SVD:  (48663, 512)


In [28]:
#存成pckl形式节省时间
import pickle
f1 = open('X_test_svd.pckl', 'wb')
pickle.dump(X_test_svd, f)
f1.close()

ValueError: write to closed file