In [None]:
# !pip install dgl
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import scipy.sparse as sp


def encode_onehot(labels):
    """Encode a sequence of class labels as one-hot rows.

    Args:
        labels: Re-iterable sequence of hashable class labels, one per
            sample (it is traversed twice).

    Returns:
        ``np.int32`` array with one row per entry of ``labels``; each row
        contains a single 1 in the column assigned to that label.  Columns
        follow the sorted order of the distinct labels whenever the labels
        can be ordered, so the encoding is reproducible between runs.
    """
    distinct = set(labels)
    try:
        # Set iteration order for strings depends on hash randomization;
        # sorting pins the label -> column mapping.
        classes = sorted(distinct)
    except TypeError:
        # Labels that cannot be compared keep the set's own order.
        classes = list(distinct)

    # Build the identity matrix once instead of once per class.
    identity = np.identity(len(classes))
    classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def normalize(mx):
    """Row-normalize a matrix so every non-empty row sums to 1.

    Args:
        mx: Matrix exposing ``sum(1)`` and usable as the right operand of a
            sparse ``dot`` (e.g. a ``scipy.sparse`` matrix).

    Returns:
        ``D^-1 * mx`` where ``D`` is the diagonal matrix of row sums.  Rows
        whose sum is zero are left as all zeros.
    """
    rowsum = np.array(mx.sum(1)).flatten()
    # np.power refuses negative exponents on integer arrays, so promote
    # integer/bool row sums to float first; float inputs keep their dtype.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero row sums produce inf here on purpose; they are reset just below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)

def Convert(tup, di):
    """Group ``(key, value)`` pairs into ``di`` as ``key -> [value, ...]``.

    Values are appended in the order the pairs are seen.  ``di`` is updated
    in place and also returned.
    """
    for key, value in tup:
        if key in di:
            di[key].append(value)
        else:
            di[key] = [value]
    return di




# def accuracy(output, labels):
#     preds = output.max(1)[1].type_as(labels)
#     correct = preds.eq(labels).double()
#     correct = correct.sum()
#     return correct / len(labels)


# def sparse_mx_to_torch_sparse_tensor(sparse_mx):
#     """Convert a scipy sparse matrix to a torch sparse tensor."""
#     sparse_mx = sparse_mx.tocoo().astype(np.float32)
#     indices = torch.from_numpy(
#         np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
#     values = torch.from_numpy(sparse_mx.data)
#     shape = torch.Size(sparse_mx.shape)
#     return torch.sparse.FloatTensor(indices, values, shape)

import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``adj @ (input @ weight)`` and adds ``bias`` when the layer was
    built with one.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Keep the attribute present (as None) so later checks work.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (then bias) uniformly from ``[-b, b]``, ``b = 1/sqrt(out_features)``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is None:
            return
        self.bias.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project ``input`` with the weight, then aggregate with ``adj``."""
        projected = torch.mm(input, self.weight)
        aggregated = torch.spmm(adj, projected)
        return aggregated if self.bias is None else aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


import torch.nn as nn
import torch.nn.functional as F
# import torch
import math
from torch.nn.parameter import Parameter
# from layers import GraphConvolution
import dgl
from dgl.nn.pytorch.conv import GraphConv, GATConv, SAGEConv, DenseGraphConv


def reset_parameters(self):
    """Refill ``self.weight`` in place from ``U(-b, b)`` with ``b = 1/sqrt(weight.size(1))``."""
    fan = self.weight.size(1)
    limit = 1. / math.sqrt(fan)
    self.weight.data.uniform_(-limit, limit)


class RNNGCN(nn.Module):
    """Two-layer GCN applied to a recursively blended adjacency.

    The adjacency slices ``adj[:, t, :]`` are folded into one matrix with a
    single learnable coefficient ``Lambda``; the GCN then runs on the last
    feature slice ``x[:, -1, :]``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

        # uniform_(0.2, 0.2) has equal bounds, i.e. Lambda starts at 0.2.
        mixing = Parameter(torch.FloatTensor(1))
        mixing.data.uniform_(0.2, 0.2)
        self.Lambda = mixing

    def forward(self, x, adj):
        """Return log-softmax class scores computed on the blended adjacency."""
        n_steps = adj.shape[1]
        blended = adj[:, 0, :].clone()
        for t in range(1, n_steps):
            # Older slices are weighted by (1 - Lambda), the new one by Lambda.
            blended = (1 - self.Lambda) * blended + self.Lambda * adj[:, t, :]

        hidden = F.relu(self.gc1(x[:, -1, :], blended))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.gc2(hidden, blended)
        return F.log_softmax(logits, dim=1)


class TRNNGCN(nn.Module):
    """GCN on a blended adjacency whose blend weights depend on node classes.

    ``H`` is a one-hot (node x class) matrix, so ``H @ Lambda @ H.T`` gives
    every node pair the ``Lambda`` entry of their two current classes.  ``H``
    starts from random classes and is overwritten with the predicted classes
    at the end of every forward pass.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, nnode, use_cuda=False):
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout
        self.Lambda = Parameter(torch.FloatTensor(nclass, nclass))
        # Equal bounds: every entry of Lambda starts at 0.5.
        self.Lambda.data.uniform_(0.5, 0.5)
        self.use_cuda = use_cuda

        # Random initial class per node.
        start_classes = torch.randint(0, nclass, (nnode, 1)).flatten()
        self.H = self._blank_membership((nnode, nclass))
        self.H[range(nnode), start_classes] = 1

    def _blank_membership(self, shape):
        """Return an all-zero matrix of ``shape``, moved to CUDA if ``use_cuda``."""
        blank = torch.zeros(shape)
        return blank.cuda() if self.use_cuda else blank

    def forward(self, x, adj):
        """Return log-softmax class scores and refresh ``H`` from them."""
        # Keep the blend coefficients inside [0, 1].
        self.Lambda.data = self.Lambda.data.clamp(0, 1)

        decay_adj = torch.mm(torch.mm(self.H, self.Lambda), self.H.T)
        if self.use_cuda:
            decay_adj = decay_adj.cuda()

        blended = adj[:, 0, :].clone()
        for t in range(1, adj.shape[1]):
            blended = (1 - decay_adj) * blended + decay_adj * adj[:, t, :]
        del decay_adj

        hidden = F.relu(self.gc1(x[:, -1, :], blended))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.gc2(hidden, blended)
        output = F.log_softmax(logits, dim=1)

        predicted = torch.argmax(output, dim=1)
        shape = self.H.shape
        # Drop the old matrices before allocating the replacement.
        del self.H
        del blended
        self.H = self._blank_membership(shape)
        self.H[range(shape[0]), predicted] = 1
        return output


class LSTMGCN(nn.Module):
    """Two-layer GCN whose adjacency is the last output step of an LSTM.

    The LSTM consumes ``adj`` directly; its final time step is used as the
    adjacency argument of both graph convolutions.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout
        # NOTE(review): with num_layers=1 the dropout=0.5 argument presumably
        # has no effect (PyTorch only applies it between layers) - confirm.
        self.LS_begin = nn.LSTM(
            input_size=nfeat,
            hidden_size=nhid,
            num_layers=1,
            dropout=0.5,
            batch_first=True,
        )

        self.nhid = nhid

    def forward(self, x, adj):
        """Return log-softmax class scores for the last feature slice of ``x``."""
        lstm_result = self.LS_begin(adj)
        sequence_out = lstm_result[0]
        learned_adj = sequence_out[:, -1, :]

        hidden = F.relu(self.gc1(x[:, -1, :], learned_adj))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.gc2(hidden, learned_adj)
        return F.log_softmax(logits, dim=1)



# class MyGNN(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(MyGNN, self).__init__()

#         self.nhid = nhid
#         self.nclass = nclass
#         self.num_layers = 2
#         self.dropout = dropout

#         # self.gc1 = GraphConvolution(nfeat, nhid)
#         # print(self.gc1)
#         # self.gc2 = GraphConvolution(nhid, nhid)
#         # print(self.gc2)
#         # self.gc1=SAGEConv()

#         self.gc1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
#         self.gc2 = SAGEConv(nhid, nhid, aggregator_type='mean')
#         # self.gc3 = GATConv(nhid, nhid, num_heads=1)
#         # self.gc1 = GraphConv(nfeat, nhid)
#         # self.gc2 = GraphConv(nhid, nhid)
#         # print(self.dropout)

#         self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass, num_layers=2, batch_first=True, bidirectional=True)

#         # print(self.LS_end)
#         # print("\n.................\n")

#         # self.linear22 = nn.Linear(nclass, nclass) # actual
#         self.linear22 = nn.Linear(nhid * 2, nclass)
#         self.gc3 = GATConv(nclass, nclass, num_heads=1)

#         # self.Lambda = Parameter(torch.FloatTensor(1))
#         # self.Lambda.data.uniform_(0.2, 0.2)

