In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.saving import register_keras_serializable
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Load the pre-sampled dataset. `flood_binary` is the label column; every other
# column becomes a model input (the Input layer below is declared with shape (5,)).
data = pd.read_csv("sampled_flood_data.csv")

# Quick class-balance report. The counts assume `flood_binary` holds 0/1 values.
target = data['flood_binary']
print("✅ Enhanced dataset size:", len(data),
      "| Flood:", target.sum(),
      "| No Flood:", (target == 0).sum())

# Cast to float32 for Keras. The stratified 70/30 split keeps class proportions
# comparable between train and test; the fixed seed makes the split reproducible.
X = data.drop(columns="flood_binary").astype("float32")
y = target.astype("float32")
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# --- Custom Layers ---
@register_keras_serializable()
def rainfall_proximity_penalty(inputs):
    """Soft AND-gate that is high only for heavy rainfall at a small distance.

    Column 0 is treated as rainfall and column 4 as distance.
    NOTE(review): these positions depend on the CSV column order after
    dropping `flood_binary`; confirm against the data file. Units are not
    visible here.

    Each factor is a sigmoid that crosses 0.5 at rainfall == 90 and at
    distance == 150, respectively.

    Returns:
        Tensor of shape (batch, 1), values in (0, 1): the product of the two
        sigmoid factors.
    """
    rain_gate = tf.sigmoid(0.3 * (inputs[:, 0] - 90))
    near_gate = tf.sigmoid(0.04 * (150 - inputs[:, 4]))
    # Add a trailing unit axis so the Lambda output matches output_shape=(1,).
    return tf.expand_dims(rain_gate * near_gate, axis=-1)

@register_keras_serializable()
def flood_risk_booster(inputs):
    """Multiplicative boost for steep slope combined with moderate-to-heavy rain.

    Column 0 is treated as rainfall and column 3 as slope.
    NOTE(review): positions depend on the CSV column order; confirm.

    Returns:
        Tensor of shape (batch, 1), values in (1.0, 1.25). The boost approaches
        1.25 when slope is well above 2.0 and rainfall is well above 60.
    """
    rain_col, slope_col = inputs[:, 0], inputs[:, 3]
    steep = tf.sigmoid(1.5 * (slope_col - 2.0))
    wet = tf.sigmoid(0.25 * (rain_col - 60))
    boost = 1.0 + 0.25 * steep * wet
    return tf.expand_dims(boost, axis=-1)

@register_keras_serializable()
def flood_suppression_mask(inputs):
    """Multiplicative damping when elevation is high and rainfall is low.

    Column 0 is treated as rainfall and column 2 as elevation.
    NOTE(review): positions depend on the CSV column order; confirm. The
    original code called the elevation factor "flatness", although it rises
    with elevation.

    Returns:
        Tensor of shape (batch, 1), values in (0.7, 1.0). The mask approaches
        0.7 when elevation is well above 9.0 and rainfall is well below 20.
    """
    rain_col, elev_col = inputs[:, 0], inputs[:, 2]
    high_ground = tf.sigmoid(0.6 * (elev_col - 9.0))
    dry = tf.sigmoid(0.2 * (20.0 - rain_col))
    mask = 1.0 - 0.3 * high_ground * dry
    return tf.expand_dims(mask, axis=-1)

