# In [None]:
import os
import warnings
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Define SelfAttention layer
class SelfAttention(tf.keras.layers.Layer):
    """SAGAN-style spatial self-attention block for NHWC feature maps.

    Every spatial position of the input attends to every other position of
    the same feature map. The query/key projections use
    ``max(1, channels // reduction)`` filters and the value projection keeps
    all channels, so the output has exactly the input's shape. The attended
    features are mixed in through a learnable scalar ``gamma`` that is
    initialised to zero, so the layer starts out as an identity mapping.

    Args:
        reduction: Channel reduction factor for the query/key projections.
            Defaults to 8, the value the original implementation hard-coded.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...). Accepting them is what lets Keras
            re-create the layer from its saved config.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, reduction=8, **kwargs):
        super().__init__(**kwargs)
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        self.reduction = reduction

    def build(self, input_shape):
        self.channels = int(input_shape[-1])
        # Guard against a 0-filter convolution when channels < reduction.
        reduced_channels = max(1, self.channels // self.reduction)
        self.query_conv = tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1)
        self.key_conv = tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1)
        self.value_conv = tf.keras.layers.Conv2D(filters=self.channels, kernel_size=1)
        # Pass the name by keyword: in Keras 3 the first positional argument
        # of add_weight is `shape`, not `name`.
        self.gamma = self.add_weight(
            name="gamma", shape=(1,), initializer="zeros", trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        # Dynamic shape so an unknown batch size works. Index the shape
        # tensor rather than unpacking it, because iterating a symbolic
        # tensor is not allowed in graph mode.
        shape = tf.shape(inputs)
        batch_size, height, width = shape[0], shape[1], shape[2]
        num_positions = height * width

        # Flatten the spatial grid: (B, H, W, C') -> (B, H*W, C'). NHWC layout
        # keeps each position's channels contiguous, so this reshape is valid.
        query = tf.reshape(self.query_conv(inputs), [batch_size, num_positions, -1])
        key = tf.reshape(self.key_conv(inputs), [batch_size, num_positions, -1])
        value = tf.reshape(self.value_conv(inputs), [batch_size, num_positions, self.channels])

        # (B, N, N): row i holds position i's weights over all N positions.
        attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True), axis=-1)

        # Weighted sum of values: (B, N, N) @ (B, N, C) -> (B, N, C), back to NHWC.
        attended_value = tf.matmul(attention, value)
        attended_value = tf.reshape(attended_value, shape)

        # Residual connection scaled by the learnable gamma.
        return self.gamma * attended_value + inputs

    def get_config(self):
        """Return the layer config, including ``reduction``, for model saving/loading."""
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config

# Dataset root. It holds a 'train' and a 'validation' split, and each split
# has one sub-folder per class ('aster' and 'noaster').
base_dir = 'C:/Users/cshen/Desktop/ML_test1/240410_nocross_test1/240430_test3_flip'
train_dir, validation_dir = (os.path.join(base_dir, split) for split in ('train', 'validation'))

# Per-class folders of the training split
train_aster_dir, train_noaster_dir = (
    os.path.join(train_dir, label) for label in ('aster', 'noaster'))

# Per-class folders of the validation split
validation_aster_dir, validation_noaster_dir = (
    os.path.join(validation_dir, label) for label in ('aster', 'noaster'))

# Build the CNN incrementally: three conv/pool stages, then self-attention,
# then a dense classifier head. With 'valid' 3x3 convolutions and 2x2
# pooling, the spatial size goes 64 -> 62 -> 31 -> 29 -> 14 -> 12 -> 6, so
# the attention layer receives 6x6x128 feature maps.
model = tf.keras.models.Sequential()

# Stage 1 also declares the input: 64x64 single-channel images.
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 1)))
model.add(tf.keras.layers.MaxPooling2D(2, 2))

# Stage 2
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(2, 2))

# Stage 3
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(2, 2))

# Self-attention over the final feature map
model.add(SelfAttention())

# Classifier head: flatten, one hidden layer with 50% dropout, and a single
# sigmoid unit for the binary (aster / noaster) decision.
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

# Print the layer-by-layer summary
model.summary()

# Training setup: binary cross-entropy matches the single sigmoid output;
# optimise with Adam at learning rate 1e-4 and report accuracy.
model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=['acc'],
)

# Preprocessing: both splits only scale pixel values by 1/255 (no augmentation).
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Loading options shared by both splits: 64x64 grayscale images in batches
# of 128, with binary labels taken from the class sub-folders.
_flow_options = dict(
    target_size=(64, 64),
    color_mode='grayscale',
    class_mode='binary',
    batch_size=128,
)
train_generator = train_datagen.flow_from_directory(train_dir, **_flow_options)
validation_generator = test_datagen.flow_from_directory(validation_dir, **_flow_options)
del _flow_options  # temporary helper; keep the module namespace unchanged

# Train the model. `Model.fit` accepts the directory iterators directly;
# `Model.fit_generator` has been deprecated since TF 2.1 and is removed in
# Keras 3.
# Each epoch makes exactly one pass over each split: len() of a
# DirectoryIterator is its number of batches. Deriving the step counts from
# the data avoids hard-coded values that can make Keras stop early with
# "Your input ran out of data".
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=100,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    verbose=1)

# Save the trained model in HDF5 format (.h5) inside the dataset root,
# reusing base_dir instead of repeating the literal path.
model.save(os.path.join(base_dir, 'test2_SA_100epoch.h5'))