#     def forward(self, x, adj):
#         # print(x.shape,adj.shape,sep="-------------")

#         out = []
#         # print(adj)
#         # count=0

#         # now_adj = adj[:, 0, :].clone()

#         # time_wise_attention = []
#         # all_time_rank = []
#         full_time_edges=set()
#         for i in range(0, adj.shape[1]):

#             # lol = adj[ :,i, :]
#             # print(lol,lol.shape)
#             # print(lol[0][1],lol[1][0])
#             # my_x = lol.cpu().detach().numpy()
#             # my_x=lol.reshape(lol.shape[0],lol.shape[1])
#             # print(my_x)
#             # print(my_x[0][1],my_x[1][0])

#             # b = my_x.transpose()
#             # if np.allclose(my_x, b, rtol=0, atol=0):
#             #     print("The array is Symmetric")
#             # else:
#             #     print("The array is NOT Symmetric")
#             # break

#             # kkk = np.diagonal(my_x).copy()
#             # print(kkk, kkk.shape)
#             # print(len(set(kkk)))


#             # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#             # count +=1
#             # print(adj[:, i, :],adj[:, i, :].shape,sep="&&&&&&&&&&&&&&&&")

#             # now_adj = (1 - self.Lambda) * now_adj + self.Lambda * adj[:, i, :]  # weight decay
#             # adj2 = dgl.from_networkx(nx.Graph(now_adj.cpu().detach().numpy()))
#             # adj2 = dgl.add_self_loop(adj2)

#             # print(i,type(adj))


#             adj2 = dgl.from_networkx(nx.Graph(adj[:, i, :].numpy()))
#             # # print(i,type(adj2))


#             adj2 = dgl.add_self_loop(adj2)

#             # print("Number of Edges = ",str(len(adj2.edges()[0])))
#             # p=adj2.edges()
#             # q = list(zip(p[0], p[1]))
#             # q= sorted([(int(x),int(y)) for (x,y) in q])
#             # print(q)

#             # print("\nTime: ",str(i),", Nodes: ",str(adj2.number_of_nodes()),", Edges: ",str(adj2.number_of_edges()),"\n")

#             a_list = list(zip([int(x) for x in adj2.edges()[0]], [int(x) for x in adj2.edges()[1]]))

#             h0 = torch.zeros(self.num_layers * 2, x[:, i, :].size(0), self.nhid)  # 2 for bidirection
#             c0 = torch.zeros(self.num_layers * 2, x[:, i, :].size(0), self.nhid)

#             one_out = F.relu(self.gc1(adj2, x[:, i, :]))
#             # print(one_out.shape)
#             one_out = F.dropout(one_out, self.dropout, training=self.training)
#             # print(one_out.shape)
#             one_out = self.gc2(adj2, one_out)
#             one_out = F.dropout(one_out, self.dropout, training=self.training)
#             # one_out, yy = self.gc3(adj2, one_out, get_attention=True)
#             one_out = self.gc3(adj2, one_out, get_attention=False)

#             # weight_dict = list(zip(a_list, [float(x) for x in yy]))
#             # # print(weight_dict)
#             # # print(yy,yy.shape)

#             # # length = len(weight_dict)
#             # # middle_index = length//2
#             # # first_half = weight_dict[:middle_index]

#             # # print(weight_dict)
#             # # print(first_half)
#             # new_dict = {}
#             # for i2, e2 in enumerate(weight_dict):
#             #     # print(i,e)
#             #     if(e2[0][0]==e2[0][1]):
#             #         continue
#             #     new_dict[e2[0]] = e2[1]


#             # # print("Number of Dict = ",str(len(new_dict)))
#             # # print(new_dict)
#             # r=sorted(list(new_dict.keys()))
#             # # print(len(r))
#             # ww=list()
#             # another_dict={}
#             # for x2 in r:
#             #     if((x2[0],x2[1]) in r):
#             #         # continue
#             #         if(x2[0]<x2[1]):
#             #             avg_new=(new_dict[(x2[1],x2[0])]+new_dict[(x2[0],x2[1])])/2
#             #             another_dict[(x2[0],x2[1])]=avg_new
#             #             full_time_edges.add((x2[0],x2[1]))
#             #         # del new_dict[(x2[1],x2[0])]
#             #     # ww.append(x2)

#             # # print(len(another_dict),another_dict)
#             # haha={k: v for k, v in sorted(another_dict.items(), key=lambda item: item[1],reverse=True)}
#             # # print(haha)
#             # edge_rank=zip(list(haha),range(len(haha)))


#             # haha_dictionary = {}
#             # all_time_rank.append(Convert(edge_rank, haha_dictionary))

#             # # print(list(set(q) - set(r)))



#             # time_wise_attention.append(new_dict)
#             # yy = yy.reshape(1, -1)
#             # yy = yy.squeeze()
#             # print(yy,yy.shape)
#             # print(one_out.shape)
#             one_out = one_out.reshape(one_out.shape[0], one_out.shape[2])
#             # print(one_out.shape)
#             # xx,yy= self.conv4(adj, x, get_attention=True)
#             # yy = yy.reshape(1, -1)
#             # yy = yy.squeeze()
#             # print(yy,yy.shape)
#             # print("\n***************************\n")
#             # one_out = self.linear()
#             # print(str(count)+" "+"Inside Loop: "+str(one_out.shape)+"\n")
#             # break
#             out += [one_out]
#             # print("\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n")
#         # print(type(out))
#         # print(len(out))
#         # print(out[0].shape)
#         # print(count)
#         # yy_dict={}
#         # for i5,e5 in enumerate(list(full_time_edges)):
#         #
#         #     for i4,e4 in enumerate(all_time_rank):
#         #         if(e5 in e4):
#         #             yy_dict[e5] += str(e4[e5][0])
#         #             # print(e5,e4[e5][0])
#         #         else:
#         #             yy_dict[e5] = ""
#         # print(yy_dict)

#         # print(full_time_edges)
#         # all_time_rank=all_time_rank.reverse()
#         # for i4,e4 in enumerate(all_time_rank):

#         #     print(i4,e4)
#         out = torch.stack(out, 1)
#         # print(type(out))

#         # print(len(time_wise_attention))

#         # print(self.LS_end(out)[0][:, -1, :]) #taking the last hidden state

#         # print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
#         # out, _ = self.lstm(x, (h0, c0))
#         out = self.LS_end(out, (h0, c0))[0][:, -1, :]

#         out = self.linear22(out)
#         # print(out.shape)
#         # out,yy = self.gc3(adj2,out,get_attention=True)
#         # yy = yy.reshape(1, -1)
#         # yy = yy.squeeze()
#         # print(yy,yy.shape)
#         # print(out.shape)
#         # out = out.reshape(out.shape[0], out.shape[2])

#         # print(time_wise_attention)

#         # print("\n-------------------\n")
#         return F.log_softmax(out, dim=1)


