In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import geopandas as gpd
from shapely.geometry import Point
import glob
import papermill as pm

In [None]:
# City registry: maps each city name to its longitude/latitude (degrees) and
# the CORDEX domain used for it.
location = {
    'Mexico City': {'lon': -99.0833, 'lat': 19.4667, 'domain': 'CAM-22'},
    'Buenos Aires': {'lon': -58.416, 'lat': -34.559, 'domain': 'SAM-22'},
    'New York': {'lon': -74.2261, 'lat': 40.8858, 'domain': 'NAM-22'},
    'Sydney': {'lon': 151.01810, 'lat': -33.79170, 'domain': 'AUS-22'},
    'Beijing': {'lon': 116.41, 'lat': 39.90, 'domain': 'EAS-22'},
    'Tokyo': {'lon': 139.84, 'lat': 35.65, 'domain': 'EAS-22'},
    'Jakarta': {'lon': 106.81, 'lat': -6.2, 'domain': 'SEA-22'},
    'Johannesburg': {'lon': 28.183, 'lat': -25.733, 'domain': 'AFR-22'},
    'Riyadh': {'lon': 46.73300, 'lat': 24.7000, 'domain': 'WAS-22'},
    'Berlin': {'lon': 13.4039, 'lat': 52.4683, 'domain': 'EUR-11'},
    'Paris': {'lon': 2.35, 'lat': 48.85, 'domain': 'EUR-11'},
    'London': {'lon': -0.13, 'lat': 51.50, 'domain': 'EUR-11'},
    'Madrid': {'lon': -3.70, 'lat': 40.42, 'domain': 'EUR-11'},
    'Los Angeles': {'lon': -118.24, 'lat': 34.05, 'domain': 'NAM-22'},
    'Montreal': {'lon': -73.56, 'lat': 45.50, 'domain': 'NAM-22'},
    'Chicago': {'lon': -87.55, 'lat': 41.73, 'domain': 'NAM-22'},
    'Bogota': {'lon': -74.06, 'lat': 4.62, 'domain': 'SAM-22'},
    'Baghdad': {'lon': 44.40, 'lat': 33.34, 'domain': 'WAS-22'},
    'Tehran': {'lon': 51.42, 'lat': 35.69, 'domain': 'WAS-22'},
    'Tashkent': {'lon': 69.24, 'lat': 41.31, 'domain': 'WAS-22'},
    'Cairo': {'lon': 31.25, 'lat': 30.06, 'domain': 'AFR-22'},
    'Delhi [New Delhi]': {'lon': 77.22, 'lat': 28.64, 'domain': 'WAS-22'},
}

# Per-model names of the variables holding the urban fraction, the orography
# and the land-sea mask in the static files.
model_dict = {
    'REMO': {'urban_variable': 'urban', 'orog_variable': 'orog', 'sea_variable': 'sftlf'},
    'RegCM': {'urban_variable': 'sftuf', 'orog_variable': 'orog', 'sea_variable': 'sftlf'},
}

