In [1]:
import glob
import json
import os.path
import pathlib
import platform
import sys

import geopandas as gpd
import imageio.v2 as imageio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import rasterio.mask
import xarray as xr
from pykrige.ok import OrdinaryKriging
from rasterio import features
from rasterio.enums import Resampling
from rasterio.plot import show
from rasterio.transform import from_origin
from shapely.geometry import Point

In [4]:
# Resolve the project folders and the TAHMO data folders used by the rest of the notebook.
# If errors occur here please refer to the readme file or to utils/file_imports.py.
# NOTE: all paths are derived from the current working directory, so the notebook
# has to be started from the folder it lives in.

cwd = pathlib.Path().resolve()
src = cwd.parent
data = src.parent.parent.parent
root = src.parent
OS_type = platform.system()

# Make the project packages importable. The membership check keeps sys.path from
# growing by two entries every time this cell is re-run.
for extra_path in (str(src), str(root)):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

# NOTE(review): the star import hides where names come from. Only `file_paths` is
# used in the cells visible here; consider importing it explicitly once it is
# confirmed that nothing else relies on the other exported names.
from utils.file_imports import *


# file_paths returns an indexable collection of folders; the first four entries are
# unpacked below in the order used by the later cells.
data_paths = file_paths(root, TAHMO = True)
shape_raw = data_paths[0]
raw_files = data_paths[1]
processed_files = data_paths[2]
animation_path = data_paths[3]

The first entry is pointing to /Users/matskerver/Documents/data_tana/TAHMO/raw_TAHMO, the second one to /Users/matskerver/Documents/data_tana/TAHMO/processed_TAHMO and the third one to /Users/matskerver/Documents/data_tana/TAHMO/interpolated_TAHMO. Animations will be put located in /Users/matskerver/Documents/data_tana/TAHMO/results


In [None]:
# Plot the TAHMO station locations on top of the Tana basin outline.
netcdf_file = 'NetCDF_TAHMO.nc'
ds = xr.open_dataset(os.path.join(raw_files, netcdf_file))

# Load geographical data
proj = 'EPSG:32737'
counties = gpd.read_file(os.path.join(shape_raw, 'total_tana_catchement_area_clip_projected.gpkg'))

# The station coordinates are longitude/latitude in degrees (the kriging cell feeds
# the same values into a 36..42 / -4..2 degree grid), so the points must be created
# in a geographic CRS and then reprojected onto the CRS of the basin outline.
# Labelling the raw degrees as EPSG:32737 (metres) puts every station far away from
# the basin. `proj` is only used as a fallback when the outline carries no CRS.
# NOTE(review): WGS84 (EPSG:4326) is assumed for the station coordinates - confirm.
station_crs = 'EPSG:4326'
target_crs = counties.crs if counties.crs is not None else proj

# Variables for plotting
variable = 'te_mean'

geo_dataframes = {}
for station in ds.data_vars:
    # Extract longitude and latitude for each station
    # Assume longitude and latitude are constant over time; thus, we take the mean
    longitude = ds[station].sel(variable='longitude').mean().values.item()
    latitude = ds[station].sel(variable='latitude').mean().values.item()

    # Select the data for the variable of interest and create a geoDataframe with this variable.
    # Every row (timestep) of a station shares the same point geometry.
    df = ds[station].sel(variable=variable).to_dataframe().reset_index()
    point = Point(longitude, latitude)
    rain_gdf = gpd.GeoDataFrame(df, geometry=[point] * len(df), crs=station_crs)
    geo_dataframes[station] = rain_gdf.to_crs(target_crs)

# Use standard procedure to plot the Tana basin.
# The style has to be selected before the figure is created; otherwise the first
# run of this cell looks different from every later run.
plt.style.use('bmh')
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
counties.plot(ax=ax, color='bisque', edgecolor='dimgray')

for key, gdf in geo_dataframes.items():
    # All rows of a station sit on the same location, so one marker is enough.
    gdf.head(1).plot(ax=ax, marker='o', color='red', markersize=15)

ax.set_title('Tana Basin Area Kenya', fontdict={'fontsize': '15', 'fontweight': '3'});


In [None]:
# Interpolate one station variable onto a regular longitude/latitude grid with
# Ordinary Kriging, one grid per timestep, and store the result as a NetCDF file.
#Warning! creates a file of multiple gb (about 6.5Gb for grid_space = 0.01)
netCDF_file = True 

