In [6]:
import tensorflow as tf
import numpy as np

In [34]:
class GCNModel(tf.keras.Model):
    def __init__(self, no_classes, hidden_units):
        super(GCNModel, self).__init__()
        self.node_dense = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.edge_dense = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.output_layer = tf.keras.layers.Dense(no_classes, activation='softmax')

    def call(self, inputs, training=False, **kwargs):
        # Unpack inputs
        node_features, edge_features, adjacency_matrix = inputs

        # Process node features
        node_output = self.node_dense(node_features)
        print(node_output.shape)

        edge_output = self.edge_dense(edge_features)
        print(edge_output.shape)

        # Determine the maximum size along the feature dimension
        max_output_size = max(node_output.shape[-2], edge_output.shape[-2])
        print(max_output_size)
        print(node_output.shape[-2])
        print(edge_output.shape[-2])
        
        # Pad node features to match the size of edge features
        node_output_padded = tf.pad(node_output, [[0, 0], [0, 50], [0, 0]])
        
        print(node_output_padded.shape)
        print(edge_output.shape)
        
        # Now both tensors have the same size along the feature dimension and can be concatenated
        combined_output = tf.concat([node_output_padded, edge_output], axis=-1)
        

        adjacency_matrix_padded = tf.pad(adjacency_matrix, [[0, 0], [0, 50], [0, 50]])
        print(adjacency_matrix_padded.shape)
        print(combined_output.shape)

        # Process adjacency matrix
        adjacency_output = tf.matmul(adjacency_matrix_padded, combined_output)

        # Perform graph convolution
        graph_output = self.output_layer(adjacency_output)

        return graph_output

In [35]:
num_hidden_units = 64
batch_size = 32
max_num_nodes = 50
max_num_edges = 100
num_node_features = 10
num_edge_features = 8
num_classes = 5

# # Define placeholders for inputs with dynamic shapes
# node_features = tf.keras.Input(shape=(None, num_node_features))  # Variable number of nodes, each with num_node_features features
# edge_features = tf.keras.Input(shape=(None, num_edge_features))  # Variable number of edges, each with num_edge_features features
# adjacency_matrix = tf.keras.Input(shape=(None, None))  # Variable size adjacency matrix

node_features_ = np.random.rand(batch_size, max_num_nodes, num_node_features).astype(np.float32)
edge_features_ = np.random.rand(batch_size, max_num_edges, num_edge_features).astype(np.float32)
adjacency_matrix_ = np.random.randint(0, 2, size=(batch_size, max_num_nodes, max_num_nodes)).astype(np.float32)

# print(node_features_)
# print(edge_features_)
# print(adjacency_matrix_)

# Instantiate the GCN model
gcn_model = GCNModel(no_classes=num_classes, hidden_units=64)

# Perform forward pass
output = gcn_model([node_features_, edge_features_, adjacency_matrix_])

(32, 50, 64)
(32, 100, 64)
100
50
100
(32, 100, 64)
(32, 100, 64)
(32, 100, 100)
(32, 100, 128)


In [36]:
output

<tf.Tensor: shape=(32, 100, 5), dtype=float32, numpy=
array([[[7.0338398e-01, 1.9089876e-02, 2.7737394e-01, 1.9508117e-07,
         1.5189657e-04],
        [6.3200849e-01, 6.7678735e-02, 3.0005270e-01, 1.3451352e-06,
         2.5880727e-04],
        [8.1729960e-01, 9.1677215e-03, 1.7344971e-01, 1.0758510e-07,
         8.2791281e-05],
        ...,
        [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
         2.0000000e-01],
        [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
         2.0000000e-01],
        [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
         2.0000000e-01]],

       [[9.0433198e-01, 2.2352502e-02, 7.2392009e-02, 3.4009558e-06,
         9.2014100e-04],
        [9.2741859e-01, 2.6691044e-02, 4.5625307e-02, 1.8744667e-07,
         2.6492609e-04],
        [9.3447089e-01, 6.2962049e-03, 5.9214264e-02, 1.8607353e-09,
         1.8655723e-05],
        ...,
        [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
   