In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import dask
from dask.distributed import Client, LocalCluster

# Start a local Dask cluster and connect to it.
# NOTE: `memory_limit` is applied **per worker**, not to the cluster as a whole.
client = Client(
    n_workers=1,
    threads_per_worker=4,
    memory_limit='100GB',
)

# Displaying the client gives a dashboard link for monitoring real-time progress
# (plus other cool visualizations).
client

In [15]:
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy.spatial
import cartopy
import cartopy.crs as ccrs

from interpolation_utils import interpolate_nearest_from_grid

In [16]:
# Coordinate reference systems used for plotting (all via the `ccrs` alias).

# Antarctic data is projected (where needed) onto this south polar stereographic CRS.
crs_3031 = ccrs.Stereographic(
    central_latitude=-90,
    true_scale_latitude=-71,
)

# Greenland data is projected (where needed) onto this north polar stereographic CRS.
crs_3413 = ccrs.Stereographic(
    central_latitude=90,
    central_longitude=-45,
    true_scale_latitude=70,
)

# Plain longitude/latitude coordinates.
crs_lonlat = ccrs.PlateCarree()

In [17]:
# Configuration: pick the ice sheet to preprocess and derive the per-dataset settings
# (BedMachine decimation factor, map projection, output file path).
target_resolution = 10e3 # meters

dataset = 'antarctica' # Supported values: 'antarctica', 'greenland'

if dataset == 'antarctica':
    # BedMachine Antarctica grid has 500 m spacing
    # Decimate by 2 to get 1 km spacing
    bm_grid_decimation = 2

    projection = crs_3031

    output_nc_path = 'data_preprocessing/input_data_ais.nc'
elif dataset == 'greenland':
    # BedMachine Greenland grid has 150 m spacing
    # Decimate by 6 to get 900 m spacing
    bm_grid_decimation = 6

    projection = crs_3413

    output_nc_path = 'data_preprocessing/input_data_gis.nc'
else:
    # Fail here with a clear message instead of a NameError in a later cell,
    # where the settings above would otherwise be silently left undefined.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'antarctica' or 'greenland'")

In [18]:
# Open the BedMachine and surface-velocity files for the selected ice sheet, then make
# sure the velocity dataset exposes 'speed' and 'speed_err' variables in either case.
if dataset == 'antarctica':
    # BedMachine Antarctica
    ds_bm = xr.open_dataset("external_datasets/BedMachineAntarctica-v3.nc")

    # Rignot surface velocity
    ds_vel = xr.open_dataset("external_datasets/antarctic_ice_vel_phase_map_v01.nc")

    # Magnitude (speed) and its error are derived from the x/y components,
    # following the NSIDC-0754 user guide:
    # https://nsidc.org/sites/default/files/nsidc-0754-v001-userguide.pdf
    vel_x, vel_y = ds_vel['VX'], ds_vel['VY']
    ds_vel['speed'] = np.sqrt(vel_x**2 + vel_y**2)

    err_x, err_y = ds_vel['ERRX'], ds_vel['ERRY']
    ds_vel['speed_err'] = np.sqrt(err_x**2 + err_y**2)
elif dataset == 'greenland':
    # BedMachine Greenland
    ds_bm = xr.open_dataset("external_datasets/BedMachineGreenland-v5.nc")

    # ITS_LIVE surface velocity
    ds_vel = xr.open_dataset("external_datasets/ITS_LIVE_velocity_120m_RGI05A_0000_v02.nc")

    # ITS_LIVE already ships magnitude and magnitude error; expose them under the common names
    for common_name, its_live_name in (('speed', 'v'), ('speed_err', 'v_error')):
        ds_vel[common_name] = ds_vel[its_live_name]

In [19]:
# Load ERA5 t2m data
ds_era5 = xr.open_dataset("external_datasets/era5_t2m_ensemble.nc")

# Average over time once and reuse the result for both statistics, instead of
# recomputing the same 'valid_time' mean for the mean and again for the std.
t2m_time_mean = ds_era5.t2m.mean(dim='valid_time')

# Mean and spread of the time-averaged field along the 'number' dimension
t2m_mean = t2m_time_mean.mean(dim='number')
t2m_std = t2m_time_mean.std(dim='number')

# Combine t2m_mean and t2m_std into a single dataset
ds_t2m = xr.Dataset({'t2m_mean': t2m_mean, 't2m_std': t2m_std})

In [20]:
# Seed the output dataset from BedMachine: keep every `bm_grid_decimation`-th grid
# point along both x and y, and retain only the four variables listed below.
bm_variables = ['thickness', 'errbed', 'mask', 'surface']
decimated_index = slice(0, None, bm_grid_decimation)

ds_output = ds_bm.isel(x=decimated_index, y=decimated_index)[bm_variables]

In [21]:
# Surface velocity: sample speed and its error onto the output grid (nearest neighbor).
# TODO: Nearest-neighbor interpolation is not really the appropriate choice here.
# We should probably resample and take the mean. But good enough for now.
speed_on_grid, speed_err_on_grid = interpolate_nearest_from_grid(
    ds_vel,
    ds_output,
    ['speed', 'speed_err'],
    target_gridded=True,
)

ds_output['speed'] = speed_on_grid
ds_output['speed_err'] = speed_err_on_grid

In [22]:
# Surface temperature: the source data is very coarse anyway, so nearest-neighbor
# interpolation is used. The source is on a longitude/latitude grid, hence the
# explicit source/target CRS and coordinate names.
t2m_on_grid, t2m_err_on_grid = interpolate_nearest_from_grid(
    ds_t2m,
    ds_output,
    ['t2m_mean', 't2m_std'],
    source_crs=crs_lonlat,
    target_crs=projection,
    y_name='latitude',
    x_name='longitude',
    target_gridded=True,
)

ds_output['t2m'] = t2m_on_grid
ds_output['t2m_err'] = t2m_err_on_grid

In [None]:
# Check our work with some plots: one map per variable in the output dataset.

decimate_by = 10 # Additional decimation only for plotting

for variable_to_plot in ds_output.data_vars:
    # Skip variables whose name marks them as not worth mapping
    if (
        "Unnamed" in variable_to_plot
        or variable_to_plot in ('x', 'y')
        or "source" in variable_to_plot
    ):
        continue

    fig = None
    try:
        fig, ax = plt.subplots(figsize=(6,4), subplot_kw=dict(projection=projection))

        pcm = ax.pcolormesh(
            ds_output['x'][::decimate_by],
            ds_output['y'][::decimate_by],
            ds_output[variable_to_plot][::decimate_by, ::decimate_by],
            transform=projection,
            rasterized=True,
        )

        fig.colorbar(pcm, ax=ax, label=variable_to_plot)
        ax.set_title(variable_to_plot)

        ax.coastlines(resolution='10m', color='black', linewidth=0.5)

    except Exception as e:
        # Plotting is best-effort: report the reason, discard the partially
        # built figure so it is not shown or kept in memory, and move on.
        print(f"Could not plot variable {variable_to_plot}: {e!r}")
        if fig is not None:
            plt.close(fig)

In [24]:
# Write the assembled output dataset to the NetCDF path chosen in the configuration cell
ds_output.to_netcdf(path=output_nc_path)