In [4]:
import numpy as np
import csv
from numpy import *
import os


def get_tf_list(tf_path):
    """Read transcription-factor names from the first column of a CSV file.

    The first row is treated as a header and skipped. Blank rows are ignored.

    Args:
        tf_path: path of the CSV file to read.

    Returns:
        List of first-column values, in file order.
    """
    # The context manager closes the handle even if parsing fails.
    with open(tf_path) as f_tf:
        tf_reader = csv.reader(f_tf)
        next(tf_reader, None)  # drop the header row (no-op on an empty file)
        # csv.reader yields [] for blank lines; skip them instead of crashing.
        tf_list = [row[0] for row in tf_reader if row]
    print('Load ' + str(len(tf_list)) + ' TFs successfully!')
    return tf_list

def get_origin_expression_data(gene_expression_path):
    """Load a gene-by-cell expression table from a UTF-8 CSV file.

    The header row holds the cell identifiers (its first field is ignored).
    Every following row is ``gene_name, value, value, ...``.

    Args:
        gene_expression_path: path of the CSV file to read.

    Returns:
        A tuple ``(expression_record, cells)`` where ``expression_record`` maps
        gene name -> list of floats and ``cells`` is the list of header labels.
        When a gene name occurs more than once a warning is printed and the
        last occurrence wins.
    """
    # Read everything inside the context manager so the handle is always closed.
    with open(gene_expression_path, encoding="utf-8") as f_expression:
        expression_reader = list(csv.reader(f_expression))
    cells = expression_reader[0][1:]
    num_cells = len(cells)

    expression_record = {}
    # Counts data rows, so a repeated gene name is counted each time it appears.
    num_genes = 0
    for row in expression_reader[1:]:
        gene_name = row[0]
        if gene_name in expression_record:
            print('Gene name ' + gene_name + ' repeat!')
        expression_record[gene_name] = list(map(float, row[1:]))
        num_genes += 1
    print(str(num_genes) + ' genes and ' + str(num_cells) + ' cells are included in origin expression data.')
    return expression_record, cells

def get_normalized_expression_data(gene_expression_path):
    """Load the expression table and log10-transform every gene vector.

    Each gene's values are replaced in place by ``log10(value + 10 ** -2)``
    stored as a NumPy array.

    Returns:
        ``(expression_record, cells)`` with the transformed vectors and the
        cell labels returned by ``get_origin_expression_data``.
    """
    expression_record, cells = get_origin_expression_data(gene_expression_path)

    # Dense genes x cells copy of the transformed values; it is only filled
    # here (a heat-map of it used to be plotted) and is not returned.
    expression_matrix = np.zeros((len(expression_record), len(cells)))
    for row_position, gene_name in enumerate(expression_record):
        transformed = np.log10(np.array(expression_record[gene_name]) + 10 ** -2)
        expression_record[gene_name] = transformed
        expression_matrix[row_position] = transformed

    return expression_record, cells

