<a href="https://colab.research.google.com/github/Sripathm2/UCLA_CS_245_Project5/blob/GNN/GNN/MPNN%2BLSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install spektral

In [None]:
import numpy as np
from scipy import sparse
import tensorflow as tf
import spektral
from spektral.layers.ops import sp_matrix_to_sp_tensor
from spektral.datasets.mnist import MNIST

# Instantiate Spektral's MNIST dataset (imported from spektral.datasets.mnist).
# NOTE(review): `data` is not referenced anywhere else in the visible code —
# confirm whether this load is still needed.
data = MNIST()


class Net(tf.keras.Model):
    """Per-graph message-passing units followed by two stacked LSTMs.

    Each graph in the input sequence gets its own pair of
    MessagePassing -> BatchNormalization -> Dropout stacks. The outputs of
    the two stacks are concatenated per graph, flattened, and the resulting
    sequence (one step per graph) is fed through two LSTM layers.

    Notes carried over from the original author: the reference training
    setup is 500 epochs, batch size 8, Adam optimizer, learning rate 1e-3.
    Those notes also mention 64 LSTM hidden states; the default here stays
    at 52 so that existing behavior is unchanged.
    """

    def __init__(self, window=6, dropout=.5, n_graphs=286, lstm_units=52,
                 verbose=True, **kwargs):
        """Create the per-graph MPNN units and the LSTM layers.

        Args:
            window: int, window of days. Accepted for interface
                compatibility; it is not used by this class.
            dropout: Dropout rate used in every MPNN unit.
            n_graphs: Number of MPNN units to create, i.e. the maximum
                number of graphs `call` can process (one unit per graph).
            lstm_units: Number of units in each of the two LSTM layers.
            verbose: When True (the default), `call` prints intermediate
                tensor shapes and the final output, as it always did.
            **kwargs: Forwarded to `tf.keras.Model`.
        """
        super().__init__(**kwargs)
        self._verbose = verbose
        self._nets = [self.build_MPNN_unit(dropout) for _ in range(n_graphs)]
        # Only referenced by a disabled experiment in `call`; kept so the
        # attribute remains available.
        self.permute = tf.keras.layers.Permute((2, 1, 3))
        self.flatten = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Flatten())
        self.LSTM1 = tf.keras.layers.LSTM(lstm_units, return_sequences=True)
        self.LSTM2 = tf.keras.layers.LSTM(lstm_units, return_sequences=False,
                                          return_state=True)

    def build_MPNN_unit(self, dropout):
        """Return a pair of identical, independently-weighted layer stacks.

        Each stack is the list
        [MessagePassing(sum, relu), BatchNormalization, Dropout(dropout)].

        Args:
            dropout: Dropout rate for the Dropout layer of each stack.

        Returns:
            Tuple `(stack_1, stack_2)` of layer lists.
        """
        def make_stack():
            return [
                spektral.layers.MessagePassing(aggregate='sum',
                                               activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(dropout),
            ]

        return (make_stack(), make_stack())

    @staticmethod
    def _run_stack(stack, features, adjacency):
        """Apply one layer stack: propagate messages, then the other layers."""
        message_passing, *post_layers = stack
        # The MessagePassing layer is driven through `propagate` directly
        # rather than through its `__call__`.
        out = message_passing.propagate(features, adjacency)
        for layer in post_layers:
            out = layer(out)
        return out

    def run_MPNN_unit(self, Adj, X, net_id):
        """Run the `net_id`-th MPNN unit on a single graph.

        Args:
            Adj: Adjacency of the graph, in the form accepted by
                `MessagePassing.propagate`.
            X: Node feature matrix of the graph.
            net_id: Index into the list of MPNN units.

        Returns:
            The outputs of the first and second stack concatenated along
            axis 1. The second stack consumes the first stack's output.
        """
        first_stack, second_stack = self._nets[net_id]
        H1 = self._run_stack(first_stack, X, Adj)
        H2 = self._run_stack(second_stack, H1, Adj)
        return tf.concat((H1, H2), axis=1)

    def _log(self, *args):
        """Print debug information when the model was created verbose."""
        if self._verbose:
            print(*args)

    def call(self, input):
        """Forward pass.

        Args:
            input: Pair `(X, Adj)`. Only the first element along the leading
                axis of each is used. After that, index `i` along the first
                axis selects the features / adjacency of graph `i`. The
                number of graphs must not exceed the `n_graphs` the model
                was built with.

        Returns:
            The transposed output of the second LSTM.
        """
        X, Adj = input
        # Drop the leading (batch) axis; only its first entry is processed.
        X = X[0]
        Adj = Adj[0]

        H_list = []
        for i in range(Adj.shape[0]):
            a = sp_matrix_to_sp_tensor(Adj[i])
            H_list.append(self.run_MPNN_unit(a, X[i], i))

        # Stack the per-graph outputs and add a leading axis of size 1.
        H_out = tf.expand_dims(H_list, axis=0)
        self._log('H_out: ', H_out.shape)
        # Flatten each graph's output so the graph axis becomes the LSTM
        # sequence axis.
        LSTM_input = self.flatten(H_out)
        self._log('LSTM Input: ', LSTM_input.shape)
        x = self.LSTM1(inputs=LSTM_input)
        self._log('After First LSTM: ', x.shape)
        # LSTM2 also returns its final memory and carry states; only the
        # output is used.
        x, final_memory_state, final_carry_state = self.LSTM2(inputs=x)
        self._log('After Second LSTM: ', x.shape)
        self._log('Feature matrix: ', X.shape)
        x = tf.transpose(x)
        self._log('Output of LSTM: ', x)
        # NOTE: a residual connection (x = X + x), a possible linear layer
        # and a ReLU were sketched here by the original author but are
        # disabled.
        return x

# Build 286 random symmetric 52x52 adjacency matrices. Each one receives
# between 1 and 24 random weighted entries (mirrored across the diagonal);
# the weights are drawn from a standard normal distribution.
A_list = []
for graph_index in range(286):
    adjacency = np.zeros([52, 52])
    entry_count = np.random.randint(1, 25)
    for entry_index in range(entry_count):
        row = np.random.randint(0, 52)
        col = np.random.randint(0, 52)
        weight = np.random.randn()
        adjacency[row, col] = adjacency[col, row] = weight
    A_list.append(adjacency)
Adj = tf.expand_dims(A_list, axis=0)[0]

# One random 52x5 node-feature matrix per graph, values uniform in [0, 1).
X = [np.random.rand(52, 5) for graph_index in range(286)]
X = tf.expand_dims(X, axis=0)[0]

# Random labels, one value per node.
y = np.random.rand(52, 1)

# Add a leading axis of size 1 to the features, adjacencies and labels.
X2 = tf.expand_dims(X, 0)
Adj2 = tf.expand_dims(Adj, 0)
y2 = tf.expand_dims(y, 0)
input = (X2, Adj2)

# Instantiate the model and configure it for MSE regression with Adam.
model = Net(window=6)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['mse'],
)
# Training call, currently disabled:
# model.fit(x=Adj2, y=y2, epochs=5)