class MyGNN(nn.Module):
    """GraphSAGE + GAT encoder per snapshot followed by a bidirectional LSTM.

    For every time slice ``adj[:, i, :]`` (``i >= 1``) a DGL graph is built,
    the matching feature slice ``x[:, i, :]`` is passed through two SAGEConv
    layers and one single-head GATConv, and the resulting per-slice node
    embeddings are stacked along a time axis.  The last output step of the
    bidirectional LSTM is mapped to ``nclass`` scores by a linear layer.

    Args:
        nfeat: Input feature size of the first SAGEConv.
        nhid: Hidden size of the graph layers and LSTM input size.
        nclass: LSTM hidden size and number of output classes.
        dropout: Dropout probability applied after each SAGEConv.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(MyGNN, self).__init__()

        self.nhid = nhid
        self.nclass = nclass
        self.num_layers = 2
        self.dropout = dropout

        self.gc1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
        self.gc2 = SAGEConv(nhid, nhid, aggregator_type='mean')

        self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass,
                              num_layers=self.num_layers, batch_first=True,
                              bidirectional=True)

        # A bidirectional LSTM emits 2 * hidden_size features per step.  The
        # layer was previously sized with nhid * 2, which only matched the
        # LSTM output when nhid == nclass.
        self.linear22 = nn.Linear(2 * self.LS_end.hidden_size, nclass)

        self.gc3 = GATConv(nhid, nhid, num_heads=1)

    def forward(self, x, adj):
        """Return log-softmax class scores, one row per node.

        ``x`` and ``adj`` are both indexed as ``[:, i, :]`` with the time
        step on axis 1.
        """
        out = []

        # NOTE(review): the loop starts at 1, so slice 0 of x/adj is never
        # used - confirm this is intended.
        for i in range(1, adj.shape[1]):
            # `nx` (networkx) is not imported in the visible part of this
            # file; it must already be in scope when this runs.
            # .numpy() requires the adjacency slice to live on the CPU.
            adj2 = dgl.from_networkx(nx.Graph(adj[:, i, :].numpy()))
            adj2 = dgl.add_self_loop(adj2)

            one_out = F.relu(self.gc1(adj2, x[:, i, :]))
            one_out = F.dropout(one_out, self.dropout, training=self.training)
            one_out = self.gc2(adj2, one_out)
            one_out = F.dropout(one_out, self.dropout, training=self.training)

            # Attention weights were requested but never used, so only the
            # node embeddings are computed now.
            one_out = self.gc3(adj2, one_out)
            # Drop the head axis (num_heads == 1): keep axes 0 and 2.
            one_out = one_out.reshape(one_out.shape[0], one_out.shape[2])
            out.append(one_out)

        out = torch.stack(out, 1)

        # No explicit (h0, c0): the LSTM then starts from zero states of the
        # correct shape.  The old hand-built zeros used nhid instead of the
        # LSTM hidden size and failed whenever nhid != nclass.
        out = self.LS_end(out)[0][:, -1, :]  # last time step

        out = self.linear22(out)
        return F.log_softmax(out, dim=1)


# class GCN(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GCN, self).__init__()

#         self.gc1 = GraphConvolution(nfeat, nhid)
#         self.gc2 = GraphConvolution(nhid, nclass)
#         self.dropout = dropout

#     def forward(self, x, adj):
#         x = F.relu(self.gc1(x, adj))
#         x = F.dropout(x, self.dropout, training=self.training)
#         x = self.gc2(x, adj)

#         return F.log_softmax(x, dim=1)


# class GAT(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GAT, self).__init__()
#         self.dropout = dropout
#         self.conv1 = GATConv(nfeat, nhid, num_heads=1)
#         self.conv2 = GATConv(nhid, nclass, num_heads=1)


#     def forward(self, x, adj):
#         # Use node degree as the initial node feature. For undirected graphs, the in-degree
#         # is the same as the out_degree.
#         # Perform graph convolution and activation function.

#         x = F.relu(self.conv1(adj, x))  #different from self-defined gcn
#         x=x.reshape(x.shape[0],x.shape[2])
#         x = F.dropout(x, self.dropout, training=self.training)
#         x = self.conv2(adj, x)
#         x=x.reshape(x.shape[0],x.shape[2])

#         return F.log_softmax(x, dim=1)

# class GAT(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GAT, self).__init__()
#         self.dropout = dropout
#         self.conv1 = GATConv(nfeat, nhid, num_heads=1)
#         self.conv2 = GATConv(nhid, nclass, num_heads=1)

#     def forward(self, x, adj):
#         # Use node degree as the initial node feature. For undirected graphs, the in-degree
#         # is the same as the out_degree.
#         # Perform graph convolution and activation function.
#         print("\n!!!!!!!!!!!!!!!!!! HELLO !!!!!!!!!!!!!!!!!!!!!\n")
#         print(x.shape)
#         x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         print("\nAfter Conv1:\n")
#         print(x.shape)
#         print(x.shape[0],x.shape[2])
#         x = x.reshape(x.shape[0], x.shape[2])
#         print("\nAfter 1st Reshape:\n")
#         print(x.shape)
#         x = F.dropout(x, self.dropout, training=self.training)
#         print("\nAfter 1st Dropout\n")
#         print(x.shape)
        
#         x = self.conv2(adj, x)
#         print("\nAfter Conv2\n")
#         print(x.shape)
#         x = x.reshape(x.shape[0], x.shape[2])
#         print("\nAfter 2nd Reshape\n")
#         print(x.shape)
#         my_res=F.log_softmax(x, dim=1)
#         print("\nReturning log_softmax\n")
#         print(my_res.shape)
#         print("\n===========END======================\n")


#         return F.log_softmax(x, dim=1)


# class GraphSage(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GraphSage, self).__init__()
#         self.dropout = dropout
#         self.conv1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
#         self.conv2 = SAGEConv(nhid, nclass, aggregator_type='mean')
#         # self.conv3 = GATConv(nclass, 4, num_heads=8)
#         # self.conv4 = GATConv(4, 20, num_heads=8)

#     def forward(self, x, adj):
#         # Use node degree as the initial node feature. For undirected graphs, the in-degree
#         # is the same as the out_degree.
#         # Perform graph convolution and activation function.
#         x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         # x = F.dropout(x, self.dropout, training=self.training)
#         x = self.conv2(adj, x)
#         # x = self.conv3(adj, x)
#         # x = self.conv4(adj, x)

#         return F.log_softmax(x, dim=1)

# class GraphSage_BiLSTM_GAT(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GraphSage_BiLSTM_GAT, self).__init__()

#         self.dropout = dropout

#         # self.conv1 = GATConv(nfeat, nhid, num_heads=1)
#         # self.conv2 = GATConv(nhid, nclass, num_heads=1)
#         # self.conv3 = GATConv(nclass, nclass, num_heads=1)
#         # self.conv4 = GATConv(nclass, nclass, num_heads=1)
#         # self.conv5 = GATConv(nclass, nclass, num_heads=1)

#         self.conv1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
#         self.conv2 = SAGEConv(nhid, nclass, aggregator_type='mean')
#         # self.conv3 = SAGEConv(nhid, nhid, aggregator_type='mean')
#         # self.conv4 = SAGEConv(nhid, nhid, aggregator_type='mean')
#         # self.conv5 = SAGEConv(nhid, nhid, aggregator_type='mean')

#         self.conv6 = GATConv(nhid, nclass, num_heads=1)
#         # self.conv7 = GATConv(nhid, nhid, num_heads=1)
#         # self.conv6 = GATConv(nhid, nhid, num_heads=1)


#         # self.conv8 = GATConv(nhid, nhid, num_heads=16)

#         self.LS_end = nn.LSTM(input_size=nclass, hidden_size=nclass, num_layers=8, dropout=dropout, batch_first=True,
#                               bidirectional=True)


#         # self.conv3 = SAGEConv(nclass, nclass, aggregator_type='mean')
#         self.conv9 = GATConv(nhid, nclass, num_heads=1)
#         # self.conv5 = GATConv(nclass, nclass, num_heads=1)
#         # self.conv6 = GATConv(nclass, nclass, num_heads=1)
#         # self.conv4 = SAGEConv(nclass, nclass, aggregator_type='mean')

#         # self.conv2 = SAGEConv(nhid, nclass, aggregator_type='mean')

#         # self.conv3 = SAGEConv(nclass, nclass, aggregator_type='mean')

#         # self.conv4 = SAGEConv(nclass, nclass, aggregator_type='mean')

#         # self.conv5 = SAGEConv(nclass, nclass, aggregator_type='mean')

#         # self.conv3 = GATConv(nclass, nclass, num_heads=16)

#         # self.LS_end = nn.LSTM(input_size=nclass, hidden_size=nclass, num_layers=8, dropout=dropout, batch_first=True,
#         #                       bidirectional=True)

#         # self.conv4 = GATConv(nfeat, nclass, num_heads=16)
#         # self.conv2 = GATConv(nhid, nclass, num_heads=1)

#         # self.dropout = dropout

#     def forward(self, x, adj):
#         # xx,yy= self.conv4(adj, x, get_attention=True)
#         # yy = yy.reshape(1, -1)
#         # yy = yy.squeeze()
#         # print(yy,yy.shape)

#         x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         x = F.dropout(x, self.dropout, training=self.training)
#         x = self.conv2(adj, x)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv3(adj, x)
#         # print(x.shape)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv4(adj, x)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv5(adj, x)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv2(adj, x)
#         # x = self.conv3(adj, x)
#         # x = self.conv4(adj, x)
#         # x = self.conv5(adj, x)

#         # x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv6(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv3(adj, x)
#         # x = F.dropout(x, self.dropout, training=self.training)
#         # # x = self.conv4(adj, x)

#         # x = self.conv3(adj, x)
#         # # x = x.reshape(x.shape[0], x.shape[2])
#         # # x = self.conv4(adj, x)
#         # # # x = x.reshape(x.shape[0], x.shape[2])
#         # # x = self.conv5(adj, x)
#         # # x = x.reshape(x.shape[0], x.shape[2])
#         # x = self.conv6(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # x = self.conv7(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # x = self.conv8(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # x = self.conv9(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])
#         # print(x.shape)
#         # print(x[0])
#         # print(x[1])
#         # x = self.LS_end(x)
#         # x = x.reshape(x.shape[0], x.shape[2])

#         return F.log_softmax(x, dim=1)

# class GraphSage_BiLSTM_GAT(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GraphSage_BiLSTM_GAT, self).__init__()
#         self.dropout = dropout

#         self.conv1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
#         self.conv2 = SAGEConv(nhid, nhid, aggregator_type='mean')
#         self.conv3 = GATConv(nhid, nhid, num_heads=1)
#         self.LS_end = nn.LSTM(input_size=nclass, hidden_size=nclass, num_layers=8, dropout=dropout, batch_first=True,
#                               bidirectional=True)
#         # self.conv4 = GATConv(nclass, nclass, num_heads=16)
#         # self.dropout = dropout

#     def forward(self, x, adj):

#         x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         x = F.dropout(x, self.dropout, training=self.training)
#         x = self.conv2(adj, x)
#         x = F.dropout(x, self.dropout, training=self.training)
#         # x = self.conv3(adj, x)
#         # x = x.reshape(x.shape[0], x.shape[2])

#         return F.log_softmax(x, dim=1)
# egcn


# class GGGGG(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GGGGG, self).__init__()

#         # self.gc1 = GraphConvolution(nfeat, nhid)
#         # self.gc2 = GraphConvolution(nhid, nhid)
#         self.dropout = dropout


#         self.gc1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
#         self.gc2 = SAGEConv(nhid, nclass, aggregator_type='mean')


#         self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass, num_layers=2, dropout=0.5,
#                               batch_first=True)
#         self.nhid = nhid
#         self.nclass = nclass
#         self.linear = nn.Linear(nclass, nclass)

#     def forward(self, x, adj):
#         out = []
#         for i in range(adj.shape[1]):
#             print(adj[:,i,:],adj[:,i,:].shape,sep="****")
#             one_out = F.relu(self.gc1(adj[:, i, :], x[:, i, :]))
#             one_out = F.dropout(one_out, self.dropout, training=self.training)
#             one_out = self.gc2(adj[:, i, :],one_out)
#             out += [one_out]
#         # print(out)
#         # print(len(out))
#         out = torch.stack(out, 1)
#         print("\n-------------\n")
#         # print(out)
#         # print(len(out))

#         out = self.LS_end(out)[0][:, -1, :]

#         return F.log_softmax(out, dim=1)

class Namespace:
    """Attribute-style view of a dictionary.

    Lets callers write ``ns.key`` instead of ``adict['key']``.  The entries
    are copied into the instance, so later attribute changes do not touch
    the source dictionary.
    """

    def __init__(self, adict):
        instance_attrs = vars(self)
        instance_attrs.update(adict)


def pad_with_last_val(vect, k):
    """Extend a 1-D tensor to length ``k`` by repeating its last element.

    Args:
        vect: Non-empty 1-D tensor.
        k: Target length; must be at least ``vect.size(0)``.

    Returns:
        ``vect`` followed by ``k - vect.size(0)`` copies of ``vect[-1]``.
    """
    # Use the tensor's own device: the plain 'cuda' string would place the
    # padding on the default GPU and break torch.cat for other devices.
    pad = torch.ones(k - vect.size(0),
                     dtype=torch.long,
                     device=vect.device) * vect[-1]
    return torch.cat([vect, pad])


# class GAT_BiLSTM(nn.Module):
#     def __init__(self, nfeat, nhid, nclass, dropout):
#         super(GAT_BiLSTM, self).__init__()
#         self.dropout = dropout

#         # self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass, num_layers=2, dropout=0.5,
#         #                       batch_first=True, bidirectional=True)
#         self.conv1 = GATConv(nfeat, nhid, num_heads=1)
#         self.conv2 = GATConv(nhid, nclass, num_heads=1)

#     def forward(self, x, adj):
#         # Use node degree as the initial node feature. For undirected graphs, the in-degree
#         # is the same as the out_degree.
#         # Perform graph convolution and activation function.

#         x = F.relu(self.conv1(adj, x))  # different from self-defined gcn
#         print(x.shape)
#         x = x.reshape(x.shape[0], x.shape[2])
#         print(x.shape)
#         print("\n------------Hello-----------\n")
#         x = F.dropout(x, self.dropout, training=self.training)
#         x = self.conv2(adj, x)
#         x = x.reshape(x.shape[0], x.shape[2])

#         return F.log_softmax(x, dim=1)



class GCNLSTM(nn.Module):
    """Per-snapshot two-layer GCN whose outputs are fed through an LSTM.

    ``forward`` slices both inputs along their second axis (``x[:, i, :]`` and
    ``adj[:, i, :]``), so that axis is presumably time -- confirm with caller.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)
        self.dropout = dropout

        # The LSTM's hidden size is the number of classes, so its last output
        # can be passed straight to log_softmax.
        self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass,
                              num_layers=2, dropout=0.5, batch_first=True)
        self.nhid = nhid
        self.nclass = nclass
        # Registered but not used by forward(); kept so the parameter set
        # stays the same.
        self.linear = nn.Linear(nclass, nclass)

    def _embed_snapshot(self, feats_t, adj_t):
        """Run the two graph convolutions on a single snapshot."""
        hidden = F.relu(self.gc1(feats_t, adj_t))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.gc2(hidden, adj_t)

    def forward(self, x, adj):
        """Return log-probabilities taken from the LSTM output at the last step."""
        per_step = [self._embed_snapshot(x[:, t, :], adj[:, t, :])
                    for t in range(adj.shape[1])]
        sequence = torch.stack(per_step, 1)
        lstm_out, _ = self.LS_end(sequence)
        return F.log_softmax(lstm_out[:, -1, :], dim=1)










