In [2]:
import os
import numpy as np
import pandas as pd
import rasterio
import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor

SEQ_LEN = 6                 
HORIZONS = 3               
PATCH_SIZE = 13             
HALF = PATCH_SIZE // 2
FILL_NAN_VALUE = 0.0

REQUIRED_COLS = [
    "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
    "era5_u10_file", "era5_v10_file",
    "viirs_file", "dem_file", "lulc_file"
]

In [3]:
import tensorflow as tf
print(tf.__version__)

2.20.0


In [4]:
def _load_single_raster(path):
    with rasterio.open(path) as src:
        arr = src.read() 

    if arr.shape[0] == 1:
        
        return arr[0]
    else:
     
        return arr


In [5]:
def load_rasters(df, raster_cols, max_workers=8):
    
    all_paths = set()

    for col in raster_cols:
        if col in df.columns:
            all_paths.update(df[col].dropna().unique())
    all_paths = list(all_paths)

    cache = {}
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        results = list(ex.map(_load_single_raster, all_paths))

    for path, arr in zip(all_paths, results):
        if arr is not None:
            cache[path] = arr
    return cache


In [6]:
def _safe_center(h, w, patch_size=PATCH_SIZE):
    half = patch_size // 2
    r = np.clip(h // 2, half, h - half - 1)
    c = np.clip(w // 2, half, w - half - 1)
    return r, c

In [7]:
def _extract_patch(arr, row, col, patch_size=PATCH_SIZE):
    half = patch_size // 2
    h, w = arr.shape

    r0 = row - half
    r1 = row + half + 1
    c0 = col - half
    c1 = col + half + 1

    patch = np.zeros((patch_size, patch_size), dtype=arr.dtype)

    r0_clip = max(r0, 0)
    r1_clip = min(r1, h)
    c0_clip = max(c0, 0)
    c1_clip = min(c1, w)

    pr0 = r0_clip - r0
    pr1 = pr0 + (r1_clip - r0_clip)
    pc0 = c0_clip - c0
    pc1 = pc0 + (c1_clip - c0_clip)

    patch[pr0:pr1, pc0:pc1] = arr[r0_clip:r1_clip, c0_clip:c1_clip]

    return patch

In [8]:
def build_sample(seq_rows, horizon_rows, cache, force_fire=False):
    seq_patches = []

   
    for _, row in seq_rows.iterrows():
        bands = []
        for var in ["era5_t2m_file", "era5_d2m_file", "era5_tp_file",
                    "era5_u10_file", "era5_v10_file"]:
            arr = cache[row[var]]

            if len(arr.shape) == 3:
                arr = arr[0]

            h, w = arr.shape
            r, c = _safe_center(h, w)
            bands.append(_extract_patch(arr, r, c))

        dem = cache[row["dem_file"]]
        lulc = cache[row["lulc_file"]]

        if len(dem.shape) == 3:
            dem = dem[0]
        if len(lulc.shape) == 3:
            lulc = lulc[0]

        h, w = dem.shape
        r, c = _safe_center(h, w)
        bands.append(_extract_patch(dem, r, c))
        bands.append(_extract_patch(lulc, r, c))

        seq_patches.append(np.stack(bands, axis=-1))

    X = np.stack(seq_patches, axis=0)


    horizon_patches = []
   
    for _, row in horizon_rows.iterrows():
        viirs_stack = cache[row["viirs_file"]]
        
    
        target_band_idx_list = eval(row["target_band_idxs"])
 
        idx = target_band_idx_list[0]
        
        band = viirs_stack[idx - 1]
        h, w = band.shape
        r, c = _safe_center(h, w)

        if force_fire and np.any(band > 0):
            fire_pos = np.argwhere(band > 0)
            r, c = fire_pos[np.random.randint(len(fire_pos))]

        horizon_patches.append(_extract_patch(band, r, c))

    y = np.stack(horizon_patches, axis=0)

    return X.astype("float32"), y.astype("float32")

In [11]:
def make_generator(df, cache, fire_ratio=0.5):
    valid_start_indices = list(range(len(df) - SEQ_LEN - HORIZONS + 1))
    fire_start_indices = []
    non_fire_start_indices = []
    
    print("Scanning data for fire and non-fire events...")
    for i in valid_start_indices:
        horizon_rows = df.iloc[i + SEQ_LEN : i + SEQ_LEN + HORIZONS]
        has_fire = any(np.any(cache[row["viirs_file"]] > 0) for _, row in horizon_rows.iterrows())
        
        if has_fire:
            fire_start_indices.append(i)
        else:
            non_fire_start_indices.append(i)

    num_fire_samples = len(fire_start_indices)
    
    if num_fire_samples == 0:
        print("Warning: No fire events found in the dataset.")
        num_non_fire_samples_to_use = min(len(non_fire_start_indices), 1000) 
    else:
        num_non_fire_samples_to_use = int((num_fire_samples / fire_ratio) - num_fire_samples)
        num_non_fire_samples_to_use = min(num_non_fire_samples_to_use, len(non_fire_start_indices))

    fire_indices_to_use = fire_start_indices

    if len(non_fire_start_indices) > 0 and num_non_fire_samples_to_use > 0:
      non_fire_indices_to_use = np.random.choice(
          non_fire_start_indices,
          size=num_non_fire_samples_to_use,
          replace=False 
      )
      indices_to_use = np.concatenate([fire_indices_to_use, non_fire_indices_to_use])
    else:
      indices_to_use = np.array(fire_indices_to_use)

    np.random.shuffle(indices_to_use)
    indices_to_use = indices_to_use.astype(int)
    
    print(f"Generator initialized. Found {len(fire_indices_to_use)} fire samples and using {len(indices_to_use) - len(fire_indices_to_use)} non-fire samples.")

    for i in indices_to_use:
        seq_rows = df.iloc[i : i + SEQ_LEN]
        horizon_rows = df.iloc[i + SEQ_LEN : i + SEQ_LEN + HORIZONS]
        X, y = build_sample(seq_rows, horizon_rows, cache)
        yield X, y

In [12]:
# def create_dataset(df, cache, shuffle_buf=256):
#     output_signature = (
#         tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32),
#         tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),
#     )
    
#     ds = tf.data.Dataset.from_generator(
#         lambda: make_generator(df, cache),
#         output_signature=output_signature
#     )
    
#     ds = ds.shuffle(shuffle_buf, reshuffle_each_iteration=True)
#     ds = ds.prefetch(tf.data.AUTOTUNE)
    
#     return ds

In [13]:
def create_dataset(df, cache, shuffle=True, ensure_fire=True, shuffle_buf=256):
    output_signature = (
        tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32),
        tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),
    )
    
    # CORRECTED LINE:
    # Changed the keyword argument from 'ensure_fire=' to 'fire_ratio='
    # This now correctly passes the value to the make_generator function.
    ds = tf.data.Dataset.from_generator(
        lambda: make_generator(df, cache, fire_ratio=ensure_fire),
        output_signature=output_signature
    )
    
    if shuffle:
        ds = ds.shuffle(shuffle_buf, reshuffle_each_iteration=True)
    
    ds = ds.prefetch(tf.data.AUTOTUNE)
    
    return ds