def get_gene_ranking(gene_order_path, low_express_gene_list, gene_num, output_path,
                     flag):  # flag=True:write to output_path
    """Rank significant genes by variance and return the top ``gene_num``.

    The input CSV has a header row followed by rows of
    ``gene name, p-value, variance``. Reading stops at the first row whose
    p-value is >= 0.01, so rows after it are never examined
    (NOTE(review): this presumably relies on the file being sorted by
    ascending p-value - confirm with the producer of the file).

    Args:
        gene_order_path: path of the CSV file described above.
        low_express_gene_list: gene names to exclude from the ranking.
        gene_num: number of top-variance genes to keep.
        output_path: CSV file receiving one ranked gene per row (only if flag).
        flag: when true, write the ranking to ``output_path``.

    Returns:
        ``(gene_rank, significant_gene_list)``. ``gene_rank`` holds the gene
        names exactly as written in the file, ordered by descending variance
        (ties keep file order). ``significant_gene_list`` holds the
        upper-cased names of every gene that passed both filters, in file
        order.
    """
    low_express_genes = set(low_express_gene_list)  # O(1) membership tests
    first_gene_by_variance = {}  # only used to report duplicated variances
    scored_genes = []  # (variance, gene name) for every retained gene
    significant_gene_list = []

    with open(gene_order_path) as f_order:
        order_reader = csv.reader(f_order)
        next(order_reader, None)  # skip the header row
        for row in order_reader:
            # column 0: gene name, column 1: p value, column 2: variance
            if float(row[1]) >= 0.01:
                break
            gene_name = row[0]
            if gene_name in low_express_genes:
                continue
            variance = float(row[2])
            if variance in first_gene_by_variance:
                print(str(first_gene_by_variance[variance]) + ' and ' + gene_name + ' variance repeat!')
            else:
                first_gene_by_variance[variance] = gene_name
            # Keeping (variance, gene) pairs avoids the nested-list corruption
            # that occurred when three or more genes shared one variance.
            scored_genes.append((variance, gene_name))
            significant_gene_list.append(gene_name.upper())

    print('After delete genes with p-value>=0.01 or low expression, ' + str(len(scored_genes)) + ' genes left.')

    # list.sort is stable even with reverse=True, so tied genes keep file order.
    scored_genes.sort(key=lambda item: item[0], reverse=True)
    gene_rank = [gene_name for _, gene_name in scored_genes[0:gene_num]]

    if flag:
        # Write exactly the genes that were ranked, one per row.
        with open(output_path, 'w', newline='\n') as f_rank:
            csv.writer(f_rank).writerows([gene_name] for gene_name in gene_rank)
    return gene_rank, significant_gene_list

def get_filtered_gold(gold_network_path, rank_list, output_path, flag):
    """Load a gold-standard network and keep edges whose genes are ranked.

    The CSV file has a header row followed by rows of
    ``TF, target[, score]``. The first two columns of every row are
    upper-cased before filtering, so ``rank_list`` is matched against
    upper-case names. If the header has fewer than three columns, every kept
    edge gets the placeholder score 999.

    Args:
        gold_network_path: path of the gold network CSV (read as UTF-8-sig).
        rank_list: gene names allowed to appear in an edge.
        output_path: CSV receiving one retained gene per row (only if flag).
        flag: when true, write the retained genes to ``output_path``.

    Returns:
        ``(gold_pair_record, gold_score_record, unique_gene_list)``:
        TF -> list of targets, ``"TF,target"`` -> score, and the genes seen in
        retained edges in first-seen order.
    """
    with open(gold_network_path, encoding='UTF-8-sig') as f_gold:
        gold_reader = list(csv.reader(f_gold))

    # Upper-case the gene columns of every row, including the last one
    # (the previous loop stopped one row early and silently lost that edge).
    for row in gold_reader:
        if len(row) >= 2:
            row[0] = str(row[0]).upper()
            row[1] = str(row[1]).upper()

    print("gold_reader[0]", gold_reader[0])
    has_score = len(gold_reader[0]) >= 3

    allowed_genes = set(rank_list)  # O(1) membership tests
    gold_pair_record = {}
    gold_score_record = {}
    unique_gene_list = []
    seen_genes = set()
    for row in gold_reader[1:]:
        # column 0: TF, column 1: target gene, column 2: regulate score
        if len(row) < 2:
            continue  # blank or malformed line
        tf_name, target_name = row[0], row[1]
        if tf_name not in allowed_genes or target_name not in allowed_genes:
            continue
        str_gene_pair = tf_name + ',' + target_name

        for gene_name in (tf_name, target_name):
            if gene_name not in seen_genes:
                seen_genes.add(gene_name)
                unique_gene_list.append(gene_name)
        if str_gene_pair in gold_score_record:
            # A repeated pair keeps the latest score; its target is appended again.
            print('Gold pair repeat!')
        if has_score:
            print("single_gold_reader[2]", row[2])
            gold_score_record[str_gene_pair] = float(row[2])
        else:
            gold_score_record[str_gene_pair] = 999
        gold_pair_record.setdefault(tf_name, []).append(target_name)
    print("gold_pair_record", gold_pair_record)

    # Some statistics of gold_network
    num_tfs = len(gold_pair_record)
    num_edges = len(gold_score_record)
    print(str(num_tfs) + ' TFs and ' + str(
        num_edges) + ' edges in gold_network consisted of genes in rank_list.')
    print(str(len(unique_gene_list)) + ' genes are common in rank_list and gold_network.')

    # Guard the denominators: nothing may have survived the filter.
    rank_denominator = num_tfs * len(rank_list)
    gold_denominator = num_tfs * len(unique_gene_list)
    rank_density = num_edges / rank_denominator if rank_denominator else 0.0
    gold_density = num_edges / gold_denominator if gold_denominator else 0.0

    # NOTE(review): the label below mentions len(rank_gene)-1 but the value is
    # computed with len(rank_list); the text is kept unchanged - confirm which
    # formula is intended.
    print('Rank genes density = edges/(TFs*(len(rank_gene)-1))=' + str(rank_density))
    print('Gold genes density = edges/(TFs*len(unique_gene_list))=' + str(gold_density))

    # write to file
    print("unique_gene_list", unique_gene_list)
    if flag:
        with open(output_path, 'w', encoding="utf-8", newline='\n') as f_unique:
            csv.writer(f_unique).writerows([gene_name] for gene_name in unique_gene_list)
    return gold_pair_record, gold_score_record, unique_gene_list

