In [1]:
import os, gc, pickle, scipy.sparse
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from colorama import Fore, Back, Style
from matplotlib.ticker import MaxNLocator

from sklearn.model_selection import GroupKFold
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import mean_squared_error

In [2]:
metadata_df = pd.read_csv(r'metadata.csv', index_col='cell_id')
metadata_df = metadata_df[metadata_df.technology=="citeseq"]
print(metadata_df.shape)

(119651, 4)


In [3]:
Y = pd.read_hdf('train_cite_targets.h5')
print(Y.shape)

(70988, 140)


In [4]:
import pickle
X_svd= pickle.load(open("X_svd.pckl", "rb"))  
print(X_svd.shape) 

(70988, 512)


In [5]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(2019)

TIME_STEP = 1  
INPUT_SIZE = 512  
INIT_LR = 0.02  
N_EPOCHS = 200  


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=256,  
            num_layers=1,  
        )
        self.out = nn.Linear(256, 140)

    def forward(self, x, h):
        # x (time_step, batch_size, input_size)
        # h (n_layers, batch, hidden_size)
        # out (time_step, batch_size, hidden_size)
        out, h = self.rnn(x, h)
        prediction = self.out(out)
        return prediction, h


rnn = RNN()
print(rnn)


RNN(
  (rnn): RNN(512, 256)
  (out): Linear(in_features=256, out_features=140, bias=True)
)


In [6]:
device = torch.device('cuda:0')
rnn=rnn.to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=INIT_LR)
loss_func = nn.MSELoss().to(device)
h_state = None 

X_svd=np.array(X_svd)
print(X_svd.shape)
#X_svd=np.reshape(X_svd, (X_svd.shape[0], 1, 512))
X_svd= torch.from_numpy(X_svd[:, np.newaxis]).to(device)
print(X_svd.shape)
Y=np.array(Y)
Y= torch.from_numpy(Y[:, np.newaxis]).to(device)

for step in range(N_EPOCHS):
    start, end = step * 100, (step + 1) * 100  # 时间跨度
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)
    #print(steps)
    
    prediction, h_state = rnn(X_svd, h_state)  
    h_state = h_state.detach() 
    loss = loss_func(prediction, Y)
    if step%5==0:
        print('Epoch={} loss={}'.format(step,loss.item() ))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

(70988, 512)
torch.Size([70988, 1, 512])
Epoch=0 loss=15.085948944091797
Epoch=5 loss=6.234812259674072
Epoch=10 loss=5.139936923980713
Epoch=15 loss=4.099538803100586
Epoch=20 loss=3.6375885009765625
Epoch=25 loss=3.3653881549835205
Epoch=30 loss=3.16005277633667
Epoch=35 loss=3.0219593048095703
Epoch=40 loss=2.911156415939331
Epoch=45 loss=2.8283615112304688
Epoch=50 loss=2.762176513671875
Epoch=55 loss=2.7098398208618164
Epoch=60 loss=2.6681408882141113
Epoch=65 loss=2.6329500675201416
Epoch=70 loss=2.6043338775634766
Epoch=75 loss=2.5798699855804443
Epoch=80 loss=2.5590834617614746
Epoch=85 loss=2.5412113666534424
Epoch=90 loss=2.525472640991211
Epoch=95 loss=2.511753559112549
Epoch=100 loss=2.4995908737182617
Epoch=105 loss=2.4887871742248535
Epoch=110 loss=2.479078531265259
Epoch=115 loss=2.4705095291137695
Epoch=120 loss=2.462122678756714
Epoch=125 loss=2.454838275909424
Epoch=130 loss=2.447923421859741
Epoch=135 loss=2.441408634185791
Epoch=140 loss=2.4356276988983154
Epoch=145