In [None]:
def data_city(city: str, location: dict, model: str, data: str = "CORDEX-CORE"):
    '''
    Load the temperature dataset, the static model fields and the nearby GHCN-Daily
    stations for one city.

    Parameters:
        city (str): Key of the city in ``location``.
        location (dict): Maps city names to dicts with ``lon``, ``lat`` and ``domain`` entries.
        model (str): Regional model whose static fields are loaded: "REMO" or "RegCM".
        data (str): Temperature data set: "CORDEX-CORE" (default) or "CORDEX-EUR-11".
            "CORDEX-EUR-11" is only accepted for cities whose domain is "EUR-11".

    Returns:
        tuple: ``(ds, sea_mask, orography, urbanfraction, lat, lon, rval)`` where
            ``ds`` is the monthly ``tn`` dataset, the next three are the static-field
            datasets of ``model`` for the city's domain, ``lat``/``lon`` are the city
            coordinates and ``rval`` is a GeoDataFrame (EPSG:3857) with the stations
            closer than 0.5 degrees to the city, sorted by distance.

    Raises:
        ValueError: If the city is unknown or incomplete in ``location``, or if the
            ``model`` or the ``data``/domain combination is not supported.
        FileNotFoundError: If no static file matches the city's domain.
    '''
    # --- validate every input before touching the file system or the network ---
    city_data = location.get(city)
    if not city_data:
        raise ValueError(f"Unknown city {city!r}: it is not a key of the location dictionary")
    lon = city_data.get('lon')
    lat = city_data.get('lat')
    domain = city_data.get('domain')
    if lon is None or lat is None or domain is None:
        raise ValueError(f"Location entry for {city!r} must define 'lon', 'lat' and 'domain'")

    atlas_root = "/lustre/gmeteo/WORK/DATA/C3S-CDS/C3S-CICA-Atlas/v1/"
    if data == "CORDEX-CORE":
        tn_path = atlas_root + "CORDEX-CORE/historical/tn_CORDEX-CORE_historical_mon_197001-200512.nc"
    elif data == "CORDEX-EUR-11" and domain == "EUR-11":
        tn_path = atlas_root + "CORDEX-EUR-11/historical/tn_CORDEX-EUR-11_historical_mon_197001-200512.nc"
    else:
        raise ValueError(
            f"Unsupported data/domain combination: data={data!r}, domain={domain!r} "
            "(use 'CORDEX-CORE', or 'CORDEX-EUR-11' for an 'EUR-11' city)"
        )

    # The urban-fraction files live in a model-specific sub-directory; the other
    # two static fields share the same layout for both models.
    urban_subdir = {'REMO': 'urbanfraction_C/orig_v3/', 'RegCM': 'urbanfraction_C/'}
    if model not in urban_subdir:
        raise ValueError(f"Unsupported model {model!r}: expected one of {sorted(urban_subdir)}")
    model_root = '/lustre/gmeteo/WORK/DATA/CORDEX-FPS-URB-RCC/nextcloud/CORDEX-CORE-WG/' + model + '/'

    def _first_match(directory: str) -> str:
        # Sorted so that the chosen file does not depend on directory listing order.
        matches = sorted(glob.glob(directory + '*' + domain + '*'))
        if not matches:
            raise FileNotFoundError(f"No file matching '*{domain}*' found in {directory}")
        return matches[0]

    # Resolve all paths first so that nothing is left open if one of them is missing.
    file_sea = _first_match(model_root + 'land-sea-mask_C/')
    file_urban = _first_match(model_root + urban_subdir[model])
    file_orography = _first_match(model_root + 'orography_C/')

    ds = xr.open_dataset(tn_path)
    sea_mask = xr.open_dataset(file_sea)
    orography = xr.open_dataset(file_orography)
    urbanfraction = xr.open_dataset(file_urban)

    # GHCN-Daily station inventory (fixed-width text, downloaded on every call).
    ghcnd_stations_url = 'https://www.ncei.noaa.gov/data/global-historical-climatology-network-daily/doc/ghcnd-stations.txt'
    ghcnd_stations_column_names = ['code', 'lat', 'lon', 'elev', 'name', 'net', 'numcode']
    ghcnd_stations_column_widths = [   11,     9,    10,      7,     34,     4,       10 ]
    # header=None: the inventory has no header row, so header=0 together with
    # `names` would silently discard the first station.
    df = pd.read_fwf(ghcnd_stations_url, header = None, widths = ghcnd_stations_column_widths, names = ghcnd_stations_column_names)
    ghcnd_stations = gpd.GeoDataFrame(df, geometry=gpd.points_from_xy(df.lon, df.lat), crs = 'EPSG:4326')

    # The distance is computed on the raw lon/lat coordinates (EPSG:4326), so it
    # and the 0.5 threshold below are expressed in degrees.
    rval = (
        ghcnd_stations
        .assign(dist = ghcnd_stations.distance(Point(lon, lat)))
        .sort_values(by = 'dist')
    )
    rval = rval[rval.dist < 0.5].to_crs(epsg=3857)
    return ds, sea_mask, orography, urbanfraction, lat, lon, rval
        

In [None]:
from papermill import execute_notebook

# Run parameters for data_city / plot_cities
city = "New York"  # Replace with the desired city name (a key of `location`)
model = "REMO"  # Replace with the model name (e.g., REMO, RegCM)
data = "CORDEX-CORE"  # Replace with the data source (e.g., CORDEX-CORE)
season = 'all'  # Season (all, specific season like 'jfm')

# `location` and `model_dict` are defined once, in the configuration cell near
# the top of the notebook; they are deliberately not redefined here so the two
# copies cannot drift apart.

# Get data for the specified city.
# The names must follow the order in which data_city returns its values:
# (ds, sea_mask, orography, urbanfraction, lat, lon, stations).
ds, sea_mask, orography, urbanfraction, city_lat, city_lon, rval = data_city(city, location, model, data)