def generate_filtered_gold(gold_pair_record, gold_score_record, output_path):
    """Write the filtered gold network to ``output_path`` as CSV.

    The file starts with the header ``TF,Target,Score`` followed by one row
    per (TF, target) edge, grouped by TF in dictionary order.

    Args:
        gold_pair_record: mapping TF -> list of target genes.
        gold_score_record: mapping ``"TF,target"`` -> score.
        output_path: destination CSV path (UTF-8).

    Raises:
        KeyError: if an edge has no entry in ``gold_score_record``.
    """
    # The context manager closes the file even when a score lookup fails.
    with open(output_path, 'w', encoding="utf-8", newline='\n') as f_filtered:
        f_filtered_writer = csv.writer(f_filtered)
        f_filtered_writer.writerow(['TF', 'Target', 'Score'])
        for tf, targets in gold_pair_record.items():
            f_filtered_writer.writerows(
                [tf, target, gold_score_record[tf + ',' + target]] for target in targets
            )

def _take_negative(tf, own_negatives, spare_negatives):
    # Return one negative row for `tf`, borrowing from any TF's spare pool
    # when its own negatives are used up; None when no negative is left.
    if own_negatives[tf]:
        return [tf, own_negatives[tf].pop(0), 0, 0]
    for borrow_tf, spare in spare_negatives.items():
        if spare:
            return [borrow_tf, spare.pop(0), 0, 0]
    return None


def get_gene_pair_list(unique_gene_list, gold_pair_record, gold_score_record, output_file):
    """Write a balanced positive/negative gene-pair list to ``output_file``.

    A positive pair is a (TF, target) edge of ``gold_pair_record``. A negative
    pair is (TF, gene) for a gene of ``unique_gene_list`` that the TF does not
    regulate. Each positive row is followed by one negative row, so the file
    keeps a 1:1 ratio. When a TF has fewer negatives than positives, spare
    negatives of other TFs are borrowed; when no negative is left anywhere,
    writing stops and 'Negtive not enough!' is printed.

    Args:
        unique_gene_list: candidate genes, in the order negatives are drawn.
        gold_pair_record: mapping TF -> list of target genes.
        gold_score_record: mapping ``"TF,target"`` -> score of positive edges.
        output_file: destination CSV with header ``TF,Target,Label,Score``.

    Raises:
        KeyError: if a positive edge has no entry in ``gold_score_record``.
    """
    # Every gene a TF does not regulate is a candidate negative for that TF.
    spare_negatives = {}
    for tf, targets in gold_pair_record.items():
        positives = set(targets)  # O(1) membership instead of a list scan
        spare_negatives[tf] = [gene for gene in unique_gene_list if gene not in positives]

    # Reserve as many negatives per TF as it has positives; the remainder
    # stays in the spare pool so other TFs can borrow from it.
    # maybe random.sample(candidates, quota) to promote performance
    own_negatives = {}
    for tf, targets in gold_pair_record.items():
        quota = len(targets)
        candidates = spare_negatives[tf]
        own_negatives[tf] = candidates[:quota]
        spare_negatives[tf] = candidates[quota:]

    # The context manager closes the file even when a score lookup fails.
    with open(output_file, 'w', newline='\n') as f_gpl:
        f_gpl_writer = csv.writer(f_gpl)
        f_gpl_writer.writerow(['TF', 'Target', 'Label', 'Score'])
        for tf, targets in gold_pair_record.items():
            once_output = []
            negatives_exhausted = False
            for target in targets:
                positive_row = [tf, target, '1', gold_score_record[tf + ',' + target]]
                negative_row = _take_negative(tf, own_negatives, spare_negatives)
                if negative_row is None:
                    # Drop this unmatched positive to keep positive:negative = 1:1.
                    negatives_exhausted = True
                    break
                once_output.append(positive_row)
                once_output.append(negative_row)
            f_gpl_writer.writerows(once_output)  # rows of one TF at a time
            if negatives_exhausted:
                print('Negtive not enough!')
                break

