In [None]:
# Construct Training/Evaluation Sets

import numpy as np
import os
import tensorflow as tf
import spektral
import numpy as np
import os

def load_data_by_rat(processed_path, selected_rats, training):
    """Load and concatenate the ``*fold0.npz`` files belonging to the given rats.

    Parameters
    ----------
    processed_path : str
        Directory containing the processed ``.npz`` files. Each file must hold
        arrays under the keys ``'X'`` and ``'Y'``.
    selected_rats : iterable of str
        Rat identifiers (e.g. ``"Rat5"``) to select. An identifier matches a
        filename only when it is not immediately followed by another digit, so
        ``"Rat1"`` does not also pick up ``"Rat10_..."`` files.
    training : bool
        Kept for call-site compatibility. Both the training and evaluation
        paths currently load the same ``*fold0.npz`` files.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``X`` and ``Y`` of all matching files, concatenated along axis 0, in
        sorted filename order (so results do not depend on ``os.listdir`` order).

    Raises
    ------
    ValueError
        If no file in ``processed_path`` matches ``selected_rats``.
    """
    import re

    # Negative lookahead on a digit: "Rat1" must not match inside "Rat10".
    patterns = [re.compile(re.escape(rat) + r"(?!\d)") for rat in selected_rats]

    X, Y = [], []
    for file in sorted(os.listdir(processed_path)):
        if not file.endswith('fold0.npz'):
            continue
        if not any(pattern.search(file) for pattern in patterns):
            continue
        print(f"Loading processed file: {file}")
        # Context manager closes the NpzFile handle; indexing loads the arrays eagerly.
        with np.load(os.path.join(processed_path, file)) as data:
            X.append(data['X'])
            Y.append(data['Y'])

    if not X:
        raise ValueError(
            f"No '*fold0.npz' files for {list(selected_rats)} found in {processed_path!r}"
        )

    return np.concatenate(X, axis=0), np.concatenate(Y, axis=0)

# Split rats into train/test (generalizability evaluation).
# The held-out rat(s) must not also be in the training list, otherwise the
# "unseen subject" test set leaks into training.
all_rats = ["Rat2", "Rat4", "Rat5", "Rat7", "Rat8", "Rat9", "Rat10"]
test_rat = ["Rat5"]
train_rat = [rat for rat in all_rats if rat not in test_rat]
assert not set(train_rat) & set(test_rat), "train and test rats overlap"

processed_path = '/content/drive/MyDrive/Graphs/ProcessedGraphs'

# Load data
X_train, Y_train = load_data_by_rat(processed_path, train_rat, training=True)
X_test, Y_test = load_data_by_rat(processed_path, test_rat, training=False)

# Print shapes for verification
print(f"Train: {X_train.shape}, {Y_train.shape}")
print(f"Test: {X_test.shape}, {Y_test.shape}")


In [None]:
# Geodesic Distance Graph Construction
from spektral.data import Dataset, Graph
import spektral
import tensorflow as tf



def geodesic_distance(channel1, channel2, n_per_ring=8):
    """Grid distance between two electrode channels on a ring array.

    Channel ``k`` sits in ring ``k // n_per_ring`` at position
    ``k % n_per_ring``. The distance is the number of ring-to-ring steps plus
    the number of around-the-ring steps. The around-the-ring part wraps, so the
    first and last positions of a ring are adjacent.

    Parameters
    ----------
    channel1, channel2 : int
        Channel indices.
    n_per_ring : int, default 8
        Channels per ring (the original hard-coded value).

    Returns
    -------
    int
        ``h_dist + v_dist``.
    """
    r1, c1 = divmod(channel1, n_per_ring)
    r2, c2 = divmod(channel2, n_per_ring)

    around = abs(c1 - c2)
    h_dist = min(around, n_per_ring - around)  # shorter way around the ring
    v_dist = abs(r1 - r2)

    return h_dist + v_dist


