In [None]:
import os
import xarray as xr
import matplotlib.pyplot as plt

# Root folder that is scanned for NetCDF (.nc) input files.
input_dir = "/home/user/MSc/Raw_Data/AIG-PFT_Data"

# Study-area boxes in decimal degrees, keyed by region name.
# Both boxes span 3 degrees of latitude and 5 degrees of longitude,
# so the two regions cover the same extent in degrees.
regions = dict(
    Eastern_Mediterranean_Nile_Delta=dict(
        lat_min=31.0, lat_max=34.0,    # band off the Nile Delta coast
        lon_min=28.0, lon_max=33.0,    # eastern Mediterranean
    ),
    East_China_Sea_Yangtze_Delta=dict(
        lat_min=30.5, lat_max=33.5,    # same latitude span as the Nile box
        lon_min=120.5, lon_max=125.5,  # off the Yangtze Delta
    ),
)

# Name of the data variable that is extracted and plotted; edit to pick another.
selected_variable = "Diat"

# Populated by the processing loop: file name -> {region name -> regional subset}.
extracted_datasets = {}

def _ordered_slice(coord, lower, upper):
    """Return a label slice covering [lower, upper] in the stored order of *coord*.

    Label-based slicing with ``sel`` follows the coordinate's stored order, so a
    descending axis needs the bounds swapped or the selection comes back empty.
    Axes with fewer than two values are treated as ascending.
    """
    values = coord.values
    descending = values.size > 1 and values[0] > values[-1]
    return slice(upper, lower) if descending else slice(lower, upper)


def extract_and_plot(file_path):
    """Extract and plot every configured region from one NetCDF file.

    Opens ``file_path`` with xarray and cuts out each box in the module-level
    ``regions`` mapping. For every non-empty box it plots ``selected_variable``.

    Args:
        file_path: Path to a NetCDF file expected to contain ``lat`` and ``lon``.

    Returns:
        A dict mapping region name to the regional dataset, loaded into memory.
        The dict is empty when the variable is missing or every region is empty.
        Returns ``None`` if the file lacks ``lat``/``lon`` or processing raised.
    """
    try:
        # The context manager releases the file handle on every path,
        # including the early returns and any exception raised below.
        with xr.open_dataset(file_path) as ds:
            if "lat" not in ds or "lon" not in ds:
                print(f"Missing lat/lon dimensions in {file_path}! Skipping file.")
                return None  # Skip processing if lat/lon are missing

            # Loop-invariant check: do it once per file, not once per region.
            if selected_variable not in ds.data_vars:
                print(f"Variable '{selected_variable}' not found in dataset. Skipping.")
                return {}

            file_datasets = {}

            for region, bounds in regions.items():
                subset = ds.sel(
                    lat=_ordered_slice(ds["lat"], bounds["lat_min"], bounds["lat_max"]),
                    lon=_ordered_slice(ds["lon"], bounds["lon_min"], bounds["lon_max"]),
                )

                if subset[selected_variable].size == 0:
                    print(f"Empty data for {region}. Skipping.")
                    continue

                # Read the regional subset into memory so the stored dataset
                # no longer relies on the file handle released by the `with`.
                subset = subset.load()

                # Store filtered dataset for saving in next step
                file_datasets[region] = subset

                # Plot the extracted data
                fig = plt.figure(figsize=(8, 6))
                subset[selected_variable].plot(cmap="viridis")
                plt.title(f"{region} - {selected_variable}")
                plt.xlabel("Longitude")
                plt.ylabel("Latitude")
                plt.show()
                # Non-interactive backends keep figures alive after show();
                # close explicitly so figures don't pile up over many files.
                plt.close(fig)

            return file_datasets

    except Exception as e:
        # Best-effort batch processing: report and move on to the next file.
        print(f"Error processing {file_path}: {e}")
        return None

# Process every .nc file in input_dir. Iterate in sorted order because
# os.listdir returns entries in arbitrary order, which would make the
# processing and plot sequence non-reproducible between runs.
for file_name in sorted(os.listdir(input_dir)):
    if not file_name.endswith(".nc"):
        continue

    file_path = os.path.join(input_dir, file_name)

    # Extract and plot
    extracted_data = extract_and_plot(file_path)

    # Keep only files that yielded at least one region. A None (failure)
    # or an empty dict (nothing usable) is falsy and is not recorded.
    if extracted_data:
        extracted_datasets[file_name] = extracted_data

print("Extraction and plotting complete.")
