## Explainability of the DeepESD model

The `deep4downscaling` library includes the module `deep4downscaling.deep.xai`, which provides functions for applying eXplainable Artificial Intelligence (XAI) techniques [1,2] to statistical downscaling models. It is designed to explain the decisions of deep learning models built with `deep4downscaling`, but with some adjustments it can also be extended to other models. In this example, we explain the DeepESD model trained for precipitation downscaling in the `downscaling_deepsd.ipynb` notebook.

### Set up the data

In [1]:
DATA_PATH = './data/input'
FIGURES_PATH = './figures'
MODELS_PATH = './models'

In [17]:
import xarray as xr
import torch
import captum.attr

# Make a local clone of deep4downscaling importable (adjust the path to your setup)
import sys
sys.path.append('/home/jovyan/deep4downscaling')
import deep4downscaling.viz
import deep4downscaling.trans
import deep4downscaling.deep.models
import deep4downscaling.deep.xai

First, we reproduce part of the preprocessing performed when training the model. This step is necessary because both loading the trained model and computing the XAI diagnostics require the shapes of the predictor and predictand arrays, as well as the predictand mask used during training.

In [3]:
# Load predictor
predictor_filename = f'{DATA_PATH}/ERA5_NorthAtlanticRegion_1-5dg_full.nc'
predictor = xr.open_dataset(predictor_filename)

# Load predictand
predictand_filename = f'{DATA_PATH}/pr_AEMET.nc'
predictand = xr.open_dataset(predictand_filename)

# Remove days with nans in the predictor
predictor = deep4downscaling.trans.remove_days_with_nans(predictor)

# Align both datasets in time
predictor, predictand = deep4downscaling.trans.align_datasets(predictor, predictand, 'time')

# Split data into training and test sets
years_train = ('1980', '2010')
years_test = ('2011', '2020')

x_train = predictor.sel(time=slice(*years_train))
y_train = predictand.sel(time=slice(*years_train))

x_test = predictor.sel(time=slice(*years_test))
y_test = predictand.sel(time=slice(*years_test))

# Standardize the test predictors w.r.t. the training ones
x_test_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=x_test)

# Compute the predictand mask of valid grid points
y_mask = deep4downscaling.trans.compute_valid_mask(y_train)

# Stack lat/lon into a single gridpoint dimension and keep only the valid grid points (mask == 1)
y_train_stack = y_train.stack(gridpoint=('lat', 'lon'))
y_mask_stack = y_mask.stack(gridpoint=('lat', 'lon'))

y_mask_stack_filt = y_mask_stack.where(y_mask_stack==1, drop=True)
y_train_stack_filt = y_train_stack.where(y_train_stack['gridpoint'] == y_mask_stack_filt['gridpoint'],
                                         drop=True)

# Convert data from xarray to numpy
x_test_stand_arr = deep4downscaling.trans.xarray_to_numpy(x_test_stand)
y_train_arr = deep4downscaling.trans.xarray_to_numpy(y_train_stack_filt)

There are no observations containing null values
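Two ideas in the preprocessing above deserve a closer look: the test predictors are standardized with statistics from the training period only, and the predictand is restricted to grid points with valid observations before being flattened. The following is a minimal NumPy sketch on synthetic arrays. It assumes per-grid-point standardization (mean and standard deviation over time) and a mask that flags grid points without missing values; `standardize_like_train` and `valid_gridpoints` are hypothetical helpers, and the exact behavior of `deep4downscaling.trans.standardize` and `compute_valid_mask` may differ, so check their documentation.

```python
import numpy as np

def standardize_like_train(x_train, x_test):
    """Standardize x_test with the per-grid-point mean/std of x_train.

    Assumption: statistics are taken over the time axis (axis 0), so each
    variable and grid point is scaled independently. Using training
    statistics only avoids leaking information from the test period.
    """
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)
    return (x_test - mean) / std

def valid_gridpoints(y_train):
    """Boolean (lat, lon) mask of grid points with no NaN over time.

    Assumption: a grid point is valid if it has observations on every
    training day; the library's own criterion may differ.
    """
    return ~np.isnan(y_train).any(axis=0)

rng = np.random.default_rng(0)
x_train = rng.normal(5.0, 2.0, size=(100, 3, 4, 4))  # (time, var, lat, lon)
x_test = rng.normal(5.0, 2.0, size=(20, 3, 4, 4))
x_test_stand = standardize_like_train(x_train, x_test)

y_train = rng.gamma(1.0, 2.0, size=(100, 5, 6))  # (time, lat, lon)
y_train[:, 0, 0] = np.nan  # simulate a grid point without observations
mask = valid_gridpoints(y_train)

# Flatten (lat, lon) into one 'gridpoint' axis and drop invalid points,
# mirroring the stack(gridpoint=('lat', 'lon')) + filtering step above
y_train_flat = y_train.reshape(y_train.shape[0], -1)[:, mask.ravel()]
print(x_test_stand.shape, y_train_flat.shape)
```

The flattened, masked predictand is what fixes the number of output units of the model, which is why these shapes must match the ones used during training.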


Next, we set the device on which the XAI computations will run. Gradient-based XAI techniques such as saliency obtain the gradients of the model output with respect to its input, which requires a backward pass through the network, as during training. We therefore recommend running these computations on a GPU, if available.

In [4]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
# Load the model to explain
model_name = 'deepesd_pr'
model = deep4downscaling.deep.models.DeepESDpr(x_shape=x_test_stand_arr.shape,
                                               y_shape=y_train_arr.shape,
                                               filters_last_conv=1,
                                               stochastic=False)
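# map_location allows loading a GPU-trained checkpoint on a CPU-only machine
model.load_state_dict(torch.load(f'{MODELS_PATH}/{model_name}.pt', map_location=device))
model.eval()  # inference mode: disables dropout and uses stored batch-norm statistics, if present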

### Explainability techniques

The module `deep4downscaling.deep.xai` computes both saliency maps and several XAI diagnostics that combine these saliency maps across time and/or spatial locations. Saliency maps can be generated with a range of techniques, such as standard saliency, attribution, or integrated gradients (see [3] for an overview of these techniques in the context of statistical downscaling).

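Because it builds on `captum`, an interpretability library for PyTorch, `deep4downscaling` gives users access to a wide variety of techniques (see [4] for a comprehensive list). The chosen `captum` method is instantiated with the model and passed to the functions below through the `xai_method` argument. For simplicity, this notebook focuses on standard saliency.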

In [6]:
xai_method = captum.attr.Saliency(model)
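Standard saliency scores each input value by the absolute gradient of one selected model output with respect to that input, |∂y_j/∂x|; `captum.attr.Saliency.attribute` returns this absolute gradient by default (`abs=True`). To illustrate the idea without PyTorch, the sketch below uses a toy two-layer NumPy network whose input gradient can be written by hand with the chain rule. `toy_forward` and `toy_saliency` are hypothetical names for illustration only, not part of `deep4downscaling` or `captum`.

```python
import numpy as np

rng = np.random.default_rng(42)
n_in, n_hidden, n_out = 8, 16, 3  # toy sizes: flattened predictors, hidden units, predictand grid points
W1 = rng.normal(size=(n_hidden, n_in))
W2 = rng.normal(size=(n_out, n_hidden))

def toy_forward(x):
    """Toy network y = W2 @ relu(W1 @ x) standing in for a downscaling model."""
    return W2 @ np.maximum(W1 @ x, 0.0)

def toy_saliency(x, target):
    """Absolute gradient of output `target` with respect to the input x.

    Chain rule: dy_j/dx = (W2[j] * relu'(W1 @ x)) @ W1
    """
    active = (W1 @ x > 0).astype(float)  # derivative of ReLU at the hidden pre-activations
    grad = (W2[target] * active) @ W1
    return np.abs(grad)

x = rng.normal(size=n_in)
sal = toy_saliency(x, target=1)  # one saliency value per input feature
print(sal.round(3))
```

In the notebook, `target` plays the role of the predictand grid point being explained, and `captum` obtains the gradient through PyTorch autograd instead of a hand-derived formula.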

#### Integrated Saliency Map (ISM)

The `deep4downscaling.deep.xai.compute_ism` function applies the chosen XAI technique to every time step of the `xr.Dataset` passed through the `data` argument. The saliency maps are computed with respect to the model output at a single predictand grid point, specified through the `coord` argument as a tuple of (latitude, longitude) coordinates. If no grid point exists at exactly these coordinates, the function uses the nearest grid point in space.

Note the role of the `postprocess` argument: when it is set to `True`, the saliency maps are postprocessed following the approach described in [5] to reduce artifacts such as noisy patterns and improve their consistency.

In [7]:
spatial_coord = (43.125797, -8.087920)
ism = deep4downscaling.deep.xai.compute_ism(data=x_test_stand,
                                            mask=y_mask.copy(deep=True),
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            coord=spatial_coord,
                                            postprocess=True)

Computing ISMs...
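When `coord` does not match a grid point exactly, `compute_ism` falls back to the nearest one. A minimal sketch of such a lookup is shown below, assuming great-circle (Haversine) distance on a regular lat/lon grid and restricting the search to valid (masked) points. `nearest_gridpoint` is a hypothetical helper and the toy grid is invented for illustration; the library may measure distance differently.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees (broadcasts)."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))

def nearest_gridpoint(coord, lats, lons, mask):
    """Return (lat, lon) of the valid grid point closest to coord=(lat, lon).

    lats/lons are the 1D coordinate vectors of a regular grid and mask is a
    boolean (lat, lon) array flagging valid grid points.
    """
    lat2d, lon2d = np.meshgrid(lats, lons, indexing='ij')
    dist = haversine_km(coord[0], coord[1], lat2d, lon2d)
    dist = np.where(mask, dist, np.inf)  # never select masked-out points
    i, j = np.unravel_index(np.argmin(dist), dist.shape)
    return float(lats[i]), float(lons[j])

# Toy 0.25-degree grid over north-western Spain (illustrative only)
lats = np.arange(42.0, 44.0, 0.25)
lons = np.arange(-9.0, -7.0, 0.25)
mask = np.ones((lats.size, lons.size), dtype=bool)
print(nearest_gridpoint((43.125797, -8.087920), lats, lons, mask))
```

Excluding masked points before taking the minimum ensures that the explained grid point is one the model actually predicts.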


In [8]:
time_to_plot = '2018-01-02'  # ISO format (YYYY-MM-DD) avoids day/month ambiguity
deep4downscaling.viz.multiple_map_plot(data=ism.sel(time=time_to_plot),
                                       colorbar='hot_r',
                                       output_path=f'{FIGURES_PATH}/ism.pdf')

Beyond individual saliency maps, `deep4downscaling` also computes XAI-based diagnostics tailored to statistical downscaling: the Aggregated Saliency Map (ASM) and the Saliency Dispersion Map (SDM). For more information on these diagnostics, we refer the reader to [5] and to the documentation of the respective functions.

#### Aggregated Saliency Map (ASM)

In [9]:
time_slice = ('2011-01-01', '2011-03-01')  # ISO format (YYYY-MM-DD) avoids day/month ambiguity
asm = deep4downscaling.deep.xai.compute_asm(data=x_test_stand.sel(time=slice(*time_slice)),
                                            mask=y_mask.copy(deep=True),
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            batch_size=1024,
                                            postprocess=True)

Computing ASMs...


100%|██████████| 60/60 [02:21<00:00,  2.36s/it]


In [10]:
deep4downscaling.viz.multiple_map_plot(data=asm,
                                       colorbar='hot_r',
                                       output_path=f'{FIGURES_PATH}/asm.pdf')
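To build intuition for what an aggregated saliency map conveys, the following sketch averages absolute saliency maps over time steps and predictand grid points, after scaling each map by its own maximum so that no single map dominates the average. This is only an illustration under our own assumptions (array layout, per-map normalization, plain mean); `aggregate_saliency` is a hypothetical helper, and the exact ASM definition is given in [5] and in the `compute_asm` documentation.

```python
import numpy as np

def aggregate_saliency(saliency):
    """Average saliency maps into a single map per predictor variable.

    saliency: array (time, target_gridpoint, var, lat, lon) holding one
    saliency map per time step and predictand grid point (assumed layout).
    Each map is scaled to [0, 1] by its own maximum before averaging.
    """
    sal = np.abs(saliency)
    max_per_map = sal.max(axis=(2, 3, 4), keepdims=True)
    sal = sal / np.where(max_per_map > 0, max_per_map, 1.0)  # guard against all-zero maps
    return sal.mean(axis=(0, 1))  # -> (var, lat, lon)

rng = np.random.default_rng(1)
saliency = rng.normal(size=(10, 5, 3, 8, 8))  # 10 days, 5 target points, 3 variables, 8x8 grid
asm_toy = aggregate_saliency(saliency)
print(asm_toy.shape)
```

Regions with high values in such a map are those the model relies on consistently, across days and target locations.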

#### Saliency Dispersion Map (SDM)

In [12]:
time_slice = ('2011-01-01', '2011-03-01')  # ISO format (YYYY-MM-DD) avoids day/month ambiguity
sdm = deep4downscaling.deep.xai.compute_sdm(data=x_test_stand.sel(time=slice(*time_slice)),
                                            mask=y_mask.copy(deep=True), var_target='pr',
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            batch_size=1024,
                                            postprocess=True)

Precomputing Haversine distances...
Computing SDMs...


100%|██████████| 60/60 [02:24<00:00,  2.40s/it]


In [16]:
deep4downscaling.viz.simple_map_plot(data=sdm,
                                     var_to_plot='pr',
                                     colorbar='Reds',
                                     output_path=f'{FIGURES_PATH}/sdm.pdf')
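The log of `compute_sdm` shows that it precomputes Haversine distances. One simple way to quantify how spatially dispersed a saliency map is uses such distances: the saliency-weighted mean great-circle distance between a predictand grid point and all predictor grid points. The sketch below implements that measure for illustration only; it is an assumption, not necessarily the SDM definition of [5], and `saliency_weighted_distance` is a hypothetical helper. Small values mean the saliency concentrates near the target location; large values mean it spreads over distant regions.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees (broadcasts)."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))

def saliency_weighted_distance(saliency, target, lats, lons):
    """Saliency-weighted mean distance (km) from `target` to the predictor grid.

    saliency: (lat, lon) map for one predictand grid point (e.g. averaged
    over predictor variables and time); target: (lat, lon) in degrees.
    """
    lat2d, lon2d = np.meshgrid(lats, lons, indexing='ij')
    dist = haversine_km(target[0], target[1], lat2d, lon2d)
    weights = np.abs(saliency)
    return float((weights * dist).sum() / weights.sum())

lats = np.arange(36.0, 45.0, 1.5)  # toy 1.5-degree predictor grid
lons = np.arange(-10.0, 4.0, 1.5)
target = (43.5, -8.5)              # a grid point of the toy grid

focused = np.zeros((lats.size, lons.size))
focused[np.argmin(np.abs(lats - target[0])), np.argmin(np.abs(lons - target[1]))] = 1.0
uniform = np.ones((lats.size, lons.size))

print(saliency_weighted_distance(focused, target, lats, lons))  # 0.0: all saliency at the target
print(saliency_weighted_distance(uniform, target, lats, lons))
```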

### References

[1] Buhrmester, V., Münch, D., & Arens, M. (2021). Analysis of explainers of black box deep neural networks for computer vision: A survey. Machine Learning and Knowledge Extraction, 3(4), 966-989.

[2] Das, A., & Rad, P. (2020). Opportunities and challenges in explainable artificial intelligence (XAI): A survey. arXiv preprint arXiv:2006.11371.

[3] González Abad, J. (2024). Towards explainable and physically-based deep learning statistical downscaling methods.

[4] https://captum.ai/docs/attribution_algorithms

[5] González-Abad, J., Baño-Medina, J., & Gutiérrez, J. M. (2023). Using explainability to inform statistical downscaling based on deep learning beyond standard validation approaches. Journal of Advances in Modeling Earth Systems, 15(11), e2023MS003641.