In [None]:
def plot_boundary(ax, lon, lat, dist_lon, dist_lat, color='blue',zorder=1):
    """
    Draw the outline of a rectangle centred on (lon, lat).

    Parameters:
        ax (matplotlib.axes.Axes): Axes that receives the outline.
        lon (float): Longitude of the rectangle centre.
        lat (float): Latitude of the rectangle centre.
        dist_lon (float): Half-width of the rectangle along the longitude axis.
        dist_lat (float): Half-height of the rectangle along the latitude axis.
        color (str, optional): Colour of the outline. Defaults to 'blue'.
        zorder (int, optional): Drawing order passed on to ``ax.plot``. Defaults to 1.
    """
    # Edges of the rectangle
    left, right = lon - dist_lon, lon + dist_lon
    bottom, top = lat - dist_lat, lat + dist_lat

    # Walk the four corners and come back to the first one to close the path
    xs = [left, right, right, left, left]
    ys = [bottom, bottom, top, top, bottom]
    ax.plot(xs, ys, color=color, linewidth=1, zorder=zorder)

In [None]:
def _haversine_km(lon_from, lat_from, lon_to, lat_to, earth_radius=6371):
    """Great-circle distance in kilometres between lon/lat points given in degrees (Haversine formula)."""
    lat_from_rad = np.radians(lat_from)
    lat_to_rad = np.radians(lat_to)
    half_dlat = np.radians(lat_from - lat_to) / 2
    half_dlon = np.radians(lon_from - lon_to) / 2
    # Both latitude factors are cosines.
    a = np.sin(half_dlat) ** 2 + np.cos(lat_from_rad) * np.cos(lat_to_rad) * np.sin(half_dlon) ** 2
    return earth_radius * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))