class MyDataset(Dataset):
    """Spektral dataset whose graphs all share one geodesic k-NN adjacency.

    Every sample ``(x_, y_)`` from ``zip(X, Y)`` becomes a ``Graph``. The single
    adjacency matrix is stored on ``self.a`` by ``read()``.

    Parameters
    ----------
    ele : int
        Number of electrodes (nodes).
    nei : int
        Number of nearest neighbours kept per node.
    distance_threshold : float
        Stored as ``self.distance_threshold``; not read by this class.
    X, Y : array-like
        Node features and labels per sample.
    **kwargs
        Forwarded to ``spektral.data.Dataset``.
    """

    def __init__(self, ele, nei, distance_threshold, X, Y, **kwargs):
        self.nei = nei
        self.ele = ele
        self.a = None
        self.distance_threshold = distance_threshold
        self.X = X
        self.Y = Y
        # Set everything first: presumably Dataset.__init__ triggers read().
        super().__init__(**kwargs)

    def read(self):
        self.a = self.geodesic_adj(self.nei, self.ele)
        return [Graph(x=x_, y=y_) for x_, y_ in zip(self.X, self.Y)]

    def geodesic_adj(self, nei, num_elec):
        """Directed k-NN adjacency with Gaussian weights ``exp(-d^2 / (2 sigma^2))``.

        Row ``i`` has nonzero entries for the ``nei`` channels geodesically
        closest to ``i``, excluding ``i`` itself. Geodesic distances are
        integers with many ties. A stable sort breaks those ties by lowest
        channel index, so the chosen neighbours do not depend on NumPy's sort
        implementation. ``nei`` is capped at ``num_elec - 1``.
        """
        sigma = 2
        geodesic_matrix = np.zeros((num_elec, num_elec))
        for i in range(num_elec):
            for j in range(num_elec):
                geodesic_matrix[i, j] = geodesic_distance(i, j)

        # Rank with self excluded explicitly instead of assuming it sorts first.
        ranking = geodesic_matrix.copy()
        np.fill_diagonal(ranking, np.inf)
        k = max(0, min(nei, num_elec - 1))
        knn_indices = np.argsort(ranking, axis=1, kind="stable")[:, :k]

        adjacency_matrix = np.zeros_like(geodesic_matrix)
        for i in range(num_elec):
            nbrs = knn_indices[i]
            adjacency_matrix[i, nbrs] = np.exp(
                -(geodesic_matrix[i, nbrs] ** 2) / (2 * sigma ** 2)
            )

        return adjacency_matrix.astype(float)

# Geodesic graph datasets: 56 electrodes, 5 nearest neighbours per node
# (the third argument, distance_threshold, is stored but not used).
data_tr = MyDataset(56, 5, 5, X_train, Y_train)
data_te = MyDataset(56, 5, 5, X_test, Y_test)

# Convert each dense adjacency into a tf.SparseTensor for the Spektral loaders.
for _dataset in (data_tr, data_te):
    _dataset.a = spektral.utils.sparse.sp_matrix_to_sp_tensor(
        spektral.layers.EdgeConv.preprocess(_dataset.a)
    )
del _dataset



In [None]:
# Random Graph Construction

from spektral.data import Dataset, Graph
import spektral
import tensorflow as tf
import numpy as np

