### Combining it all together

This cookbook shows how to visualize the Layerwise Relevance Propagation (LRP) data that was generated to determine the importance of given features to the AQI classification.

In [None]:
import glob
import os
import pickle
import sys

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

Query local directories for all of the input LRP data.

In [None]:
from datetime import timedelta

# Directory holding the pickled relevance output and the directory set aside
# for plots.
pickles = '/lcrc/group/earthscience/rjackson/opencrums/scripts/relevance_pickles/'
out_plot_path = '/lcrc/group/earthscience/rjackson/merra_relevances/'

# `hour` selects which 'rel-<hour>hr*.pickle' files are read.  When this code
# runs as a script the value comes from the first command-line argument.
# Inside a Jupyter kernel sys.argv[1] is normally a kernel option such as
# '-f', so int() would raise; fall back to DEFAULT_HOUR in that case.
# NOTE(review): DEFAULT_HOUR is a placeholder -- set it to the hour you want
# to analyse when running interactively.
DEFAULT_HOUR = 0
try:
    hour = int(sys.argv[1])
except (IndexError, ValueError):
    hour = DEFAULT_HOUR

# glob returns files in arbitrary order; sort so that pickle_list[0] (used
# later to discover the input keys) is the same file on every run.
pickle_list = sorted(glob.glob(pickles + 'rel-%dhr*.pickle' % hour))

Load a copy of the MERRA2 data to get the base lat/lon grid.

In [None]:
# Get lats, lons for plotting
# Open every file matching DUCMASS*.nc in the reduced MERRA2 directory as a
# single dataset.
ds = xr.open_mfdataset(
            '/lcrc/group/earthscience/rjackson/MERRA2/se_reduced/DUCMASS*.nc')
print(ds.time)
# Materialises the DUCMASS variable as a NumPy array.
# NOTE(review): `x` does not appear to be read again in the cells shown here.
x = ds["DUCMASS"].values
lon = ds["lon"].values
lat = ds["lat"].values
# NOTE(review): `ax_extent` is not assigned in any cell visible in this
# notebook.  It is indexed below as [lon_min, lon_max, lat_min, lat_max] and
# must be defined before this cell runs, otherwise a NameError is raised.
# Indices of the longitudes that fall inside [ax_extent[0], ax_extent[1]].
lon_inds = np.argwhere(
            np.logical_and(
                ds.lon.values >= ax_extent[0],
                ds.lon.values <= ax_extent[1])).astype(int)
# Indices of the latitudes that fall inside [ax_extent[2], ax_extent[3]].
lat_inds = np.argwhere(
            np.logical_and(
                ds.lat.values >= ax_extent[2],
                ds.lat.values <= ax_extent[3])).astype(int)
# np.argwhere returns one row per match, so the index arrays are 2-D.
# Assuming ds.lon / ds.lat are 1-D (TODO confirm), the subset arrays below
# have shape (N, 1) rather than (N,).
lon = lon[lon_inds]
lat = lat[lat_inds]

time = ds["time"].values
ds.close()

Calculate the normalized relevance for each time period. The normalized relevance here simply scales each of the input relevances from 0 to 1, where 0 is the minimum value of relevance for the time period and 1 is the maximum relevance.

In [None]:
# Names of the five AQI classes, in the order of the classifier's output index.
classification = ['Good', 'Moderate', 'Unhealthy Sens.', 'Unhealthy', 'Hazardous']

if not pickle_list:
    raise FileNotFoundError(
        'pickle_list is empty: no relevance pickles were found')

# Read the first file only to discover which entries hold input relevances
# (every key containing "input_").
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(pickle_list[0], mode="rb") as p:
    relevances = pickle.load(p)
input_keys = [key for key in relevances.keys() if "input_" in key]

# Running per-class sums of the normalized relevance maps, one array per input.
# NOTE(review): the map shape is taken from `lon`; confirm it matches the
# squeezed shape of a single relevance map in the pickles.
mean_relevances = {
    key: np.zeros((5, lon.shape[0], lon.shape[1])) for key in input_keys}
