## Import packages necessary to structure the data and compile the EfficientNet model for 6 channel images

In [4]:
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout, Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications.efficientnet import EfficientNetB0
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import numpy as np
import rasterio
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
import os

## Create, train and test the six channel VGG
### Define a list with all the 6-channel image paths to save computational ressources in the model

In [None]:
# Identify and save six-channel image paths
def identify_six_channel_images(directory, output_file):
    six_channel_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.tif'):
                file_path = os.path.join(root, file)
                try:
                    with rasterio.open(file_path) as src:
                        if src.count == 6:
                            six_channel_files.append(file_path)
                except Exception as e:
                    print(f"Error processing file {file_path}: {e}")
    with open(output_file, 'w') as f:
        for path in six_channel_files:
            f.write("%s\n" % path)
    print(f"Identified {len(six_channel_files)} six-channel images.")

# Pre-identify six-channel images 
identify_six_channel_images('/Volumes/HD710PRO/Fire_and_Hurricane_Images/Fire/6channel/Training', 'training_six_channel_images.txt')
identify_six_channel_images('/Volumes/HD710PRO/Fire_and_Hurricane_Images/Fire/6channel/Validation', 'validation_six_channel_images.txt')
identify_six_channel_images('/Volumes/HD710PRO/Fire_and_Hurricane_Images/Fire/6channel/Test', 'testing_six_channel_images.txt')


### Create a custom data generator for 6-Channel GeoTIFF Imagery

In [5]:
# Custom Data Generator for Six-Channel Images
class SixChannelGenerator(Sequence):
    def __init__(self, file_paths, labels, batch_size=32, dim=(256, 256), n_channels=6, shuffle=True):
        self.file_paths = file_paths
        self.labels = labels
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.file_paths) / self.batch_size))

    def __getitem__(self, index):
        batch_paths = self.file_paths[index * self.batch_size:(index + 1) * self.batch_size]
        batch_labels = self.labels[index * self.batch_size:(index + 1) * self.batch_size]
        batch_x = np.empty((len(batch_paths), *self.dim, self.n_channels), dtype=np.float32)
        batch_y = np.array(batch_labels, dtype=np.float32)
        
        for i, path in enumerate(batch_paths):
            with rasterio.open(path) as src:
                img = src.read()[:self.n_channels, :self.dim[0], :self.dim[1]]
                img = np.moveaxis(img, 0, -1)  # Convert from channels_first to channels_last format
                batch_x[i,] = img / 255.0  # Normalize images

        return batch_x, batch_y

    def on_epoch_end(self):
        if self.shuffle:
            temp = list(zip(self.file_paths, self.labels))
            np.random.shuffle(temp)
            self.file_paths, self.labels = zip(*temp)


### Define the adapted EfficientNet model for Six-Channel Input and define the function to load 6 channel data

In [None]:
# Model Adaptation for Six-Channel Input
def create_efficientnet_six_channel(input_shape=(256, 256, 6), dropout_rate=0.3):
    input_tensor = Input(shape=input_shape)
    
    # Consider initializing this layer with custom weights or explore ways to adapt pre-trained weights
    x = Conv2D(3, (3, 3), padding='same', activation='relu')(input_tensor)
    
    base_model = EfficientNetB0(include_top=False, input_tensor=x, weights=None)  # No pre-trained weights
    
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(1024, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)
    
    model = Model(inputs=input_tensor, outputs=output)
    return model

def load_preprocessed_dataset(file_list_path, percentage=0.25):
    with open(file_list_path, 'r') as file:
        all_files = [line.strip() for line in file.readlines()]

    # Ensure proper shuffling to mix damaged and undamaged images
    np.random.seed(42)
    np.random.shuffle(all_files)

    # Select a subset after shuffling
    subset_size = int(len(all_files) * percentage)
    selected_files = all_files[:subset_size]

    # Correct the logic to match the actual directory names
    labels = [1 if '/Damage/' in file else 0 for file in selected_files]

    damaged_count = labels.count(1)
    undamaged_count = labels.count(0)
    print(f"Loaded {len(selected_files)} images: {damaged_count} damaged, {undamaged_count} undamaged.")
    return selected_files, labels



### Load 25% of the data, define the classes and do weight class computing for class imbalnce, initialize the data generators

In [6]:
# data
print("Loading datasets...")
train_files, train_labels = load_preprocessed_dataset('training_six_channel_images.txt', 0.25)
val_files, val_labels =  load_preprocessed_dataset('validation_six_channel_images.txt', 0.25)
test_files, test_labels = load_preprocessed_dataset('testing_six_channel_images.txt', 0.25)
#Check distribution 
print(f"Distribution in training data: {np.bincount(train_labels)}")
print(f"Distribution in validation data: {np.bincount(val_labels)}")
print(f"Distribution in test data: {np.bincount(test_labels)}")

# with 1 representing 'damaged' and 0 'undamaged' in the filepath
classes = np.unique(train_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_labels)

# Create a dictionary mapping class indices to their respective weights
class_weights_dict = {classes[i]: weight for i, weight in enumerate(class_weights)}
print("Class weights:", class_weights_dict)


#Initialize data generators
train_generator = SixChannelGenerator(train_files, train_labels, batch_size=32)
val_generator = SixChannelGenerator(val_files, val_labels, batch_size=32)
test_generator = SixChannelGenerator(test_files, test_labels, batch_size=32)

Loading datasets...
Loaded 21231 images: 5727 damaged, 15504 undamaged.
Loaded 2652 images: 691 damaged, 1961 undamaged.
Loaded 2716 images: 808 damaged, 1908 undamaged.
Distribution in training data: [15504  5727]
Distribution in validation data: [1961  691]
Distribution in test data: [1908  808]
Class weights: {0: 0.6846942724458205, 1: 1.85358826610791}


### Initialize, train, compile and fit the model

In [9]:
# Initialize the model
model = create_efficientnet_six_channel()

# Model Training with Callbacks for Optimal Training Control
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
model_checkpoint = ModelCheckpoint('efficientnet_6channel_best.keras', save_best_only=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6, verbose=1)

# Compile the model
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
# Fit the model
model.fit(train_generator, epochs=20, validation_data=val_generator,
          callbacks=[early_stopping, model_checkpoint, reduce_lr], verbose=1)

model.save('6channelEN_25.keras')


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 4: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 10: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 11/20
Epoch 12/20
Epoch 12: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 13/20


### Test the model on the unseen test 

In [10]:
#Test the model
test_loss, test_accuracy, test_precision, test_recall = model.evaluate(test_generator, verbose = 1)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}, Test Precision: {test_precision}, Test Recall: {test_recall}")

Test Loss: 0.1784033179283142, Test Accuracy: 0.9267305135726929, Test Precision: 0.9258741140365601, Test Recall: 0.8193069100379944