if netCDF_file:
    
    #Reads in the required datasets and shapefiles to use the Kriging interpolation on.
    ds = xr.open_dataset(os.path.join(raw_files, 'NetCDF_TAHMO.nc'))
    path_shape = os.path.join(shape_raw, 'total_tana_catchement_area_clip.shp')
    # Not used by the interpolation itself; kept available for inspection.
    shapefile = gpd.read_file(path_shape)
    
    #Variables that determine the grid resolution (in degrees longitude and latitude) and the extend of the earth 
    #to be plotted (same units).
    grid_space = 0.05
    variable = 'evap'
    longitudes_mapped = [36, 42]
    latitudes_mapped = [-4, 2]

    # Timesteps with fewer valid stations than this are stored as an all-NaN grid
    # instead of aborting the whole run.
    # NOTE(review): 3 is an assumed lower bound for a usable variogram fit - tune if needed.
    min_stations = 3
    
    
    grid_lons = np.arange(longitudes_mapped[0], longitudes_mapped[1], grid_space)
    grid_lats = np.arange(latitudes_mapped[0], latitudes_mapped[1], grid_space)

    #Extract the station, location and variable data from the dataset ds.
    #to_array() stacks the stations in the order of ds.data_vars, which is the same
    #order the station loop below uses, so positions line up with the values.
    station_latitudes = ds.sel(variable='latitude').to_array().mean(dim='time')
    latitudes = station_latitudes.values
    longitudes = ds.sel(variable='longitude').to_array().mean(dim='time').values
    station_ids = station_latitudes.coords['variable'].values
    variable_data = ds.sel(variable=variable)

    # Create empty variable to store the grids in and a variable to keep track of the progress of the interpolation.
    # Variable interval determines the timesteps on which the progess will be printed.
    z_values = []
    progress = 0
    interval = 100
    total_steps = len(ds.time)
    
    
    # ------------- Don't Change anything beneath this line. Variables should be changed above -------------- #

    
    #We loop through each of the timesteps in the provided netcdf file.
    for time_value in ds.time:
        
        #Collect the value of every station for the current timestep. A dedicated name is
        #used so the `data` path defined in the setup cell is not overwritten.
        station_values = []
        for station in ds.data_vars:
            # Select the data for the current time and variable of interest
            value = ds[station].sel(time=time_value, variable=variable).values
            station_values.append(value.item() if value.size > 0 else np.nan)
        
        # Convert the data to a numpy array so we can apply a mask. As some stations have gaps in their data we 
        # need to exclude these to prevent errors. This is done with the Mask. We also need to remove the longitudes
        # and latitudes of these points to ensure consistent array dimensions. 
        
        station_values = np.array(station_values)
        valid_mask = ~np.isnan(station_values)
        filtered_data = station_values[valid_mask]
        filtered_longitudes = longitudes[valid_mask]
        filtered_latitudes = latitudes[valid_mask]

        if filtered_data.size < min_stations:
            # Not enough observations to krige this timestep: keep the time axis
            # complete with an empty grid rather than losing all previous work.
            print(f'Warning: only {filtered_data.size} valid stations at {time_value.values}; storing NaN grid.')
            z = np.full((grid_lats.size, grid_lons.size), np.nan)
        else:
            # Setup the Ordinary Kriging interpolator and subsequently execute it.
            OK = OrdinaryKriging(
                filtered_longitudes,
                filtered_latitudes,
                filtered_data,
                variogram_model='gaussian',
                verbose=False,
                enable_plotting=False
            )
            # With style 'grid' the result is indexed as [latitude, longitude].
            z, ss = OK.execute('grid', grid_lons, grid_lats)
        z_values.append(z)
        
        # Keep the user updated on the progress of the interpolation process as it can take some time.
        # The counter is increased first so the message reports completed timesteps.
        progress += 1
        if progress % interval == 0 or progress == total_steps:
            print(f'Update: {progress}/{total_steps} succesfully processed.')

        
    # Put all the data back into a suitable Xarray to be converted into the final NetCDF file
    z_array = np.stack(z_values)

    time_dim = ds.time 
    lon_dim = grid_lons  # Longitude grid (same array the interpolation used)
    lat_dim = grid_lats  # Latitude grid (ascending, south to north)
    kriging_ds = xr.Dataset(
        {
            variable: (['time', 'latitude', 'longitude'], z_array)
        },
        coords={
            'time': time_dim,
            'latitude': lat_dim,
            'longitude': lon_dim
        }
    )

    # The file name follows the interpolated variable, so changing `variable` above
    # no longer writes e.g. temperature data into a file labelled 'evap'.
    output_file = f'kriging_results_{variable}.nc'
    kriging_ds.to_netcdf(os.path.join(processed_files, output_file))
    print(f'file succesfully saved at {processed_files}')


In [13]:
# Render every timestep of the kriging result, clipped to the Tana basin, as a PNG
# frame and combine the frames into an animated GIF.
animation = True

