In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Concatenate, Add, BatchNormalization
from tensorflow.keras.models import Model
from spektral.layers import GlobalAvgPool, GlobalMaxPool, GCNConv

# New callback to log gradient norms
class GradientLogger(tf.keras.callbacks.Callback):
    """Print summary statistics of per-variable gradient L2 norms after each epoch.

    At the end of every epoch the callback pulls a single batch from
    ``dataset`` and runs a forward pass with ``training=True``, so Dropout and
    BatchNormalization behave as they do inside ``fit``. It then evaluates the
    model's compiled loss, including any regularization losses the model
    registers. Finally it prints the mean / max / min of the L2 norms of the
    gradients with respect to each trainable variable. Variables whose
    gradient is ``None`` are skipped, and nothing is printed if no gradient
    exists.

    In training mode BatchNormalization updates its moving statistics. To
    keep this diagnostic pass from altering the model, the callback snapshots
    the model's non-trainable variables before the pass and restores them
    afterwards. The model's weights are never updated by this callback.

    Args:
        dataset: an object with a ``.take(n)`` method (e.g. ``tf.data.Dataset``)
            whose elements are ``(inputs, targets)`` or
            ``(inputs, targets, sample_weight)``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch, logs=None):
        # take(1) yields at most one element, so no explicit break is needed.
        for batch in self.dataset.take(1):
            inputs, targets, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
            norm_list = self._gradient_norms(inputs, targets, sample_weight)
            if norm_list:
                avg_norm = np.mean(norm_list)
                max_norm = np.max(norm_list)
                min_norm = np.min(norm_list)
                print(f"Epoch {epoch+1} - Gradients: Mean L2 norm: {avg_norm:.6f}, Max: {max_norm:.6f}, Min: {min_norm:.6f}")

    def _gradient_norms(self, inputs, targets, sample_weight):
        """Return the L2 norm of each non-None gradient of the loss on one batch."""
        non_trainable = self.model.non_trainable_weights
        # Snapshot state mutated by a training-mode forward pass
        # (e.g. BatchNormalization moving mean/variance).
        saved_state = [var.numpy() for var in non_trainable]
        try:
            with tf.GradientTape() as tape:
                predictions = self.model(inputs, training=True)
                # Same loss that fit() optimizes: compiled loss, sample weights,
                # and the model's regularization losses.
                loss = self.model.compiled_loss(
                    targets,
                    predictions,
                    sample_weight=sample_weight,
                    regularization_losses=self.model.losses,
                )
            gradients = tape.gradient(loss, self.model.trainable_variables)
        finally:
            # Undo any side effects of the diagnostic pass on model state.
            for var, value in zip(non_trainable, saved_state):
                var.assign(value)
        return [tf.norm(grad).numpy() for grad in gradients if grad is not None]

class NoMaskGCNConv(GCNConv):
    """``GCNConv`` that ignores incoming Keras masks and emits none.

    Any mask Keras hands to ``call`` is discarded; the parent layer always
    receives ``mask=None``. ``compute_mask`` returns ``None`` so downstream
    layers see an unmasked output.
    """

    def call(self, inputs, mask=None):
        # The supplied mask is intentionally unused.
        del mask
        return super().call(inputs, mask=None)

    def compute_mask(self, inputs, mask=None):
        # No mask is propagated past this layer.
        return None

def build_revised_gcn_model(
    n_node_features,
    n_classes,
    num_nodes,
    *,
    gcn_units=(32, 64, 128),
    dense_units=(512, 256),
    gcn_dropout=0.2,
    dense_dropout=0.3,
):
    """Build a batch-mode GCN graph classifier.

    Architecture:
      * one ``NoMaskGCNConv -> BatchNormalization -> Dropout`` block per entry
        of ``gcn_units`` (ReLU activation);
      * a residual connection: the penultimate GCN block's output is linearly
        projected (``Dense``) to the last block's width and added to it;
      * global max-pooling and global average-pooling, concatenated;
      * one ``Dense -> BatchNormalization -> Dropout`` block per entry of
        ``dense_units`` (ReLU activation);
      * a softmax ``Dense`` layer with ``n_classes`` outputs.

    The defaults reproduce the original fixed architecture exactly, including
    the order in which layers are created.

    Args:
        n_node_features: feature width of each node.
        n_classes: number of output classes.
        num_nodes: number of nodes per graph (fixed for every sample).
        gcn_units: output widths of the GCN blocks. At least two entries are
            required because the residual connection uses the last two blocks.
        dense_units: widths of the fully connected layers in the head.
        gcn_dropout: dropout rate after each GCN block.
        dense_dropout: dropout rate after each fully connected layer.

    Returns:
        An uncompiled ``Model`` with inputs ``[node_features, adjacency_matrix]``
        of shapes ``(num_nodes, n_node_features)`` and ``(num_nodes, num_nodes)``
        (batch dimension implicit), producing softmax class probabilities.

    Raises:
        ValueError: if ``gcn_units`` has fewer than two entries.
    """
    gcn_units = tuple(gcn_units)
    if len(gcn_units) < 2:
        raise ValueError(
            "gcn_units must have at least two entries for the residual connection"
        )

    # Inputs: node features and adjacency matrix.
    # NOTE(review): per Spektral's docs, GCNConv expects a pre-normalized
    # adjacency (e.g. from spektral's gcn_filter); confirm the pipeline does this.
    x_in = tf.keras.Input(shape=(num_nodes, n_node_features), name="node_features")
    a_in = tf.keras.Input(shape=(num_nodes, num_nodes), name="adjacency_matrix")

    # --- GCN blocks ---
    # `previous` ends up holding the penultimate block's output for the residual.
    previous = x_in
    x = x_in
    for units in gcn_units:
        previous = x
        x = NoMaskGCNConv(units, activation='relu')([x, a_in])
        x = BatchNormalization()(x)
        x = Dropout(gcn_dropout)(x)

    # --- Residual connection ---
    # Project the penultimate block to the last block's width so they can be summed.
    skip = Dense(gcn_units[-1], activation='linear')(previous)
    x = Add()([x, skip])

    # --- Global pooling over nodes ---
    x_max = GlobalMaxPool()(x)
    x_avg = GlobalAvgPool(name='global_avg_pool')(x)
    x = Concatenate()([x_max, x_avg])

    # --- Fully connected head ---
    for units in dense_units:
        x = Dense(units, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(dense_dropout)(x)

    # Final classification layer
    output = Dense(n_classes, activation='softmax')(x)

    return Model(inputs=[x_in, a_in], outputs=output)

# Graph geometry: the image is split into an (img_size // patch_size) squared
# grid of nodes, i.e. 48 // 6 = 8 per side -> 8 * 8 = 64 nodes.
img_size = 48
patch_size = 6
num_nodes = (img_size // patch_size) * (img_size // patch_size)
# Node feature width: patch_size ** 2 = 36 values, presumably one per patch
# pixel, plus 2 extra features (TODO confirm what they are against the
# graph-construction code).
n_node_features = patch_size ** 2 + 2
n_classes = 7

model = build_revised_gcn_model(
    n_node_features=n_node_features,
    n_classes=n_classes,
    num_nodes=num_nodes,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
gradient_logger = GradientLogger(train_dataset)

# train_dataset / validation_dataset and train_generator / validation_generator
# are defined outside this cell; the generators' lengths set the number of
# batches per epoch. Callback order matters: early stopping runs first.
history = model.fit(
    x=train_dataset,
    epochs=20,
    steps_per_epoch=len(train_generator),
    validation_data=validation_dataset,
    validation_steps=len(validation_generator),
    callbacks=[early_stop, gradient_logger],
)
