With this dataset, we are working with monthly environmental indices:

- LST *(Land Surface Temperature)*,
- NDVI *(Normalized Difference Vegetation Index)*,
- SM *(Soil Moisture)*,
- SPI *(Standardized Precipitation Index)*.

Before we dive into the code, we will need the following libraries:

- **Rasterio:** To read .tif files and extract spatial metadata.
- **Xarray:** To stack these 2D images into a 3D "Data Cube" (Time, Lat, Lon).
- **Matplotlib/Seaborn:** For visualization.

---

# 0. Initial Setup

## 0.1. Environment Installation

We need specialized geospatial libraries to handle coordinate systems and reproject the different data layers so they line up perfectly over Tunisia.

In [None]:
# Install rasterio for geospatial data handling
!pip install -q rasterio
# Install rioxarray for automatic CRS handling and reprojection
!pip install -q rioxarray regionmask



## 0.2. Library Imports

We load the standard data science stack along with `rioxarray` for geographic processing.

In [None]:
import os
import zipfile
import glob
import re
import rasterio
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import rioxarray
from rasterio.enums import Resampling


# Set plotting style
plt.style.use('ggplot')
%matplotlib inline

## 0.3. Helper Functions

These functions handle the messy parts of the data: extracting dates from your specific filenames and ensuring the spatial orientation (North/South) is correct.

In [7]:
def extract_date_from_filename(filename):
    """Parses year and month from string (e.g., '2023_01')"""
    match = re.search(r'(\d{4})_(\d{2})', filename)
    return f"{match.group(1)}-{match.group(2)}-01" if match else None

def load_and_project(base_path, folder_name, var_label, reference_da=None):
    """Loads TIFs, reprojects to UTM 32N, and aligns to a master grid."""
    # Search for files
    search_pattern = os.path.join(base_path, folder_name, "**/*.tif")
    files = sorted(glob.glob(search_pattern, recursive=True))
    
    if not files:
        print(f"⚠️ No files found for {var_label}")
        return None

    # Target CRS for Tunisia: UTM Zone 32N
    TARGET_CRS = "EPSG:32632"
    da_list = []
    
    for f in files:
        date_str = extract_date_from_filename(os.path.basename(f))
        if not date_str: continue
            
        # Open with rioxarray (automatically handles orientation)
        da = rioxarray.open_rasterio(f, masked=True).squeeze()
        
        # Reproject to Metric UTM 32N
        if reference_da is None:
            # First variable (LST) defines the "Master Grid"
            da = da.rio.reproject(TARGET_CRS, resampling=Resampling.bilinear)
        else:
            # Match all others to the Master Grid pixel-for-pixel
            da = da.rio.reproject_match(reference_da, resampling=Resampling.bilinear)
            
        da = da.expand_dims(time=[pd.to_datetime(date_str)])
        da_list.append(da)
    
    full_da = xr.concat(da_list, dim='time').rename(var_label)
    print(f"✅ {var_label} loaded & aligned. Shape: {full_da.shape}")
    return full_da
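
A quick sanity check of the date parser can catch naming problems before any raster is opened. The sketch below restates `extract_date_from_filename` so that it runs on its own; the file names are hypothetical examples, not files from the dataset.

```python
import re

def extract_date_from_filename(filename):
    """Return 'YYYY-MM-01' if the name contains 'YYYY_MM', else None."""
    match = re.search(r'(\d{4})_(\d{2})', filename)
    return f"{match.group(1)}-{match.group(2)}-01" if match else None

# Hypothetical file names, for illustration only
samples = ["LST_2023_01.tif", "ndvi-2004_11-mean.tif", "readme.txt"]
parsed = {name: extract_date_from_filename(name) for name in samples}

for name, date in parsed.items():
    print(f"{name!r:28} -> {date}")
```

Note that the pattern only checks for four digits, an underscore, and two digits. It does not validate the month, so a name containing something like `2023_13` would pass the regex unnoticed.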

## 0.4. Unified Data Loading

This cell creates your "Data Cube." It uses LST as the spatial reference and forces NDVI, Soil Moisture, and SPI to align to it. This solves the "backwards" and "distorted" issues by using a metric coordinate system.

In [11]:
# Update path to your Kaggle input
INPUT_PATH = '/Downloads/drouddght_data/'

# 1. Load Master Grid (LST)
lst_final = load_and_project(INPUT_PATH, 'LST', 'LST')

# 2. Align others to LST
ndvi_final = load_and_project(INPUT_PATH, 'NDVI', 'NDVI', reference_da=lst_final)
sm_final   = load_and_project(INPUT_PATH, 'SM', 'Soil_Moisture', reference_da=lst_final)
spi_final  = load_and_project(INPUT_PATH, 'SPI', 'SPI', reference_da=lst_final)

# 3. Final Merge
ds = xr.merge([lst_final, ndvi_final, sm_final, spi_final], join='inner')

print("\n--- Final Structured Dataset ---")
display(ds)

⚠️ No files found for LST
⚠️ No files found for NDVI
⚠️ No files found for Soil_Moisture
⚠️ No files found for SPI


TypeError: objects must be an iterable containing only DataTree(s), Dataset(s), DataArray(s), and dictionaries: [None, None, None, None]
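
The `TypeError` above is a downstream symptom: every loader returned `None` because no `.tif` files matched under `INPUT_PATH`, and `xr.merge` cannot combine `None` values. A minimal guard, sketched below with a hypothetical helper name `require_all_loaded`, would fail earlier with a clearer message.

```python
def require_all_loaded(layers):
    """Raise a descriptive error if any layer failed to load (is None).

    `layers` maps a variable label to the object returned by a loader.
    This helper is illustrative and not part of xarray or rioxarray.
    """
    missing = [name for name, layer in layers.items() if layer is None]
    if missing:
        raise FileNotFoundError(
            "No raster files were loaded for: " + ", ".join(missing)
            + ". Check the input path and folder names before merging."
        )
    return list(layers.values())

# Stand-in values: strings play the role of loaded DataArrays here
try:
    require_all_loaded({"LST": None, "NDVI": "ndvi_array", "SPI": None})
except FileNotFoundError as err:
    message = str(err)
    print(message)
```

In the notebook, the list returned by the helper could be passed straight to `xr.merge(..., join='inner')`.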

## 0.5. Spatial & Orientation Verification

We plot a quick map to ensure Tunisia is right-side up and the aspect ratio is correct (not squashed).

In [None]:
plt.figure(figsize=(7, 9))
ds['NDVI'].isel(time=0).plot(cmap='YlGn', robust=True)

# Force metric 1:1 aspect ratio
plt.gca().set_aspect('equal', adjustable='box')

plt.title("Spatial Verification: Tunisia (UTM 32N)")
plt.xlabel("Easting (m)")
plt.ylabel("Northing (m)")
plt.show()

---

# 1. Understanding the Dataset

## 1.1. Data Source Audit & Reliability Analysis

Before building an uncertainty-aware forecasting model, we must establish how much "belief" we can place in each variable. In the context of Evidence Theory ("Smets" or "Dubois-Prade") or Possibility Theory, we treat these sensors not just as numbers, but as "mass functions" of evidence.

- **LST & NDVI:** High-resolution optical data. Very reliable for current state, but "stale" if clouds were present during the satellite overpass.
- **Soil Moisture (SM):** Often a blend of satellite microwave data and land-surface models. It has higher "epistemic uncertainty" (model bias).
- **SPI:** A purely mathematical transformation of precipitation. Its uncertainty comes from the density of the weather stations used to interpolate the rain map.

In this cell, we check for **Data Overlap** *(do sources disagree?)* and define our **Reliability Weights** which will later inform our uncertainty-aware model *(e.g., using Evidence Theory)*.

In [None]:
# 1.1 Data Source & Reliability Audit

# In UTM projection, coordinates are named 'x' and 'y'
total_pixels = ds.x.size * ds.y.size

# Checking for Spatial Disagreement (Mask Mismatch)
# This identifies if different sensors have different 'NoData' areas (e.g., coastal pixels)
nan_mask = ds.isnull()
disagreement_map = nan_mask.LST != nan_mask.NDVI

print(f"Total Grid Pixels: {total_pixels}")
print(f"Spatial Disagreement: {disagreement_map.sum().values.max()} pixels")

# Defining Reliability Factors (discount coefficients applied to each source's mass function)
# These represent our 'Source Trust' coefficients
reliability_factors = {
    "LST": 0.85,          # High trust in thermal bands
    "NDVI": 0.90,         # Gold standard for vegetation state
    "Soil_Moisture": 0.70, # Lower trust due to model-based estimation
    "SPI": 0.80           # High trust, but statistically derived
}

# Summary of Time-Step Continuity
expected_months = pd.date_range(start=ds.time.min().values, 
                                end=ds.time.max().values, 
                                freq='MS')
missing_months = expected_months.difference(ds.time.values)

print(f"Temporal Coverage: {len(ds.time)} / {len(expected_months)} months")
if not missing_months.empty:
    print(f"⚠️ Warning: Missing months detected: {missing_months}")
else:
    print("✅ Time series is temporally continuous.")

That **Spatial Disagreement (298,301 pixels)** is a massive red flag. Out of ~349k pixels, nearly 85% of your grid has "disagreement" in where data is valid (NaN values).

This tells us that your sources have very different "masks." For example, the Soil Moisture model might provide data over the desert where the NDVI sensor (which looks for greenness) sees "No Data," or the SPI dataset has a slightly different coastline. For an uncertainty-aware model, this means we must be very careful not to "hallucinate" relationships in areas where only one sensor is active.
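
Because `disagreement_map` keeps the time dimension, summing it without naming a dimension counts pixel-months rather than unique pixels, so it helps to separate the two readings. The following is a NumPy sketch on synthetic masks (the array shapes and the 20% missing rate are assumptions); in xarray the equivalent reductions would be `.any(dim='time')` and `.mean(dim='time')`.

```python
import numpy as np

# Synthetic validity masks with shape (time, y, x); True means NaN
rng = np.random.default_rng(0)
lst_nan = rng.random((6, 4, 5)) < 0.2
ndvi_nan = rng.random((6, 4, 5)) < 0.2

disagreement = lst_nan != ndvi_nan                  # mismatch per pixel and month

pixel_months = int(disagreement.sum())              # counts every (time, y, x) cell
pixels_any = int(disagreement.any(axis=0).sum())    # pixels mismatching at least once
share_per_pixel = disagreement.mean(axis=0)         # fraction of months in disagreement

print(f"Mismatching pixel-months: {pixel_months}")
print(f"Pixels with any mismatch: {pixels_any} of {disagreement[0].size}")
print(f"Max share of months in disagreement: {share_per_pixel.max():.2f}")
```

A per-pixel share map of this kind also shows whether the disagreement is concentrated along coastlines and mask edges or spread across the whole grid.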

## 1.2. Individual Feature Temporal Dynamics

To understand the "character" of our features, we need to see them as a human would. We are looking for:

1. **Seasonality:** Is the signal dominated by the calendar?

2. **Noise:** Does LST have spikes that look like sensor errors?

3. **Range:** Are the scales comparable (e.g., NDVI is 0 to 1, but LST is in Kelvin/Celsius)?



In [None]:
# 1.2 Temporal Visualization of Features

# Calculate the national spatial mean for each variable
df_features = ds.mean(dim=['x', 'y']).to_dataframe()

fig, axes = plt.subplots(4, 1, figsize=(15, 14), sharex=True)
feature_list = [
    ('LST', 'Red', 'Land Surface Temp'),
    ('NDVI', 'Green', 'Vegetation Index'),
    ('Soil_Moisture', 'Purple', 'Soil Moisture'),
    ('SPI', 'Blue', 'Standardized Precip Index')
]

for i, (col, color, title) in enumerate(feature_list):
    axes[i].plot(df_features.index, df_features[col], color=color, lw=1.5)
    axes[i].set_title(f"National Temporal Trend: {title}", fontsize=12, loc='left')
    axes[i].set_ylabel("Value")
    axes[i].grid(True, alpha=0.3)
    
    # Adding a rolling mean to visualize the long-term trend behind the seasonality
    axes[i].plot(df_features.index, df_features[col].rolling(12, center=True).mean(), 
                color='black', linestyle='--', alpha=0.8, label='12-Month Trend')
    axes[i].legend(loc='upper right')  # show the '12-Month Trend' label

plt.xlabel("Year")
plt.tight_layout()
plt.show()

# Display basic statistics to understand scaling needs
print("--- Statistical Summary ---")
display(df_features.describe().T)

**Analysis for Model Selection:**

- **The "Memory" Effect:** Notice how SPI and Soil Moisture fluctuate rapidly, while NDVI has a smoother, slightly lagged curve. This suggests our model needs memory (like an LSTM or a Transformer).
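- **Stationarity:** LST is dominated by a regular seasonal cycle, which makes it predictable but not stationary until that cycle is removed, while SPI behaves much more like a stochastic (noise-like) signal.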
- **Uncertainty Application:** Since your spatial disagreement is high, Possibility Theory might be more robust than Evidence Theory here. Possibility Theory allows us to define "imprecise" boundaries for pixels that are near the edge of a sensor's mask, whereas Evidence Theory ("Smets" or "Dubois-Prade") might struggle if the "conflict" between sources is too high due to missing data.
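
The "memory" observation can be quantified with a lag-1 autocorrelation. The sketch below uses synthetic series (a noise-like signal and a 6-month moving average of it) as stand-ins for a fast and a slow variable; the helper name `lag_autocorr` is hypothetical. On the real data it would be more meaningful on anomalies, since a raw seasonal cycle inflates autocorrelation by itself.

```python
import numpy as np

def lag_autocorr(series, lag=1):
    """Pearson correlation between a series and itself shifted by `lag`."""
    series = np.asarray(series, dtype=float)
    return float(np.corrcoef(series[:-lag], series[lag:])[0, 1])

rng = np.random.default_rng(42)
fast = rng.standard_normal(300)                           # noise-like series
slow = np.convolve(fast, np.ones(6) / 6, mode="valid")    # 6-month moving average

r_fast = lag_autocorr(fast)
r_slow = lag_autocorr(slow)
print(f"lag-1 autocorrelation, noise-like series: {r_fast:+.2f}")
print(f"lag-1 autocorrelation, smoothed series:   {r_slow:+.2f}")
```

A variable whose autocorrelation stays high over several lags carries information about its own future, which is the property a sequence model can exploit.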

**Analysis of Statistical Distributions:**

- **LST (Mean: 31.89°C, Max: 47.52°C):** These values are in Celsius. A mean of ~32°C suggests a hot, arid climate (consistent with Tunisia), but a maximum of 47.5°C indicates extreme heat stress. The high standard deviation (10.13) confirms strong seasonal swings between winter and summer.
- **NDVI (Mean: 0.177, Max: 0.240):** This range is quite low. In healthy Mediterranean forests, NDVI can reach 0.6–0.8. A max of 0.24 means your dataset is dominated by sparse vegetation, olive groves, or arid lands. This small "dynamic range" (0.13 to 0.24) means the model must be very sensitive to small changes.
- **Soil_Moisture (Mean: 7.69, Max: 12.16):** These units are likely $kg/m^2$ (common in GLDAS/Noah models) rather than volumetric fractions ($m^3/m^3$). If we assume a 10 cm soil layer, 7.69 $kg/m^2$ is roughly 7.7% volumetric moisture, which is extremely dry and typical for Tunisian topsoil.
- **SPI (Mean: -0.03, Min: -1.42, Max: 2.47):** SPI is unitless and expressed in standard deviations.
  - **SPI = 0** is the historical average.
  - **SPI < -1.0** indicates "Moderately Dry."
  - **SPI < -1.5** is "Severely Dry."
  - Your minimum of -1.42 falls in the "Moderately Dry" class, just short of the "Severely Dry" threshold. Because these statistics come from the national spatial mean, they smooth out local extremes, so they do not show the "Exceptional" droughts (SPI < -2.0) that occurred in very recent years.
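
The unit conversion and the SPI classes above can be checked with a few lines of arithmetic. This is a minimal sketch: the 10 cm layer depth is the same assumption made in the text, the function names are hypothetical, and the label for SPI below -2.0 follows the common "Extremely Dry" wording (called "Exceptional" above).

```python
WATER_DENSITY = 1000.0  # kg per cubic metre

def volumetric_fraction(sm_kg_m2, layer_depth_m=0.10):
    """Convert areal water content (kg/m^2) to a volumetric fraction (m^3/m^3).

    The 10 cm default layer depth is an assumption, as in the text above.
    """
    water_depth_m = sm_kg_m2 / WATER_DENSITY
    return water_depth_m / layer_depth_m

def spi_dry_class(spi):
    """Map an SPI value to the dry-side classes listed above."""
    if spi <= -2.0:
        return "Extremely Dry"
    if spi <= -1.5:
        return "Severely Dry"
    if spi <= -1.0:
        return "Moderately Dry"
    return "Near Normal or Wet"

vol = volumetric_fraction(7.69)
print(f"7.69 kg/m^2 over 10 cm = {vol:.4f} m^3/m^3 ({vol * 100:.1f}%)")
print(f"SPI -1.42 -> {spi_dry_class(-1.42)}")
```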

**Observations on "Band" and "Spatial_Ref":**

- **Band (Mean: 1.0):** This is a metadata artifact from the .tif loading process. It just confirms every file has 1 band. This column can be ignored/dropped during preprocessing.
- **Spatial_Ref (Mean: 0.0):** This is a coordinate placeholder for the UTM 32N projection information. It contains no variable data.

**Decision for Modeling:**

The data is on completely different scales (LST goes up to 47, NDVI is 0.2). This confirms that normalization is mandatory. However, since we want to be "uncertainty aware," we should not just use a simple Min-Max scaler. We should consider a **Z-score (StandardScaler)** because SPI is already standardized; putting the other variables on a similar "standard deviation" scale will make the Theory of Evidence fusion much more mathematically stable.

**Hybrid Evidential-Possibilistic Model:**

1. **Transferable Belief Model (TBM):** Proposed by Philippe Smets. It doesn't force the data to 100%. If sensors conflict, it assigns that conflict to the Empty Set ($\varnothing$), which we interpret as "The sensors are broken/conflicting." This is a huge "Warning Light" for your model.
2. **Possibility Theory (Maxitive Fusion):** Instead of multiplying (which amplifies conflict), Possibility Theory uses a `max` operator. If one sensor says 1.0 and another says 0.0, the "Possibility" stays 1.0. It is much more "tolerant" of disagreement.
3. **Conflict Redistribution (PCR6):** A more recent alternative. It takes each partial "conflicting mass" and shares it back among the hypotheses involved in that conflict, in proportion to the masses the sources assigned to them, rather than discarding it through normalization.
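
To make the first two options concrete, here is a small sketch of the unnormalized conjunctive rule (as used in the TBM) next to maxitive fusion. The mass values and possibility degrees are invented for illustration, and the three-state frame (Dry, Normal, Wet) is an assumption borrowed from the fusion phase later in this notebook. PCR6 is not implemented here.

```python
from itertools import product

THETA = frozenset({"Dry", "Normal", "Wet"})   # frame of discernment

def conjunctive(m1, m2):
    """Unnormalized conjunctive rule: mass on the empty set measures conflict."""
    combined = {}
    for (a, wa), (b, wb) in product(m1.items(), m2.items()):
        key = a & b                              # set intersection
        combined[key] = combined.get(key, 0.0) + wa * wb
    return combined

# Illustrative masses: one source backs Dry, the other backs Wet
m_lst = {frozenset({"Dry"}): 0.8, THETA: 0.2}
m_ndvi = {frozenset({"Wet"}): 0.7, THETA: 0.3}

fused = conjunctive(m_lst, m_ndvi)
conflict = fused.get(frozenset(), 0.0)

# Maxitive (disjunctive) fusion of two possibility distributions
pi_lst = {"Dry": 1.0, "Normal": 0.3, "Wet": 0.0}
pi_ndvi = {"Dry": 0.0, "Normal": 0.4, "Wet": 1.0}
pi_max = {state: max(pi_lst[state], pi_ndvi[state]) for state in pi_lst}

print(f"Mass on the empty set (conflict): {conflict:.2f}")
print(f"Maxitive fusion: {pi_max}")
```

With these numbers, more than half of the combined mass lands on the empty set, which is the "Warning Light" behaviour described in item 1, while the maxitive result keeps both Dry and Wet fully possible.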

## 1.3. Spatio-Temporal Field Visualization (Animated)

We will use `matplotlib.animation` to create a sequence. Since different variables represent different physical processes, we'll use specific colormaps:

- **LST:** `inferno` (perceptually uniform for heat).
- **NDVI:** `YlGn` (Yellow to Green for vegetation health).
- **Soil Moisture:** `Blues` (Light to deep water content).
- **SPI:** `RdYlBu` (Red for dry/negative, Blue for wet/positive, the meteorological standard).

In [None]:
from matplotlib.animation import FuncAnimation, FFMpegWriter
from IPython.display import HTML


def create_feature_animation(da, var_name, cmap, title, save_filename=None):
    # Calculate global limits to keep the legend/colorbar static
    vmin, vmax = float(da.min()), float(da.max())
    
    fig, ax = plt.subplots(figsize=(6, 8))
    
    # Initialize the plot with the first time step
    im = da.isel(time=0).plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, 
                               add_colorbar=True, 
                               cbar_kwargs={'label': var_name})
    
    ax.set_aspect('equal', adjustable='box')
    title_text = ax.set_title(f"{title} - {da.time.values[0].astype('M8[M]')}")

    def update(i):
        # Update the image data without redrawing the colorbar
        im.set_array(da.isel(time=i).values.flatten())
        title_text.set_text(f"{title} - {da.time.values[i].astype('M8[M]')}")
        return im, title_text

    # Create animation (capped at the first 24 months to keep it fast; use da.sizes['time'] for all)
    n_frames = min(24, da.sizes['time'])
    ani = FuncAnimation(fig, update, frames=n_frames, interval=300, blit=True)
    
    # Save if a filename is provided
    if save_filename:
        writer = FFMpegWriter(fps=4, metadata=dict(artist='Bahri, Dhiaa Eddine'), bitrate=1800)
        ani.save(f"/kaggle/working/{save_filename}.mp4", writer=writer)
        print(f"✅ Video saved to: /kaggle/working/{save_filename}.mp4")

    plt.close() 
    return ani

# Example: Animating the first 24 months of a variable
# Change 'NDVI' to 'LST', 'Soil_Moisture', or 'SPI' as needed
# ndvi_anim = create_feature_animation(ds['NDVI'], "NDVI", "YlGn", "Vegetation Health")
# lst_anim = create_feature_animation(
#     ds['LST'], 
#     "LST (°C)", 
#     "inferno", 
#     "Land Surface Temperature", 
#     # save_filename="tunisia_lst_3year_trend"
# )

# sm_anim = create_feature_animation(
#     ds['Soil_Moisture'], 
#     "Soil Moisture (mm)", 
#     "Blues", 
#     "Soil Moisture", 
#     # save_filename="tunisia_sm_3year_trend"
# )

spi_anim = create_feature_animation(
    ds['SPI'], 
    "Standardized Precipitation Index", 
    "RdYlBu", 
    "Standardized Precipitation Index", 
    # save_filename="tunisia_spi_3year_trend"
)

# HTML(ndvi_anim.to_jshtml())
# HTML(lst_anim.to_jshtml())
# HTML(sm_anim.to_jshtml())
HTML(spi_anim.to_jshtml())

## 1.4. Chronological Data Splitting

In drought forecasting, we must respect the "arrow of time." We cannot use a random shuffle split because drought has **temporal memory**: the soil moisture of today is heavily dependent on the rainfall of last month. If we allow the model to see future data during training, we create "data leakage," making our uncertainty assessments useless.

We will split your 299 months (approx. 25 years) into three distinct blocks:

- **Train (70%):** Establishing the long-term climatological "baseline."
- **Validation (15%):** Tuning the hyper-parameters and fuzzy membership functions.
- **Test (15%):** Evaluating the model on the most recent, severe drought events in Tunisia (2021–2024).

In [None]:
# 1.4 Implementing the Chronological Split Strategy

# Define the split points
n_time = len(ds.time)
train_end = int(n_time * 0.70)
val_end = int(n_time * 0.85)

# Split the dataset
ds_train = ds.isel(time=slice(0, train_end))
ds_val   = ds.isel(time=slice(train_end, val_end))
ds_test  = ds.isel(time=slice(val_end, None))

print(f"--- Data Split Summary ---")
print(f"Train: {ds_train.time.values[0].astype('M8[D]')} to {ds_train.time.values[-1].astype('M8[D]')} ({len(ds_train.time)} months)")
print(f"Val:   {ds_val.time.values[0].astype('M8[D]')} to {ds_val.time.values[-1].astype('M8[D]')} ({len(ds_val.time)} months)")
print(f"Test:  {ds_test.time.values[0].astype('M8[D]')} to {ds_test.time.values[-1].astype('M8[D]')} ({len(ds_test.time)} months)")

# Baseline variability: we calculate the historical standard deviation over the training period
# This helps our uncertainty model understand the 'natural noise' of each sensor
uncertainty_base = {var: float(ds_train[var].std()) for var in ['LST', 'NDVI', 'Soil_Moisture', 'SPI']}
print(f"\nBase Aleatory Uncertainty (Standard Deviation):")
for var, val in uncertainty_base.items():
    print(f"- {var}: {val:.4f}")

**Analysis of the Split:**

By reserving the last 15% for testing, we are specifically testing the model's ability to handle the **2023-2024 Tunisian drought.** This is the ultimate "stress test" for an uncertainty-aware model:
> *can it maintain high "belief" in its forecast during an unprecedented extreme event?*

## 1.5. Reliability & Conflict Assessment

Before we leave the "Understanding" phase, we must mathematically define the Conflict ($K$) between our sources. In the Theory of Evidence, conflict arises when two sensors provide evidence for mutually exclusive states.

In our context, if **LST Anomaly** is very high (indicating extreme heat/drought) but **NDVI Anomaly** is also high (indicating healthy, green vegetation), these two sources are in conflict. This often happens in Tunisia due to irrigation or deep-rooted vegetation (olives) that resists short-term heat.

In [None]:
# 1.5 Conflict Mapping (Preparing for Evidence Theory)

def calculate_conflict_map(ds_slice):
    """
    Identifies pixels where LST and NDVI provide opposing signals.
    High value = High Conflict (One says drought, the other says healthy).
    """
    # Standardize locally for comparison
    lst_norm = (ds_slice.LST - ds_slice.LST.mean()) / ds_slice.LST.std()
    ndvi_norm = (ds_slice.NDVI - ds_slice.NDVI.mean()) / ds_slice.NDVI.std()
    
    # Conflict is high when LST is high (+) and NDVI is high (+), 
    # or LST is low (-) and NDVI is low (-). 
    # Drought typically sees LST (+) and NDVI (-).
    # We use the product to highlight areas where signs align incorrectly for drought.
    conflict = lst_norm * ndvi_norm
    return conflict

# Calculate conflict for a specific recent time step
recent_conflict = calculate_conflict_map(ds.isel(time=-1))

plt.figure(figsize=(10, 6))
recent_conflict.plot(cmap='bwr', robust=True)
plt.title("Spatial Conflict Map ($K$): LST vs NDVI\n(Red: High Conflict | Blue: Consensus)")
plt.gca().set_aspect('equal')
plt.show()

print(f"Mean Conflict Coefficient: {float(recent_conflict.mean()):.4f}")

A **Mean Conflict Coefficient of -0.6120** is actually excellent news.

In the way we calculated this ($LST_{norm} \times NDVI_{norm}$), a negative value means the variables are moving in opposite directions. For drought monitoring, this is the "Natural Consensus":

- When Temperature ($LST$) goes up (positive anomaly), Vegetation ($NDVI$) goes down (negative anomaly).
- $(+) \times (-) = (-)$

Since your average is significantly negative (-0.61), it suggests that for most of Tunisia, your sensors are "agreeing" on the physical reality of the ecosystem. The "Red" areas (positive values) in your map are the specific exceptions where the **Zadeh Paradox** would be a risk; these are likely irrigated zones or coastal areas where the relationship breaks down.
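
One way to read this number: the mean of the product of two standardized fields is their Pearson correlation, so the "Mean Conflict Coefficient" is the spatial correlation between LST and NDVI at that time step. The sketch below checks the identity on synthetic fields (the field sizes and coefficients are invented). It holds exactly with the population standard deviation (`ddof=0`) and when both fields are valid on the same pixels; with mismatched NaN masks it is only approximate.

```python
import numpy as np

rng = np.random.default_rng(1)
lst = rng.normal(30.0, 5.0, size=(40, 50))                   # synthetic field
ndvi = 0.5 - 0.01 * lst + rng.normal(0.0, 0.03, lst.shape)   # hotter = less green

def standardize(field):
    """Z-score with the population standard deviation (ddof=0)."""
    return (field - field.mean()) / field.std()

mean_product = float((standardize(lst) * standardize(ndvi)).mean())
pearson_r = float(np.corrcoef(lst.ravel(), ndvi.ravel())[0, 1])

print(f"mean of z-score product: {mean_product:+.4f}")
print(f"Pearson correlation:     {pearson_r:+.4f}")
```

Since the standardization here is spatial (across pixels at one date), the value describes how hot and green areas are arranged on the map, not how each pixel deviates from its own history.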

## 1.6. Final Data Audit: Reliability Weights

Feature       | Reliability (α) | Justification
--------------|-----------------|-----------------
LST           | 0.85            | High precision, but sensitive to daily fluctuations and cloud shadows.
NDVI          | 0.90            | Most stable indicator of long-term drought impact on the ground.
Soil Moisture | 0.70            | Often model-derived; high uncertainty in deep soil layers.
SPI           | 0.80            | Statistically robust, but spatial resolution is often coarser than satellite data.

In [None]:
# 1.6 Finalizing Reliability Constants
# These will be used in Phase 3 (Modeling) to weight the evidence fusion.

source_reliability = {
    "LST": 0.85,
    "NDVI": 0.90,
    "Soil_Moisture": 0.70,
    "SPI": 0.80
}

print("✅ Phase 1 Complete: Dataset characteristics, spatial conflicts, and source reliability established.")

---

# 2. Preprocessing

Now that we have established the "Reliability" of our expert witnesses (the data columns), we move to **Phase 2**. In this phase, we transform raw physical measurements into standardized signals that our hybrid model can interpret.

## 2.1. Monthly Climatology and Anomaly Calculation

A raw LST of 35°C in Tunisia is "cool" for August but "extreme" for January. To make our model context-aware, we calculate **Anomalies**. This removes the seasonal "heartbeat" of the Mediterranean climate, leaving only the signals of environmental stress.

We calculate the **Monthly Mean ($\mu$)** and **Standard Deviation ($\sigma$)** for the Training period (2000–2017) and apply those parameters to the Validation and Test sets to prevent **Temporal Data Leakage**.

In [None]:
# 2.1 Calculating Monthly Climatologies and Anomalies

def get_climatology(ds_train):
    """Calculates the mean and std for each month using the training set."""
    clim_mean = ds_train.groupby("time.month").mean("time")
    clim_std = ds_train.groupby("time.month").std("time")
    return clim_mean, clim_std

# Calculate baseline from Training Set ONLY
ds_mean, ds_std = get_climatology(ds_train)

def apply_anomalies(ds, ds_mean):
    """Subtracts the historical monthly mean to get the anomaly."""
    return ds.groupby("time.month") - ds_mean

# Apply to all splits
ds_train_anom = apply_anomalies(ds_train, ds_mean)
ds_val_anom   = apply_anomalies(ds_val, ds_mean)
ds_test_anom  = apply_anomalies(ds_test, ds_mean)

print("✅ Anomalies calculated using 2000-2017 baseline.")
print(f"LST Anomaly Example (First Pixel): {float(ds_train_anom.LST.isel(x=0, y=0, time=0)):.2f}°C")

That `nan°C` is a timely red flag. It confirms that the 85% spatial disagreement we identified earlier isn't just a metadata curiosity: it's a structural reality of your data.

If the "First Pixel" is `NaN`, it means that either the raw value was missing or our 2000–2017 climatology baseline for that specific month/pixel coordinate has no data. Before we touch any more math, we need to map the "Anatomy of Missingness."

## 2.2. Corrective Step: The "Vouching" Mask

If a pixel is missing for all years in January, the mean becomes `NaN`. When you subtract that `NaN` from your data, the result is `NaN`. This is why your "First Pixel" example returned `NaN`.

We must distinguish between **Land** (where data exists) and **Void** (where it doesn't).

Instead of just plotting missingness, we will create a **Count Map**. This tells us how many actual physical observations supported each "Mean" we calculated in Step 2.1. This is the first step toward true **Epistemic Uncertainty** modeling.

In [None]:
# 2.2 Quantifying the "Evidence Base" for our Climatology

def analyze_missingness_fixed(ds):
    null_counts = ds.isnull().sum(dim='time').compute()
    total_time = len(ds.time)
    null_perc = (null_counts / total_time) * 100
    
    # Increase figure width to accommodate 4 maps side-by-side
    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    var_names = ['LST', 'NDVI', 'Soil_Moisture', 'SPI']  # avoid shadowing the built-in vars()
    
    for i, var in enumerate(var_names):
        null_perc[var].plot(ax=axes[i], cmap='Reds', vmin=0, vmax=100, add_colorbar=True)
        
        # Lock the aspect ratio so the map is not stretched
        axes[i].set_aspect('equal', adjustable='box')
        
        axes[i].set_title(f"{var}\n% Missing")
    
    plt.tight_layout()
    plt.show()

analyze_missingness_fixed(ds)
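
The cell above maps the percentage of missing months. The Count Map described earlier, meaning the number of observations behind each monthly mean, can be sketched as follows. This is a NumPy illustration on synthetic data with a hypothetical helper name; on the real cube, something like `ds_train.notnull().groupby("time.month").sum("time")` should give the equivalent result.

```python
import numpy as np

def valid_count_by_month(data, months):
    """Count non-NaN observations per calendar month and pixel.

    data:   array of shape (time, y, x)
    months: array of shape (time,) with calendar months 1..12
    Returns an integer array of shape (12, y, x).
    """
    data = np.asarray(data, dtype=float)
    months = np.asarray(months)
    counts = np.zeros((12,) + data.shape[1:], dtype=int)
    for m in range(1, 13):
        counts[m - 1] = np.isfinite(data[months == m]).sum(axis=0)
    return counts

# Synthetic example: 3 years of monthly data on a 2x2 grid
months = np.tile(np.arange(1, 13), 3)
data = np.ones((36, 2, 2))
data[months == 1, 0, 0] = np.nan      # pixel (0, 0) is never observed in January
data[12, 1, 1] = np.nan               # pixel (1, 1) misses one January

counts = valid_count_by_month(data, months)
print(counts[0])   # January support per pixel
```

A count of zero marks a pixel whose climatological mean is undefined for that month, which is exactly the case that produces `NaN` anomalies.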

Soil Moisture is a massive "smoking gun" for our model's uncertainty. If the Soil Moisture plot is a uniform, near-white rectangle (0% missing everywhere on the `Reds` scale) while **LST** and **NDVI** have irregular "jagged" holes, it means you are dealing with two different types of "Expert" data:

1. **Direct Observation (LST/NDVI):** These are likely from optical satellites (like MODIS). They have holes where there are clouds or sensor gaps. This is **real-world noise**.
2. **Reanalysis/Model Data (Soil Moisture):** If it's a perfect rectangle, it's likely a "gridded product" (like GLDAS or ERA5-Land). These use a computer model to fill the gaps.

**The Danger: "Fake Certainty"**

This is a classic trap. The Soil Moisture "Expert" is essentially saying: 
> *"I am never missing!"*

but in reality, it's just guessing what happened under the clouds.

If we use the **standard Dempster-Shafer** rule here, the model might "trust" Soil Moisture more simply because it's always there, even though it's actually the least "raw" measurement. We need to penalize this "rectangle of 0%" to reflect its **Epistemic Uncertainty**.

## 2.3. The "Tunisia-First" Masking & Interpolation

Instead of a vector boundary built with `geopandas` and `regionmask`, we derive the Tunisia mask directly from the data: a pixel counts as land if it has valid NDVI in more than 5% of months. Then, we will fill the internal gaps using Linear Interpolation along time, but only for small gaps to avoid creating "fake data" over large missing periods.

In [None]:
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np

# 1. CREATE THE LAND MASK (Logical approach)
# A pixel is 'Tunisia' if it has valid NDVI more than 5% of the time.
# This is version-independent and perfectly matches your grid.
land_mask = (ds.NDVI.notnull().sum(dim='time') > (len(ds.time) * 0.05)).compute()

# 2. DEFINE THE CLEANING FUNCTION
def clean_dataset(ds_input, mask):
    # Apply the mask: Everything outside Tunisia becomes NaN
    ds_land = ds_input.where(mask)
    
    # INTERPOLATION: 
    # We use 'linear' interpolation along the time dimension.
    # limit=2 fills at most 2 consecutive NaNs per gap. Longer gaps still get
    # their first 2 steps filled; use max_gap instead to skip long gaps entirely.
    print("Starting temporal interpolation... this may take a minute.")
    ds_filled = ds_land.interpolate_na(
        dim="time", 
        method="linear", 
        limit=2
    )
    return ds_filled

# 3. EXECUTE
ds_clean = clean_dataset(ds, land_mask)

# 4. VERIFY THE MASK
plt.figure(figsize=(6, 8))
land_mask.plot(cmap='Greys_r')
plt.title("Generated Tunisia Land Mask\n(Black = Sea/Void, White = Study Area)")
plt.gca().set_aspect('equal')
plt.show()

print("✅ Step 2.3 Complete: Masked and Interpolated.")
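
If the intent is to fill only short gaps and leave long ones untouched, the gap length has to be checked explicitly. The sketch below shows that logic on a single 1-D series with plain NumPy; the helper `fill_short_gaps` is hypothetical. In xarray, the `max_gap` argument of `interpolate_na` serves a similar purpose.

```python
import numpy as np

def fill_short_gaps(values, max_gap=2):
    """Linearly interpolate interior NaN runs no longer than `max_gap`."""
    values = np.asarray(values, dtype=float)
    out = values.copy()
    valid_idx = np.flatnonzero(~np.isnan(values))
    # Walk over consecutive valid samples; the NaN run between them has
    # length (right - left - 1). Leading and trailing NaNs are never filled.
    for left, right in zip(valid_idx[:-1], valid_idx[1:]):
        gap = right - left - 1
        if 0 < gap <= max_gap:
            out[left + 1:right] = np.interp(
                np.arange(left + 1, right),
                [left, right],
                [values[left], values[right]],
            )
    return out

series = [1.0, np.nan, 3.0, np.nan, np.nan, np.nan, np.nan, 8.0]
filled = fill_short_gaps(series, max_gap=2)
print(filled)
```

Here the one-month gap is filled, while the four-month gap stays empty instead of being partially invented.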

## 2.4. The "Interpolation Penalty" (Crucial for Uncertainty)

Since we have now "invented" some data via interpolation, we must update our Reliability Factors ($\alpha$) from Phase 1.

If a value was interpolated, its reliability should drop. In Dempster-Shafer theory, this is how we tell the model:
> *"I have a value here, but I'm not 100% sure it's real."*

In [None]:
# Create a 'Quality Flag' layer
# 1 = Raw Data, 0.5 = Interpolated, 0 = Missing
quality_flag = xr.where(ds.notnull(), 1.0, 0.0)
quality_flag = quality_flag.where(land_mask) # Only land

# Update quality for pixels that WERE null but are NOW filled
interpolated_mask = ds.isnull() & ds_clean.notnull()
quality_flag = xr.where(interpolated_mask, 0.5, quality_flag)

print("✅ Quality metadata created: Interpolated pixels penalized by 50%.")
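
One possible way to use the quality flag later is to scale each source's reliability by it, so that interpolated samples send more mass to ignorance. The multiplication rule below is an assumption made for illustration, not something fixed by the steps above.

```python
import numpy as np

ALPHA_NDVI = 0.90   # source reliability from the audit table

# Quality flag for one pixel over five months:
# 1.0 = raw, 0.5 = interpolated, 0.0 = missing
quality = np.array([1.0, 0.5, 1.0, 0.0, 0.5])

# Assumed rule: scale the source reliability by the per-sample quality
alpha_eff = ALPHA_NDVI * quality
ignorance = 1.0 - alpha_eff   # mass that discounting would move to Theta

print("effective reliability:", alpha_eff)
print("mass sent to ignorance:", ignorance)
```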

## 2.5. Standardize (Z-Scores)

Now that the data is cleaned, we calculate the statistics. Note that we chunk the dataset and call `.compute()` only on the small climatology arrays; keeping the full standardization lazy is the secret to preventing Kaggle memory crashes.

In [None]:
import gc

# Separate SPI from the variables that need scaling
vars_to_scale = ['LST', 'NDVI', 'Soil_Moisture']
spi_var = ['SPI']

# 1. Free memory left over from earlier cells
gc.collect()

# 2. Ensure ds_clean is 'chunked' (this is the most important step)
# This breaks the data into small blocks so it never fills the RAM
ds_clean = ds_clean.chunk({'time': 12, 'x': -1, 'y': -1})

# 3. Calculate the Climatology from the training slice only (first 70% of months)
# The result is small (12 monthly maps per variable), so computing it eagerly is safe
print("Calculating Climatology for LST, NDVI, and Soil Moisture...")
train_slice = ds_clean.isel(time=slice(0, int(len(ds_clean.time) * 0.7)))
ds_mean = train_slice[vars_to_scale].groupby("time.month").mean("time").compute()
ds_std  = train_slice[vars_to_scale].groupby("time.month").std("time").compute()

# 4. Define the recipe for Standardization (DO NOT COMPUTE YET)
print("Defining lazy standardization recipe...")
ds_scaled = (ds_clean[vars_to_scale].groupby("time.month") - ds_mean) / ds_std

if 'month' in ds_scaled.coords:
    ds_scaled = ds_scaled.drop_vars('month')

# 5. Combine with SPI (Still Lazy); SPI is already standardized, so it is not rescaled
ds_z = xr.merge([ds_scaled, ds_clean[spi_var]])

# 6. Final verification - We only compute ONE small slice to prove it works
print("Testing one time-step to verify...")
test_slice = ds_z.LST.isel(time=0).compute()
print("✅ Success! The recipe is ready without crashing the kernel.")
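
A detail worth guarding against: wherever the training standard deviation of a month is zero (for example a pixel with a constant value), dividing by `ds_std` yields `inf` or `NaN`. The NumPy sketch below shows the same train-only monthly standardization for a single series with an explicit guard; `monthly_zscore` and the `eps` threshold are hypothetical. In xarray, a comparable guard would be `ds_std.where(ds_std > eps)`.

```python
import numpy as np

def monthly_zscore(values, months, n_train, eps=1e-6):
    """Standardize each calendar month using training-period statistics only."""
    values = np.asarray(values, dtype=float)
    months = np.asarray(months)
    z = np.full(values.shape, np.nan)
    for m in np.unique(months[:n_train]):
        train = values[:n_train][months[:n_train] == m]
        mu, sigma = np.nanmean(train), np.nanstd(train)
        if sigma > eps:                      # months without variability stay NaN
            sel = months == m
            z[sel] = (values[sel] - mu) / sigma
    return z

# Synthetic series: 4 years, seasonal cycle plus a yearly drift
months = np.tile(np.arange(1, 13), 4)
years = np.repeat(np.arange(4), 12)
values = months + years.astype(float)
values[months == 12] = 5.0                   # December never varies

z = monthly_zscore(values, months, n_train=36)
print(np.round(z[months == 1], 3))           # January: 3 training years + 1 held-out year
print(z[months == 12])                       # December: left as NaN by the guard
```

The held-out January lands well above the training values in z-score terms, which is the behaviour expected when statistics come from the training period only.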

---

# 3. Data Fusion

## 3.1. The Fuzzy Membership Functions

Now that we have **Z-scores** for all features (including SPI), we need to translate these raw numbers into "Expert Opinions." In fuzzy logic, a Z-score of $-2.0$ shouldn't just be a number; it should represent a **Degree of Membership** in a category (e.g., "Severe Drought").

We will define three fuzzy sets for each expert:

1. **Dry (D):** High membership when Z-scores are low (for NDVI, SM, SPI) or high (for LST).
2. **Normal (N):** High membership when Z-scores are near zero.
3. **Wet (W):** High membership when Z-scores are positive (for NDVI, SM, SPI) or low (for LST).

### 3.1.1. Defining the Membership Curves

We will use **Sigmoid** functions for the extremes (Dry/Wet) and a **Gaussian** function for the Normal state. This provides a smooth transition between states, which is essential for the "Uncertainty" aspect of our hybrid model.

In [None]:
import numpy as np

def fuzzy_membership(z_score, var_name):
    """
    Assigns membership values for Dry, Normal, and Wet states.
    Note: LST is 'inverted' (High Z-score = Dry).
    """
    # 1. Define 'Dry' Membership
    if var_name == 'LST':
        # High LST = Dry (Right-leaning Sigmoid)
        m_dry = 1 / (1 + np.exp(-2 * (z_score - 1.5)))
        m_wet = 1 / (1 + np.exp(2 * (z_score + 1.5)))
    else:
        # Low NDVI/SM/SPI = Dry (Left-leaning Sigmoid)
        m_dry = 1 / (1 + np.exp(2 * (z_score + 1.5)))
        m_wet = 1 / (1 + np.exp(-2 * (z_score - 1.5)))
        
    # 2. Define 'Normal' Membership (Gaussian centered at 0)
    m_normal = np.exp(-0.5 * (z_score**2))
    
    # 3. No normalization is applied here, so the three values need not sum to 1.
    # Rescale them before using them as basic probability assignments (m) in DS Theory.
    return m_dry, m_normal, m_wet

print("‚úÖ Fuzzy logic rules defined.")

### 3.1.2. Translating to Evidence Theory (Mass Functions)

This is the bridge to the **Dempster-Shafer (DS)** part of your project. Each fuzzy membership value becomes a **Mass ($m$)**.

This is where the "Reliability Factors" ($\alpha$) come in. If Expert "LST" is only 80% reliable, we multiply its fuzzy memberships by $0.8$, and the remaining $0.2$ becomes **Ignorance ($\Theta$)**.

**Visualizing the NDVI Expert's Logic:**

We will plot the membership curves over Z-scores from **-4 to +4**, a range that covers nearly every standardized anomaly you are likely to see in your Tunisian dataset.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_fuzzy_logic(var_name='NDVI'):
    # Generate a range of Z-scores
    z_range = np.linspace(-4, 4, 100)
    
    # Calculate memberships using the logic we defined
    # (Gaussian for Normal, mirrored sigmoids for Dry and Wet)
    left = 1 / (1 + np.exp(2 * (z_range + 1.5)))    # high for low Z-scores
    right = 1 / (1 + np.exp(-2 * (z_range - 1.5)))  # high for high Z-scores
    m_normal = np.exp(-0.5 * (z_range**2))
    
    # LST is inverted: a high Z-score (hot) means Dry
    m_dry, m_wet = (right, left) if var_name == 'LST' else (left, right)
    dry_threshold = 1.5 if var_name == 'LST' else -1.5
    
    # Plotting
    plt.figure(figsize=(10, 6))
    plt.plot(z_range, m_dry, label='Dry Membership (Stress)', color='red', lw=2)
    plt.plot(z_range, m_normal, label='Normal Membership', color='green', lw=2)
    plt.plot(z_range, m_wet, label='Wet Membership', color='blue', lw=2)
    
    # Add context lines
    plt.axvline(x=dry_threshold, color='gray', linestyle='--', alpha=0.5, label='Severe Threshold')
    plt.axvline(x=0, color='black', linestyle=':', alpha=0.3)
    
    plt.title(f"Fuzzy Expert Opinion for {var_name}")
    plt.xlabel("Standardized Anomaly (Z-Score)")
    plt.ylabel(r"Degree of Membership ($\mu$)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

visualize_fuzzy_logic('NDVI')

# Other experts, commented out to stay focused:
# visualize_fuzzy_logic('LST')  # Note: the Dry and Wet curves are flipped for LST
# visualize_fuzzy_logic('Soil_Moisture')
# visualize_fuzzy_logic('SPI')

**How to interpret this for your Hybrid Model:**

- **The Overlap:** Notice that at a Z-score of **-1.0**, the expert is "split." It has a ~27% membership in **Dry** and a ~61% membership in **Normal**. This is the "Fuzzy" part: it acknowledges that nature doesn't have a hard cutoff.
- **The High-Certainty Zone:** Once the Z-score hits **-3.0**, the "Dry" membership is about 0.95. The expert is shouting:
  > *"This is a drought!"*
- **The Evidence Transition:** In the next step, we will take these three values (Dry, Normal, Wet) and multiply them by your **Reliability Factor ($\alpha$)**.
    - If $\alpha = 0.8$ and Dry $= 1.0$, then the **Mass ($m$)** for Drought becomes **0.8**.
    - The "missing" **0.2** becomes **Ignorance ($\Theta$)**, meaning:
      > *"I am 80% sure it's a drought, and 20% of me is just unsure because I'm an imperfect sensor."*
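
To make the numbers above concrete, here is a minimal sketch that evaluates the NDVI-style curves at a Z-score of -1.0 and then applies the reliability discount. The helper names `ndvi_memberships` and `discount` are hypothetical, and rescaling the memberships so they sum to 1 before discounting is an assumption added here so that the four masses form a valid assignment.

```python
import math

def ndvi_memberships(z):
    """Dry/Normal/Wet memberships for a 'low = dry' index (NDVI, SM, SPI)."""
    dry = 1 / (1 + math.exp(2 * (z + 1.5)))
    normal = math.exp(-0.5 * z ** 2)
    wet = 1 / (1 + math.exp(-2 * (z - 1.5)))
    return {'D': dry, 'N': normal, 'W': wet}

def discount(memberships, alpha):
    """Hypothetical helper: rescale memberships to sum to 1 (an assumption),
    scale them by alpha, and assign the remaining 1 - alpha to Theta."""
    total = sum(memberships.values())
    masses = {state: alpha * value / total for state, value in memberships.items()}
    masses['Theta'] = 1 - alpha
    return masses

mu = ndvi_memberships(-1.0)
masses = discount(mu, alpha=0.8)

print({state: round(value, 3) for state, value in mu.items()})
# {'D': 0.269, 'N': 0.607, 'W': 0.007}
print({state: round(value, 3) for state, value in masses.items()})
```

Whatever the Z-score, the four masses returned by `discount` add up to 1, with exactly `1 - alpha` held back as ignorance.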
      
Now that we have the expert "opinions" (Fuzzy Membership) for each variable, we enter:

## 3.2. Dempster-Shafer (DS) Evidence Fusion

This is where your model becomes a "Hybrid." We aren't just averaging the numbers; we are mathematically resolving the conflict between experts (e.g., if LST says "Hot/Dry" but NDVI still looks "Green/Normal").

### 3.2.1. Defining the "Basic Probability Assignments" (Masses)

We convert our Fuzzy values into **Masses ($m$)**. For each pixel, each expert $i$ provides:

- **$m_i(\{D\})$:** Mass assigned to **Drought**.
- **$m_i(\{N\})$:** Mass assigned to **Normal**.
- **$m_i(\{W\})$:** Mass assigned to **Wet**.
- **$m_i(\Theta)$:** Mass assigned to **Ignorance** (Uncertainty).

### 3.2.2. The Fusion Engine (Dempster's Rule of Combination)

To combine Expert A (LST) and Expert B (NDVI), we use the orthogonal sum. We will define a function that takes two "opinion vectors" and merges them into one.
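
For reference, the general form of Dempster's rule for two mass functions $m_1$ and $m_2$ is:

$$
K = \sum_{A \cap B = \emptyset} m_1(A)\, m_2(B), \qquad
m_{12}(C) = \frac{1}{1 - K} \sum_{A \cap B = C} m_1(A)\, m_2(B) \quad (C \neq \emptyset)
$$

In this notebook mass is only ever placed on the singletons $\{D\}$, $\{N\}$, $\{W\}$ and on $\Theta$, so for a state $S \in \{D, N, W\}$ the sums reduce to:

$$
m_{12}(\{S\}) = \frac{m_1(\{S\})\, m_2(\{S\}) + m_1(\{S\})\, m_2(\Theta) + m_1(\Theta)\, m_2(\{S\})}{1 - K}, \qquad
m_{12}(\Theta) = \frac{m_1(\Theta)\, m_2(\Theta)}{1 - K}
$$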

In [None]:
def dempster_combination(m1, m2):
    """
    Combines two mass distributions using Dempster's Rule.
    m1, m2: dicts with keys 'D', 'N', 'W', 'Theta'
    """
    # 1. Calculate the 'Conflict' (K)
    # K is the sum of products where the experts disagree (e.g., D ∩ N, D ∩ W)
    k = (m1['D']*m2['N'] + m1['D']*m2['W'] + 
         m1['N']*m2['D'] + m1['N']*m2['W'] + 
         m1['W']*m2['D'] + m1['W']*m2['N'])
    
    # If k = 1, the experts completely contradict each other (Extreme Conflict)
    if k >= 1.0:
        return {'D': 0, 'N': 0, 'W': 0, 'Theta': 1}
    
    scaling_factor = 1 / (1 - k)
    
    # 2. Calculate Combined Masses
    m_combined = {}
    for state in ['D', 'N', 'W']:
        # Intersection of same states + intersection with ignorance
        m_combined[state] = (m1[state]*m2[state] + 
                             m1[state]*m2['Theta'] + 
                             m2[state]*m1['Theta']) * scaling_factor
        
    # 3. Calculate remaining Ignorance
    m_combined['Theta'] = (m1['Theta'] * m2['Theta']) * scaling_factor
    
    return m_combined

print("✅ Fusion Engine logic ready.")

**Why this is better than a simple average:**

- **The Ignorance Filter:** If Soil Moisture is missing data (Mass goes to $\Theta$), the formula automatically gives more weight to the other experts.
- **Conflict Detection:** If $K$ is high, it tells us the sensors are "arguing." This is a high-level feature we can later feed into your TPU-based Neural Network to help it learn which expert is right in specific Tunisian regions (like the more humid North vs. the arid South).
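
The "Ignorance Filter" can be checked with a small sketch. It assumes that missing data is encoded as a vacuous mass function (all mass on $\Theta$); the compact `dempster` function below is a stand-alone restatement of the rule written for this illustration, and the NDVI masses are made-up example values.

```python
STATES = ('D', 'N', 'W')

def dempster(m1, m2):
    """Dempster's rule restricted to singletons D, N, W plus Theta."""
    # Conflict: every pairing of two different states
    k = sum(m1[a] * m2[b] for a in STATES for b in STATES if a != b)
    if k >= 1.0:
        raise ValueError("Total conflict: the rule is undefined")
    fused = {
        s: (m1[s] * m2[s] + m1[s] * m2['Theta'] + m2[s] * m1['Theta']) / (1 - k)
        for s in STATES
    }
    fused['Theta'] = m1['Theta'] * m2['Theta'] / (1 - k)
    return fused

# Example NDVI opinion (illustrative values)
ndvi = {'D': 0.09, 'N': 0.72, 'W': 0.09, 'Theta': 0.10}
# Assumption: a sensor with no data contributes a vacuous opinion
missing_sm = {'D': 0.0, 'N': 0.0, 'W': 0.0, 'Theta': 1.0}

fused = dempster(ndvi, missing_sm)
print(fused)
```

Fusing with the vacuous expert returns the NDVI opinion unchanged, which is what "giving more weight to the other experts" means in practice.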

**Testing a "Conflict" Scenario:**

Let's simulate a real-world scenario:

- **LST Expert:** Says it's extremely hot (Z = +2.5) $\rightarrow$ High mass for **Dry**.
- **NDVI Expert:** Says the plants are still green (Z = 0.0) $\rightarrow$ High mass for **Normal**.

In [None]:
# Simulated reliability (alpha)
alpha_lst = 0.8
alpha_ndvi = 0.9

# Get fuzzy memberships for LST (Hot = Dry)
# We'll use the logic: m_dry=0.9, m_normal=0.1
m_lst = {'D': 0.9*alpha_lst, 'N': 0.1*alpha_lst, 'W': 0, 'Theta': (1-alpha_lst)}

# Get fuzzy memberships for NDVI (Normal)
m_ndvi = {'D': 0.1*alpha_ndvi, 'N': 0.8*alpha_ndvi, 'W': 0.1*alpha_ndvi, 'Theta': (1-alpha_ndvi)}

# Fuse them!
result = dempster_combination(m_lst, m_ndvi)

print(f"Fused Opinion: {result}")


def robust_fusion(m1, m2):
    """Yager-style fusion: conflict (K) is reassigned to Theta instead of being normalized away."""
    # Calculate Conflict (K)
    k = (m1['D']*m2['N'] + m1['D']*m2['W'] + 
         m1['N']*m2['D'] + m1['N']*m2['W'] + 
         m1['W']*m2['D'] + m1['W']*m2['N'])
    
    m_fused = {}
    # Combine beliefs only (No normalization here yet)
    for state in ['D', 'N', 'W']:
        m_fused[state] = (m1[state]*m2[state] + 
                          m1[state]*m2['Theta'] + 
                          m2[state]*m1['Theta'])
        
    # Yager-style: All conflict (K) + intersection of ignorance goes to Theta
    m_fused['Theta'] = (m1['Theta'] * m2['Theta']) + k
    
    return m_fused

print("✅ Robust Fusion Engine initialized. Zadeh Paradox mitigated by assigning conflict to Ignorance.")
print(f"Robust Fused Opinion: {robust_fusion(m_lst, m_ndvi)}")

Feature                                 | Standard DS                       | Robust (Yager)
----------------------------------------|-----------------------------------|------------------
Philosophy                              | "Someone must be right."          | "If you're fighting, I don't trust either."
Certainty ($1 - m(\Theta)$, test above) | High (95%)                        | Low (38%)
Risk                                    | High (Potential for false alarms) | Low (Admits data quality issues)
Model Role                              | Makes a choice.                   | Passes the problem to the Neural Network.
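
The two certainty figures in the table can be reproduced from the conflict scenario above. This sketch reads "certainty" as $1 - m(\Theta)$, which is an interpretation of the table rather than a definition given in the notebook, and it hard-codes the masses from the test (for example $0.9 \times 0.8 = 0.72$).

```python
STATES = ('D', 'N', 'W')

def conflict(m1, m2):
    """Sum of products over every pair of different states."""
    return sum(m1[a] * m2[b] for a in STATES for b in STATES if a != b)

# Masses from the LST vs. NDVI conflict scenario
m_lst = {'D': 0.72, 'N': 0.08, 'W': 0.0, 'Theta': 0.2}
m_ndvi = {'D': 0.09, 'N': 0.72, 'W': 0.09, 'Theta': 0.1}

k = conflict(m_lst, m_ndvi)
theta_product = m_lst['Theta'] * m_ndvi['Theta']

theta_dempster = theta_product / (1 - k)  # conflict normalized away
theta_yager = theta_product + k           # conflict moved to ignorance

print(f"K = {k:.4f}")                                   # K = 0.5976
print(f"Certainty (Dempster): {1 - theta_dempster:.1%}")  # 95.0%
print(f"Certainty (Yager):    {1 - theta_yager:.1%}")     # 38.2%
```

Almost 60% of the joint mass is conflict in this scenario, which is why the two rules disagree so sharply.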

Since we are dealing with a 25-year dataset over the entire map of Tunisia, a per-pixel Python loop would be far too slow. We need to **vectorize** the Robust Fusion logic so that it runs as array operations across all pixels at once.

## 3.3. Vectorized Robust Fusion (The "Evidence Cube")

We will convert our fuzzy logic into `xarray`/`numpy` operations. Instead of dictionaries, we will create a new dataset where each pixel has four "Mass" layers: **D, N, W, and Theta**.
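
Before moving to `xarray`, the idea can be checked on a toy grid with plain NumPy. The sketch below uses random but valid mass layers (not real data) and a hypothetical `yager_fuse` helper, and confirms that the array version gives the same answer as a pixel-by-pixel loop.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_masses(shape):
    """Random mass layers (D, N, W, Theta) that sum to 1 at every pixel."""
    raw = rng.random((4,) + shape)
    return raw / raw.sum(axis=0)

def yager_fuse(m1, m2):
    """Yager-style fusion; works on scalars and on whole grids alike."""
    d1, n1, w1, t1 = m1
    d2, n2, w2, t2 = m2
    k = d1*n2 + d1*w2 + n1*d2 + n1*w2 + w1*d2 + w1*n2
    return np.stack([
        d1*d2 + d1*t2 + d2*t1,
        n1*n2 + n1*t2 + n2*t1,
        w1*w2 + w1*t2 + w2*t1,
        t1*t2 + k,  # conflict is added to ignorance
    ])

a = random_masses((3, 4))
b = random_masses((3, 4))

# Vectorized: the whole grid in one call
fused = yager_fuse(a, b)

# Reference: the same rule applied one pixel at a time
looped = np.empty_like(fused)
for i in range(3):
    for j in range(4):
        looped[:, i, j] = yager_fuse(a[:, i, j], b[:, i, j])

print(np.allclose(fused, looped))          # True
print(np.allclose(fused.sum(axis=0), 1))   # True
```

The second check shows that the fused layers still sum to 1 at every pixel, as long as both inputs do.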

In [None]:
import xarray as xr
import numpy as np

def get_mass_assignments(ds_z, alpha_dict):
    """
    Converts Z-scores into Mass assignments (m) for all experts.
    alpha_dict: {'LST': 0.8, 'NDVI': 0.9, ...}
    """
    mass_ds = xr.Dataset()
    
    for var in ds_z.data_vars:
        z = ds_z[var]
        alpha = alpha_dict.get(var, 0.8)
        
        # 1. Calculate Fuzzy Memberships (Vectorized)
        if var == 'LST':
            m_dry_raw = 1 / (1 + np.exp(-2 * (z - 1.5)))
            m_wet_raw = 1 / (1 + np.exp(2 * (z + 1.5)))
        else:
            m_dry_raw = 1 / (1 + np.exp(2 * (z + 1.5)))
            m_wet_raw = 1 / (1 + np.exp(-2 * (z - 1.5)))
            
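        m_norm_raw = np.exp(-0.5 * (z**2))
        
        # 2. Rescale so the three memberships sum to 1 per pixel. Without this,
        #    D + N + W + Theta would not equal 1 (near z = 0 the raw sum is ~1.09)
        total = m_dry_raw + m_norm_raw + m_wet_raw
        
        # 3. Apply Reliability Discounting (Mass Assignment)
        mass_ds[f'{var}_D'] = alpha * m_dry_raw / total
        mass_ds[f'{var}_N'] = alpha * m_norm_raw / total
        mass_ds[f'{var}_W'] = alpha * m_wet_raw / total
        mass_ds[f'{var}_Theta'] = 1 - alpha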
        
    return mass_ds

# Define your expert reliabilities based on Phase 1 research
alphas = {'LST': 0.85, 'NDVI': 0.9, 'Soil_Moisture': 0.75, 'SPI': 0.95}
ds_mass = get_mass_assignments(ds_z, alphas)

## 3.4. The Chain-Fusion Step

Because we have 4 experts, we fuse them sequentially:

1. Fuse **LST** and **NDVI** $\rightarrow$ **Result A**
2. Fuse **Result A** and **Soil Moisture** $\rightarrow$ **Result B**
3. Fuse **Result B** and **SPI** $\rightarrow$ **Final Decision Cube**
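
One caveat about chaining: the Yager-style rule used here is commutative but not associative, so the order in which experts are fused can change the result. The sketch below shows this with `functools.reduce` on dictionary masses; the `yager` helper and the two toy opinions are illustrative, not values from the dataset.

```python
from functools import reduce

STATES = ('D', 'N', 'W')

def yager(m1, m2):
    """Yager-style fusion of two mass dicts with keys D, N, W, Theta."""
    k = sum(m1[a] * m2[b] for a in STATES for b in STATES if a != b)
    fused = {
        s: m1[s] * m2[s] + m1[s] * m2['Theta'] + m2[s] * m1['Theta']
        for s in STATES
    }
    fused['Theta'] = m1['Theta'] * m2['Theta'] + k
    return fused

# Two experts say "Dry", one says "Normal" (toy values)
says_dry = {'D': 0.8, 'N': 0.0, 'W': 0.0, 'Theta': 0.2}
says_normal = {'D': 0.0, 'N': 0.8, 'W': 0.0, 'Theta': 0.2}

# Same three opinions, fused in two different orders
order_1 = reduce(yager, [says_dry, says_normal, says_dry])
order_2 = reduce(yager, [says_dry, says_dry, says_normal])

print(round(order_1['D'], 3), round(order_2['D'], 3))  # 0.704 0.192
```

Because the dissenting expert has more influence when it is fused last, it is worth fixing the fusion order deliberately and keeping it the same across runs.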

In [None]:
def fuse_two_experts_vectorized(m1_prefix, m2_prefix, ds):
    """
    Fuses two experts across the entire map using Robust (Yager) logic.
    """
    # Extract mass components
    d1, n1, w1, t1 = ds[f'{m1_prefix}_D'], ds[f'{m1_prefix}_N'], ds[f'{m1_prefix}_W'], ds[f'{m1_prefix}_Theta']
    d2, n2, w2, t2 = ds[f'{m2_prefix}_D'], ds[f'{m2_prefix}_N'], ds[f'{m2_prefix}_W'], ds[f'{m2_prefix}_Theta']
    
    # 1. Calculate Conflict K (Vectorized)
    k = (d1*n2 + d1*w2 + n1*d2 + n1*w2 + w1*d2 + w1*n2)
    
    # 2. Combine Beliefs (Robust Logic)
    # New state = (Agreement + Ignorance interaction)
    fused_d = (d1*d2 + d1*t2 + d2*t1)
    fused_n = (n1*n2 + n1*t2 + n2*t1)
    fused_w = (w1*w2 + w1*t2 + w2*t1)
    
    # 3. Conflict goes to Ignorance
    fused_t = (t1*t2) + k
    
    return fused_d, fused_n, fused_w, fused_t

# --- EXECUTION ---
# Step 1: Fuse LST and NDVI
d_ln, n_ln, w_ln, t_ln = fuse_two_experts_vectorized('LST', 'NDVI', ds_mass)

# Step 2: Create temporary storage for intermediate result to keep code clean
# In a real run, we would continue this for SM and SPI...

**Why this is the turning point:**

Once this runs, we will have a `fused_t` (Conflict/Ignorance) map.

- Regions with high `fused_t` are "Uncertainty Hotspots."
- In the North of Tunisia, where we expect the sensors to agree more often, `fused_t` should be low.
- In the transition zones (Steppe), we expect `fused_t` to be higher.

To visualize the uncertainty, we will calculate the final fusion of all four experts (LST, NDVI, Soil Moisture, and SPI) and then plot the **Ignorance ($\Theta$)** layer.

This map is essentially a "Conflict Map." If a pixel is bright, it means the satellite data is contradictory, and the hybrid model is correctly identifying its own limitations.

## 3.5. The 4-Way Fusion Execution

In [None]:
# Sequential Fusion of all 4 Experts
# 1. LST + NDVI
d1, n1, w1, t1 = fuse_two_experts_vectorized('LST', 'NDVI', ds_mass)

# 2. Add Soil Moisture
# We create a temporary structure to hold the previous result
ds_temp = xr.Dataset({'tmp_D': d1, 'tmp_N': n1, 'tmp_W': w1, 'tmp_Theta': t1})
d2, n2, w2, t2 = fuse_two_experts_vectorized('tmp', 'Soil_Moisture', 
                                            xr.merge([ds_temp, ds_mass]))

# 3. Add SPI (The final expert)
ds_temp2 = xr.Dataset({'tmp2_D': d2, 'tmp2_N': n2, 'tmp2_W': w2, 'tmp2_Theta': t2})
final_D, final_N, final_W, final_Theta = fuse_two_experts_vectorized('tmp2', 'SPI', 
                                                                    xr.merge([ds_temp2, ds_mass]))

# Create the Final Hybrid Dataset
ds_hybrid = xr.Dataset({
    'Drought_Belief': final_D,
    'Normal_Belief': final_N,
    'Wet_Belief': final_W,
    'Uncertainty': final_Theta
})

print("‚úÖ Full 4-Way Robust Fusion complete.")

## 3.6. Visualizing the "Uncertainty Hotspots"

Let's look at a slice from a known drought year (e.g., Summer 2021). We want to see if the model is more "confused" in the desert transition zones or the agricultural North.
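
If the `time` coordinate holds real dates, a season can also be picked by label rather than by position. The sketch below uses a toy stand-in for `ds_hybrid`; the start date and the 299 monthly steps are assumptions made only for this illustration.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy stand-in for ds_hybrid: monthly steps dated on the first of each month.
# The start date is an assumption, not taken from the real dataset.
time = pd.date_range("2000-01-01", periods=299, freq="MS")
toy = xr.Dataset(
    {"Uncertainty": (("time", "y", "x"), np.zeros((299, 2, 2)))},
    coords={"time": time},
)

# Label-based selection: June to August 2021
summer_2021 = toy.sel(time=slice("2021-06-01", "2021-08-31"))

print(summer_2021.time.dt.month.values.tolist())  # [6, 7, 8]
```

Selecting by date avoids having to count how many months a negative index reaches back from the end of the record.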

In [None]:
import matplotlib.pyplot as plt

# 1. Select a single time step. Index -30 is the 30th month from the end of the
# record; adjust it so that it lands on the month you want to inspect.
# We add .squeeze() to drop any dimension that has only 1 entry
sample_time = ds_hybrid.isel(time=-30).squeeze()

# If 'month' still has multiple entries, we must pick one to plot a 2D map
if 'month' in sample_time.dims and sample_time.month.size > 1:
    sample_time = sample_time.isel(month=0)

fig, ax = plt.subplots(1, 2, figsize=(15, 8))

time_label = sample_time.time.dt.strftime("%B %Y").values

# 2. Use '...' in transpose to avoid the ValueError
# This handles ('y', 'x', 'month') by moving y and x to the front
plot_data_1 = sample_time.Drought_Belief.transpose('y', 'x', ...)
plot_data_2 = sample_time.Uncertainty.transpose('y', 'x', ...)

# Map 1: Belief
plot_data_1.plot.pcolormesh(ax=ax[0], x='x', y='y', cmap='YlOrRd', robust=True)
ax[0].set_title(f"Hybrid Belief: Drought (D), {time_label}")
ax[0].set_aspect('equal')

# Map 2: Uncertainty
plot_data_2.plot.pcolormesh(ax=ax[1], x='x', y='y', cmap='Purples', robust=True)
ax[1].set_title(r"Hybrid Uncertainty: Conflict ($\Theta$)")
ax[1].set_aspect('equal')

plt.tight_layout()
plt.show()

---

# 4. Processing

## 4.1. Preparing for the TPU (Data Shaping)

To move into the Deep Learning phase, we need to transform our `ds_hybrid` into a format the TPU can digest. We will build a **4D array** in channels-last layout, `(Time, Y, X, Features)`, so that each time step is one `(Y, X, Features)` sample.

We will treat the 4 fused mass layers (D, N, W, Theta) as our **Features**.
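
As a small-scale illustration of the target layout, the sketch below stacks four toy layers into a channels-last array with NumPy. Filling NaN pixels with 0 is an assumption made for this example; the right fill value depends on how masked pixels should be treated during training.

```python
import numpy as np

t, ny, nx = 5, 3, 4

# Four toy (time, y, x) layers standing in for D, N, W and Theta
layers = [np.random.default_rng(i).random((t, ny, nx)) for i in range(4)]
layers[0][0, 0, 0] = np.nan  # e.g. a masked pixel

# Stack along a new last axis: (time, y, x, features)
cube = np.stack(layers, axis=-1).astype(np.float32)

# Assumption: replace NaN with 0 so the network never sees missing values
cube = np.nan_to_num(cube, nan=0.0)

print(cube.shape, cube.dtype)  # (5, 3, 4, 4) float32
```

The same `axis=-1` stacking is what produces the final `Features` dimension in the Dask version.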

In [None]:
import dask.array as da

# 1. Ensure we are calling the correct variables from ds_hybrid
# Let's double check the keys exist
print("Available variables:", list(ds_hybrid.data_vars))

# 2. Extract arrays as Dask arrays (Lazy)
# This prevents the "empty list" error by pulling them explicitly
try:
    d_map = ds_hybrid['Drought_Belief'].data
    n_map = ds_hybrid['Normal_Belief'].data
    w_map = ds_hybrid['Wet_Belief'].data
    u_map = ds_hybrid['Uncertainty'].data
    
    # 3. Stack them along a NEW 'feature' dimension (axis=-1)
    # The shape will be (time, y, x, 4)
    data_for_tpu = da.stack([d_map, n_map, w_map, u_map], axis=-1)
    
    print(f"✅ Success! Data Shape: {data_for_tpu.shape}")
    print("Format: (Time, Y, X, Features)")

except KeyError as e:
    print(f"❌ Error: One of the features was not found in ds_hybrid: {e}")

# 1. Fix the 5D to 4D issue
# If a redundant 'month' axis is present (axis 3), average over it to collapse it
if data_for_tpu.ndim == 5:
    data_for_tpu = data_for_tpu.mean(axis=3)

print(f"Corrected Shape for TPU: {data_for_tpu.shape}") # Should be (299, 851, 410, 4)

## 4.2. The TPU Dataset Bridge

Now that we have a 4D Dask array, we can set up the TensorFlow pipeline. This code needs to be very specific about types to work on a TPU.
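
It helps to know in advance how many batches the pipeline will produce. The pure-Python sketch below mimics batching with the remainder dropped; `batch_indices` is a hypothetical helper, and the 299 time steps come from the shape printed earlier.

```python
def batch_indices(n_samples, batch_size, drop_remainder=True):
    """Yield lists of sample indices, one list per batch."""
    for start in range(0, n_samples, batch_size):
        idx = list(range(start, min(start + batch_size, n_samples)))
        if drop_remainder and len(idx) < batch_size:
            return  # the incomplete final batch is discarded
        yield idx

batches = list(batch_indices(299, 8))
dropped = 299 - 8 * len(batches)

print(len(batches), dropped)  # 37 3
```

With 299 months and a batch size of 8, the last 3 months never reach the model in an epoch unless the data is shuffled or padded.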

In [None]:
import tensorflow as tf

def dask_to_tf_dataset(dask_array, batch_size=8):
    # Determine the shape of a single time-step (y, x, features)
    sample_shape = dask_array.shape[1:] 
    
    def generator():
        # Loop through the time dimension
        for i in range(dask_array.shape[0]):
            # Pull one month into RAM at a time, cast to match the float32 signature
            yield dask_array[i].compute().astype("float32")

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=tf.TensorSpec(shape=sample_shape, dtype=tf.float32)
    )

    # TPU note: keep the global batch size a multiple of the number of TPU cores (8 here)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(tf.data.AUTOTUNE)

# Create the final pipeline
tpu_input = dask_to_tf_dataset(data_for_tpu)
print("🚀 TPU Pipeline Ready.")