if animation:

    # Scale of the plot and a custom colormapping. Change vmin and vmax for other variables
    # to suitable values. 
    
    vmin = 0.3
    vmax = 0.6
    cdict = {
            'red':   [(0.0, 1.0, 1.0), (0.05, 0.59, 0.59), (1.0, 0.0, 0.0)],
            'green': [(0.0, 1.0, 1.0), (0.05, 0.29, 0.29), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)],
            'blue':  [(0.0, 0.88, 0.88), (0.05, 0.1, 0.1), (0.2, 1.0, 1.0), (1.0, 1.0, 1.0)]
        }

    cm = mpl.colors.LinearSegmentedColormap('my_colormap', cdict, 1024)

    # Load the dataset created in the previous cell and determine the required output (variable and timesteps)
    netcdf_file_raster = 'kriging_results_evap.nc'
    ds = xr.open_dataset(os.path.join(processed_files, netcdf_file_raster))
    variable = 'evap'
    # ds.sizes is the stable mapping of dimension lengths (ds.dims[...] is deprecated).
    time_indices = range(0, ds.sizes['time'], 1)
    # Named so the imported file_paths() helper is not overwritten by a list.
    frame_filenames = []

    # Open the shapefile that the kriging interpolation will be clipped on.
    path_shape = os.path.join(shape_raw, 'total_tana_catchement_area_clip.shp')
    shapefile_gdf = gpd.read_file(path_shape)
    if shapefile_gdf.crs is not None:
        # The raster below is written in longitude/latitude; the clip geometry
        # has to be in the same coordinates for the mask to overlap.
        shapefile_gdf = shapefile_gdf.to_crs('EPSG:4326')

    # Raster georeferencing is identical for every timestep, so it is built once.
    grid_lons, grid_lats = ds['longitude'].values, ds['latitude'].values
    grid_space = grid_lons[1] - grid_lons[0]
    # The grid coordinates are the points the kriging was evaluated at, i.e. pixel
    # centres; from_origin expects the outer corner of the top-left pixel.
    transform = rasterio.transform.from_origin(
        grid_lons.min() - grid_space / 2,
        grid_lats.max() + grid_space / 2,
        grid_space,
        grid_space,
    )
    # The transform is north-up (row 0 = highest latitude). A latitude axis stored
    # south-to-north therefore has to be flipped before writing, otherwise the
    # field is mirrored vertically relative to the basin outline.
    lats_ascending = grid_lats[0] < grid_lats[-1]
    temp_tif = 'temp.tif'

    for time_index in time_indices:
        selected_data = ds.isel(time=time_index)
        data_array = selected_data[variable]

        raster_values = data_array.values
        if lats_ascending:
            raster_values = np.ascontiguousarray(raster_values[::-1])

        # Generate raster to plot upon
        with rasterio.open(
            temp_tif, 'w', driver='GTiff',
            height=raster_values.shape[0], width=raster_values.shape[1],
            count=1, dtype=str(raster_values.dtype),
            crs='+proj=latlong',
            transform=transform
        ) as raster:
            raster.write(raster_values, 1)

        # Clip raster to the previously opened shapefile so the Tana Basin is correctly represented.
        # rasterio.mask is a submodule that must be imported explicitly (see the import cell).
        # A dedicated handle name keeps the `src` path from the setup cell intact.
        with rasterio.open(temp_tif) as raster_src:
            out_image, out_transform = rasterio.mask.mask(raster_src, shapefile_gdf.geometry, crop=True)

        # out_image is (band, row, col); the single band is plotted directly, so no
        # second GeoTIFF has to be written and read back.
        img_array = out_image[0]
        n_rows, n_cols = img_array.shape
        # Affine terms: c/f = top-left corner, a = pixel width, e = negative pixel height.
        extent = [out_transform.c, out_transform.c + out_transform.a * n_cols,
                  out_transform.f + out_transform.e * n_rows, out_transform.f]

        # Create plots of the result that will be saved as .png
        fig, ax = plt.subplots(figsize=(10, 8))
        img = ax.imshow(img_array, extent=extent, cmap=cm, origin='upper', vmin=vmin, vmax=vmax)
        ax.set_title(f'{variable} on {selected_data.time.dt.strftime("%Y-%m-%d").values}')
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')

        cbar = fig.colorbar(img, ax=ax)
        cbar.set_label(variable)

        # Save the plot and save all the file names to an array.
        image_filename = f'image_{time_index}.png'
        fig.savefig(os.path.join(animation_path, image_filename))
        plt.close(fig)  # release the figure; many frames are created in this loop
        frame_filenames.append(image_filename)

    # The scratch raster is no longer needed once all frames are rendered.
    if os.path.exists(temp_tif):
        os.remove(temp_tif)

    gif_filename = f'time_series_{variable}.gif'  
    with imageio.get_writer(os.path.join(animation_path, gif_filename), mode='I', duration=0.5) as writer:
        for filename in frame_filenames:
            image = imageio.imread(os.path.join(animation_path, filename))
            writer.append_data(image)  # Pass the image data array directly

    # Delete the individual image files after creating the GIF
    for filename in frame_filenames:
        os.remove(os.path.join(animation_path, filename))