# Number of samples accumulated into each class; used later to turn the sums
# into means.
num_points = np.zeros(5)

for picks in pickle_list:
    with open(picks, mode='rb') as p:
        relevances = pickle.load(p)
    # Predicted class of every sample = index of the largest network output.
    classes = relevances['output'].numpy().argmax(axis=1)
    # Not used by the cells shown here; kept for interactive inspection.
    aqi = relevances['aqi'].argmax(axis=1)
    for k in input_keys:
        relevances[k] = relevances[k].numpy()

    for j in range(len(classes)):
        num_points[classes[j]] += 1
        for k in input_keys:
            # Min-max scale this sample's relevance to [0, 1].
            sample = relevances[k][j, :, :]
            r_min = sample.min()
            r_max = sample.max()
            r_range = r_max - r_min
            if r_range == 0:
                # A constant field has no spread; use 0 instead of letting
                # 0/0 produce NaNs that would poison the running sum.
                relevances[k][j, :, :] = 0.
            else:
                relevances[k][j, :, :] = (sample - r_min) / r_range
            mean_relevances[k][classes[j], :, :] += np.squeeze(
                relevances[k][j, :, :])

Generate the dataset-wide averages of normalized relevance. Here, relevances near 0 should be interpreted as not relevant, 0.5 as no preference, and 1 as very relevant.

In [None]:
# Turn the per-class sums in mean_relevances into per-class means, and track
# the largest 95th percentile (r_max) and the overall mean (r_mean) of the
# averaged maps.
# NOTE: the division happens in place, so running this cell twice divides the
# maps twice; re-run the accumulation cell first if you need to repeat it.
r_max = -np.inf
r_mean = 0
i = 0
for j in range(5):
    if num_points[j] == 0:
        # No sample was assigned to this class: leave its maps at zero instead
        # of computing 0/0, which would turn r_max and r_mean into NaN.
        continue
    for k in input_keys:
        mean_relevances[k][j, :, :] /= num_points[j]
        r_max = np.max([r_max, np.percentile(mean_relevances[k][j, :, :], 95)])
        r_mean += np.mean(mean_relevances[k][j, :, :])
        i += 1
# i counts the (class, input) maps that contributed; guard the empty case.
r_mean = r_mean / i if i else np.nan

Use Cartopy to generate the plots of relevance over the Houston domain.

In [None]:
# State/province boundary lines drawn on every panel.
states_provinces = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        facecolor='none')

# Make sure the output folder exists so savefig cannot fail on a missing
# directory.
# NOTE(review): `out_plot_path` is defined in the configuration cell but the
# figures are written to this relative folder -- confirm which is intended.
os.makedirs('output_relevance_', exist_ok=True)

# 101 evenly spaced levels covering the closed interval [0, 1]; np.arange
# would stop at 0.99 and leave the highest relevances unfilled.
relevance_levels = np.linspace(0., 1., 101)

# One figure per input variable, one panel per AQI class.
for key in input_keys:
    fig, ax = plt.subplots(5, 1,
            subplot_kw=dict(projection=ccrs.PlateCarree()),
            figsize=(10, 15))
    for class_idx, panel in enumerate(ax):
        r = np.squeeze(mean_relevances[key][class_idx])

        # Peak averaged relevance of this panel, printed as a quick check.
        print(r.max())
        c = panel.contourf(lon, lat, r,
            cmap='seismic', levels=relevance_levels)
        plt.colorbar(c, label='Normalized relevance', ax=panel)

        panel.coastlines()
        panel.add_feature(states_provinces)
        panel.add_feature(cfeature.BORDERS)
        panel.set_title(classification[class_idx])
        # x is longitude and y is latitude on a PlateCarree map.
        panel.set_xlabel('Longitude')
        panel.set_ylabel('Latitude')
    fig.savefig('output_relevance_/relevance-%s.png' % (key))
    # Close each figure so the loop does not accumulate open figures.
    plt.close(fig)