class GCN(nn.Module):
    """Two stacked GraphConvolution layers with ReLU and dropout in between."""

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        """Return per-row log-probabilities (log_softmax over dim 1)."""
        hidden = F.dropout(F.relu(self.gc1(x, adj)),
                           self.dropout, training=self.training)
        logits = self.gc2(hidden, adj)
        return F.log_softmax(logits, dim=1)


    
class GAT(nn.Module):
    """Two single-head GATConv layers with ReLU and dropout in between."""

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(nfeat, nhid, num_heads=1)
        self.conv2 = GATConv(nhid, nclass, num_heads=1)

    @staticmethod
    def _drop_head_axis(h):
        # Keeps axes 0 and 2 only; presumably axis 1 is the (single)
        # attention head -- confirm against the GATConv output layout.
        return h.reshape(h.shape[0], h.shape[2])

    def forward(self, x, adj):
        """Return log-probabilities; note the conv layers take the graph first.

        ``adj`` is passed as the first argument of each conv layer and ``x`` as
        the second, the reverse of the hand-written GCN in this file.
        """
        hidden = self._drop_head_axis(F.relu(self.conv1(adj, x)))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self._drop_head_axis(self.conv2(adj, hidden))
        return F.log_softmax(logits, dim=1)

    
class GraphSage(nn.Module):
    """Two mean-aggregator SAGEConv layers with ReLU and dropout in between.

    Note that the second layer outputs ``nhid`` features; ``nclass`` is
    accepted but not used when building the layers.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(nfeat, nhid, aggregator_type='mean')
        self.conv2 = SAGEConv(nhid, nhid, aggregator_type='mean')

    def forward(self, x, adj):
        """Return log_softmax (dim 1) of the second layer's output.

        The conv layers take the graph (``adj``) first and the features
        second, the reverse of the hand-written GCN in this file.
        """
        hidden = F.relu(self.conv1(adj, x))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return F.log_softmax(self.conv2(adj, hidden), dim=1)
    
class GCNLSTM(nn.Module):
    """Two-layer GCN run on every snapshot, then an LSTM across snapshots.

    Both inputs are sliced along their second axis inside ``forward``
    (presumably the time axis -- confirm with the caller).
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)
        self.dropout = dropout

        # hidden_size == nclass: the last LSTM output is used as class scores.
        self.LS_end = nn.LSTM(input_size=nhid, hidden_size=nclass,
                              num_layers=2, dropout=0.5, batch_first=True)
        self.nhid = nhid
        self.nclass = nclass
        # Not referenced in forward(); retained so the parameters registered
        # on the module do not change.
        self.linear = nn.Linear(nclass, nclass)

    def forward(self, x, adj):
        """Return log-probabilities from the LSTM output at the final step."""
        snapshots = []
        for step in range(adj.shape[1]):
            adj_t = adj[:, step, :]
            hidden = F.relu(self.gc1(x[:, step, :], adj_t))
            hidden = F.dropout(hidden, self.dropout, training=self.training)
            snapshots.append(self.gc2(hidden, adj_t))
        sequence = torch.stack(snapshots, 1)
        last_step = self.LS_end(sequence)[0][:, -1, :]
        return F.log_softmax(last_step, dim=1)
 
    
    
#egcn
    
class Namespace:
    """Expose the entries of a dictionary as attributes.

    Lets callers write ``ns.key`` where they would otherwise write
    ``adict['key']``.
    """

    def __init__(self, adict):
        # Merge the mapping into the instance's own attribute dictionary.
        vars(self).update(adict)
        