def random_adj(num_nodes, max_degree=2, seed=None):
    """Build a random undirected, unweighted graph with bounded node degree.

    Every unordered pair ``(i, j)`` with ``i < j`` is visited in random order.
    An edge is accepted while both endpoints are still below ``max_degree``.

    Parameters
    ----------
    num_nodes : int
        Number of nodes.
    max_degree : int, default 2
        Maximum degree of any node.
    seed : int or None
        If given, the edge order comes from a private
        ``np.random.RandomState(seed)``. This yields the same order as
        ``np.random.seed(seed)`` would, without reseeding NumPy's global RNG.
        If ``None``, the global ``np.random`` state is used.

    Returns
    -------
    np.ndarray
        Symmetric ``(num_nodes, num_nodes)`` array of 0.0/1.0 with a zero diagonal.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    adj = np.zeros((num_nodes, num_nodes))
    degrees = np.zeros(num_nodes, dtype=int)

    possible_edges = [(i, j) for i in range(num_nodes) for j in range(i + 1, num_nodes)]
    rng.shuffle(possible_edges)

    for i, j in possible_edges:
        if degrees[i] < max_degree and degrees[j] < max_degree:
            adj[i, j] = adj[j, i] = 1
            degrees[i] += 1
            degrees[j] += 1

    return adj



class MyDataset(Dataset):
    """Spektral dataset whose graphs all share one random bounded-degree adjacency.

    Parameters
    ----------
    ele : int
        Number of nodes (electrodes) per graph.
    max_degree : int
        Maximum node degree of the random graph.
    X, Y : array-like
        Node features and labels per sample, zipped into ``Graph`` objects.
    seed : int or None, optional
        Forwarded to ``random_adj``. ``None`` (the default) draws a new random
        graph each time a dataset is constructed.
    **kwargs
        Forwarded to ``spektral.data.Dataset``.
    """

    def __init__(self, ele, max_degree, X, Y, seed=None, **kwargs):
        self.ele = ele
        self.max_degree = max_degree
        self.X = X
        self.Y = Y
        self.seed = seed
        self.a = None
        # Set everything first: presumably Dataset.__init__ triggers read().
        super().__init__(**kwargs)

    def read(self):
        self.a = self.random_adj(self.ele, self.max_degree, seed=self.seed)
        return [Graph(x=x_, y=y_) for x_, y_ in zip(self.X, self.Y)]

    def random_adj(self, num_nodes, max_degree, seed=None):
        # Delegates to the module-level random_adj function.
        return random_adj(num_nodes, max_degree, seed=seed)



data_tr = MyDataset(56, max_degree=2, X=X_train, Y=Y_train)
data_te = MyDataset(56, max_degree=2, X=X_test, Y=Y_test)

# Each MyDataset draws its own random graph in read(), so the two datasets
# would otherwise get different graphs. The model must be evaluated on the
# same graph it was trained on, so the test set reuses the training graph.
data_tr.a = spektral.utils.sparse.sp_matrix_to_sp_tensor(
    spektral.layers.EdgeConv.preprocess(data_tr.a)
)
data_te.a = data_tr.a



In [None]:
# Euclidean Graph Construction

from spektral.data import Dataset, Graph
import spektral
import tensorflow as tf

def electrode_coordinates(num_elec=56, n_per_ring=8,
                          ring_spacing_mm=3.33,
                          radius_mm=0.5,
                          include_depth=False):
    """Coordinates of electrodes laid out in rings along a cylinder.

    Electrode ``i`` is in ring ``i // n_per_ring`` at angle
    ``theta = 2*pi*(i % n_per_ring) / n_per_ring``.

    Columns are ``x = radius_mm * cos(theta)`` and ``y = ring * ring_spacing_mm``.
    With only these two columns, electrodes at ``theta`` and ``2*pi - theta`` in
    the same ring have (numerically near-)identical coordinates. Pass
    ``include_depth=True`` to append ``z = radius_mm * sin(theta)``, which makes
    every electrode distinct.

    Returns
    -------
    np.ndarray
        Shape ``(num_elec, 2)`` by default, ``(num_elec, 3)`` with ``include_depth``.
    """
    coords = np.zeros((num_elec, 3 if include_depth else 2))

    for i in range(num_elec):
        r, c = divmod(i, n_per_ring)

        theta = 2 * np.pi * c / n_per_ring
        x = radius_mm * np.cos(theta)
        y = r * ring_spacing_mm

        if include_depth:
            coords[i] = [x, y, radius_mm * np.sin(theta)]
        else:
            coords[i] = [x, y]

    return coords

def euclidean_distance_matrix(coords):
    """Return the matrix of pairwise Euclidean distances between rows of ``coords``.

    Entry ``[i, j]`` is ``||coords[i] - coords[j]||_2``.
    """
    pairwise_delta = np.expand_dims(coords, 1) - np.expand_dims(coords, 0)
    return np.sqrt(np.sum(pairwise_delta ** 2, axis=-1))

class MyDataset(Dataset):
    """Spektral dataset whose graphs all share one Euclidean k-NN adjacency.

    Parameters
    ----------
    ele : int
        Number of electrodes (nodes); also passed to ``electrode_coordinates``.
    nei : int
        Number of nearest neighbours kept per node.
    distance_threshold : float
        Stored as ``self.distance_threshold``; not read by this class.
    X, Y : array-like
        Node features and labels per sample, zipped into ``Graph`` objects.
    **kwargs
        Forwarded to ``spektral.data.Dataset``.
    """

    def __init__(self, ele, nei, distance_threshold, X, Y, **kwargs):
        self.nei = nei
        self.ele = ele
        self.a = None
        self.distance_threshold = distance_threshold
        self.X = X
        self.Y = Y

        self.coords = electrode_coordinates(num_elec=ele)

        # Set everything first: presumably Dataset.__init__ triggers read().
        super().__init__(**kwargs)

    def read(self):
        self.a = self.euclidean_adj(self.nei, self.ele)
        return [Graph(x=x_, y=y_) for x_, y_ in zip(self.X, self.Y)]

    def euclidean_adj(self, nei, num_elec):
        """Directed k-NN adjacency with Gaussian weights ``exp(-d^2 / (2 sigma^2))``.

        Row ``i`` has nonzero entries for the ``nei`` electrodes nearest to
        ``i``, excluding ``i`` itself.

        NOTE(review): the default 2-D ``electrode_coordinates`` uses only
        ``x = r*cos(theta)``, so some distinct electrodes in a ring can sit at
        distance ~0 from each other. Self is therefore masked out explicitly
        rather than assumed to be column 0 of the sort. The stable sort makes
        tie-breaking deterministic. ``nei`` is capped at ``n - 1``.
        """
        sigma = 2.0

        dist_matrix = euclidean_distance_matrix(self.coords)

        ranking = dist_matrix.astype(float)  # copy; diagonal masked below
        np.fill_diagonal(ranking, np.inf)
        k = max(0, min(nei, dist_matrix.shape[1] - 1))
        knn_indices = np.argsort(ranking, axis=1, kind="stable")[:, :k]

        adjacency_matrix = np.zeros_like(dist_matrix)
        for i in range(num_elec):
            nbrs = knn_indices[i]
            d = dist_matrix[i, nbrs]
            adjacency_matrix[i, nbrs] = np.exp(
                - (d ** 2) / (2 * sigma ** 2)
            )

        return adjacency_matrix.astype(float)

# Euclidean graph datasets: 56 electrodes, 5 nearest neighbours per node
# (the third argument, distance_threshold, is stored but not used).
data_tr = MyDataset(56, 5, 5, X_train, Y_train)
data_te = MyDataset(56, 5, 5, X_test, Y_test)

# Convert each dense adjacency into a tf.SparseTensor for the Spektral loaders.
for _dataset in (data_tr, data_te):
    _dataset.a = spektral.utils.sparse.sp_matrix_to_sp_tensor(
        spektral.layers.EdgeConv.preprocess(_dataset.a)
    )
del _dataset


In [None]:
# Model Training
import tensorflow as tf
import spektral
import numpy as np
from sklearn.metrics import f1_score


# Training parameters
batch_size = 256
epochs = 2000
patience = 100        # epochs without a training-loss improvement before stopping
l2_reg = 5e-3         # L2 penalty on the EdgeConv kernel
learning_rate = 1e-3
SEED = 42

# Seed Python, NumPy and TensorFlow RNGs (weight init, loader shuffling, dropout).
tf.keras.utils.set_random_seed(SEED)

# The test loader only feeds evaluation, so keep its batch order fixed.
loader_te = spektral.data.MixedLoader(data_te, batch_size=batch_size, shuffle=False)

def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
    """Separable Gaussian blur via two depthwise convolutions (horizontal, then vertical).

    The filter has ``2 * int(kernel_size / 2) + 1`` taps and is normalised to
    sum to 1. A rank-3 ``image`` gets a temporary batch axis, which is removed
    again before returning.
    """
    half_width = tf.cast(kernel_size / 2, dtype=tf.int32)
    taps = half_width * 2 + 1
    offsets = tf.cast(tf.range(-half_width, half_width + 1), dtype=tf.float32)
    sigma_f = tf.cast(sigma, dtype=tf.float32)
    weights = tf.exp(-tf.pow(offsets, 2.0) / (2.0 * tf.pow(sigma_f, 2.0)))
    weights = weights / tf.reduce_sum(weights)

    # One kernel per input channel for the depthwise convolutions.
    channels = tf.shape(image)[-1]
    horizontal = tf.tile(tf.reshape(weights, [1, taps, 1, 1]), [1, 1, channels, 1])
    vertical = tf.tile(tf.reshape(weights, [taps, 1, 1, 1]), [1, 1, channels, 1])

    is_single = image.shape.ndims == 3
    batched = tf.expand_dims(image, axis=0) if is_single else image
    out = tf.nn.depthwise_conv2d(batched, horizontal, strides=[1, 1, 1, 1], padding=padding)
    out = tf.nn.depthwise_conv2d(out, vertical, strides=[1, 1, 1, 1], padding=padding)
    return tf.squeeze(out, axis=0) if is_single else out


class Augmentation_layer1(tf.keras.layers.Layer):
    """Training-time additive Gaussian-noise augmentation.

    When ``training`` is truthy, ``call`` applies ``Trans`` ``M`` times. Each
    pass adds zero-mean Gaussian noise whose standard deviation is drawn
    uniformly from [0, 0.1). Otherwise the input is only cast to float32.

    Parameters
    ----------
    N : int
        Stored as ``self.N``; not read by this layer.
    M : int
        Number of noise passes per training call.
    """

    def __init__(self, N, M):
        super(Augmentation_layer1, self).__init__()
        self.M = M
        self.N = N

    def Trans(self, image, count):
        """Reshape to ``(dim0, dim1, -1)`` and add Gaussian noise.

        ``count`` is accepted for interface compatibility and ignored.
        """
        image = tf.convert_to_tensor(image)
        # Dynamic shape: works when the batch dimension is unknown (None) inside tf.function.
        dims = tf.shape(image)
        image = tf.reshape(image, [dims[0], dims[1], -1])

        stddev = tf.random.uniform(shape=[], minval=0, maxval=0.1, dtype=tf.float32)  # Gaussian noise level
        noise = tf.random.normal(
            tf.shape(image), mean=0.0, stddev=tf.cast(stddev, image.dtype), dtype=image.dtype
        )
        return image + noise

    def call(self, image, training=None):
        image = tf.cast(image, dtype=tf.float32)

        if training:
            for _ in range(self.M):
                image = self.Trans(image, None)

        return tf.cast(image, dtype=tf.float32)




class Net(tf.keras.Model):
    """Graph classifier.

    The layer pipeline is:
    LSTM -> EdgeConv -> BN/ReLU/Dropout -> GeneralConv -> BN/Dropout/ReLU
    -> GlobalAvgPool -> Dense(512) -> Dropout -> Dense(128) -> Dropout
    -> Dense(num_classes, softmax).

    Called with ``(x, a)``. NOTE(review): presumably ``x`` is
    (batch, nodes, features) from a MixedLoader, so the LSTM scans across the
    node axis and ``a`` is the shared adjacency — confirm.

    Parameters
    ----------
    num_classes : int, default 3
        Size of the softmax output.
    **kwargs
        Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): self.A and self.C1 are built but never used in call();
        # kept so existing attribute references keep working. C1's kernel_size
        # is tied to the global batch_size, which looks unintended.
        self.A = Augmentation_layer1(3, 3)
        self.C1 = tf.keras.layers.Conv1D(64, kernel_size=batch_size, activation='relu', padding='same')
        self.L1 = tf.keras.layers.LSTM(256, activation='relu', return_sequences=True)

        self.conv1 = spektral.layers.GeneralConv(channels=128)
        self.conv2 = spektral.layers.EdgeConv(32, kernel_regularizer=tf.keras.regularizers.l2(l2_reg))

        self.b1 = tf.keras.layers.BatchNormalization()
        self.b2 = tf.keras.layers.BatchNormalization()
        self.d3 = tf.keras.layers.Dropout(0.2)  # reused at several points (stateless)
        self.d4 = tf.keras.layers.Dropout(0.2)
        self.flatten = spektral.layers.GlobalAvgPool()
        self.fc1 = tf.keras.layers.Dense(512, activation="relu")
        self.fc2 = tf.keras.layers.Dense(128, activation="relu")
        self.fc3 = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.a1 = tf.keras.layers.Activation('relu')

    def call(self, inputs, training=None, mask=None):
        # `mask` is accepted for Keras compatibility and ignored.
        x, a = inputs

        x1 = self.L1(x, training=training)
        x1 = self.conv2([x1, a])

        x1 = self.b1(x1, training=training)
        x1 = self.a1(x1)
        x1 = self.d3(x1, training=training)
        x1 = self.conv1([x1, a])
        x1 = self.b2(x1, training=training)
        x1 = self.d3(x1, training=training)

        x1 = self.a1(x1)
        x1 = self.flatten(x1)
        # Dropout is the identity unless training is truthy, so no `if training:` guards are needed.
        output = self.fc1(x1)
        output = self.d3(output, training=training)
        output = self.fc2(output)
        output = self.d4(output, training=training)
        return self.fc3(output)