In [25]:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.transforms import LayerPreprocess, AdjToSpTensor

# Load the Cora citation dataset. The transforms apply GCNConv's
# preprocessing to the adjacency matrix and convert it to a sparse tensor.
dataset = Citation('cora',
                   transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()])
# Train / validation / test node masks provided by the dataset.
# NOTE(review): mask_te is not used in the visible code.
mask_tr, mask_va, mask_te = dataset.mask_tr, dataset.mask_va, dataset.mask_te

# Parameters
channels = 16          # Number of channels in the first layer
dropout = 0.5          # Dropout rate for the features
l2_reg = 5e-4 / 2      # L2 regularization rate
learning_rate = 1e-2   # Learning rate
# NOTE(review): `epochs` and `patience` are defined but unused below; the
# fit call trains for a single epoch without early stopping.
epochs = 200           # Number of training epochs
patience = 10          # Patience for early stopping
a_dtype = dataset[0].a.dtype  # Only needed for TF 2.1

N = dataset.n_nodes          # Number of nodes in the graph
F = dataset.n_node_features  # Original size of node features
n_out = dataset.n_labels     # Number of classes

# Model definition: node features and a sparse adjacency as inputs.
x_in = Input(shape=(F,))
a_in = Input((N,), sparse=True, dtype=a_dtype)

# Two GCN layers, each preceded by dropout; the second one outputs class
# probabilities via softmax.
do_1 = Dropout(dropout)(x_in)
gc_1 = GCNConv(channels,
               activation='relu',
               kernel_regularizer=l2(l2_reg),
               use_bias=False)([do_1, a_in])
do_2 = Dropout(dropout)(gc_1)
gc_2 = GCNConv(n_out,
               activation='softmax',
               use_bias=False)([do_2, a_in])

# Build model
model = Model(inputs=[x_in, a_in], outputs=gc_2)
# Use `learning_rate`: `lr` is a deprecated alias that newer Keras releases
# no longer accept.
optimizer = Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

# Train model. The node masks are passed as sample weights to the loaders.
loader_tr = SingleLoader(dataset, sample_weights=mask_tr)
# NOTE(review): loader_va is created but not passed to fit.
loader_va = SingleLoader(dataset, sample_weights=mask_va)
model.fit(loader_tr.load(),
          steps_per_epoch=loader_tr.steps_per_epoch,
          epochs=1)




To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



  self._set_arrayXarray(i, j, x)


Model: "functional_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 1433)]       0                                            
__________________________________________________________________________________________________
dropout_2312 (Dropout)          (None, 1433)         0           input_7[0][0]                    
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 2708)]       0                                            
__________________________________________________________________________________________________
gcn_conv_6 (GCNConv)            (None, 16)           22928       dropout_2312[0][0]               
                                                                 input_8[0][0]         

<tensorflow.python.keras.callbacks.History at 0x7f90833c5160>

<RepeatDataset shapes: (((2708, 1433), (2708, 2708)), (2708, 7), (2708,)), types: ((tf.float32, tf.float64), tf.int32, tf.bool)>