def plot_cities(ds: xr.Dataset, city: str, model: str, model_dict: dict, sea_mask: xr.Dataset, orography: xr.Dataset, urbanfraction: xr.Dataset,
                CORDEX: str, lat: float, lon: float, dist_lon: float, dist_lat: float, vtmin: float, vtmax: float, rval: gpd.GeoDataFrame, season: str = 'all',
                period: slice = slice('1980-01-01', '2000-12-31'), urban_min: float = 0.1, orog_max: float = 100, sea_max: float = 50, percentil: float = 10, cmap_urban: str = 'binary',
                cmap_orog: str = 'terrain', cmap_sea: str = 'ocean_r'):
    '''
    Plots urban, orography, sea masks, and mean min temperature for selected cities and models.

    The first figure shows the three static fields in its top row and, below, one map per
    ensemble member with the time-mean ``tn`` minus the mean over the urban cells.
    The second figure shows, per member, the monthly climatology of urban and selected
    non-urban ("rural") cells relative to the urban mean.

    Parameters:
    ds (xr.Dataset): Dataset with variable ``tn`` and coordinates ``time``, ``lat``, ``lon``, ``member`` plus ``member_id``.
    city (str): The chosen city (matched against ``UC_NM_MN`` in the urban-centre database).
    model (str): The chosen model; only members whose ``member_id`` contains this string are used.
    model_dict (dict): Maps model names to the variable names of the urban, orography and sea fields.
    sea_mask (xr.Dataset): Dataset containing sea mask data.
    orography (xr.Dataset): Dataset containing orography data.
    urbanfraction (xr.Dataset): Dataset containing urban fraction data.
    CORDEX (str): Name of the CORDEX dataset: "CORDEX-CORE" or "CORDEX-EUR-11".
    lat (float): Latitude of the city.
    lon (float): Longitude of the city.
    dist_lon (float): Half-width of the plotted box in longitude.
    dist_lat (float): Half-height of the plotted box in latitude.
    vtmin (float): Minimum value of the temperature colour scale.
    vtmax (float): Maximum value of the temperature colour scale.
    rval (gpd.GeoDataFrame): Nearby stations. Currently not used by this function; kept for interface compatibility.
    season (str, optional): The selected season for filtering temperature data.
                            Options are "all" (default), "jfm" (January-February-March), "amj" (April-May-June), "jas" (July-August-September), and "ond" (October-November-December).
    period (slice): Slice object with datetime indices.
    urban_min (float, optional): Cells whose urban fraction exceeds this value are urban. Defaults to 0.1.
    orog_max (float, optional): Maximum allowed difference between a cell's orography and the mean urban orography. Defaults to 100.
    sea_max (float, optional): Cells whose sea-mask value exceeds this are kept as land. Defaults to 50.
    percentil (float, optional): Lower percentile of the shaded band; the upper one is ``100 - percentil``. Defaults to 10.
    cmap_urban (str, optional): Colormap of the urban-fraction panel.
    cmap_orog (str, optional): Colormap of the orography panel.
    cmap_sea (str, optional): Colormap of the sea-mask panel.

    Returns:
    tuple: ``(fig, fig2)`` - the map figure and the time-series figure.

    Raises:
    ValueError: If no member of ``model`` has data inside the selected box.
    KeyError: If ``season`` is neither "all" nor one of the supported seasons.
    '''
    # City outline from the urban-centre GeoPackage
    root = '/lustre/gmeteo/WORK/DATA/CORDEX-FPS-URB-RCC/nextcloud/'
    ucdb_info = gpd.read_file(root  + 'CORDEX-CORE-WG/GHS_FUA_UCD/GHS_STAT_UCDB2015MT_GLOBE_R2019A_V1_2.gpkg')
    ucdb_city = ucdb_info.query(f'UC_NM_MN =="{city}"').to_crs(crs = 'EPSG:4326')
    if city == 'London':
        # Keep only the entry located in the United Kingdom
        ucdb_city = ucdb_city[ucdb_city['CTR_MN_NM'] == 'United Kingdom']

    # Restrict the data to the requested period and, if requested, to one season
    season_to_month_bounds = {
        'jfm': (1, 3),
        'amj': (4, 6),
        'jas': (7, 9),
        'ond': (10, 12)
    }
    ds = ds.sel(time=period)
    if season != 'all':
        lower_bound, upper_bound = season_to_month_bounds[season]
        months = ds['time'].dt.month
        ds = ds.sel(time=(months >= lower_bound) & (months <= upper_bound))

    # Box around the city
    lon_min = lon - dist_lon
    lon_max = lon + dist_lon
    lat_min = lat - dist_lat
    lat_max = lat + dist_lat
    box = dict(lon=slice(lon_min, lon_max), lat=slice(lat_min, lat_max))
    tn_box = ds['tn'].sel(**box)

    # Members of the requested model that have data inside the box
    cols = 3
    reduce_dims = {'CORDEX-CORE': 'time', 'CORDEX-EUR-11': ['lat', 'lon']}.get(CORDEX)
    valid_axes = []
    if reduce_dims is not None:
        for m in ds['member'].values:
            has_data = bool(tn_box.sel(member=m).mean(dim=reduce_dims).notnull().any())
            if has_data and model in str(ds["member_id"].sel(member=m).values):
                valid_axes.append(m)
    if not valid_axes:
        raise ValueError(
            f"No member of model {model!r} has data around {city!r} in {CORDEX!r} "
            "for the selected period/season"
        )

    # One row for the static fields plus the rows needed for the member maps.
    # The grid is created with one spare row, which is removed again below.
    rows = (len(valid_axes) + cols - 1) // cols + 1
    fig, axes = plt.subplots(rows + 1, cols, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(20, rows*10))
    # squeeze=False keeps a 2-D array of Axes even when there is a single member
    fig2, axes2 = plt.subplots(len(valid_axes), 1, figsize=(20, rows*15), squeeze=False)
    axes2 = axes2[:, 0]

    # Variable names of the static fields for this model
    urban_variable = model_dict.get(model, {}).get('urban_variable')
    orog_variable = model_dict.get(model, {}).get('orog_variable')
    sea_variable = model_dict.get(model, {}).get('sea_variable')

    # Urban fraction panel and urban mask
    urban = urbanfraction[urban_variable].sel(**box)
    urban.plot(ax=axes[0, 0], cmap=cmap_urban, vmax=1)
    axes[0, 0].set_title('Urban Fraction')
    axes[0, 0].coastlines()
    ucdb_city.plot(ax=axes[0, 0], facecolor="none", edgecolor="red")

    urban_mask = urban > urban_min
    if model=="REMO":
        # For REMO the mask is reduced to its first time step
        urban_mask=urban_mask.sel(time=urban_mask['time'][0])

    lon_values = urban['lon'].values
    lat_values = urban['lat'].values
    # Half the grid spacing; only used to draw cell outlines, which are symmetric
    # around the cell centre, so the sign does not matter.
    half_cell_lon = (lon_values[0] - lon_values[1]) / 2
    half_cell_lat = (lat_values[0] - lat_values[1]) / 2

    # Orography panel and mask (cells within orog_max of the mean urban orography)
    orog = orography[orog_variable].sel(**box)
    orog.plot(ax=axes[0, 1], cmap=cmap_orog)
    axes[0, 1].set_title('Orography')
    axes[0, 1].coastlines()
    ucdb_city.plot(ax=axes[0,1], facecolor="none", edgecolor="red")

    orog_urban = orog.where(urban_mask).mean().item()
    orog_mask = (orog < (orog_max + orog_urban)) & (orog > (orog_urban - orog_max))

    # Sea panel and mask
    sea = sea_mask[sea_variable].sel(**box)
    sea.plot(ax=axes[0, 2], cmap=cmap_sea,vmax=100,vmin=0)
    axes[0, 2].set_title('Sea Mask')
    axes[0, 2].coastlines()
    ucdb_city.plot(ax=axes[0,2], facecolor="none", edgecolor="red")

    sea_masked = (sea > sea_max)

    # Classify every cell once: urban, rural candidate (passes the sea and
    # orography masks) or neither. Each entry is (lon index, lat index, is_urban).
    classified_cells = []
    count = 0  # number of urban cells
    for i in range(len(lon_values)):
        for j in range(len(lat_values)):
            if urban_mask[j, i].any():
                classified_cells.append((i, j, True))
                count += 1
            elif sea_masked[j, i].any() and orog_mask[j, i].any():
                classified_cells.append((i, j, False))

    # Location of the warmest urban cell (mean over time and all members)
    urban_tn_mean = tn_box.where(urban_mask).mean(dim=['time', 'member'])
    max_indices = urban_tn_mean == urban_tn_mean.max()

    lon_max_urban = ds['lon'].sel(lon=slice(lon_min, lon_max)).where(max_indices)
    lon_max_urban = lon_max_urban.values.flatten()[~np.isnan(lon_max_urban.values.flatten())]

    lat_max_urban = ds['lat'].sel(lat=slice(lat_min, lat_max)).where(max_indices)
    lat_max_urban = lat_max_urban.values.flatten()[~np.isnan(lat_max_urban.values.flatten())]

    # Distance (km) from every rural candidate to the warmest urban cell;
    # cells that are not rural candidates keep a distance of 0.
    distances = np.zeros((len(lat_values), len(lon_values)))
    for i, j, is_urban in classified_cells:
        if not is_urban:
            distances[j, i] = _haversine_km(lon_max_urban, lat_max_urban, lon_values[i], lat_values[j])[0]

    # Keep as many of the closest rural candidates as there are urban cells
    sorted_distances = np.sort(distances[distances != 0])
    selected_distances = sorted_distances[:count]
    selected_distances_mask = np.isin(distances, selected_distances)

    def outline_cells(ax):
        # Red outline: urban cell. Blue outline: rural candidate; a black dot
        # marks the rural cells selected for the time series.
        for i, j, is_urban in classified_cells:
            if is_urban:
                plot_boundary(ax, lon_values[i], lat_values[j], half_cell_lon, half_cell_lat,'red',zorder=50)
            else:
                plot_boundary(ax, lon_values[i], lat_values[j], half_cell_lon, half_cell_lat,'blue',zorder=5)
                if selected_distances_mask[j, i]:
                    ax.plot(lon_values[i], lat_values[j], marker='o', markersize=3, color='black')

    # Outline the cells once on each of the three static-field panels
    for col in range(cols):
        outline_cells(axes[0, col])

    # One map per member: time-mean tn relative to the mean over the urban cells
    for k, m in enumerate(valid_axes):
        row = k // cols + 1
        col = k % cols
        tn_member = tn_box.sel(member=m)
        tn_mean = tn_member.where(urban_mask).mean(dim=['time','lat','lon']).compute()
        change = tn_member.mean(dim=['time']) - tn_mean
        change.plot(ax=axes[row, col], cmap='seismic', vmin=vtmin,vmax=vtmax)
        axes[row, col].set_title(f'{ds["member_id"].sel(member=m).values} - ºC')
        axes[row, col].coastlines()
        ucdb_city.plot(ax=axes[row, col], facecolor="none", edgecolor="GreenYellow",zorder=100)
        outline_cells(axes[row, col])

    # Remove the panels that were not used
    for i in range(len(valid_axes), rows * cols):
        row = i // cols + 1
        col = i % cols
        fig.delaxes(axes[row, col])

    all_month_names = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December' ]

    # Time series plot: monthly climatology relative to the urban mean
    for k, m in enumerate(valid_axes):
        ax_ts = axes2[k]
        tn_member = tn_box.sel(member=m)

        # Urban cells: mean over cells, and per-cell climatology for the spread
        tn_urban = tn_member.where(urban_mask)
        time_urban = (tn_urban
                   .groupby('time.month')
                   .mean(dim=['lat', 'lon', 'time'])
                   .compute())
        # The urban mean is the baseline of every anomaly in this panel, so its
        # own anomaly is identically zero and is drawn as the reference line.
        time_plot_urban = time_urban - time_urban
        time_plot_urban.plot(ax=ax_ts, color='red', linewidth=3, linestyle='--',label='Mean of the urban cells')

        time_plot_urban_percentil = tn_urban.groupby('time.month').mean('time') - time_urban
        urban_space_axes = (time_plot_urban_percentil.get_axis_num('lat'), time_plot_urban_percentil.get_axis_num('lon'))
        lower_percentile_urban = np.nanpercentile(time_plot_urban_percentil, percentil, axis=urban_space_axes)
        upper_percentile_urban = np.nanpercentile(time_plot_urban_percentil, 100-percentil, axis=urban_space_axes)
        ax_ts.fill_between(time_plot_urban_percentil['month'], lower_percentile_urban, upper_percentile_urban, color='red', alpha=0.1)
        ax_ts.fill_between(time_plot_urban_percentil['month'], time_plot_urban_percentil.min(dim=['lat','lon']), time_plot_urban_percentil.max(dim=['lat','lon']), color='red', alpha=0.1)

        # Selected rural cells: same quantities, relative to the urban mean
        tn_rural = (tn_member
                  .where(~urban_mask)
                  .where(sea_masked)
                  .where(orog_mask)
                  .where(selected_distances_mask))
        time_rural = (tn_rural
                  .groupby('time.month')
                  .mean(dim=['lat', 'lon', 'time'])
                  .compute())
        time_plot_rural = time_rural - time_urban
        time_plot_rural.plot(ax=ax_ts, color='blue', linewidth=3, linestyle='--', label='Mean of the not urban cells')

        time_plot_rural_percentil = tn_rural.groupby('time.month').mean('time') - time_urban
        rural_space_axes = (time_plot_rural_percentil.get_axis_num('lat'), time_plot_rural_percentil.get_axis_num('lon'))
        lower_percentile_rural = np.nanpercentile(time_plot_rural_percentil, percentil, axis=rural_space_axes)
        upper_percentile_rural = np.nanpercentile(time_plot_rural_percentil, 100-percentil, axis=rural_space_axes)
        ax_ts.fill_between(time_plot_rural_percentil['month'], lower_percentile_rural, upper_percentile_rural, color='blue', alpha=0.1)
        ax_ts.fill_between(time_plot_rural_percentil['month'], time_plot_rural_percentil.min(dim=['lat','lon']), time_plot_rural_percentil.max(dim=['lat','lon']), color='blue', alpha=0.1)

        # Thin line for every individual urban cell and every selected rural cell
        for i, j, is_urban in classified_cells:
            if is_urban:
                line_color = 'red'
            elif selected_distances_mask[j, i]:
                line_color = 'blue'
            else:
                continue
            cell_anomaly = (ds['tn'].sel(lon=lon_values[i], lat=lat_values[j], member=m).groupby('time.month').mean(dim='time') - time_urban)
            cell_anomaly.plot(ax=ax_ts, color=line_color, linewidth=0.5)

        # Empty lines that only provide the legend entries for the thin lines
        ax_ts.plot([], color='red', linewidth=0.5, label='Urban Cells')
        ax_ts.plot([], color='blue', linewidth=0.5, label='Not Urban Cells')

        # Set x-axis ticks and labels based on season
        if season == 'all':
            ax_ts.set_xticks(range(1, 13))
            ax_ts.set_xticklabels(all_month_names)
        else:
            start_month, end_month = season_to_month_bounds[season]
            ax_ts.set_xticks(range(start_month, end_month + 1))
            ax_ts.set_xticklabels(all_month_names[start_month - 1:end_month])

        # Set axis labels, title and legend
        ax_ts.set_xlabel('Months')
        ax_ts.set_ylabel('Temperature Anomaly (°C)')
        ax_ts.set_title('Monthly Temperature Anomalies (Urban vs. not Urban)')
        ax_ts.legend()
    return fig, fig2