# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
import numpy as np

class ChannelSelfAttention(Layer):
    """Re-weight the channels of a 4-D feature map with a learned residual gate.

    The input is assumed to be channels-last ``(batch, height, width, channels)``,
    as produced by a default ``Conv2D`` (TODO confirm for other callers).
    Average- and max-pooling over the two spatial axes give one descriptor per
    channel. Their sum is passed through a sigmoid to form per-channel weights
    ``attn`` of shape ``(batch, 1, 1, channels)``, which broadcast cleanly
    against the input.

    The layer returns ``x + gamma * (x * attn)``. ``gamma`` is a scalar weight
    initialised to zero, so the layer starts out as the identity and learns how
    much attention to mix in. Multiplying the whole output by a zero-initialised
    ``gamma`` instead would zero the feature map at initialisation and block
    gradients to every upstream layer.

    Args:
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``dtype``, ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, **kwargs):
        # Forwarding kwargs lets the layer be named and rebuilt via from_config().
        super(ChannelSelfAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the scalar residual-gate weight ``gamma`` (initialised to 0)."""
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)
        super(ChannelSelfAttention, self).build(input_shape)

    def call(self, x):
        """Return ``x + gamma * x * sigmoid(avgpool(x) + maxpool(x))``.

        Args:
            x: Feature map, assumed ``(batch, height, width, channels)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        spatial_axes = (1, 2)
        # Keep dims so the per-channel weights broadcast over height and width.
        avg_pool = tf.reduce_mean(x, axis=spatial_axes, keepdims=True)
        max_pool = tf.reduce_max(x, axis=spatial_axes, keepdims=True)
        attn = tf.nn.sigmoid(avg_pool + max_pool)
        # Residual form: identity at init (gamma == 0), attention learned from there.
        return x + self.gamma * (x * attn)

def spherical_conv_layer(inputs, filters=16, kernel_size=(3, 3)):
    """Apply a same-padded, SELU-activated 2-D convolution to ``inputs``.

    NOTE(review): despite its name, this is an ordinary planar ``Conv2D``. No
    spherical geometry (e.g. longitude wrap-around) is handled here.

    Args:
        inputs: Keras tensor to convolve.
        filters: Number of output channels.
        kernel_size: Convolution window size.

    Returns:
        The output tensor of a freshly created ``Conv2D`` layer.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        padding='same',
        activation='selu',
    )
    return conv(inputs)

def build_s2fscnn(input_shape):
    """Assemble the S2FSCNN binary classifier as a functional Keras model.

    Pipeline: ``spherical_conv_layer`` (16 filters, 3x3) -> batch norm ->
    ``ChannelSelfAttention`` -> 2x2 max-pool -> dropout(0.3) -> flatten ->
    dense(128, selu) -> dense(64, selu) -> dense(1, sigmoid).

    Args:
        input_shape: Per-sample input shape (without the batch axis), passed
            straight to ``layers.Input``.

    Returns:
        An uncompiled ``Model`` named ``'S2FSCNN'`` with one sigmoid output unit.
    """
    inputs = layers.Input(shape=input_shape)
    tensor = spherical_conv_layer(inputs, filters=16, kernel_size=(3, 3))

    # Created after the conv layer, in the same order as they are applied.
    stages = (
        layers.BatchNormalization(),
        ChannelSelfAttention(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(128, activation='selu'),
        layers.Dense(64, activation='selu'),
        layers.Dense(1, activation='sigmoid'),
    )
    for stage in stages:
        tensor = stage(tensor)

    return Model(inputs=inputs, outputs=tensor, name='S2FSCNN')
def compile_model(model, learning_rate=0.001):
    """Compile ``model`` in place for binary classification and return it.

    Uses the Adam optimiser with binary cross-entropy loss and tracks accuracy,
    precision and recall.

    Args:
        model: Keras model to compile. Given the binary cross-entropy loss, it
            presumably ends in a single sigmoid unit (verify against callers).
        learning_rate: Adam step size. A PKOA-optimized value can be supplied
            here.

    Returns:
        The same ``model`` object, now compiled.
    """
    adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    tracked_metrics = [
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ]
    model.compile(optimizer=adam, loss='binary_crossentropy', metrics=tracked_metrics)
    return model
class CustomMetrics(tf.keras.callbacks.Callback):
    """Report specificity, F1 score and error rate on validation data each epoch.

    At the end of every epoch, the model's predictions on the validation inputs
    are thresholded into 0/1 labels and compared with the true labels to build
    a binary confusion matrix. The derived metrics are printed, and they are
    also stored in ``logs`` when a ``logs`` dict is supplied.

    Args:
        validation_data: ``(x_val, y_val)`` pair to evaluate on. It presumably
            holds NumPy arrays or other inputs accepted by ``Model.predict``;
            confirm for your data pipeline. TF2 / Keras 3 ``fit`` do not
            populate ``Callback.validation_data``, so pass the data here. If
            omitted, a ``validation_data`` attribute already set on the instance
            (e.g. by legacy Keras) is used instead.
        threshold: Predicted probability above which a sample is labelled 1.
            The default 0.5 matches rounding a sigmoid output (0.5 maps to 0).
    """

    def __init__(self, validation_data=None, threshold=0.5):
        super().__init__()
        self._eval_data = validation_data
        self.threshold = threshold

    def _resolve_validation_data(self):
        """Return ``(x_val, y_val)``, or raise ``ValueError`` if none is available."""
        data = self._eval_data
        if data is None:
            data = getattr(self, 'validation_data', None)
        if data is None:
            raise ValueError(
                "CustomMetrics has no validation data; construct it as "
                "CustomMetrics(validation_data=(x_val, y_val))."
            )
        return data[0], data[1]

    @staticmethod
    def _confusion_metrics(y_true, y_pred):
        """Return ``(specificity, f1_score, error_rate)`` for flat 0/1 label arrays.

        A 1e-8 term in each denominator avoids division by zero when a class
        is absent.
        """
        tp = np.sum((y_pred == 1) & (y_true == 1))
        tn = np.sum((y_pred == 0) & (y_true == 0))
        fp = np.sum((y_pred == 1) & (y_true == 0))
        fn = np.sum((y_pred == 0) & (y_true == 1))

        specificity = tn / (tn + fp + 1e-8)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
        error_rate = (fp + fn) / (tp + tn + fp + fn + 1e-8)
        return specificity, f1_score, error_rate

    def on_epoch_end(self, epoch, logs=None):
        """Compute, print and (when ``logs`` is given) record the extra metrics."""
        x_val, y_val = self._resolve_validation_data()
        # verbose=0 stops predict() from printing its own progress bar each epoch.
        probabilities = self.model.predict(x_val, verbose=0)
        # Flatten both sides: predict() on a one-unit head returns (N, 1) while
        # labels are often (N,). Comparing those shapes directly would broadcast
        # to an (N, N) matrix and corrupt every confusion-matrix count.
        y_pred = (np.ravel(probabilities) > self.threshold).astype(np.int64)
        y_true = np.ravel(np.asarray(y_val)).astype(np.int64)

        specificity, f1_score, error_rate = self._confusion_metrics(y_true, y_pred)

        if logs is not None:
            logs['val_specificity'] = specificity
            logs['val_f1'] = f1_score
            logs['val_error_rate'] = error_rate

        print(f" — Specificity: {specificity:.4f}, F1: {f1_score:.4f}, Error Rate: {error_rate:.4f}")