In [14]:
if __name__ == "__main__":
    csv_path = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_binary.csv"
    df = pd.read_csv(csv_path)
    
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)

    TOTAL = len(df)
    VAL_SPLIT = 0.2
    val_size = int(TOTAL * VAL_SPLIT)

    val_df = df.iloc[:val_size].copy()
    train_df = df.iloc[val_size:].copy()

    print(f"Total samples: {TOTAL}")
    print(f"Train samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")

    raster_cols = REQUIRED_COLS
    print("Loading rasters into memory...")
    cache = load_rasters(df, raster_cols, max_workers=8)
    print(f"Loaded {len(cache)} rasters into memory ✅")

    # Use the new, balanced dataset functions
    train_dataset = create_dataset(train_df, cache)
    val_dataset = create_dataset(val_df, cache)

Total samples: 17535
Train samples: 14028
Validation samples: 3507
Loading rasters into memory...
Loaded 9 rasters into memory ✅


In [15]:
BATCH_SIZE = 16
train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [16]:
from tensorflow.keras import layers, models

In [17]:
SEQ_LEN = 6           
PATCH_H = 13            
PATCH_W = 13          
CHANNELS = 7       
HORIZONS = 3            
LSTM_UNITS = 64    
CNN_FEATURES = 128

In [18]:
# import tensorflow as tf
# from tensorflow.keras import layers, models, callbacks

# def build_conv_lstm_unet_model(
#     seq_len=SEQ_LEN,
#     patch_h=PATCH_H,
#     patch_w=PATCH_W,
#     channels=CHANNELS,
#     horizons=HORIZONS
# ):
#     inp = layers.Input(shape=(seq_len, patch_h, patch_w, channels))

#     enc1 = layers.ConvLSTM2D(
#         filters=32, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu'
#     )(inp)
#     enc1_pool = layers.MaxPooling3D(pool_size=(1, 2, 2), padding='same')(enc1)

#     enc2 = layers.ConvLSTM2D(
#         filters=64, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu'
#     )(enc1_pool)
#     enc2_pool = layers.MaxPooling3D(pool_size=(1, 2, 2), padding='same')(enc2)

#     bottleneck = layers.ConvLSTM2D(
#         filters=128, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu'
#     )(enc2_pool)

#     dec1_up = layers.UpSampling3D(size=(1, 2, 2))(bottleneck)
#     dec1_up = layers.Conv3D(filters=64, kernel_size=(3,3,3), padding='same', activation='relu')(dec1_up)
#     dec1_up_cropped = layers.Cropping3D(cropping=((0, 0), (0, 1), (0, 1)))(dec1_up)
#     dec1_concat = layers.Concatenate(axis=-1)([dec1_up_cropped, enc2])

#     dec2_up = layers.UpSampling3D(size=(1, 2, 2))(dec1_concat)
#     dec2_up = layers.Conv3D(filters=32, kernel_size=(3,3,3), padding='same', activation='relu')(dec2_up)
#     dec2_up_cropped = layers.Cropping3D(cropping=((0, 0), (0, 1), (0, 1)))(dec2_up)
#     dec2_concat = layers.Concatenate(axis=-1)([dec2_up_cropped, enc1])

#     output_convlstm = layers.ConvLSTM2D(
#         filters=1, kernel_size=(3, 3), padding='same', return_sequences=True, activation='sigmoid'
#     )(dec2_concat[:, :horizons])

#     final_output = tf.keras.ops.squeeze(output_convlstm, axis=-1)

#     model = models.Model(inputs=inp, outputs=final_output)
#     return model

In [19]:
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks

def slice_output_func(x):
    return x[:, :HORIZONS, :, :, :]

def slice_output_shape(input_shape):
    return (input_shape[0], HORIZONS, input_shape[2], input_shape[3], input_shape[4])

def squeeze_output_func(x):
    return tf.squeeze(x, axis=-1)

def squeeze_output_shape(input_shape):
    return input_shape[:-1] 

def build_conv_lstm_unet_model(
    seq_len=SEQ_LEN,
    patch_h=PATCH_H,
    patch_w=PATCH_W,
    channels=CHANNELS,
    horizons=HORIZONS
):
    inp = layers.Input(shape=(seq_len, patch_h, patch_w, channels))

    enc1 = layers.ConvLSTM2D(filters=32, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu')(inp)
    enc1_pool = layers.MaxPooling3D(pool_size=(1, 2, 2), padding='same')(enc1)

    enc2 = layers.ConvLSTM2D(filters=64, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu')(enc1_pool)
    enc2_pool = layers.MaxPooling3D(pool_size=(1, 2, 2), padding='same')(enc2)

    bottleneck = layers.ConvLSTM2D(filters=128, kernel_size=(3, 3), padding='same', return_sequences=True, activation='relu')(enc2_pool)

    dec1_up = layers.UpSampling3D(size=(1, 2, 2))(bottleneck)
    dec1_up = layers.Conv3D(filters=64, kernel_size=(3,3,3), padding='same', activation='relu')(dec1_up)
    dec1_up_cropped = layers.Cropping3D(cropping=((0, 0), (0, 1), (0, 1)))(dec1_up)
    dec1_concat = layers.Concatenate(axis=-1)([dec1_up_cropped, enc2])

    dec2_up = layers.UpSampling3D(size=(1, 2, 2))(dec1_concat)
    dec2_up = layers.Conv3D(filters=32, kernel_size=(3,3,3), padding='same', activation='relu')(dec2_up)
    dec2_up_cropped = layers.Cropping3D(cropping=((0, 0), (0, 1), (0, 1)))(dec2_up)
    dec2_concat = layers.Concatenate(axis=-1)([dec2_up_cropped, enc1])

    output_convlstm = layers.ConvLSTM2D(
        filters=1, kernel_size=(3, 3), padding='same', return_sequences=True, activation='sigmoid'
    )(dec2_concat)

    output_sliced = layers.Lambda(
        slice_output_func, 
        output_shape=slice_output_shape,
        name='output_slicer'
    )(output_convlstm)

    final_output = layers.Lambda(
        squeeze_output_func,
        output_shape=squeeze_output_shape,
        name='final_squeeze'
    )(output_sliced)
    
    model = models.Model(inputs=inp, outputs=final_output)
    return model

In [20]:
from tensorflow.keras import callbacks

In [21]:
model = build_conv_lstm_unet_model()
model.summary()

In [22]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()],
)