def pad_with_last_val(vect, k):
    """Extend a 1-D tensor to length ``k`` by repeating its final element.

    Args:
        vect: non-empty 1-D tensor (``vect[-1]`` must exist).
        k: desired length; expected to be at least ``vect.size(0)``.

    Returns:
        A new tensor: ``vect`` followed by ``k - vect.size(0)`` copies of
        ``vect[-1]``.
    """
    # Use the tensor's own device rather than the bare 'cuda' string, which
    # always resolves to the current CUDA device and breaks torch.cat when
    # ``vect`` lives on another GPU.
    pad = torch.ones(k - vect.size(0),
                     dtype=torch.long,
                     device=vect.device) * vect[-1]
    return torch.cat([vect, pad])




#only use EGCN

class EGCN(torch.nn.Module):  # egcn_o
    """Stack of two GRCU layers followed by an MLP classification head.

    Args:
        nfeat: input feature width of the first GRCU layer.
        nhid: output width of both GRCU layers and hidden width of the MLP.
        nclass: number of outputs of the MLP head.
        device: device the sub-modules are moved to at construction time.
        skipfeats: when true, the last entry of the input ``Nodes_list`` is
            concatenated to the final embeddings before the MLP.

    The layers are registered through ``ModuleList`` and ordinary attribute
    assignment, so the inherited ``parameters()`` yields the GRCU *and* MLP
    weights. The earlier hand-rolled ``parameters()`` returned only the GRCU
    weights (the MLP head was therefore never given to the optimizer) and
    replaced the Module-internal ``_parameters`` container.
    """

    def __init__(self, nfeat, nhid, nclass, device='cpu', skipfeats=False):
        super().__init__()
        self.device = device
        self.skipfeats = skipfeats

        # With the skip connection the raw features are concatenated to the
        # embeddings, so the first Linear layer must accept the wider input.
        mlp_in = nhid + nfeat if skipfeats else nhid
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=mlp_in, out_features=nhid),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=nhid, out_features=nclass),
        ).to(self.device)

        feats = [nfeat, nhid, nhid]
        self.GRCU_layers = torch.nn.ModuleList()
        for i in range(1, len(feats)):
            GRCU_args = Namespace({'in_feats': feats[i - 1],
                                   'out_feats': feats[i],
                                   'activation': torch.nn.RReLU()})
            self.GRCU_layers.append(GRCU(GRCU_args).to(self.device))

    def forward(self, Nodes_list, A_list):
        """Return log-probabilities computed from the last output of the stack.

        Each GRCU layer is called as ``unit(A_list, Nodes_list)`` and its
        result replaces ``Nodes_list`` for the next layer.
        """
        node_feats = Nodes_list[-1]
        for unit in self.GRCU_layers:
            Nodes_list = unit(A_list, Nodes_list)

        out = Nodes_list[-1]
        if self.skipfeats:
            # use node_feats.to_dense() if 2hot encoded input
            out = torch.cat((out, node_feats), dim=1)

        return F.log_softmax(self.mlp(out), dim=1)

class GRCU(torch.nn.Module):
    """Graph-convolution layer whose weight matrix is evolved by a matrix GRU.

    ``args`` must provide ``in_feats``, ``out_feats`` and ``activation``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # The GRU cell works on a matrix of shape (in_feats, out_feats).
        cell_args = Namespace({'rows': args.in_feats, 'cols': args.out_feats})
        self.evolve_weights = mat_GRU_cell(cell_args)

        self.activation = self.args.activation
        self.GCN_init_weights = Parameter(
            torch.Tensor(self.args.in_feats, self.args.out_feats))
        self.reset_param(self.GCN_init_weights)

    def reset_param(self, t):
        """Fill ``t`` in place with U(-b, b), b = 1 / sqrt(number of columns)."""
        bound = 1. / math.sqrt(t.size(1))
        t.data.uniform_(-bound, bound)

    def forward(self, A_list, node_embs_list):
        """Return one activated embedding per entry of ``A_list``.

        The weights start from ``GCN_init_weights`` and are passed through
        ``evolve_weights`` once per step before being applied at that step.
        """
        weights = self.GCN_init_weights
        outputs = []
        for step, adj_hat in enumerate(A_list):
            embs = node_embs_list[step]
            # Evolve first, then use the new weights for this step.
            weights = self.evolve_weights(weights)
            projected = embs.matmul(weights)
            outputs.append(self.activation(adj_hat.matmul(projected)))
        return outputs

class mat_GRU_cell(torch.nn.Module):
    """GRU cell operating on a whole matrix; ``args`` supplies ``rows``/``cols``.

    In this variant the cell's input is the previous state itself (the TopK
    summary is constructed but not consulted by ``forward``).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        rows, cols = args.rows, args.cols

        self.update = mat_GRU_gate(rows, cols, torch.nn.Sigmoid())
        self.reset = mat_GRU_gate(rows, cols, torch.nn.Sigmoid())
        self.htilda = mat_GRU_gate(rows, cols, torch.nn.Tanh())

        # Still registered so the module keeps the same parameters.
        self.choose_topk = TopK(feats=rows, k=cols)

    def forward(self, prev_Q):
        """Return the next state computed from ``prev_Q`` alone."""
        update_gate = self.update(prev_Q, prev_Q)
        reset_gate = self.reset(prev_Q, prev_Q)

        candidate = self.htilda(prev_Q, reset_gate * prev_Q)

        return (1 - update_gate) * prev_Q + update_gate * candidate

        

class mat_GRU_gate(torch.nn.Module):
    """One gate of the matrix GRU: ``activation(W @ x + U @ hidden + bias)``.

    ``W`` and ``U`` are ``(rows, rows)``; ``bias`` is ``(rows, cols)`` and
    starts at zero.
    """

    def __init__(self, rows, cols, activation):
        super().__init__()
        self.activation = activation

        # Both weight matrices are square in ``rows`` (the input feature
        # count) and are initialised one after the other, W first.
        for name in ("W", "U"):
            weight = Parameter(torch.empty(rows, rows))
            setattr(self, name, weight)
            self.reset_param(weight)

        self.bias = Parameter(torch.zeros(rows, cols))

    def reset_param(self, t):
        """Fill ``t`` in place with U(-b, b), b = 1 / sqrt(number of columns)."""
        bound = 1. / math.sqrt(t.size(1))
        t.data.uniform_(-bound, bound)

    def forward(self, x, hidden):
        pre_activation = self.W.matmul(x) + self.U.matmul(hidden) + self.bias
        return self.activation(pre_activation)

class TopK(torch.nn.Module):
    """Select the ``k`` highest-scoring rows of an embedding matrix.

    Rows are scored by projecting them on the learnable ``scorer`` vector
    (normalised by its norm) and adding ``mask``.
    """

    def __init__(self, feats, k):
        super().__init__()
        self.scorer = Parameter(torch.Tensor(feats, 1))
        self.reset_param(self.scorer)

        self.k = k

    def reset_param(self, t):
        """Fill ``t`` in place with U(-b, b), b = 1 / sqrt(number of rows)."""
        stdv = 1. / math.sqrt(t.size(0))
        t.data.uniform_(-stdv, stdv)

    def forward(self, node_embs, mask):
        """Return the selected rows, scaled by tanh(score), transposed.

        Rows whose score is ``-inf`` are dropped; if fewer than ``k`` remain,
        the last kept index is repeated to reach ``k`` entries.
        """
        scores = node_embs.matmul(self.scorer) / self.scorer.norm()
        scores = scores + mask

        vals, topk_indices = scores.view(-1).topk(self.k)
        topk_indices = topk_indices[vals > -float("Inf")]

        if topk_indices.size(0) < self.k:
            # The helper is defined at module level in this file; the former
            # ``u.`` prefix pointed at a module that is never imported here.
            topk_indices = pad_with_last_val(topk_indices, self.k)

        # Row indexing below needs a dense tensor; ``is_sparse`` covers every
        # dtype/device, unlike the legacy FloatTensor type checks.
        if node_embs.is_sparse:
            node_embs = node_embs.to_dense()

        out = node_embs[topk_indices] * torch.tanh(scores[topk_indices].view(-1, 1))

        # we need to transpose the output
        return out.t()
        
import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
import random
# from utils import encode_onehot
# from models import MyGNN,GCN,GAT,GraphSage,EGCN,LSTMGCN,RNNGCN,TRNNGCN

# import tensorflow
# from dynamicgem.embedding.dynAERNN import DynAERNN

import dgl

import scipy as sp
import scipy.linalg as linalg
import networkx as nx
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans, vq
from scipy import stats

from sklearn.cluster import SpectralClustering
from sklearn import metrics

from itertools import permutations

try:
    import google.colab

    IN_COLAB = True
except:
    IN_COLAB = False