class PrintValidationStats(callbacks.Callback):
    """Print the first few model predictions on validation inputs after each epoch.

    ``Model.fit`` in TF2 / Keras 3 does not populate ``Callback.validation_data``,
    so indexing that attribute fails on the first epoch. The validation inputs
    are therefore passed in explicitly.

    Args:
        validation_data: Validation inputs (array / DataFrame), or an
            ``(x, y, ...)`` tuple/list whose first element is the inputs.
            Optional, so ``PrintValidationStats()`` still constructs. When
            omitted, the callback falls back to a ``validation_data`` attribute
            if a legacy Keras set one; otherwise it prints nothing.
        num_preds: Number of predictions to print per epoch (default 5).
    """

    def __init__(self, validation_data=None, num_preds=5):
        super().__init__()
        # Private name avoids clashing with the framework-managed attribute.
        self._val_inputs = validation_data
        self.num_preds = num_preds

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the validation inputs and print the first `num_preds` values."""
        data = self._val_inputs
        if data is None:
            # Legacy Keras filled this in; modern versions leave it unset or None.
            data = getattr(self, "validation_data", None)
        if data is None:
            return  # nothing to predict on; don't crash training over a debug print
        x_val = data[0] if isinstance(data, (tuple, list)) else data
        preds = self.model.predict(x_val, verbose=0)  # suppress per-epoch progress bar
        print(f"Epoch {epoch+1} raw preds (first {self.num_preds}):",
              preds[:self.num_preds].flatten())

# Usage: pass held-out inputs explicitly, e.g.
# `callbacks=[early_stop, PrintValidationStats(X_val)]` (or an `(X_val, y_val)` tuple).

# --- Model Architecture (FireNet Style) ---
# A small MLP produces a base logit. Three hand-crafted gates, computed from the
# raw inputs, are multiplied together and added to that logit before the sigmoid.
# The gates are Lambda layers with no weights, so they are not trained.
input_layer = layers.Input(shape=(5,))

# Trainable branch: batch-normalise the features, then a 128 -> 64 -> 32 ReLU stack.
x = layers.BatchNormalization()(input_layer)
for width in (128, 64, 32):
    x = layers.Dense(width, activation='relu')(x)
logits = layers.Dense(1)(x)

# Rule-based branch: each gate sees the un-normalised inputs and yields (batch, 1).
penalty, booster, suppressor = [
    layers.Lambda(gate_fn, output_shape=(1,))(input_layer)
    for gate_fn in (rainfall_proximity_penalty, flood_risk_booster, flood_suppression_mask)
]

gate_product = layers.Multiply()([penalty, booster, suppressor])
modulated_logits = layers.Add()([logits, gate_product])
adjusted_output = layers.Activation("sigmoid")(modulated_logits)

model = models.Model(inputs=input_layer, outputs=adjusted_output)
# The output is already a probability (sigmoid above), hence from_logits=False.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False, label_smoothing=0.05)
model.compile(optimizer="adam", loss=bce, metrics=["accuracy"])

# --- Train & Evaluate ---
# Early stopping watches val_loss (computed on the validation_split below).
# restore_best_weights=True asks Keras to keep the weights from the best val_loss epoch.
early_stop = callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)
model.fit(
    X_train,
    y_train,
    validation_split=0.2,  # last 20% of X_train/y_train; train_test_split already shuffled them
    epochs=10,
    batch_size=8,
    callbacks=[early_stop],
)

# compile() registered metrics=["accuracy"], so evaluate returns [loss, accuracy].
loss, acc = model.evaluate(X_test, y_test)
print(f"🌊 FloodNet Accuracy: {acc:.4f}")

# NOTE(review): ".h5" is Keras' legacy HDF5 format (newer Keras recommends ".keras").
# The path is left unchanged because other code may load this exact file.
model.save("models/FV-FloodNet.h5")


✅ Enhanced dataset size: 2000 | Flood: 1000 | No Flood: 1000
Epoch 1/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9954 - loss: 0.2282 - val_accuracy: 1.0000 - val_loss: 0.1240
Epoch 2/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.9998 - loss: 0.1300 - val_accuracy: 1.0000 - val_loss: 0.1219
Epoch 3/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.9800 - loss: 0.1594 - val_accuracy: 1.0000 - val_loss: 0.1198
Epoch 4/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 1.0000 - loss: 0.1282 - val_accuracy: 1.0000 - val_loss: 0.1196
Epoch 5/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.9997 - loss: 0.1252 - val_accuracy: 1.0000 - val_loss: 0.1255
Epoch 6/10
140/140 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.9875 - loss: 0.1467 - val_acc



🌊 FloodNet Accuracy: 1.0000