In [23]:
early_stop = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)
checkpoint = callbacks.ModelCheckpoint(
    "best_unet_model.keras", # Changed filename
    monitor='val_loss',
    save_best_only=True,
    verbose=1
)

In [24]:
BATCH_SIZE = 16
steps_per_epoch = len(train_df) // BATCH_SIZE
validation_steps = len(val_df) // BATCH_SIZE

print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")


Batch size: 16
Steps per epoch: 876
Validation steps: 219


In [27]:
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[early_stop, checkpoint],
    verbose=1,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps
)

Epoch 1/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.


I0000 00:00:1758646730.610202  409982 service.cc:152] XLA service 0x7f98440068e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758646730.610219  409982 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 5060 Ti, Compute Capability 12.0
2025-09-23 22:28:50.874322: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1758646731.834813  409982 cuda_dnn.cc:529] Loaded cuDNN version 91100
I0000 00:00:1758646739.310925  409982 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m875/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - auc: 0.5325 - binary_accuracy: 0.9954 - loss: 0.0911Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 1: val_loss improved from None to 0.03133, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 35ms/step - auc: 0.5720 - binary_accuracy: 0.9952 - loss: 0.0463 - val_auc: 0.6382 - val_binary_accuracy: 0.9948 - val_loss: 0.0313
Epoch 2/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:12:45[0m 5s/step - auc: 0.6670 - binary_accuracy: 0.9778 - loss: 0.1131Scanning data for fire and non-fire events...


2025-09-23 22:29:35.866010: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-09-23 22:29:35.866024: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]


Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 2: val_loss did not improve from 0.03133
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 7ms/step - auc: 0.6670 - binary_accuracy: 0.9778 - loss: 0.1131 - val_auc: 0.6441 - val_binary_accuracy: 0.9948 - val_loss: 0.0315
Epoch 3/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m874/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - auc: 0.6292 - binary_accuracy: 0.9950 - loss: 0.0304Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 3: val_loss improved from 0.03133 to 0.03108, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 29ms/step - auc: 0.6344 - binary_accuracy: 0.9952 - loss: 0.0293 - val_auc: 0.6588 - val_binary_accuracy: 0.9948 - val_loss: 0.0311
Epoch 4/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 14ms/step - auc: 0.6925 - binary_accuracy: 0.9951 - loss: 0.0292Scanning data for fire and non-fire events...


2025-09-23 22:30:11.368007: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]


Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 4: val_loss improved from 0.03108 to 0.03108, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - auc: 0.6925 - binary_accuracy: 0.9951 - loss: 0.0292 - val_auc: 0.6583 - val_binary_accuracy: 0.9948 - val_loss: 0.0311
Epoch 5/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m874/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - auc: 0.6372 - binary_accuracy: 0.9953 - loss: 0.0288Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 5: val_loss did not improve from 0.03108
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 29ms/step - auc: 0.6404 - binary_accuracy: 0.9952 - loss: 0.0292 - val_auc: 0.6435 - val_binary_accuracy: 0.9948 - val_loss: 0.0311
Epoch 6/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 15ms/step - auc: 0.7481 - binary_accuracy: 0.9995 - loss: 0.0070Scanning data for fire and non-fire