def getNormLaplacian(W):
    """Return the symmetric normalised Laplacian of weight matrix ``W``.

    With ``D = diag(d_1, ..., d_n)`` holding the row sums of ``W`` and
    ``L = D - W``, the result is ``Lbar = D^(-1/2) L D^(-1/2)``.

    Rows whose sum is zero (isolated nodes) get a scaling factor of 0, so the
    corresponding rows/columns of ``Lbar`` are zero instead of the whole call
    failing on a singular ``D``.
    """
    W = np.asarray(W, dtype=float)
    d = W.sum(axis=1)
    L = np.diag(d) - W

    # D is diagonal, so D^(-1/2) is just an element-wise power of the degrees;
    # no dense matrix inverse is needed.
    inv_sqrt = np.zeros_like(d)
    connected = d != 0
    inv_sqrt[connected] = d[connected] ** -0.5

    # Scale row i and column j of L by inv_sqrt[i] and inv_sqrt[j].
    return inv_sqrt[:, None] * L * inv_sqrt[None, :]


def getKlargestEigVec(Lbar, k):
    """Return the ``k`` largest eigenvalues of ``Lbar`` and their eigenvectors.

    Returns:
        ``(eigval[ix], eigvec[:, ix])`` where ``ix`` lists the positions of
        the ``k`` largest eigenvalues in decreasing order.
    """
    eigval, eigvec = linalg.eig(Lbar)

    # Rank by position rather than through a value -> index dictionary: with
    # repeated eigenvalues the dictionary kept a single index per value and
    # the same eigenvector column was returned several times.
    ix = np.argsort(eigval)[::-1][:k]
    return eigval[ix], eigvec[:, ix]


def getKlargestSigVec(Lbar, k):
    """Return the ``k`` largest singular values of ``Lbar`` and the matching
    left singular vectors.

    Returns:
        ``(sigval[:k], lsigvec[:, :k])``.
    """
    lsigvec, sigval, _ = linalg.svd(Lbar)

    # scipy returns the singular values in non-increasing order, so the first
    # k columns are the wanted ones. The former value -> index dictionary
    # collapsed repeated singular values onto a single column.
    return sigval[:k], lsigvec[:, :k]


def checkResult(Lbar, eigvec, eigval, k):
    """Print the residual norms of the first ``k`` eigenpairs of ``Lbar``.

    For each i < k the residual is ``Lbar @ v_i - lambda_i * v_i``; its norm
    is printed as a multiple of ``np.spacing(1)``.
    """
    eps = np.spacing(1)
    residual_norms = []
    for i in range(k):
        vec = eigvec[:, i]
        residual_norms.append(np.linalg.norm(np.dot(Lbar, vec) - eigval[i] * vec))
    scaled = residual_norms / eps
    print("Lbar*v-lamda*v are %s*%s" % (scaled, eps))


"""# Model"""


def one_hot(l, classnum=1):  # classnum fix some special case
    """One-hot encode the integer labels in ``l``.

    The result has ``len(l)`` rows and ``max(l.max() + 1, classnum)`` columns,
    so ``classnum`` forces a minimum width when the highest class is absent.
    """
    width = max(l.max() + 1, classnum)
    encoded = np.zeros((len(l), width))
    # Each ``row`` is a view into ``encoded``, so the assignment writes through.
    for row, label in zip(encoded, l):
        row[label] = 1
    return encoded


def train(epoch, model, optimizer, features, adj, labels, idx_train, idx_val, model_type, file_name):
    """Run one optimisation step, then evaluate on the validation split.

    The epoch summary is printed and appended to ``file_name + "_performance"``.

    Args:
        epoch: zero-based epoch index (reported as ``epoch + 1``).
        model: module called as ``model(features, adj)``; its output is fed to
            ``F.nll_loss``.
        optimizer: optimizer stepping the model's parameters.
        features, adj, labels: model inputs and integer targets.
        idx_train, idx_val: index tensors selecting the two splits.
        model_type: unused here; kept for call compatibility.
        file_name: prefix of the performance log file.

    Returns:
        Validation accuracy as a float.
    """
    t = time.time()

    # --- training step ---
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])

    pred_train = torch.argmax(output, dim=1)
    # float(): accuracy_score may hand back a plain Python float, which has
    # no .item() method.
    acc_train = float(metrics.accuracy_score(labels[idx_train].cpu().detach().numpy(),
                                             pred_train[idx_train].cpu().detach().numpy()))

    # NOTE(review): retain_graph=True is kept from the original code; confirm
    # whether any model actually needs the graph after this call.
    loss_train.backward(retain_graph=True)
    optimizer.step()

    # --- validation ---
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    # Predict from the eval-mode output; reusing the training-mode
    # predictions would report accuracy from before the update and with
    # dropout active.
    pred_val = torch.argmax(output, dim=1)
    acc_val = float(metrics.accuracy_score(labels[idx_val].cpu().detach().numpy(),
                                           pred_val[idx_val].cpu().detach().numpy()))

    print(loss_val, acc_val)

    fields = ['Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val),
              'time: {:.4f}s'.format(time.time() - t)]
    print(*fields)

    # The context manager closes the log even if the write fails.
    with open(file_name + "_performance", "a+") as performance_file:
        performance_file.write(';'.join(fields) + '\n')

    return acc_val


def test(model, features, adj, labels, idx_test):
    """Evaluate ``model`` on the test split.

    Returns:
        ``(loss, accuracy, auc, f1)`` where loss is the NLL loss as a float,
        and f1 / auc are weighted averages (auc is one-vs-rest).
    """
    model.eval()
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        output = model(features, adj)
        loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    pred_labels = torch.argmax(output, dim=1)

    y_true = labels[idx_test].cpu().detach().numpy()
    y_pred = pred_labels[idx_test].cpu().detach().numpy()
    y_score = output[idx_test].cpu().detach().numpy()

    acc_test = metrics.accuracy_score(y_true, y_pred)
    f1_test = metrics.f1_score(y_true, y_pred, average='weighted')
    # Force the one-hot matrix to have one column per model output; otherwise
    # it is narrower than y_score whenever the highest class is missing from
    # the test split.
    auc_test = metrics.roc_auc_score(one_hot(y_true, y_score.shape[1]),
                                     y_score, multi_class='ovr', average='weighted')

    return loss_test.item(), acc_test, auc_test, f1_test


def _sum_snapshots(adj):
    """Return a new tensor holding the sum of all slices ``adj[:, t, :]``."""
    total = adj[:, 0, :].clone()
    for step in range(1, adj.shape[1]):  # time_steps
        total += adj[:, step, :]
    return total


def _symmetric_normalize(adj):
    """Add the identity to ``adj`` in place and return D^-1/2 (A+I) D^-1/2."""
    adj += torch.eye(adj.shape[0], adj.shape[1])
    degree = torch.sum(adj, dim=1)
    d_inv_sqrt = torch.diag(degree ** (-0.5))
    return torch.mm(torch.mm(d_inv_sqrt, adj), d_inv_sqrt)


def _dynaernn_train_and_test(Probability_matrix, adj, labels):
    """Embed the snapshots with DynAERNN, cluster with k-means and score the
    best label permutation against ``labels``."""
    length = adj.shape[1]
    lookup = length - 2

    dim_emb = class_num
    if args_cuda:
        tensorflow.device('/gpu:0')
    embedding = DynAERNN(d=dim_emb,
                         beta=5,
                         n_prev_graphs=lookup,
                         nu1=1e-6,
                         nu2=1e-6,
                         n_aeunits=[50, 30],
                         n_lstmunits=[50, dim_emb],
                         rho=0.3,
                         n_iter=args_epochs,
                         xeta=1e-3,
                         n_batch=10,
                         modelfile=['./intermediate/enc_model_dynAERNN.json',
                                    './intermediate/dec_model_dynAERNN.json'],
                         weightfile=['./intermediate/enc_weights_dynAERNN.hdf5',
                                     './intermediate/dec_weights_dynAERNN.hdf5'],
                         savefilesuffix="testing")

    graphs = [nx.Graph(adj[:, l, :].numpy()) for l in range(length)]
    embs = []
    for temp_var in range(lookup, length):
        emb, _ = embedding.learn_embeddings(graphs[:temp_var])
        embs.append(emb)
    centroid = kmeans(embs[-1], class_num)[0]  # change kSigvec from complex64 to float
    result = vq(embs[-1], centroid)[0]

    # Cluster ids are arbitrary, so every relabelling is tried and the best
    # score of each metric is kept.
    one_hot_result = torch.tensor(one_hot(result, class_num))
    acc_test = 0
    f1_test = 0
    auc_test = 0
    for count, perm in enumerate(permutations(range(class_num)), start=1):
        one_hot_i = one_hot(np.array(perm))
        perm_result = torch.mm(one_hot_result, torch.tensor(one_hot_i))
        pred_labels = torch.argmax(perm_result, axis=1)
        acc_test = max(metrics.accuracy_score(labels, pred_labels), acc_test)
        f1_test = max(metrics.f1_score(labels, pred_labels, average='weighted'), f1_test)
        auc_test = max(metrics.roc_auc_score(one_hot(labels), perm_result,
                                             multi_class='ovr', average='weighted'), auc_test)
        if count % 10000 == 0:
            print(count)
            print(acc_test, f1_test, auc_test)
    print(str(acc_test) + '\t' + str(f1_test) + '\t' + str(auc_test))
    try:
        spec_norm = getKlargestSigVec(adj - Probability_matrix, 2)[0]
    except Exception:
        spec_norm = []
    # No loss is computed on this path. The return statement used to reference
    # an undefined ``loss``; 0 follows the older commented-out
    # ``return 0, acc_test, spec_norm``.
    loss = 0
    return loss, acc_test, auc_test, f1_test, spec_norm