# Model, Adam optimiser with weight_decay=1e-3 (on top of the EdgeConv L2
# regulariser), and cross-entropy on one-hot targets with label smoothing.
model = Net()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, weight_decay=1e-3)
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

# Training function
@tf.function
def train_on_batch(inputs, target):
    """Run one optimisation step of the global ``model`` on a batch.

    ``target`` holds integer class labels, the same format the training loop
    receives from the loader. They are one-hot encoded to match ``loss_fn``
    (CategoricalCrossentropy) and ``categorical_accuracy``.

    Returns
    -------
    (tf.Tensor, tf.Tensor)
        Scalar loss (including ``model.losses`` regularisation) and batch accuracy.
    """
    # reshape (not squeeze) keeps a batch of one sample rank 1.
    target = tf.one_hot(tf.reshape(tf.cast(target, tf.int32), [-1]), depth=3)
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target, predictions) + sum(model.losses)
    acc = tf.reduce_mean(tf.keras.metrics.categorical_accuracy(target, predictions))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss, acc


# Evaluation function
def evaluate(loader):
    """Evaluate the global ``model`` on one epoch (``loader.steps_per_epoch`` batches).

    Loss and accuracy are batch-size-weighted means over the epoch. Macro-F1
    is computed once over all predictions of the epoch; averaging per-batch
    macro-F1 scores does not give the epoch's macro-F1.

    Returns
    -------
    np.ndarray or None
        ``[loss, accuracy, macro_f1]``, or ``None`` if the loader yields no batches.
    """
    losses, accuracies, sizes = [], [], []
    y_true, y_pred = [], []
    for step, (inputs, target) in enumerate(loader, start=1):
        predictions = model(inputs, training=False)
        # reshape (not squeeze) keeps a final batch of one sample rank 1.
        target = tf.one_hot(tf.reshape(target, [-1]), depth=3)
        losses.append(float(loss_fn(target, predictions)))
        accuracies.append(float(tf.reduce_mean(tf.keras.metrics.categorical_accuracy(target, predictions))))
        sizes.append(len(target))
        y_pred.append(np.argmax(np.asarray(predictions), axis=1))
        y_true.append(np.argmax(target.numpy(), axis=1))
        # The loader is not assumed to terminate, so stop after exactly one epoch.
        if step == loader.steps_per_epoch:
            break

    if not sizes:
        return None

    f1 = f1_score(np.concatenate(y_true), np.concatenate(y_pred), average='macro')
    return np.array([
        np.average(losses, weights=sizes),
        np.average(accuracies, weights=sizes),
        f1,
    ])