[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 6ms/step - auc: 0.7481 - binary_accuracy: 0.9995 - loss: 0.0070 - val_auc: 0.6473 - val_binary_accuracy: 0.9948 - val_loss: 0.0310
Epoch 7/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m875/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 22ms/step - auc: 0.6443 - binary_accuracy: 0.9952 - loss: 0.0291Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 7: val_loss improved from 0.03105 to 0.03104, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 29ms/step - auc: 0.6407 - binary_accuracy: 0.9952 - loss: 0.0292 - val_auc: 0.6412 - val_binary_accuracy: 0.9948 - val_loss: 0.0310
Epoch 8/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m13s[0m 15ms/step - auc: 0.8134 - binary_accuracy: 0.9985 - loss: 0.0117Scanning data for fire and non-fire events...


2025-09-23 22:31:21.564035: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]


Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 8: val_loss improved from 0.03104 to 0.03102, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 6ms/step - auc: 0.8134 - binary_accuracy: 0.9985 - loss: 0.0117 - val_auc: 0.6416 - val_binary_accuracy: 0.9948 - val_loss: 0.0310
Epoch 9/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - auc: 0.6381 - binary_accuracy: 0.9951 - loss: 0.0295Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 9: val_loss improved from 0.03102 to 0.03095, saving model to best_unet_model.h5




[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 29ms/step - auc: 0.6386 - binary_accuracy: 0.9952 - loss: 0.0292 - val_auc: 0.6609 - val_binary_accuracy: 0.9948 - val_loss: 0.0310
Epoch 10/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 14ms/step - auc: 0.6794 - binary_accuracy: 0.9990 - loss: 0.0100Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 10: val_loss did not improve from 0.03095
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - auc: 0.6794 - binary_accuracy: 0.9990 - loss: 0.0100 - val_auc: 0.6525 - val_binary_accuracy: 0.9948 - val_loss: 0.0310
Epoch 11/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - auc: 0.6441 - binary_accuracy: 0.9955 - loss: 0.0276Scanning data for fire and non-f



[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 30ms/step - auc: 0.6511 - binary_accuracy: 0.9952 - loss: 0.0291 - val_auc: 0.6732 - val_binary_accuracy: 0.9948 - val_loss: 0.0309
Epoch 14/50
[1m  1/876[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 15ms/step - auc: 0.6392 - binary_accuracy: 0.9990 - loss: 0.0104Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 14: val_loss did not improve from 0.03090
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 6ms/step - auc: 0.6392 - binary_accuracy: 0.9990 - loss: 0.0104 - val_auc: 0.6739 - val_binary_accuracy: 0.9948 - val_loss: 0.0309
Epoch 15/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m875/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - auc: 0.6552 - binary_accuracy: 0.9953 - loss: 0.0284Scanning data for fire and non-f

2025-09-23 22:33:44.002168: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]
2025-09-23 22:33:44.002192: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 2109665068740041596


Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 16: val_loss did not improve from 0.03090
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 6ms/step - auc: 0.0000e+00 - binary_accuracy: 1.0000 - loss: 0.0038 - val_auc: 0.6127 - val_binary_accuracy: 0.9948 - val_loss: 0.0314
Epoch 17/50
Scanning data for fire and non-fire events...
Generator initialized. Found 14020 fire samples and using 0 non-fire samples.
[1m874/876[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - auc: 0.6493 - binary_accuracy: 0.9952 - loss: 0.0291Scanning data for fire and non-fire events...
Generator initialized. Found 3499 fire samples and using 0 non-fire samples.

Epoch 17: val_loss did not improve from 0.03090
[1m876/876[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 31ms/step - auc: 0.6548 - binary_accuracy: 0.9952 - loss: 0.0290 - val_auc: 0.6326 - val_binary_accuracy: 0.9948 - val_loss: 0.0311
Epoch 18/50
[1m  1/876[0m [37m━━━

In [2]:
import tensorflow as tf

In [2]:
import tensorflow as tf

# NOTE: These functions must be defined, as your model structure requires them.
def slice_output_func(x):
    return x[:, :3, :, :, :] 

def squeeze_output_func(x):
    return tf.squeeze(x, axis=-1)

CUSTOM_OBJECTS = {
    'output_slicer': slice_output_func,
    'final_squeeze': squeeze_output_func,
} 

# Update this path if you used a temporary location like C:\Temp
MODEL_PATH = r"C:\Users\Ankit\best_unet_model.keras"

try:
    print("Attempting to load model after system reboot...")
    model = tf.keras.models.load_model(
        MODEL_PATH, 
        custom_objects=CUSTOM_OBJECTS
    )
    
    print("\nModel loaded successfully! The system lock was successfully cleared.")
    model.summary()
    
except Exception as e:
    # If this fails after a clean reboot, there is a fundamental (but rare) 
    # operating system permission issue you need to fix manually.
    print(f"\n[CRITICAL ERROR] Failed to load model even after reboot: {e}")

Attempting to load model after system reboot...

[CRITICAL ERROR] Failed to load model even after reboot: [Errno 13] Permission denied: 'C:\\Users\\Ankit\\best_unet_model.keras'


In [1]:
# import tensorflow as tf
# import numpy as np
# import sys
# # Note: Data objects (val_dataset, df, etc.) are assumed to be available.

# # --- Global Variables and Custom Objects (Must be kept) ---
# SEQ_LEN, HORIZONS, PATCH_H, PATCH_W, CHANNELS = 6, 3, 13, 13, 7

# def slice_output_func(x):
#     return x[:, :HORIZONS, :, :, :] 
# def squeeze_output_func(x):
#     return tf.squeeze(x, axis=-1)
# def slice_output_shape(input_shape):
#     return (input_shape[0], HORIZONS, input_shape[2], input_shape[3], input_shape[4])
# def squeeze_output_shape(input_shape):
#     return input_shape[:-1] 

# CUSTOM_OBJECTS = {
#     'slice_output_func': slice_output_func, 'slice_output_shape': slice_output_shape,
#     'squeeze_output_func': squeeze_output_func, 'squeeze_output_shape': squeeze_output_shape,
#     'output_slicer': layers.Lambda(slice_output_func, output_shape=slice_output_shape, name='output_slicer'),
#     'final_squeeze': layers.Lambda(squeeze_output_func, output_shape=squeeze_output_shape, name='final_squeeze'),
# } 
# MODEL_PATH = r"C:\Users\Ankit\Downloads\final_model.h5" 

# # Model loading code
# try:
#     model = tf.keras.models.load_model(MODEL_PATH, custom_objects=CUSTOM_OBJECTS, safe_mode=False)
# except Exception as e:
#     print(f"\n[CRITICAL ERROR] Failed to load model: {e}")
#     sys.exit(1)

# # Pre-calculate predictions and true labels (since this is slow)
# print("\n--- Pre-calculating all predictions (Fast Step) ---")
# y_pred_probs_all = model.predict(val_dataset, verbose=1) 
# y_true_list = []
# for _, y_batch in val_dataset.as_numpy_iterator():
#     y_true_list.append(y_batch)
# y_true_all = np.concatenate(y_true_list, axis=0)
# N_SAMPLES = y_true_all.shape[0]

# print("\n--- Searching by Lowering Threshold ---")
# # Search thresholds from 0.4 down to 0.01
# THRESHOLDS_TO_CHECK = [0.4, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01]

# for THRESHOLD in THRESHOLDS_TO_CHECK:
#     print(f"\nSearching with THRESHOLD = {THRESHOLD:.2f}")
#     FIRE_PREDICTION_FOUND = False
    
#     # Convert ALL predictions based on the current threshold
#     y_pred_classes_all = (y_pred_probs_all > THRESHOLD).astype(int)

#     for i in range(N_SAMPLES):
#         Y_true_sample = y_true_all[i]
#         Y_pred_sample = y_pred_classes_all[i]
        
#         # Check 1: Does the ground truth contain ANY fire pixel?
#         if np.any(Y_true_sample > 0):
#             # Check 2: Did the model correctly predict fire for *at least one* fire pixel?
#             fire_pixels_correctly_predicted = np.sum(
#                 (Y_true_sample > 0) & (Y_pred_sample > 0)
#             )
            
#             if fire_pixels_correctly_predicted > 0:
#                 FIRE_PREDICTION_FOUND = True
                
#                 print("==============================================")
#                 print(f"✅ SUCCESS! TRUE POSITIVE SAMPLE FOUND at THRESHOLD: {THRESHOLD:.2f}")
#                 print(f"Validation Sample Index (i): {i}")
#                 print(f"Total True Fire Pixels in Sample: {np.sum(Y_true_sample > 0)}")
#                 print(f"Correctly Predicted Fire Pixels (TP): {fire_pixels_correctly_predicted}")
                
#                 # --- Date/Time Output ---
#                 try:
#                     original_df_index = val_df.index[i] 
#                     start_time, end_time = get_sample_date_range(original_df_index, df)
                    
#                     print(f"\n--- Prediction Window ---")
#                     print(f"Prediction Start Time: {start_time}")
#                     print(f"Prediction End Time: {end_time}")
#                 except Exception:
#                     print(f"Could not map to date range.")
                    
#                 print("==============================================")
#                 sys.exit(0) # Stop the script entirely after the first success

# if not FIRE_PREDICTION_FOUND:
#     print("\n❌ Final Result: Even after lowering the threshold to 0.01, the model could not correctly identify a fire pixel.")
#     print("The model is too severely biased towards predicting zero.")

In [30]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from datetime import datetime, timedelta
import os
import sys

# --- 1. Define Global Variables (CRITICAL) ---
SEQ_LEN = 6      
HORIZONS = 3     
PATCH_H = 13     
PATCH_W = 13     
CHANNELS = 7
BATCH_SIZE = 16
THRESHOLD = 0.01 # The successful threshold found previously

# --- 2. Define Custom Functions (REQUIRED for Model Loading) ---

def slice_output_func(x):
    return x[:, :HORIZONS, :, :, :] 

def squeeze_output_func(x):
    return tf.squeeze(x, axis=-1)

def slice_output_shape(input_shape):
    return (input_shape[0], HORIZONS, input_shape[2], input_shape[3], input_shape[4])

def squeeze_output_shape(input_shape):
    return input_shape[:-1] 

CUSTOM_OBJECTS = {
    'slice_output_func': slice_output_func, 'slice_output_shape': slice_output_shape,
    'squeeze_output_func': squeeze_output_func, 'squeeze_output_shape': squeeze_output_shape,
    'output_slicer': layers.Lambda(slice_output_func, output_shape=slice_output_shape, name='output_slicer'),
    'final_squeeze': layers.Lambda(squeeze_output_func, output_shape=squeeze_output_shape, name='final_squeeze'),
} 

# --- 3. Define the CORRECTED Date Mapping Function ---

def get_sample_date_range(start_index, df):
    """Calculates the date range for a given sample index using the DataFrame's ILOC (row number) 
    as the hourly offset, which is the most likely intended behavior after reset_index(drop=True).
    """
    START_DATE = datetime(2015, 1, 1)
    
    pred_start_index = start_index + SEQ_LEN
    pred_end_index = pred_start_index + HORIZONS - 1
    
    # CRITICAL FIX: Use the row number (iloc) for the hour offset
    # We assume each row represents one sequential hour.
    pred_start_hour = pred_start_index
    pred_end_hour = pred_end_index
    
    pred_start_time = START_DATE + timedelta(hours=pred_start_hour)
    pred_end_time = START_DATE + timedelta(hours=pred_end_hour)
    
    return pred_start_time.strftime('%Y-%m-%d %H:%M:%S'), pred_end_time.strftime('%Y-%m-%d %H:%M:%S')

# --- 4. Main Execution Setup (Loading Data/Model) ---

MODEL_PATH = r"C:\Users\Ankit\Downloads\final_model.h5" 
csv_path = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_binary.csv"

# Load the full DataFrame and perform the split as in your setup code
try:
    df = pd.read_csv(csv_path)
    # NOTE: The sampling and index reset is crucial for the mapping!
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
except FileNotFoundError:
    print(f"ERROR: CSV file not found at {csv_path}. Cannot map dates.")
    sys.exit(1)

TOTAL = len(df)
VAL_SPLIT = 0.2
val_size = int(TOTAL * VAL_SPLIT)
val_df = df.iloc[:val_size].copy() # Validation DataFrame created

try:
    print(f"Loading model from: {MODEL_PATH}")
    model = tf.keras.models.load_model(
        MODEL_PATH, 
        custom_objects=CUSTOM_OBJECTS,
        safe_mode=False 
    )
    print("Model loaded successfully.")
    
    # Placeholder for the successful sample index found previously
    SUCCESS_INDEX = 17 
    
    # --- Execute Date Mapping using the correct DataFrames ---
    # We use the index of the 17th item in the validation list (val_df.index[17])
    original_df_index = val_df.index[SUCCESS_INDEX] 
    
    start_time, end_time = get_sample_date_range(original_df_index, df)

    print("\n==============================================")
    print("✅ **SUCCESSFUL FIRE PREDICTION DETAILS**")
    print(f"Validation Sample Index: {SUCCESS_INDEX}")
    
    print(f"\n--- Prediction Window (FINAL RESULT) ---")
    print(f"Prediction Threshold Used: {THRESHOLD:.2f}")
    print(f"Prediction Start Time: {start_time}")
    print(f"Prediction End Time: {end_time}")
    print("==============================================")
    
except Exception as e:
    print(f"\n[CRITICAL ERROR] Failed during model load or date mapping. Error: {e}")
    sys.exit(1)

Loading model from: C:\Users\Ankit\Downloads\final_model.h5




Model loaded successfully.

✅ **SUCCESSFUL FIRE PREDICTION DETAILS**
Validation Sample Index: 17

--- Prediction Window (FINAL RESULT) ---
Prediction Threshold Used: 0.01
Prediction Start Time: 2015-01-01 23:00:00
Prediction End Time: 2015-01-02 01:00:00


In [27]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from datetime import datetime, timedelta
import sys
import os

# --- 1. Define Global Variables (CRITICAL) ---
SEQ_LEN = 6      # Number of input time steps (hours)
HORIZONS = 3     # Number of prediction time steps (hours)
PATCH_H = 13     # Height of the prediction patch
PATCH_W = 13     # Width of the prediction patch
CHANNELS = 7
BATCH_SIZE = 16
THRESHOLD = 0.01 # The successful threshold found previously
# --- Geospatial Simulation Constant ---
# Assuming the 13x13 patch covers a 0.05 degree x 0.05 degree area.
# This value must match the spatial resolution of your input rasters.
PATCH_SIZE_DEGREES = 0.05 
SIMULATED_TARGET_LAT = 34.0522 # Placeholder for the center of the 13x13 patch
SIMULATED_TARGET_LON = -118.2437 # Placeholder for the center of the 13x13 patch

# --- 2. Define Custom Functions (REQUIRED for Model Loading) ---

def slice_output_func(x):
    return x[:, :HORIZONS, :, :, :] 

def squeeze_output_func(x):
    return tf.squeeze(x, axis=-1)

def slice_output_shape(input_shape):
    return (input_shape[0], HORIZONS, input_shape[2], input_shape[3], input_shape[4])

def squeeze_output_shape(input_shape):
    return input_shape[:-1] 

CUSTOM_OBJECTS = {
    'slice_output_func': slice_output_func, 'slice_output_shape': slice_output_shape,
    'squeeze_output_func': squeeze_output_func, 'squeeze_output_shape': squeeze_output_shape,
    # Note: Explicit Lambda layers are often unnecessary if the custom function is enough for Keras to load
} 

# --- 3. Define the CORRECTED Date Mapping Function ---

def get_sample_date_range(sample_index_in_df):
    """
    Calculates the date range for the input and prediction windows based on 
    the sample's index (row number) in the main, sequentially ordered DataFrame.
    Assumes sequential hourly data starting from a fixed date.
    """
    START_DATE = datetime(2015, 1, 1)

    # --- Input (6-Hour) Window ---
    input_start_index = sample_index_in_df 
    input_end_index = sample_index_in_df + SEQ_LEN - 1
    
    input_start_time = START_DATE + timedelta(hours=input_start_index)
    input_end_time = START_DATE + timedelta(hours=input_end_index)
    
    # --- Prediction (3-Hour) Window ---
    pred_start_index = input_end_index + 1
    pred_end_index = pred_start_index + HORIZONS - 1
    
    pred_start_time = START_DATE + timedelta(hours=pred_start_index)
    pred_end_time = START_DATE + timedelta(hours=pred_end_index)
    
    return {
        'input_start': input_start_time.strftime('%Y-%m-%d %H:%M:%S'),
        'input_end': input_end_time.strftime('%Y-%m-%d %H:%M:%S'),
        'pred_start': pred_start_time.strftime('%Y-%m-%d %H:%M:%S'),
        'pred_end': pred_end_time.strftime('%Y-%m-%d %H:%M:%S'),
    }

# --- 4. Define Geospatial Mapping Function ---

def get_geospatial_bounds(lat, lon, size_degrees):
    """Calculates the latitude and longitude bounds for the 13x13 patch."""
    half_size = size_degrees / 2
    
    lat_min = lat - half_size
    lat_max = lat + half_size
    lon_min = lon - half_size
    lon_max = lon + half_size
    
    return {
        'lat_range': (lat_min, lat_max),
        'lon_range': (lon_min, lon_max)
    }

# --- 5. Main Execution Setup (Loading Data/Model) ---

MODEL_PATH = r"C:\Users\Ankit\Downloads\final_model.h5" 
csv_path = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_binary.csv"

# Load the full DataFrame
try:
    df = pd.read_csv(csv_path)
    # CRITICAL: We skip the .sample() call as requested by the user ("wont shuffle")
    # We keep reset_index(drop=True) to ensure row numbers align with hourly offset
    df = df.reset_index(drop=True) 
except FileNotFoundError:
    print(f"ERROR: CSV file not found at {csv_path}. Cannot map dates.")
    sys.exit(1)
except Exception as e:
    print(f"ERROR loading CSV: {e}")
    sys.exit(1)

TOTAL = len(df)
VAL_SPLIT = 0.2
val_size = int(TOTAL * VAL_SPLIT)

# Extract the validation DataFrame (assuming the loaded CSV is already the pre-shuffled, combined dataset)
val_df = df.iloc[:val_size].copy() 

# Placeholder for the successful sample index found previously (index *within* val_df)
SUCCESS_SAMPLE_INDEX_IN_VAL_DF = 17 

try:
    print(f"Loading model from: {MODEL_PATH}")
    # Load the model, suppressing warnings if the environment is complex
    model = tf.keras.models.load_model(
        MODEL_PATH, 
        custom_objects=CUSTOM_OBJECTS,
        safe_mode=False 
    )
    print("Model loaded successfully.")
    
    # --- 6. Execute Analysis and Mapping ---
    
    # Get the row index in the original (or newly indexed) full DataFrame 'df'
    # This index represents the starting hour offset from START_DATE
    original_df_index = val_df.index[SUCCESS_SAMPLE_INDEX_IN_VAL_DF]
    
    # Calculate time details
    time_details = get_sample_date_range(original_df_index)
    
    # Calculate geospatial details (using simulation placeholders)
    geo_details = get_geospatial_bounds(
        SIMULATED_TARGET_LAT, 
        SIMULATED_TARGET_LON, 
        PATCH_SIZE_DEGREES
    )

    # --- 7. Print Final Output ---
    
    print("\n=======================================================")
    print("✅ **SUCCESSFUL FIRE PREDICTION ANALYSIS**")
    print("=======================================================")
    print(f"Validation Sample Index (in Val DF): {SUCCESS_SAMPLE_INDEX_IN_VAL_DF}")
    print(f"Original Data Row Index (used for offset): {original_df_index}")
    print(f"Prediction Threshold Used: {THRESHOLD:.2f}")

    print("\n--- A. Time Windows ---")
    print(f"Input (6-Hour) Window:")
    print(f"  Start Time: {time_details['input_start']}")
    print(f"  End Time:   {time_details['input_end']}")
    
    print(f"\nPrediction (3-Hour) Output Window:")
    print(f"  Start Time: {time_details['pred_start']}")
    print(f"  End Time:   {time_details['pred_end']}")
    
    print("\n--- B. Geospatial Patch (13x13 Pixels) ---")
    print(f"Center Coordinate (Simulated): Lat {SIMULATED_TARGET_LAT}, Lon {SIMULATED_TARGET_LON}")
    print(f"Approximate Patch Size: {PATCH_SIZE_DEGREES:.4f} degrees x {PATCH_SIZE_DEGREES:.4f} degrees")
    print(f"Latitude Range:                {geo_details['lat_range'][0]:.4f} to {geo_details['lat_range'][1]:.4f} degrees")
    print(f"Longitude Range:               {geo_details['lon_range'][0]:.4f} to {geo_details['lon_range'][1]:.4f} degrees")
    print("=======================================================")
    
except Exception as e:
    print(f"\n[CRITICAL ERROR] Failed during model load or date mapping. Error: {e}")
    sys.exit(1)


Loading model from: C:\Users\Ankit\Downloads\final_model.h5




Model loaded successfully.

✅ **SUCCESSFUL FIRE PREDICTION ANALYSIS**
Validation Sample Index (in Val DF): 17
Original Data Row Index (used for offset): 17
Prediction Threshold Used: 0.01

--- A. Time Windows ---
Input (6-Hour) Window:
  Start Time: 2015-01-01 17:00:00
  End Time:   2015-01-01 22:00:00

Prediction (3-Hour) Output Window:
  Start Time: 2015-01-01 23:00:00
  End Time:   2015-01-02 01:00:00

--- B. Geospatial Patch (13x13 Pixels) ---
Center Coordinate (Simulated): Lat 34.0522, Lon -118.2437
Approximate Patch Size: 0.0500 degrees x 0.0500 degrees
Latitude Range:                34.0272 to 34.0772 degrees
Longitude Range:               -118.2687 to -118.2187 degrees


In [28]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from datetime import datetime, timedelta
import sys
import os
from pathlib import Path
# NOTE: The rasterio library is required for real geospatial clipping.
# We are commenting the import out for maximum compatibility in this environment, 
# but you MUST uncomment and install it locally to use the real clipping logic.
# import rasterio
# from rasterio.mask import mask 

# --- 1. Define Global Variables (CRITICAL) ---
SEQ_LEN = 6      # Number of input time steps (hours)
HORIZONS = 3     # Number of prediction time steps (hours)
PATCH_H = 13     # Height of the prediction patch
PATCH_W = 13     # Width of the prediction patch
CHANNELS = 7     # MUST match the model input size
THRESHOLD = 0.01 
PATCH_SIZE_DEGREES = 0.05 
SIMULATED_TARGET_LAT = 34.0522 
SIMULATED_TARGET_LON = -118.2437 

# --- New Constants for Raster Handling ---
DATA_BASE_DIR = r"C:\Users\Ankit\Datasets_Forest_fire\Rasters" # Mock base directory for all GeoTIFFs
CLIPPED_OUTPUT_DIR = r"C:\Users\Ankit\Datasets_Forest_fire\Clipped_Inputs" # Directory to save 13x13 NumPy arrays

# The 7 actual features used by the model (e.g., T2M replaces ET2M)
CHANNELS_LIST = ['t2m', 'd2m', 'tp', 'u10', 'v10', 'lulc', 'DEM'] 
STATIC_CHANNELS = ['lulc', 'DEM'] # Channels that are time-invariant

# Mapping from the internal channel name (for tensor creation) to the full file prefix (from CSV)
CHANNEL_FILE_MAP = {
    't2m': 'era5_t2m_file', 'd2m': 'era5_d2m_file', 'tp': 'era5_tp_file',
    'u10': 'era5_u10_file', 'v10': 'era5_v10_file',
    'lulc': 'lulc_file', 'DEM': 'dem_file',
}
# NOTE: 'viirs_file' is excluded here, assuming it holds the label/ground truth, 
# not an input feature for the 7-channel model.

# --- 2. Define Custom Functions (REQUIRED for Model Loading) ---

def slice_output_func(x):
    return x[:, :HORIZONS, :, :, :] 

def squeeze_output_func(x):
    return tf.squeeze(x, axis=-1)

def slice_output_shape(input_shape):
    return (input_shape[0], HORIZONS, input_shape[2], input_shape[3], input_shape[4])

def squeeze_output_shape(input_shape):
    return input_shape[:-1] 

CUSTOM_OBJECTS = {
    'slice_output_func': slice_output_func, 'slice_output_shape': slice_output_shape,
    'squeeze_output_func': squeeze_output_func, 'squeeze_output_shape': squeeze_output_shape,
} 

# --- 3. Define the Date Mapping Function ---

def get_sample_date_range(sample_index_in_df):
    """Calculates the date range for the input and prediction windows."""
    START_DATE = datetime(2015, 1, 1)

    # Input (6-Hour) Window
    input_start_index = sample_index_in_df 
    input_end_index = sample_index_in_df + SEQ_LEN - 1
    input_start_time = START_DATE + timedelta(hours=input_start_index)
    input_end_time = START_DATE + timedelta(hours=input_end_index)
    
    # Prediction (3-Hour) Window
    pred_start_index = input_end_index + 1
    pred_end_index = pred_start_index + HORIZONS - 1
    pred_start_time = START_DATE + timedelta(hours=pred_start_index)
    pred_end_time = START_DATE + timedelta(hours=pred_end_index)
    
    return {
        'input_start': input_start_time.strftime('%Y-%m-%d %H:%M:%S'),
        'input_end': input_end_time.strftime('%Y-%m-%d %H:%M:%S'),
        'pred_start': pred_start_time.strftime('%Y-%m-%d %H:%M:%S'),
        'pred_end': pred_end_time.strftime('%Y-%m-%d %H:%M:%S'),
    }

# --- 4. Define Geospatial Mapping Function ---

def get_geospatial_bounds(lat, lon, size_degrees):
    """Calculates the latitude and longitude bounds for the 13x13 patch."""
    half_size = size_degrees / 2
    
    lat_min = lat - half_size
    lat_max = lat + half_size
    lon_min = lon - half_size
    lon_max = lon + half_size
    
    return {
        'lat_range': (lat_min, lat_max),
        'lon_range': (lon_min, lon_max)
    }


# --- 5. Function to Load and Clip Rasters (UPDATED) ---

def load_and_clip_rasters(time_details, geo_details, data_dir, output_dir, channels_list, static_channels):
    """
    Clips the required 13x13 area from the source rasters for the 6 input hours.
    Saves the clipped data as NumPy arrays (ready for tensor creation).
    Uses CHANNEL_FILE_MAP for correct file naming.
    """
    
    print("\n--- C. Raster Clipping Setup ---")
    
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # 1. Define the clipping geometry (Bounding Box)
    lat_min, lat_max = geo_details['lat_range']
    lon_min, lon_max = geo_details['lon_range']
    
    # GeoJSON format for the clipping polygon (required by rasterio.mask)
    geometry = [
        {
            'type': 'Polygon',
            'coordinates': [[
                [lon_min, lat_min], [lon_min, lat_max],
                [lon_max, lat_max], [lon_max, lat_min],
                [lon_min, lat_min]
            ]]
        }
    ]
    
    # 2. Generate the 6 required input hours
    start_dt = datetime.strptime(time_details['input_start'], '%Y-%m-%d %H:%M:%S')
    end_dt = datetime.strptime(time_details['input_end'], '%Y-%m-%d %H:%M:%S')
    
    current_dt = start_dt
    input_dates = []
    while current_dt <= end_dt:
        input_dates.append(current_dt)
        current_dt += timedelta(hours=1)
    
    print(f"Generated {len(input_dates)} input time steps for clipping: {start_dt.strftime('%H:%M')} to {end_dt.strftime('%H:%M')} on {start_dt.strftime('%Y-%m-%d')}.")
    
    clipped_arrays = {}

    for channel_base_name in channels_list:
        clipped_arrays[channel_base_name] = []
        # Get the file prefix defined by the user's CSV column names
        full_file_prefix = CHANNEL_FILE_MAP[channel_base_name]
        
        # Determine file paths to load
        if channel_base_name in static_channels:
            # Static channels only need one file: LULC or DEM
            mock_file_name = f"{full_file_prefix}.tif"
            # Use the first hour's date/time for naming the mock saved file
            file_paths = [(Path(data_dir) / full_file_prefix / mock_file_name, input_dates[0])] 
        else:
            # Dynamic channels need 6 sequential files
            file_paths = []
            for dt in input_dates:
                # Assuming dynamic files are named using the full prefix + date_hour
                mock_file_name = f"{full_file_prefix}_{dt.strftime('%Y%m%d_%H')}.tif"
                file_paths.append((Path(data_dir) / full_file_prefix / mock_file_name, dt))
        
        for file_path, dt in file_paths:
            print(f"  - Processing: {file_path.name}")
            
            # --- MOCK CLIPPING LOGIC ---
            try:
                # Simulate the clipped array result: (1, 13, 13)
                # Value is set to channel index for easy identification in the final tensor
                channel_index = channels_list.index(channel_base_name)
                clipped_data = np.full((1, PATCH_H, PATCH_W), fill_value=channel_index + 1, dtype=np.float32)
                
                # Save the clipped NumPy array
                output_file_name = f"clipped_{channel_base_name}_{dt.strftime('%Y%m%d_%H')}.npy"
                np.save(output_path / output_file_name, clipped_data)
                
                clipped_arrays[channel_base_name].append(clipped_data.squeeze())
                print(f"    -> MOCK SUCCESS: Clipped and saved 13x13 array to {output_file_name}")

                # --- REAL RASTERIO LOGIC (UNCOMMENT AND USE LOCALLY) ---
                # with rasterio.open(file_path) as src:
                #     out_image, out_transform = mask(src, geometry, crop=True)
                #     # Save the clipped NumPy array
                #     output_file_name = f"clipped_{channel_base_name}_{dt.strftime('%Y%m%d_%H')}.npy"
                #     np.save(output_path / output_file_name, out_image)
                #     clipped_arrays[channel_base_name].append(out_image.squeeze())
                #     print(f"    -> REAL SUCCESS: Saved clipped 13x13 array to {output_file_name}")
                
                # If static, the first clipped array is duplicated 6 times for SEQ_LEN
                if channel_base_name in static_channels:
                    static_array = clipped_arrays[channel_base_name][0]
                    for _ in range(SEQ_LEN - 1):
                         clipped_arrays[channel_base_name].append(static_array)
                    print(f"    -> STATIC: Repeated array {SEQ_LEN} times for {channel_base_name}.")
                    break # Exit the inner loop since the sequence is full
                
            except Exception as e:
                print(f"    -> ERROR: Could not process {file_path.name}. Check if the GeoTIFF file exists and is valid. Error: {e}")
                clipped_arrays[channel_base_name].append(np.full((PATCH_H, PATCH_W), np.nan)) # Use NaNs as placeholder
                
    # 3. Create the final tensor (SEQ_LEN, PATCH_H, PATCH_W, CHANNELS)
    all_channels_stacked = [np.stack(clipped_arrays[c], axis=0) for c in channels_list] 
    
    if not all(arr.shape[0] == SEQ_LEN for arr in all_channels_stacked):
        print("\n[ERROR] Final tensor construction failed: Not all channels have the required SEQ_LEN. Check file integrity.")
        return None
    
    # Stack the sequence arrays along the channel axis (-1)
    final_input_tensor = np.stack(all_channels_stacked, axis=-1)
    
    if final_input_tensor.shape == (SEQ_LEN, PATCH_H, PATCH_W, CHANNELS):
        print(f"\n✅ FINAL INPUT TENSOR CREATED: Shape {final_input_tensor.shape}. It is ready for model inference.")
        return final_input_tensor
    else:
        print(f"\n[ERROR] Final tensor has incorrect shape: {final_input_tensor.shape}. Expected ({SEQ_LEN}, {PATCH_H}, {PATCH_W}, {CHANNELS}).")
        return None

# --- 6. Main Execution Setup (Loading Data/Model) ---

MODEL_PATH = r"C:\Users\Ankit\Downloads\final_model.h5" 
csv_path = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_binary.csv"

# Load the full DataFrame
try:
    df = pd.read_csv(csv_path)
    df = df.reset_index(drop=True) 
except FileNotFoundError:
    print(f"ERROR: CSV file not found at {csv_path}. Cannot map dates.")
    sys.exit(1)
except Exception as e:
    print(f"ERROR loading CSV: {e}")
    sys.exit(1)

TOTAL = len(df)
VAL_SPLIT = 0.2
val_size = int(TOTAL * VAL_SPLIT)
val_df = df.iloc[:val_size].copy() 

SUCCESS_SAMPLE_INDEX_IN_VAL_DF = 17 

try:
    print(f"Loading model from: {MODEL_PATH}")
    model = tf.keras.models.load_model(
        MODEL_PATH, 
        custom_objects=CUSTOM_OBJECTS,
        safe_mode=False 
    )
    print("Model loaded successfully.")
    
    # --- 7. Execute Analysis and Mapping ---
    
    original_df_index = val_df.index[SUCCESS_SAMPLE_INDEX_IN_VAL_DF]
    time_details = get_sample_date_range(original_df_index)
    geo_details = get_geospatial_bounds(
        SIMULATED_TARGET_LAT, 
        SIMULATED_TARGET_LON, 
        PATCH_SIZE_DEGREES
    )

    # --- 8. Execute Clipping and Tensor Creation (NEW STEP) ---
    input_tensor = load_and_clip_rasters(
        time_details, 
        geo_details, 
        DATA_BASE_DIR, 
        CLIPPED_OUTPUT_DIR, 
        CHANNELS_LIST, 
        STATIC_CHANNELS
    )
    
    # --- 9. Print Final Output ---
    
    print("\n=======================================================")
    print("✅ **SUCCESSFUL FIRE PREDICTION ANALYSIS**")
    print("=======================================================")
    print(f"Validation Sample Index (in Val DF): {SUCCESS_SAMPLE_INDEX_IN_VAL_DF}")
    print(f"Original Data Row Index (used for offset): {original_df_index}")
    print(f"Prediction Threshold Used: {THRESHOLD:.2f}")

    print("\n--- A. Time Windows ---")
    print(f"Input (6-Hour) Window:")
    print(f"  Start Time: {time_details['input_start']}")
    print(f"  End Time:   {time_details['input_end']}")
    
    print(f"\nPrediction (3-Hour) Output Window:")
    print(f"  Start Time: {time_details['pred_start']}")
    print(f"  End Time:   {time_details['pred_end']}")
    
    print("\n--- B. Geospatial Patch (13x13 Pixels) ---")
    print(f"Center Coordinate (Simulated): Lat {SIMULATED_TARGET_LAT}, Lon {SIMULATED_TARGET_LON}")
    print(f"Approximate Patch Size: {PATCH_SIZE_DEGREES:.4f} degrees x {PATCH_SIZE_DEGREES:.4f} degrees")
    print(f"Latitude Range:                {geo_details['lat_range'][0]:.4f} to {geo_details['lat_range'][1]:.4f} degrees")
    print(f"Longitude Range:               {geo_details['lon_range'][0]:.4f} to {geo_details['lon_range'][1]:.4f} degrees")
    print(f"\n--- D. Prediction Inference ---")
    if input_tensor is not None:
        # NOTE: You can now run the prediction using your loaded model!
        # prediction = model.predict(np.expand_dims(input_tensor, axis=0))
        # print(f"Model prediction successfully simulated with input shape {input_tensor.shape}.")
        print(f"Input tensor successfully created (Shape: {input_tensor.shape}). Ready for model.predict().")
    else:
        print("Input tensor could not be created due to file processing errors.")
    print("=======================================================")
    
except Exception as e:
    print(f"\n[CRITICAL ERROR] Failed during model load or execution. Error: {e}")
    sys.exit(1)


Loading model from: C:\Users\Ankit\Downloads\final_model.h5




Model loaded successfully.

--- C. Raster Clipping Setup ---
Generated 6 input time steps for clipping: 17:00 to 22:00 on 2015-01-01.
  - Processing: era5_t2m_file_20150101_17.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_17.npy
  - Processing: era5_t2m_file_20150101_18.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_18.npy
  - Processing: era5_t2m_file_20150101_19.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_19.npy
  - Processing: era5_t2m_file_20150101_20.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_20.npy
  - Processing: era5_t2m_file_20150101_21.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_21.npy
  - Processing: era5_t2m_file_20150101_22.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array to clipped_t2m_20150101_22.npy
  - Processing: era5_d2m_file_20150101_17.tif
    -> MOCK SUCCESS: Clipped and saved 13x13 array t