def single_train_and_test(lambda_matrix, Probability_matrix, features, adj, labels, idx_train, idx_val, idx_test,
                          model_type, normalize=False, file_name='pop'):
    """Prepare the inputs for ``model_type``, train it, and report test metrics.

    Reads the module-level settings ``class_num``, ``args_hidden``,
    ``args_dropout``, ``args_lr``, ``args_weight_decay``, ``args_epochs`` and
    ``args_cuda``, and appends ``(acc, auc, f1)`` to the module-level
    ``full_list``.

    Args:
        lambda_matrix: optional class-by-class decay rates (GCN only); when
            ``None`` the snapshots are simply summed.
        Probability_matrix: subtracted from the aggregated adjacency for the
            spectral-norm estimate (which falls back to 0 on any failure).
        features, adj: inputs indexed as ``[:, t, :]``
            (presumably node x time x dim -- confirm with the data loader).
        labels, idx_train, idx_val, idx_test: targets and split indices.
        model_type: name of the model to build.
        normalize: apply symmetric normalisation with self-loops to the
            aggregated adjacency (static-graph models only).
        file_name: prefix handed to ``train`` for its performance log.

    Returns:
        ``(loss, acc, auc, f1, spec_norm)`` measured at the epoch with the
        best validation accuracy; ``None`` for "SPEC"/"SPEC_sklearn", which
        are not handled here.

    Raises:
        ValueError: if ``model_type`` is not recognised.
    """
    if model_type == "DynAERNN":
        return _dynaernn_train_and_test(Probability_matrix, adj, labels)

    # --- choose adj matrix: static models take n*n, the others n*t*n ---
    now_adj = None
    if model_type == 'GCN':
        if lambda_matrix is not None:
            decay_adj = torch.zeros(adj.shape[0], adj.shape[0])
            for j in range(adj.shape[0]):
                for k in range(adj.shape[2]):
                    decay_adj[j][k] = lambda_matrix[labels[j]][labels[k]]
            # Exponential-decay blend of the snapshots, oldest first.
            now_adj = adj[:, 0, :].clone()
            for i in range(1, adj.shape[1]):  # time_steps
                now_adj = (1 - decay_adj) * now_adj + decay_adj * adj[:, i, :]
        else:
            now_adj = _sum_snapshots(adj)
    elif model_type in ('GAT', 'GraphSage', 'GAT_BiLSTM', 'GraphSageTransformer'):
        now_adj = _sum_snapshots(adj)
    elif model_type == 'GraphSage_BiLSTM_GAT':
        # NOTE(review): only the first snapshot is used for this model (the
        # original loop re-assigned it on every iteration); confirm whether a
        # sum over time was intended.
        now_adj = adj[:, 0, :].clone()
    elif model_type == 'EGCN':
        adj = torch.transpose(adj, 0, 1)
        features = torch.transpose(features, 0, 1)

    if now_adj is not None:
        # _symmetric_normalize adds self-loops to now_adj in place, so the
        # spectral-norm estimate below sees A + I, as before.
        adj = _symmetric_normalize(now_adj) if normalize else now_adj
        features = features[:, -1, :]

    if model_type in ("SPEC", "SPEC_sklearn"):
        return None

    # --- define model ---
    # An if/elif chain (not a lookup table) so that only the selected class
    # name has to exist when this function runs.
    common = dict(nhid=args_hidden, nclass=class_num, dropout=args_dropout)
    if model_type == 'GCN':
        model = GCN(nfeat=features.shape[1], **common)
    elif model_type == 'RNNGCN':
        model = RNNGCN(nfeat=features.shape[2], **common)
    elif model_type == 'TRNNGCN':
        model = TRNNGCN(nfeat=features.shape[2],
                        nnode=features.shape[0],
                        use_cuda=args_cuda,
                        **common)
    elif model_type == 'GCNLSTM':
        model = GCNLSTM(nfeat=features.shape[2], **common)
    elif model_type == 'MyGNN':
        model = MyGNN(nfeat=features.shape[2], **common)
    elif model_type in ('GAT', 'GAT_BiLSTM'):
        adj = dgl.from_networkx(nx.Graph(adj.numpy()))  # fit in dgl
        model = GAT(nfeat=features.shape[1], **common)
    elif model_type == 'GraphSage':
        adj = dgl.from_networkx(nx.Graph(adj.numpy()))  # fit in dgl
        model = GraphSage(nfeat=features.shape[1], **common)
    elif model_type == 'GraphSage_BiLSTM_GAT':
        adj = dgl.from_networkx(nx.Graph(adj.numpy()))  # fit in dgl
        model = GraphSage_BiLSTM_GAT(nfeat=features.shape[1], **common)
    elif model_type == 'GraphSageTransformer':
        adj = dgl.from_networkx(nx.Graph(adj.numpy()))  # fit in dgl
        model = GraphSageTransformer(nfeat=features.shape[1], **common)
    elif model_type == 'GGGGG':
        model = GGGGG(nfeat=features.shape[2], **common)
    elif model_type == 'EGCN':
        model = EGCN(nfeat=features.shape[2],
                     nhid=args_hidden,
                     nclass=class_num,
                     device=torch.device('cpu'))
    else:
        raise ValueError("unknown model_type: {!r}".format(model_type))

    # EGCN is built on the CPU above and is therefore left there.
    if args_cuda and model_type != 'EGCN':
        model = model.to(torch.device('cuda:0'))
        features = features.cuda()
        adj = adj.to(torch.device('cuda:0'))
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()

    # --- optimizer and train ---
    optimizer = optim.Adam(model.parameters(),
                           lr=args_lr, weight_decay=args_weight_decay)
    train_time_1 = time.time()
    # Start below any reachable accuracy so the first epoch always records
    # test metrics; starting at 0 left them unbound when validation accuracy
    # never rose above 0.
    best_val = -1.0
    loss = acc = auc = f1 = float('nan')
    for epoch in range(args_epochs):
        acc_val = train(epoch, model, optimizer, features, adj, labels, idx_train, idx_val, model_type, file_name)
        if acc_val > best_val:
            best_val = acc_val
            loss, acc, auc, f1 = test(model, features, adj, labels, idx_test)
    train_time_2 = time.time()
    print(train_time_2 - train_time_1)

    if model_type in ('RNNGCN', 'TRNNGCN'):
        print(model.Lambda, end='\t')
    print("acc= " + str(acc) + ", auc= " + str(auc) + ", f1= " + str(f1) + "\n")
    full_list.append((acc, auc, f1))

    # Best effort: temporal models have no aggregated adjacency, and the
    # probability matrix may be a placeholder.
    try:
        spec_norm = getKlargestSigVec(now_adj - Probability_matrix, 2)[0]
    except Exception:
        spec_norm = 0  # temperal adj
    return loss, acc, auc, f1, spec_norm


"""# Run Exp for Spectral Clustering and GCN with Decay Rates

# Run Exp on Simulated and Real Datasets
"""


