# Import packages

In [1]:
import os
import time
import tensorflow as tf
import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse as sp

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Flatten
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import sparse_categorical_accuracy
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Adagrad
from tensorflow.keras.regularizers import l2
from spektral.data import Dataset, Graph
from spektral.data import MixedLoader
from spektral.utils import normalized_laplacian
from spektral.layers import GCNConv, GlobalSumPool, ChebConv, GlobalAttnSumPool
from spektral.layers.ops import sp_matrix_to_sp_tensor

In [2]:
# Restrict CUDA to a single device; "2" is a hard-coded, machine-specific GPU index.
# NOTE(review): this is set after `import tensorflow` in the cell above — confirm it
# still takes effect before TensorFlow first initialises the GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [3]:
# Shell command: delete everything under ./checkpoints/train before training.
# NOTE(review): destructive. Because the directory is emptied here, the later
# "restore the latest checkpoint" step presumably never finds one on a full run — confirm intent.
!rm -rf ./checkpoints/train/*

# Self-defined dataset

In [4]:
class scRNA(Dataset):
    """Dataset with one Graph per expression row and a single shared adjacency.

    The shared adjacency matrix is kept on the dataset itself as ``self.a``;
    each Graph carries only node features (``x``) and a label (``y``).
    """

    def __init__(self, **kwargs):
        """Create the dataset; keyword arguments are forwarded to ``Dataset``."""
        # Declared before the base initialiser runs: the notebook reads `data.a`
        # right after construction, so read() is presumably invoked from
        # Dataset.__init__ and must not have its result reset afterwards.
        self.a = None
        super().__init__(**kwargs)

    def read(self):
        """Load expressions and labels, store the adjacency, return the Graph list."""
        features, labels, node_order = _get_scRNA_exprs()
        # The adjacency rows/columns follow the same node order as the features.
        self.a = _get_adjacency(node_order)

        graphs = []
        for node_features, label in zip(features, labels):
            graphs.append(Graph(x=node_features, y=label))
        return graphs

def _get_adjacency(node_order):
    """Build the adjacency matrix of the gene-gene interaction network.

    Parameters
    ----------
    node_order : list
        Node names; rows and columns of the returned matrix follow this order.

    Returns
    -------
    The sparse adjacency matrix produced by ``nx.adjacency_matrix``.
    ``weight = None`` is passed, so edge attributes are not used as weights.
    """
    # Read the comma-delimited network file as an adjacency list.
    g = nx.read_adjlist("../5.1.Edge_of_gene_gene_interaction_network.csv",
                        delimiter = ",")
    # `nx.adj_matrix` was only a deprecated alias of `nx.adjacency_matrix` and is
    # gone in networkx >= 3.0; the canonical name works across versions.
    return nx.adjacency_matrix(g, nodelist = node_order, weight = None)

def _get_scRNA_exprs():
    """Read the expression table and the cell labels from disk.

    Returns
    -------
    x : numpy.ndarray
        Transposed expression values with a trailing axis of length 1 added,
        i.e. shape ``(rows, columns, 1)`` of the transposed table.
    y : numpy.ndarray
        Values of the ``Number_label`` column of the cell label file.
    node_order : list
        Column names of the transposed expression table (the node order).
    """
    # Expression table is stored under the "exprs" key; transpose it right away.
    expression = pd.read_hdf("../6.Filtered_node_exprs.h5", key = "exprs").T
    labels = pd.read_csv("../3.Cell_label.csv", index_col = 0)

    # Add a trailing feature axis so every node carries a length-1 feature vector.
    n_rows, n_cols = expression.shape
    features = expression.values.reshape(n_rows, n_cols, 1)

    return features, labels["Number_label"].values, expression.columns.to_list()

# Set seed

In [5]:
# Seed NumPy's global RNG; the np.random.choice calls in the dataset split draw from it.
# NOTE(review): only NumPy is seeded here — TensorFlow's RNG is not, so weight
# initialisation is presumably not reproducible between runs.
np.random.seed(seed = 203)

# Loading self defined dataset

In [6]:
# Load data
data = scRNA()

# The shared adjacency matrix lives on the dataset as `data.a`.
# Turn it into the filter expected by ChebConv, then into a sparse tensor.
# (Alternatives that were tried: normalized_laplacian(data.a), GCNConv.preprocess(data.a).)
data.a = sp_matrix_to_sp_tensor(ChebConv.preprocess(data.a))

# Parameter setting

In [7]:
# Hyperparameters shared by the loaders, the model definition and the training loop.
batch_size = 32  # Number of graphs per batch, used by all three loaders
epochs = 500  # Number of training epochs, passed to the training loader
patience = 20  # Epochs without validation-loss improvement before early stopping
l2_reg = 9e-3  # L2 factor for the kernel regularizers of the conv and dense layers

# Dataset split

In [8]:
# Cell label table; its "Number_label" column provides the per-class counts for the split below.
# NOTE(review): the same CSV is also read inside _get_scRNA_exprs().
cell = pd.read_csv("../3.Cell_label.csv", index_col = 0)

In [9]:
# Per-class random split of row positions into train / validation / testing.
# Class i is treated as the contiguous position range [start, end), with the ranges
# laid out back to back in label order 0..7.
# NOTE(review): this assumes the dataset rows are sorted by Number_label — confirm.
train_fraction = 0.7       # share of each class drawn for training
validation_fraction = 0.2  # share of each class drawn for validation; the rest is testing

# The counts do not change between iterations, so compute them once.
label_counts = cell["Number_label"].value_counts()

train_parts = []
validation_parts = []
testing_parts = []

end = 0
for i in range(0, 8):
    cell_numbers = label_counts[i]
    # The current class begins where the previous one ended (0 for the first class).
    start, end = end, end + cell_numbers
    cell_total = np.arange(start, end)

    # Draw order (train first, then validation) is kept so that a fixed NumPy seed
    # yields the same split as before.
    tmp_train = np.random.choice(cell_total, int(cell_numbers * train_fraction), replace = False)
    tmp = np.setdiff1d(cell_total, tmp_train)
    tmp_validation = np.random.choice(tmp, int(cell_numbers * validation_fraction), replace = False)
    tmp_testing = np.setdiff1d(tmp, tmp_validation)

    train_parts.append(tmp_train)
    validation_parts.append(tmp_validation)
    testing_parts.append(tmp_testing)

# Join once at the end instead of re-allocating with np.append on every iteration.
train_position = np.concatenate(train_parts)
validation_position = np.concatenate(validation_parts)
testing_position = np.concatenate(testing_parts)

del start, end, cell_total, tmp, tmp_train, tmp_testing, tmp_validation
del train_parts, validation_parts, testing_parts

In [10]:
# Index the dataset with the position arrays from the split to get the three subsets.
train = data[train_position]
testing = data[testing_position]
valid = data[validation_position]

In [11]:
# Report how many graphs (cells) each split contains.
print(f"Number of cell in train set: {train.n_graphs}")
print(f"Number of cell in validation set: {valid.n_graphs}")
print(f"Number of cell in testing set: {testing.n_graphs}")

Number of cell in train set: 59793
Number of cell in validation set: 17081
Number of cell in testing set: 8549


In [12]:
# Data to be predicted: four subsets taken from hard-coded, consecutive row ranges of `data`.
# The first range starts at 85423, which equals the combined size of the
# train / validation / testing splits printed above (59793 + 17081 + 8549).
# NOTE(review): the boundaries are magic numbers tied to the current input files —
# they must be updated by hand if the data change.
Donor_A = data[85423:88323]
Donor_C = data[88323:97842]
Data6k = data[97842:103261]
Data8k = data[103261:111642]

# Data Loader

In [13]:
# The dataset is in mixed mode (per-graph features, one shared adjacency),
# so every split is served through a MixedLoader.
# Only the training loader receives an explicit number of epochs.
loader_tr = MixedLoader(train, batch_size = batch_size, epochs = epochs)
# Validation and testing loaders keep the loader's default `epochs` setting.
loader_va = MixedLoader(valid, batch_size = batch_size)
loader_te = MixedLoader(testing, batch_size = batch_size)

# Build model 

In [14]:
# Build model
class Net(Model):
    """Graph classifier: two ChebConv + BatchNormalization stages, then a dense head.

    The convolution output is flattened and passed through three ReLU dense layers
    and a final 8-way softmax layer.
    """

    def __init__(self, **kwargs):
        """Create all layers; keyword arguments are forwarded to ``Model``."""
        super().__init__(**kwargs)

        # Two Chebyshev graph convolutions (32 channels, K = 5, ELU), each followed
        # by batch normalisation. Kernels are L2-regularised with `l2_reg`.
        self.conv1 = ChebConv(32, K = 5, activation = "elu", kernel_regularizer = l2(l2_reg))
        self.bn0 = BatchNormalization()
        self.conv2 = ChebConv(32, K = 5, activation = "elu", kernel_regularizer = l2(l2_reg))
        self.bn1 = BatchNormalization()
        # Earlier GCNConv-based variant, kept for reference:
        #self.conv1 = GCNConv(32, activation = "elu", kernel_regularizer = l2(l2_reg))
        #self.conv2 = GCNConv(32, activation = "elu", kernel_regularizer = l2(l2_reg))
        #self.conv3 = GCNConv(32, activation = "elu", kernel_regularizer = l2(l2_reg))
        #self.conv4 = GCNConv(32, activation = "elu", kernel_regularizer = l2(l2_reg))

        # Reference: https://www.kaggle.com/kmader/mnist-graph-deep-learning
        
        self.flatten = Flatten()
        
        # Author's note: the flattened output has 281280 features.
        self.fc1 = Dense(1024, activation = "relu", kernel_regularizer = l2(l2_reg))
        self.fc2 = Dense(256, activation = "relu", kernel_regularizer = l2(l2_reg))
        self.fc3 = Dense(64, activation = "relu", kernel_regularizer = l2(l2_reg))
        self.fc4 = Dense(8, activation = "softmax")  # scRNA-seq has 8 classes

    def call(self, inputs):
        """Run the forward pass.

        ``inputs`` is unpacked as ``(x, a)``: the node features and the graph
        filter that is passed to both convolution layers. Returns the output of
        the final softmax layer (8 values per sample).
        """
        x, a = inputs
        
        # Graph convolution stages; the same filter `a` is reused for both.
        x = self.conv1([x, a])
        x = self.bn0(x)
        x = self.conv2([x, a])
        x = self.bn1(x)
        # Dense classification head on the flattened node features.
        output = self.flatten(x)
        output = self.fc1(output)
        output = self.fc2(output)
        output = self.fc3(output)
        output = self.fc4(output)

        return output

In [15]:
# Create the model, its optimizer and the loss function used by training and evaluation.
GNN_model = Net()
GNN_optimizer = SGD(learning_rate = 0.001)  # plain SGD with a fixed learning rate
# Default arguments: the model's last layer already applies softmax.
loss_fn = SparseCategoricalCrossentropy()

# Checkpoint

In [16]:
# Directory in which the checkpoint manager stores its files.
checkpoint_path = "./checkpoints/train"

# Track both the model and the optimizer; only the 2 newest checkpoints are kept.
ckpt = tf.train.Checkpoint(GNN_model = GNN_model, GNN_optimizer = GNN_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep = 2)

# Resume from the most recent checkpoint when one is present.
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint:
    ckpt.restore(latest_checkpoint)
    print ('Latest checkpoint restored!!')

# Define model training methods

In [17]:
# Per-epoch metric histories. The training loop appends one value per completed
# epoch to each array, and the final cell writes them to a CSV file.
training_total_loss = np.array([])  # data loss + model (regularisation) losses
training_acc = np.array([])
training_loss = np.array([])        # data loss only
model_loss = np.array([])           # sum of GNN_model.losses
validation_loss = np.array([])
validation_acc = np.array([])
testing_loss = np.array([])
testing_acc = np.array([])

# Training function
@tf.function
def train_on_batch(inputs, target):
    """Run one optimisation step on a single batch.

    Returns a 4-tuple ``(total loss, accuracy, data loss, model losses)``, where
    the total loss is the data loss plus the sum of ``GNN_model.losses``.
    """
    with tf.GradientTape() as tape:
        predictions = GNN_model(inputs, training = True)
        data_loss = loss_fn(target, predictions)
        # Losses collected by the model itself (here presumably the l2 kernel regularisers).
        extra_losses = sum(GNN_model.losses)
        total_loss = data_loss + extra_losses

    # Accuracy is a reporting metric only, so it is computed outside the tape.
    batch_acc = tf.reduce_mean(sparse_categorical_accuracy(target, predictions))

    variables = GNN_model.trainable_variables
    grads = tape.gradient(total_loss, variables)
    GNN_optimizer.apply_gradients(zip(grads, variables))
    return total_loss, batch_acc, data_loss, extra_losses


# Evaluation function
def evaluate(loader):
    """Evaluate the model over one epoch of ``loader``.

    Iterates until ``loader.steps_per_epoch`` batches have been processed and
    returns ``[loss, accuracy]`` averaged over those batches, weighted by batch
    size. If the loader ends before that many batches, nothing is returned
    (implicit ``None``).
    """
    per_batch = []
    for step, (inputs, target) in enumerate(loader, start = 1):
        predictions = GNN_model(inputs, training = False)
        batch_loss = loss_fn(target, predictions)
        batch_acc = tf.reduce_mean(sparse_categorical_accuracy(target, predictions))
        # Last column is the batch size, used as the averaging weight.
        per_batch.append((batch_loss, batch_acc, len(target)))
        if step == loader.steps_per_epoch:
            stacked = np.array(per_batch)
            return np.average(stacked[:, :-1], 0, weights = stacked[:, -1])


# Model training

In [18]:
%%time
# Setup training
best_val_loss = 999999
current_patience = patience
step = 0
epochs = 0
# Training loop
results_tr = []
start = time.time()
for batch in loader_tr:
    step += 1

    # Training step
    inputs, target = batch
    loss, acc, train_lossed, model_losses = train_on_batch(inputs, target)
    results_tr.append((loss, acc, train_lossed, model_losses, len(target)))
    
    if step == loader_tr.steps_per_epoch:
        results_va = evaluate(loader_va)
        if results_va[0] < best_val_loss:
            best_val_loss = results_va[0]
            current_patience = patience
            results_te = evaluate(loader_te)
        else:
            current_patience -= 1
            if current_patience == 0:
                print("Early stopping")
                break

        # Print results
        results_tr = np.array(results_tr)
        results_tr = np.average(results_tr[:, :-1], 0, weights = results_tr[:, -1])
        training_total_loss = np.append(training_total_loss, results_tr[0])
        training_acc = np.append(training_acc, results_tr[1])
        training_loss = np.append(training_loss, results_tr[2])
        model_loss = np.append(model_loss, results_tr[3])
        validation_loss = np.append(validation_loss, results_va[0])
        validation_acc = np.append(validation_acc, results_va[1])
        testing_loss = np.append(testing_loss, results_te[0])
        testing_acc = np.append(testing_acc, results_te[1])
        print("Epoch {}: {}s \n"
              "Total train loss: {:.4f}, acc: {:.4f} | Train loss: {:.4f} | Model loss: {:.4f} | "
              "Valid loss: {:.4f}, acc: {:.4f} | "
              "Test loss: {:.4f}, acc: {:.4f}".format(epochs, time.time() - start ,
                                                      *results_tr, *results_va, *results_te))

        # Reset epoch
        results_tr = []
        step = 0
        start = time.time()
        # Save model
        ckpt_save_path = ckpt_manager.save()
        print ("Saving checkpoint for epoch {} sucessfully.".format(epochs))
        epochs += 1

Epoch 0: 147.80355381965637s 
Total train loss: 22.9112, acc: 0.8347 | Train loss: 0.4088 | Model loss: 22.5025 | Valid loss: 0.3142, acc: 0.8687 | Test loss: 0.3190, acc: 0.8685
Saving checkpoint for epoch 0 sucessfully.
Epoch 1: 129.75979161262512s 
Total train loss: 21.1969, acc: 0.9460 | Train loss: 0.1487 | Model loss: 21.0482 | Valid loss: 0.3101, acc: 0.8704 | Test loss: 0.3094, acc: 0.8732
Saving checkpoint for epoch 1 sucessfully.
Epoch 2: 138.35864305496216s 
Total train loss: 19.7552, acc: 0.9866 | Train loss: 0.0667 | Model loss: 19.6885 | Valid loss: 0.3244, acc: 0.8720 | Test loss: 0.3094, acc: 0.8732
Saving checkpoint for epoch 2 sucessfully.
Epoch 3: 122.41925001144409s 
Total train loss: 18.4490, acc: 0.9986 | Train loss: 0.0332 | Model loss: 18.4159 | Valid loss: 0.3405, acc: 0.8747 | Test loss: 0.3094, acc: 0.8732
Saving checkpoint for epoch 3 sucessfully.
Epoch 4: 122.46817183494568s 
Total train loss: 17.2459, acc: 0.9999 | Train loss: 0.0214 | Model loss: 17.2244 

In [19]:
# Write the per-epoch metric histories to a CSV file, one row per recorded epoch.
history_header = ("Epochs", "Training_total_loss",
                  "Training_accuracy", "Training_loss", "Model_loss",
                  "Validation_loss", "Validation_accuracy",
                  "Testing_loss", "Testing_accuracy")
history_columns = (training_total_loss, training_acc,
                   training_loss, model_loss,
                   validation_loss, validation_acc,
                   testing_loss, testing_acc)
# One "{}" per column, separated by ", " and terminated by a newline.
row_template = ", ".join(["{}"] * len(history_header)) + "\n"

with open("./1.Training_history.csv", "w") as outfile:
    outfile.write(row_template.format(*history_header))
    for epoch_index, epoch_values in enumerate(zip(*history_columns)):
        outfile.write(row_template.format(epoch_index, *epoch_values))
        