def get_low_express_gene(origin_expression_record, num_cells):
    """Return the genes whose non-zero value count is at most ``num_cells // 10``.

    Args:
        origin_expression_record: mapping gene name -> sequence of values.
        num_cells: total number of cells; a tenth of it (floored) is the cutoff.

    Returns:
        Gene names, in dictionary order, with no more non-zero values than
        the cutoff.
    """
    cutoff = num_cells // 10
    low_genes = []
    for gene_name, values in origin_expression_record.items():
        expressed_cells = len([value for value in values if value != 0])
        if expressed_cells <= cutoff:
            low_genes.append(gene_name)
    return low_genes



def loadData(dataset,gene_pair_list_path,gene_expression_path,resultPath):
    """Build the model input arrays for every (TF, gene) pair and save them.

    Rows of the pair file labelled '1' define the TFs and the gene universe.
    Every TF is paired with every gene of that universe. For each pair whose
    two genes both exist in the expression data, the two log-transformed
    expression vectors are cut into consecutive windows of 100 cells, and each
    full window contributes one feature vector of 200 values (TF window
    followed by target window). A trailing partial window is dropped.

    Three files are written, named ``resultPath`` + 'matrix.npy',
    'label.npy' and 'gene_pair.npy'.

    Args:
        dataset: unused; kept so existing calls keep working.
        gene_pair_list_path: CSV with a header row and rows ``TF, target, label``.
        gene_expression_path: expression CSV passed to
            ``get_normalized_expression_data``.
        resultPath: prefix (normally a directory ending in '/') for the outputs.
    """
    origin_expression_record, cells = get_normalized_expression_data(gene_expression_path)
    print("len(origin_expression_record)", len(origin_expression_record))

    # Load gold_pair_record
    all_gene_list = []
    seen_genes = set()  # mirrors all_gene_list for O(1) membership tests
    gold_pair_record = {}
    single_pair = None  # stays None when the file has no data rows
    with open(gene_pair_list_path, encoding='UTF-8') as f_genePairList:  # gene pair and label file
        pair_rows = list(csv.reader(f_genePairList))[1:]

    for single_pair in pair_rows:
        print("single_pair",single_pair)
        if single_pair[2] != '1':
            continue
        tf_name, target_name = single_pair[0], single_pair[1]
        gold_pair_record.setdefault(tf_name, []).append(target_name)
        # count all genes in gold edges
        for gene_name in (tf_name, target_name):
            if gene_name not in seen_genes:
                seen_genes.add(gene_name)
                all_gene_list.append(gene_name)

    # print dataset statistics
    print('All genes:' + str(len(all_gene_list)))
    print('TFs:' + str(len(gold_pair_record.keys())))
    if single_pair is not None:
        print("len(single_pair)", len(single_pair))

    # Generate the feature matrix
    label_list = []
    pair_list = []
    total_matrix = []
    num_label1 = 0
    num_label0 = 0
    miss = 0  # pairs skipped because a gene is absent from the expression data
    gap = 100  # window length in cells; each feature has 2 * gap values
    positive_targets = {tf: set(targets) for tf, targets in gold_pair_record.items()}

    for num_tf, tf_name in enumerate(gold_pair_record):
        for j, target_name in enumerate(all_gene_list):
            print('Generating matrix of gene pair ' + str(num_tf) + ' ' + str(j))

            if tf_name not in origin_expression_record or target_name not in origin_expression_record:
                miss += 1
                continue

            if target_name in positive_targets[tf_name]:
                label = 1
                num_label1 += 1
            else:
                label = 0
                num_label0 += 1
            label_list.append(label)
            pair_list.append(tf_name + ',' + target_name)

            tf_data = origin_expression_record[tf_name]
            target_data = origin_expression_record[target_name]

            single_tf_list = []
            for k in range(0, len(tf_data), gap):
                feature = np.concatenate((tf_data[k:k + gap], target_data[k:k + gap]))
                # keep full windows only; a short tail window is discarded
                if len(feature) == 2 * gap:
                    single_tf_list.append(feature)

            total_matrix.append(np.asarray(single_tf_list))

    total_matrix = np.asarray(total_matrix)
    label_list = np.array(label_list)
    pair_list = np.array(pair_list)

    # np.save does not create missing directories.
    output_dir = os.path.dirname(resultPath + 'matrix.npy')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.save(resultPath + 'matrix.npy', total_matrix)
    np.save(resultPath + 'label.npy', label_list)
    np.save(resultPath + 'gene_pair.npy', pair_list)

    print('PCC matrix generation finish.')
    if miss:
        print('Skipped pairs without expression data:' + str(miss))
    print('Positive edges:' + str(num_label1))
    print('Negative edges:' + str(num_label0))
    total_pairs = num_label1 + num_label0
    if total_pairs:
        print('Density=' + str(num_label1 / total_pairs))
    else:
        print('Density=undefined (no gene pair was generated)')