def load_real_data(dataset_name, data_path=None):
    """Load a dynamic-graph dataset from an ``.npz`` archive and split its nodes.

    The archive must contain ``adjs`` and ``labels``, plus ``attmats`` for
    every dataset except "DBLPE" (which gets identity features instead).
    Nodes whose adjacency rows are all zero in every snapshot are removed.
    The remaining nodes are shuffled with ``random.shuffle`` and split
    70% / 20% / 10% into train / validation / test.

    Args:
        dataset_name: one of "DBLP3", "DBLP5", "Brain", "Reddit", "DBLPE".
        data_path: optional path to the archive; when ``None`` the built-in
            Google Drive location for ``dataset_name`` is used.

    Returns:
        ``(features, graphs, labels, idx_train, idx_val, idx_test, [])``:
        float features, float adjacency with the node axis first, long
        labels, three index tensors, and an empty placeholder list.

    Raises:
        KeyError: if ``data_path`` is ``None`` and ``dataset_name`` is unknown.
    """
    print(dataset_name)
    dataset_dict = {
        "DBLP3": "/content/drive/MyDrive/InterpretableClustering-master/DBLP3.npz",
        "DBLP5": "/content/drive/MyDrive/InterpretableClustering-master/DBLP5.npz",
        "Brain": "/content/drive/MyDrive/InterpretableClustering-master/Brain.npz",
        "Reddit": "/content/drive/MyDrive/GCN-SE-main/reddit.npz",
        "DBLPE": "/content/drive/MyDrive/InterpretableClustering-master/DBLPE.npz",
    }
    if data_path is None:
        data_path = dataset_dict[dataset_name]
    print(data_path)

    # Read everything needed while the archive is open, then release the file.
    with np.load(data_path) as dataset:
        adjs = dataset['adjs']
        raw_labels = dataset['labels']
        attmats = None if dataset_name == "DBLPE" else dataset['attmats']

    Graphs = torch.LongTensor(adjs)  # (n_time, n_node, n_node)
    Graphs = torch.transpose(Graphs, 0, 1)  # (n_node, n_time, n_node)

    # Row sum over all snapshots and columns; rows summing to zero are dropped
    # from both node axes.
    degree = Graphs.sum(dim=(1, 2))
    non_zero_index = torch.nonzero(degree, as_tuple=True)[0]
    Graphs = Graphs[non_zero_index, :, :]
    Graphs = Graphs[:, :, non_zero_index]

    if dataset_name == "DBLPE":
        Labels = torch.LongTensor(np.argmax(raw_labels, axis=2))  # (n_node, n_time, num_classes) argmax
        # No attributes for this dataset: identity features at every step.
        Features = torch.zeros(Graphs.shape)
        for i in range(Features.shape[1]):
            Features[:, i, :] = torch.eye(Features.shape[0], Features.shape[2])
    else:
        Labels = torch.LongTensor(np.argmax(raw_labels, axis=1))  # (n_node, num_classes) argmax
        Features = torch.LongTensor(attmats)  # (n_node, n_time, att_dim)
        Features = Features[non_zero_index]
    Labels = Labels[non_zero_index]

    # shuffle datasets
    number_of_nodes = Graphs.shape[0]
    nodes_id = list(range(number_of_nodes))
    random.shuffle(nodes_id)

    train_end = (7 * number_of_nodes) // 10
    val_end = (9 * number_of_nodes) // 10
    idx_train = torch.LongTensor(nodes_id[:train_end])
    idx_val = torch.LongTensor(nodes_id[train_end:val_end])
    idx_test = torch.LongTensor(nodes_id[val_end:number_of_nodes])

    return Features.float(), Graphs.float(), Labels.long(), idx_train, idx_val, idx_test, []


"""/content/drive/MyDrive/Dynamic_New"""


def test_real_dataset(file_name):
    """Run one train/test round and append its metrics to a summary file.

    Calls ``single_train_and_test`` with names that are not defined in this
    function (``Probability_matrix``, ``features``, ``adj``, ``labels``,
    ``idx_train``, ``idx_val``, ``idx_test``, ``model_type`` and
    ``args_normalize``), then appends one tab-separated line holding the
    returned accuracy, AUC and F1 values to ``file_name``.

    Args:
        file_name: Path of the summary file. It is opened in ``"a+"`` mode and
            is also forwarded to ``single_train_and_test`` as ``file_name``.

    Returns:
        The ``file_name`` argument, unchanged.
    """
    # No lambda matrix is used in this experiment; the callee still expects
    # the positional argument.
    lambda_matrix = None

    # The file is opened *before* training (as the original did) so that an
    # unwritable path fails immediately rather than after a long run; the
    # ``with`` block guarantees the handle is closed even if training raises.
    # The Colab and non-Colab code paths opened the same path, so a single
    # ``open`` covers both.
    with open(file_name, "a+") as summary_file:
        # Loss and spectral norm are returned by the callee but not reported.
        _loss, acc, auc, f1, _specnorm = single_train_and_test(
            lambda_matrix, Probability_matrix, features, adj, labels,
            idx_train, idx_val, idx_test, model_type,
            normalize=args_normalize, file_name=file_name)
        summary_file.write(
            "accuracy= {:.6f}\tauc= {:.6f}\tf1= {:.6f}\n".format(acc, auc, f1))

    return file_name



# Wall-clock start of the experiment runs.
total_time_3 = time.time()


# Per-run results; printed after the loop below.
full_list = []

# Dataset selection. Names previously selected here:
# "DBLPE", "DBLP5", "DBLP3", "Reddit".
dataset_name = "Brain"
(features, adj, labels,
 idx_train, idx_val, idx_test,
 Probability_matrix) = load_real_data(dataset_name)
print(*(tensor.shape for tensor in (features, adj, labels, idx_train, idx_val, idx_test)))

# Number of classes is derived from the largest label value.
class_num = 1 + int(labels.max())
print(class_num)
total_adj = adj
total_labels = labels

# Model selection. Options noted by the author:
# GCN, GAT, GraphSage, GCNLSTM, EGCN, RNNGCN, TRNNGCN, MyGNN
model_type = 'MyGNN'

# Run settings kept as module-level names
# (presumably read by the training code elsewhere in the file -- verify).
args_hidden = class_num
args_dropout = 0.5
args_lr = 0.0025
args_weight_decay = 5e-4
args_epochs = 30
args_no_cuda = True
args_cuda = False if args_no_cuda else torch.cuda.is_available()
args_normalize = True

# Summary file for this dataset/model pair.
# Alternative location used previously:
# "/content/drive/MyDrive/SP_CIKM_all_30/" + dataset_name + '_' + model_type + ".txt"
file_name = f"{dataset_name}_{model_type}.txt"

print(f"\n{model_type}\n")
for i in range(30):
    print(f"\nIteration: {i}\n")
    test_real_dataset(file_name)
print(full_list)

def Average(lst):
    """Return the arithmetic mean of ``lst``.

    Args:
        lst: A sized collection of numbers.

    Returns:
        ``sum(lst) / len(lst)``. For an empty collection ``float('nan')`` is
        returned instead of raising ``ZeroDivisionError``, so that summary
        reporting after a long run is not aborted when no results were
        collected.
    """
    # ``len`` is used (rather than truthiness) so array-like inputs also work.
    if len(lst) == 0:
        return float("nan")
    return sum(lst) / len(lst)
# Split the (accuracy, auc, f1) triples in full_list into one list per metric.
acc_list = [acc for acc, _auc, _f1 in full_list]
auc_list = [auc for _acc, auc, _f1 in full_list]
f1_list = [f1 for _acc, _auc, f1 in full_list]
print(len(acc_list))
# Mean accuracy, AUC and F1, comma-separated on one line.
print(*(Average(metric) for metric in (acc_list, auc_list, f1_list)), sep=',')
# Elapsed wall-clock time in seconds since total_time_3 was taken.
total_time_4 = time.time()
print(total_time_4 - total_time_3)
# my_list = []
# with open(file_name) as fp:
#     Lines = fp.readlines()
#     for line in Lines:
#         a = line.strip()
#         my_list.append([float(a.split('\t')[i].split('= ')[1]) for i in [0, 1, 2]])


# df = pd.DataFrame(my_list, columns=['acc', 'auc', 'f1'])
# my_dict=dict(df.mean())
# summary_file = open(file_name, "a+")
# summary_file.write("break\n"+str(my_dict['acc'])+", "+str(my_dict['auc'])+", "+str(my_dict['f1'])+"\n")



Brain
/content/drive/MyDrive/InterpretableClustering-master/Brain.npz
torch.Size([5000, 12, 20]) torch.Size([5000, 12, 5000]) torch.Size([5000]) torch.Size([3500]) torch.Size([1000]) torch.Size([500])
10

MyGNN


Iteration: 0

tensor(2.3183, grad_fn=<NllLossBackward0>) 0.086
Epoch: 0001 loss_train: 2.3251 acc_train: 0.0909 loss_val: 2.3183 acc_val: 0.0860 time: 53.7982s
tensor(2.3143, grad_fn=<NllLossBackward0>) 0.091
Epoch: 0002 loss_train: 2.3194 acc_train: 0.0960 loss_val: 2.3143 acc_val: 0.0910 time: 57.1359s
tensor(2.3102, grad_fn=<NllLossBackward0>) 0.098
Epoch: 0003 loss_train: 2.3144 acc_train: 0.0997 loss_val: 2.3102 acc_val: 0.0980 time: 54.1844s
tensor(2.3060, grad_fn=<NllLossBackward0>) 0.097
Epoch: 0004 loss_train: 2.3095 acc_train: 0.0977 loss_val: 2.3060 acc_val: 0.0970 time: 54.3532s
tensor(2.3018, grad_fn=<NllLossBackward0>) 0.099
Epoch: 0005 loss_train: 2.3046 acc_train: 0.1077 loss_val: 2.3018 acc_val: 0.0990 time: 55.0856s
tensor(2.2975, grad_fn=<NllLossBackward0>) 

# New section