# Using masks and computing weighted average

This example is based on the xarray example at http://xarray.pydata.org/en/stable/examples/area_weighted_temperature.html

## Import python packages

In [None]:
import xarray as xr

xr.set_options(display_style="html")
import intake
import cftime
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np

%matplotlib inline

In [None]:
# Pangeo CMIP6 catalog description (a JSON file) served from Google Cloud Storage.
cat_url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
# Open the catalog through intake; `col` is the collection queried by the search cell below.
col = intake.open_esm_datastore(cat_url)
# A bare name as the last statement of a notebook cell displays the object.
col

## Search data

In [None]:
# Keyword facets that narrow the catalog; each facet lists the values to accept.
search_facets = {
    "source_id": ["NorESM2-LM"],
    "experiment_id": ["historical"],
    "table_id": ["Amon"],
    "variable_id": ["tas"],
    "member_id": ["r1i1p1f1"],
}
cat = col.search(**search_facets)
# Display the catalog entries that matched the search.
cat.df

## Create dictionary from the list of datasets we found
- This step may take several minutes so be patient!

In [None]:
# Load every dataset matched by the search into a dictionary.
# `use_cftime=True` is forwarded through `zarr_kwargs`; presumably it makes the time
# axis decode to cftime objects (needed for non-standard model calendars) - confirm
# against the xarray documentation.
# NOTE(review): recent intake-esm releases expect these options as `xarray_open_kwargs`;
# check the installed version before renaming the keyword.
dset_dict = cat.to_dataset_dict(zarr_kwargs={"use_cftime": True})

In [None]:
# Keys under which the loaded datasets are stored (iterating a dict yields its keys).
list(dset_dict)

In [None]:
# Work with the dataset stored under the first key of the dictionary.
first_key = list(dset_dict)[0]
dset = dset_dict[first_key]
dset

Plot the first timestep

In [None]:
# Map projection used for drawing: Mercator with its central longitude at -10 degrees.
projection = ccrs.Mercator(central_longitude=-10)

f, ax = plt.subplots(subplot_kw={"projection": projection})

# First time step of the `tas` variable.
first_step = dset["tas"].isel(time=0)

# `ax=ax` draws on the axes created above instead of relying on pyplot's implicit
# "current axes", so the plot lands on the right map even if other figures are open.
# `transform=ccrs.PlateCarree()` declares the data coordinates as plain
# longitude/latitude so they can be re-projected onto the Mercator map.
first_step.plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cbar_kwargs={"shrink": 0.7},
    cmap="coolwarm",
)
ax.coastlines()

## Compute weighted mean

1. Create the weights: on a rectangular (regular latitude–longitude) grid, the area of a grid cell is proportional to the cosine of its latitude.
2. Compute weighted mean values

In [None]:
def computeWeightedMean(ds):
    """Return the latitude-weighted mean of ``ds`` over "lon" and "lat".

    Parameters
    ----------
    ds
        Object exposing a ``lat`` coordinate (interpreted as degrees) and a
        ``weighted`` method, e.g. an xarray Dataset or DataArray with
        "lon" and "lat" dimensions.

    Returns
    -------
    The result of averaging ``ds`` over ("lon", "lat") with weights equal
    to the cosine of the latitude; all other dimensions are kept.
    """
    # cos(latitude) stands in for the relative area of each grid cell.
    lat_in_radians = np.deg2rad(ds.lat)
    cos_lat = np.cos(lat_in_radians)
    cos_lat.name = "weights"

    # Weighted reduction over both horizontal dimensions.
    return ds.weighted(cos_lat).mean(("lon", "lat"))

## Compute weighted average over the entire globe

In [None]:
# Area-weighted mean over "lon" and "lat"; `dset` is passed unmasked, so every
# grid cell contributes (global mean).
weighted_mean = computeWeightedMean(dset)

## Comparison with unweighted mean
- We select a time range
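- Note how the weighted mean temperature is higher than the unweighted one: the unweighted mean gives the small, cold grid cells near the poles the same influence as the much larger cells in the tropics.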

In [None]:
# Both curves are restricted to the same time window.
time_window = slice("2000-01-01", "2010-01-01")

weighted_mean["tas"].sel(time=time_window).plot(label="weighted")
# Plain (unweighted) mean over the horizontal dimensions, for comparison.
dset["tas"].sel(time=time_window).mean(("lon", "lat")).plot(label="unweighted")

plt.legend()

## Compute weighted Arctic average
Let's also restrict the data to latitudes north of 60$^\circ$N.

In [None]:
# Keep only the grid cells whose latitude is greater than 60 degrees, then average.
# This rebinds `weighted_mean`, replacing the global result computed earlier.
arctic_only = dset.where(dset["lat"] > 60.0)
weighted_mean = computeWeightedMean(arctic_only)

In [None]:
# Both curves are restricted to the same time window.
time_window = slice("2000-01-01", "2010-01-01")

weighted_mean["tas"].sel(time=time_window).plot(label="weighted")

# Unweighted counterpart: same latitude mask, plain mean over the horizontal dimensions.
arctic_tas = dset["tas"].where(dset["lat"] > 60.0)
arctic_tas.sel(time=time_window).mean(("lon", "lat")).plot(label="unweighted")

plt.legend()