## Generate the .npy inputs consumed by STGRNS.
# The known and unknown networks are written to separate directories;
# with a shared prefix the second run overwrote the first run's files.
# The directory names are the ones the training cell reads from.

dataset = "06hHep_ExpressionDataOrdered.csv"
gene_expression_path = "Data/exampleData/" + dataset

# Known network (training data)
known_network = "known_network.csv"
gene_pair_list_path = "Data/exampleData/" + known_network
resultPath = "Data/exampleData/input/known/"
os.makedirs(resultPath, exist_ok=True)  # np.save does not create directories
loadData(dataset,gene_pair_list_path,gene_expression_path,resultPath)


# Unknown network (data to predict)
unknown_network = "unknown_network.csv"
gene_pair_list_path = "Data/exampleData/" + unknown_network
gene_expression_path = "Data/exampleData/" + dataset
resultPath = "Data/exampleData/input/unknow/"
os.makedirs(resultPath, exist_ok=True)
loadData(dataset,gene_pair_list_path,gene_expression_path,resultPath)

Gene name 1-Mar repeat!
11515 genes and 425 cells are included in origin expression data.
len(origin_expression_record) 11514
single_pair ['KLF6', 'CRABP2', '0']
single_pair ['ID2', 'SCHIP1', '0']
single_pair ['FOS', 'DPCD', '1']
single_pair ['ZNF143', 'APEH', '0']
single_pair ['PTTG1', 'NIP7', '0']
single_pair ['ZNF143', 'MRPS27', '0']
single_pair ['NFYB', 'MCM3', '0']
single_pair ['NFYB', 'DPYSL2', '1']
single_pair ['KLF6', 'MT2A', '0']
single_pair ['OTX2', 'RUVBL1', '0']
single_pair ['MCM3', 'ZNF143', '0']
single_pair ['PTTG1', 'ZWILCH', '0']
single_pair ['MCM4', 'MCM3', '1']
single_pair ['MCM3', 'SERPINB1', '0']
single_pair ['FOS', 'CCBL2', '1']
single_pair ['CRABP2', 'GRPR', '0']
single_pair ['ZNF143', 'ATG5', '0']
single_pair ['KLF6', 'SMYD3', '0']
single_pair ['KLF6', 'DPYSL2', '0']
single_pair ['NFYB', 'MTFP1', '0']
single_pair ['MCM3', 'BUB1B', '0']
single_pair ['OTX2', 'TNFRSF12A', '0']
single_pair ['ZNF143', 'KNTC1', '0']
single_pair ['NFYB', 'DPCD', '0']
single_pair ['FOS',

single_pair ['FOS', 'KNTC1', '0']
single_pair ['OTX2', 'UTRN', '0']
single_pair ['RUVBL1', 'MCM3', '0']
single_pair ['NFYB', 'HAUS1', '0']
single_pair ['CRABP2', 'DKC1', '0']
single_pair ['OTX2', 'GSN', '0']
single_pair ['NFYB', 'TF', '0']
single_pair ['ID2', 'TNFRSF12A', '0']
single_pair ['MCM4', 'ASNS', '0']
single_pair ['ID3', 'ASNS', '0']
single_pair ['CRABP2', 'ZWILCH', '0']
single_pair ['FOS', 'GSN', '1']
single_pair ['PTTG1', 'PIM2', '0']
single_pair ['ID3', 'TM2D2', '0']
single_pair ['TGIF1', 'CAMK2D', '0']
single_pair ['PITX2', 'OTX2', '0']
single_pair ['ID3', 'DDIT3', '0']
single_pair ['ID2', 'ASNS', '0']
single_pair ['RUVBL1', 'ZWILCH', '0']
single_pair ['DDIT3', 'UBE2T', '0']
single_pair ['KNTC1', 'OTX2', '0']
single_pair ['TGIF1', 'GSN', '0']
single_pair ['DDIT3', 'NIF3L1', '0']
single_pair ['RUVBL1', 'DDIT3', '0']
single_pair ['KNTC1', 'NFYB', '0']
single_pair ['MCM7', 'PCCA', '0']
single_pair ['ID3', 'PCCA', '0']
single_pair ['MCM3', 'PITX2', '0']
single_pair ['MCM4', 'I

Generating matrix of gene pair 0 33
Generating matrix of gene pair 0 34
Generating matrix of gene pair 0 35
Generating matrix of gene pair 0 36
Generating matrix of gene pair 0 37
Generating matrix of gene pair 0 38
Generating matrix of gene pair 0 39
Generating matrix of gene pair 0 40
Generating matrix of gene pair 0 41
Generating matrix of gene pair 0 42
Generating matrix of gene pair 0 43
Generating matrix of gene pair 0 44
Generating matrix of gene pair 0 45
Generating matrix of gene pair 0 46
Generating matrix of gene pair 0 47
Generating matrix of gene pair 0 48
Generating matrix of gene pair 0 49
Generating matrix of gene pair 0 50
Generating matrix of gene pair 0 51
Generating matrix of gene pair 0 52
Generating matrix of gene pair 0 53
Generating matrix of gene pair 0 54
Generating matrix of gene pair 0 55
Generating matrix of gene pair 0 56
Generating matrix of gene pair 0 57
Generating matrix of gene pair 0 58
Generating matrix of gene pair 0 59
Generating matrix of gene pa

Generating matrix of gene pair 11 1
Generating matrix of gene pair 11 2
Generating matrix of gene pair 11 3
Generating matrix of gene pair 11 4
Generating matrix of gene pair 11 5
Generating matrix of gene pair 11 6
Generating matrix of gene pair 11 7
Generating matrix of gene pair 11 8
Generating matrix of gene pair 11 9
Generating matrix of gene pair 11 10
Generating matrix of gene pair 11 11
Generating matrix of gene pair 11 12
Generating matrix of gene pair 11 13
Generating matrix of gene pair 11 14
Generating matrix of gene pair 11 15
Generating matrix of gene pair 11 16
Generating matrix of gene pair 11 17
Generating matrix of gene pair 11 18
Generating matrix of gene pair 11 19
Generating matrix of gene pair 11 20
Generating matrix of gene pair 11 21
Generating matrix of gene pair 11 22
Generating matrix of gene pair 11 23
Generating matrix of gene pair 11 24
Generating matrix of gene pair 11 25
Generating matrix of gene pair 11 26
Generating matrix of gene pair 11 27
Generating

Generating matrix of gene pair 6 21
Generating matrix of gene pair 6 22
Generating matrix of gene pair 6 23
Generating matrix of gene pair 7 0
Generating matrix of gene pair 7 1
Generating matrix of gene pair 7 2
Generating matrix of gene pair 7 3
Generating matrix of gene pair 7 4
Generating matrix of gene pair 7 5
Generating matrix of gene pair 7 6
Generating matrix of gene pair 7 7
Generating matrix of gene pair 7 8
Generating matrix of gene pair 7 9
Generating matrix of gene pair 7 10
Generating matrix of gene pair 7 11
Generating matrix of gene pair 7 12
Generating matrix of gene pair 7 13
Generating matrix of gene pair 7 14
Generating matrix of gene pair 7 15
Generating matrix of gene pair 7 16
Generating matrix of gene pair 7 17
Generating matrix of gene pair 7 18
Generating matrix of gene pair 7 19
Generating matrix of gene pair 7 20
Generating matrix of gene pair 7 21
Generating matrix of gene pair 7 22
Generating matrix of gene pair 7 23
Generating matrix of gene pair 8 0
Gen

In [5]:
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import warnings
warnings.filterwarnings('ignore')
from sklearn import metrics
import os
import csv
import math
from torch.utils.data import (DataLoader)
torch.set_default_tensor_type(torch.DoubleTensor)

def numpy2loader(X, y, batch_size):
    """Wrap two NumPy arrays in separate, unshuffled ``DataLoader`` objects.

    Args:
        X: feature array.
        y: label array.
        batch_size: number of leading-axis items per batch for both loaders.

    Returns:
        ``(X_loader, y_loader)``; both iterate in the arrays' original order.
    """
    feature_loader = DataLoader(torch.from_numpy(X), batch_size=batch_size)
    label_loader = DataLoader(torch.from_numpy(y), batch_size=batch_size)
    return feature_loader, label_loader

def loaderToList(data_loader):
    """Collect every batch yielded by ``data_loader`` into a list.

    Any iterable is accepted; it no longer has to support ``len()``.

    Args:
        data_loader: iterable of batches.

    Returns:
        List of the batches in iteration order.
    """
    return list(data_loader)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to the input, then apply dropout.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines and odd columns
    hold cosines of ``position * 10000 ** (-column / d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Middle axis of size 1 lets the table broadcast over the second axis of x.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        # Positions are taken from the first axis of x.
        sequence_length = x.size(0)
        return self.dropout(x + self.pe[:sequence_length, :])

class STGRNS(nn.Module):
    """Transformer-encoder classifier over a sequence of feature vectors.

    Args:
        input_dim: input width of the ``prenet`` linear layer.
        nhead: number of attention heads of the encoder layer; it is now
            actually forwarded to the layer (it used to be ignored in favour
            of a hard-coded 2).
        d_model: feature width used by the positional encoding, the encoder
            layer and the prediction head.
        num_classes: number of output logits.
        dropout: dropout probability for the positional encoding and encoder.
    """

    def __init__(self, input_dim, nhead=2, d_model=80, num_classes=2, dropout=0.1):
        super().__init__()
        # NOTE(review): prenet is created but forward() never applies it, so
        # the last axis of the input must already equal d_model. It is kept so
        # parameter initialisation order and state_dict keys stay unchanged.
        self.prenet = nn.Linear(input_dim, d_model)
        self.positionalEncoding = PositionalEncoding(d_model=d_model, dropout=dropout)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, dim_feedforward=256, nhead=nhead, dropout=dropout
        )

        self.pred_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, window_size):
        """Return class logits for a batch.

        Args:
            window_size: 3-D tensor; presumably (batch, windows, d_model) -
                confirm with the data-generation step.

        Returns:
            Tensor of logits with ``num_classes`` values per batch item.
        """
        # Swap the first two axes: the encoder layer is built with the default
        # batch_first=False and the positional encoding indexes the first axis.
        out = window_size.permute(1, 0, 2)
        out = self.positionalEncoding(out)
        out = self.encoder_layer(out)
        out = out.transpose(0, 1)
        # Average over the sequence axis before classification.
        stats = out.mean(dim=1)
        out = self.pred_layer(stats)
        return out



def STGRNSForGRNSRconstruction(batch_sizes, epochs,known_data_path,unknown_data_path,num_threads):
    """Train STGRNS on the known network, then score the unknown network.

    Both data paths are prefixes to which 'matrix.npy' and 'label.npy' are
    appended. After training, the positive-class probability of every test
    item is computed, evaluation metrics are printed, and three arrays are
    saved under ``log/``: ``y_test.npy``, ``y_predict.npy`` (probabilities)
    and ``y_predict2.npy`` (thresholded 0/1 predictions).

    Args:
        batch_sizes: batch size for training and prediction.
        epochs: number of training epochs.
        known_data_path: prefix of the training arrays.
        unknown_data_path: prefix of the arrays to predict.
        num_threads: value passed to ``torch.set_num_threads``.
    """
    torch.set_num_threads(num_threads)
    batch_size = batch_sizes
    n_epochs = epochs
    log_dir = "log/"
    os.makedirs(log_dir, exist_ok=True)

    x_train = np.load(known_data_path + 'matrix.npy')
    y_train = np.load(known_data_path + 'label.npy')

    X_trainloader, y_trainloader = numpy2loader(x_train, y_train, batch_size)
    X_trainList = loaderToList(X_trainloader)
    y_trainList = loaderToList(y_trainloader)

    # The model feeds its input straight into the encoder, so its width must
    # equal the feature width of the data. It was previously set to `epochs`,
    # which only worked when epochs happened to equal that width.
    feature_dim = x_train.shape[-1]
    model = STGRNS(input_dim=feature_dim, nhead=2, d_model=feature_dim, num_classes=2)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)

    model.train()
    for epoch in range(n_epochs):
        train_loss = []
        for data, labels in zip(X_trainList, y_trainList):
            logits = model(data)
            labels = labels.long()  # CrossEntropyLoss expects integer class ids
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()
            train_loss.append(loss.item())
        train_loss = sum(train_loss) / len(train_loss)

        print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}")

    ###predict-----------------------------------------------
    x_test = np.load(unknown_data_path + 'matrix.npy')
    y_test = np.load(unknown_data_path + 'label.npy')

    X_testloader, _ = numpy2loader(x_test, y_test, batch_size)
    X_testList = loaderToList(X_testloader)

    model.eval()
    y_predict = []
    with torch.no_grad():
        for data in X_testList:
            # Explicit dim: softmax over the class axis of the (batch, classes) logits.
            probabilities = F.softmax(model(data), dim=-1)
            y_predict.extend(probabilities[:, 1].cpu().numpy().tolist())

    fpr, tpr, _ = metrics.roc_curve(y_test, y_predict, pos_label=1)
    auc = metrics.auc(fpr, tpr)

    precision, recall, _ = metrics.precision_recall_curve(y_test, y_predict)
    AUPR = metrics.auc(recall, precision)

    # NOTE(review): 0.001 is a very permissive cutoff for a probability;
    # kept as is - confirm it is intentional.
    y_predict2 = [1 if pre > 0.001 else 0 for pre in y_predict]
    acc = metrics.accuracy_score(y_test, y_predict2)
    bacc = metrics.balanced_accuracy_score(y_test, y_predict2)
    f1 = metrics.f1_score(y_test, y_predict2)

    # The metrics used to be computed and then discarded; report them.
    print(f"[ Test ] AUROC = {auc:.5f}, AUPR = {AUPR:.5f}, "
          f"acc = {acc:.5f}, balanced acc = {bacc:.5f}, F1 = {f1:.5f}")

    ##storing the predicted data
    np.save(log_dir + 'y_test.npy', y_test)
    np.save(log_dir + 'y_predict.npy', y_predict)

    ##storing the predicted network
    np.save(log_dir + 'y_predict2.npy', y_predict2)


## Run settings: batch size, epoch count and CPU threads used by torch.
batch_sizes = 32
epochs = 200
num_threads = 1

## Prefix of the arrays generated for the known network (training data).
data_path = 'Data/exampleData/input/known/'

## Prefix of the arrays generated for the unknown network (data to predict).
unknown_data_path = 'Data/exampleData/input/unknow/'

## Train the model, then predict the unknown network.
STGRNSForGRNSRconstruction(batch_sizes, epochs, data_path, unknown_data_path, num_threads)

[ Train | 001/200 ] loss = 0.29910
[ Train | 002/200 ] loss = 0.32289
[ Train | 003/200 ] loss = 0.22787
[ Train | 004/200 ] loss = 0.20970
[ Train | 005/200 ] loss = 0.19196
[ Train | 006/200 ] loss = 0.17742
[ Train | 007/200 ] loss = 0.16641
[ Train | 008/200 ] loss = 0.18922
[ Train | 009/200 ] loss = 0.17154
[ Train | 010/200 ] loss = 0.13853
[ Train | 011/200 ] loss = 0.12771
[ Train | 012/200 ] loss = 0.12285



KeyboardInterrupt