# Setup training
best_tr_loss = np.inf
current_patience = patience
step = 0
# Shown until the first successful evaluation (e.g. if the first epoch's loss is NaN).
results_te = np.full(3, np.nan)

# MixedLoader shuffles every epoch itself (shuffle=True by default), so one
# loader covering all `epochs` is enough; no per-epoch re-creation or reshuffle.
loader_tr = spektral.data.MixedLoader(data_tr, batch_size=batch_size, epochs=epochs)
# Training loop
results_tr = []
epoch = 0
for batch in loader_tr:
    step += 1

    # Training step
    inputs, target = batch
    # Integer labels -> one-hot; reshape (not squeeze) keeps a batch of one rank 1.
    target = tf.one_hot(tf.reshape(target, [-1]), depth=3)
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target, predictions) + sum(model.losses)
    acc = tf.reduce_mean(tf.keras.metrics.categorical_accuracy(target, predictions))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    results_tr.append((float(loss), float(acc), len(target)))

    if step % loader_tr.steps_per_epoch == 0:
        epoch += 1
        # Aggregate first: early stopping must compare the epoch's mean loss,
        # not a raw per-batch tuple.
        results_tr = np.array(results_tr)
        results_tr = np.average(results_tr[:, :-1], 0, weights=results_tr[:, -1])

        if results_tr[0] < best_tr_loss:
            best_tr_loss = results_tr[0]
            current_patience = patience
            # Test metrics are refreshed only on epochs that improve training loss.
            results_te = evaluate(loader_te)
        else:
            current_patience -= 1
            if current_patience == 0:
                print("Early stopping")
                break

        # Print results
        print(
            "Train loss: {:.4f}, acc: {:.4f} | "
            "Test loss: {:.4f}, acc: {:.4f}, f1: {:.4f}".format(
                *results_tr, *results_te
            )
        )

        results_tr = []