In [1]:
import planetary_computer
import xarray as xr
import fsspec
import pystac_client

from kerchunk.combine import MultiZarrToZarr

import pandas as pd
import dask
import dask_gateway

import dask.bag as db 


from dask.distributed import wait


import numpy as np

In [2]:
# Planetary Computer STAC API; `sign_inplace` signs asset hrefs with SAS
# tokens (visible as the `st=…&sig=…` query strings in the asset output below).
stac_url = "https://planetarycomputer-test.microsoft.com/stac"
catalog = pystac_client.Client.open(stac_url, modifier=planetary_computer.sign_inplace)

In [3]:
# A single model/scenario/year, used to explore the item and asset layout.
cmip6_query = {"cmip6:model": {"eq": "MIROC6"}, "cmip6:scenario": {"eq": "ssp585"}}
search = catalog.search(
    collections=["nasa-nex-gddp-cmip6"], datetime="2020/2020", query=cmip6_query
)
items = search.item_collection()
len(items)

1

In [5]:
# Inspect the per-variable NetCDF assets of the first item (hrefs are SAS-signed).
items[0].assets

{'pr': <Asset href=https://nasagddp.blob.core.windows.net/nex-gddp-cmip6/NEX/GDDP-CMIP6/MIROC6/ssp585/r1i1p1f1/pr/pr_day_MIROC6_ssp585_r1i1p1f1_gn_2020.nc?st=2023-12-26T13%3A11%3A31Z&se=2024-01-03T13%3A11%3A31Z&sp=rl&sv=2021-06-08&sr=c&skoid=c85c15d6-d1ae-42d4-af60-e2ca0f81359b&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2023-12-27T13%3A11%3A30Z&ske=2024-01-03T13%3A11%3A30Z&sks=b&skv=2021-06-08&sig=2jsntmFQ5PEXK%2BFm7gLHskGtb%2BJKSSaxUwAB7iS%2BnUY%3D>,
 'tas': <Asset href=https://nasagddp.blob.core.windows.net/nex-gddp-cmip6/NEX/GDDP-CMIP6/MIROC6/ssp585/r1i1p1f1/tas/tas_day_MIROC6_ssp585_r1i1p1f1_gn_2020.nc?st=2023-12-26T13%3A11%3A31Z&se=2024-01-03T13%3A11%3A31Z&sp=rl&sv=2021-06-08&sr=c&skoid=c85c15d6-d1ae-42d4-af60-e2ca0f81359b&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2023-12-27T13%3A11%3A30Z&ske=2024-01-03T13%3A11%3A30Z&sks=b&skv=2021-06-08&sig=2jsntmFQ5PEXK%2BFm7gLHskGtb%2BJKSSaxUwAB7iS%2BnUY%3D>,
 'hurs': <Asset href=https://nasagddp.blob.core.windows.net/nex-gddp-cmip6/NE

In [3]:
# Connect to the hub's Dask Gateway and fetch its default cluster options
# (shown and optionally edited in the next cell).
gateway = dask_gateway.Gateway()
cluster_options = gateway.cluster_options()

In [4]:
# Show the default options for new clusters (worker cores/memory, image, environment).
cluster_options

VBox(children=(HTML(value='<h2>Cluster Options</h2>'), GridBox(children=(HTML(value="<p style='font-weight: bo…

Options<worker_cores=1.0,
        worker_memory=8.0,
        image='pcccr.azurecr.io/public/planetary-computer/python:2023.6.22.0',
        gpu=False,
        environment={'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR',
         'GDAL_HTTP_MERGE_CONSECUTIVE_RANGES': 'YES',
         'GDAL_HTTP_MAX_RETRY': '5',
         'GDAL_HTTP_RETRY_DELAY': '3',
         'USE_PYGEOS': '0'}>


In [15]:
# Lower bound for adaptive scaling; the bag cells below size their partitions
# relative to this worker count.
MIN_WORKERS = 100

cluster = gateway.new_cluster(cluster_options)
client = cluster.get_client()
cluster.adapt(minimum=MIN_WORKERS)

In [16]:
# Display the client summary, including the dashboard link.
client

0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.6e16c08c660349aba0282ed2b181f3bd/status,


In [None]:
%%time

single_ref_sets = []
sas_token = items[0].assets["pr"].href.split("?")[1]
for d in [item.properties["kerchunk:indices"] for item in items]:
    for key in d["templates"]:
        d["templates"][key] = d["templates"][key] + "?" + sas_token
    single_ref_sets.append(d)
mzz = MultiZarrToZarr(
    single_ref_sets, concat_dims=["time"], identical_dims=["lat", "lon"]
)
d = mzz.translate()
m = fsspec.get_mapper("reference://", fo=d)

m.fs.clear_instance_cache()
ds = xr.open_dataset(
    m, engine="zarr", consolidated=False, decode_times=True, chunks="auto"
)
ds = ds[['pr', 'hurs', 'tas', 'tasmin', 'tasmax']]
ds = ds.convert_calendar(calendar="gregorian", align_on="date", missing=-99)

# ds = ds.chunk({'time': -1, 'lat': 75, 'lon': 90}).persist()

# wait(ds)


In [None]:
# Order items by id so their files are opened in a stable (presumably yearly) order.
items = sorted(items, key=lambda stac_item: stac_item.id)

In [None]:
# Open every NetCDF asset of every item as a file-like object for xarray.
asset_files = [
    fsspec.open(asset.href).open()
    for item in items
    for asset in item.assets.values()
]

# Keep the five variables of interest, map onto gregorian dates (-99 for
# missing days), rechunk to full time series per spatial tile and persist.
ds = (
    xr.open_mfdataset(asset_files, chunks="auto")[["pr", "hurs", "tas", "tasmin", "tasmax"]]
    .convert_calendar(calendar="gregorian", align_on="date", missing=-99)
    .chunk({"time": -1, "lat": 75, "lon": 90})
    .persist()
)

wait(ds)


In [None]:
%%time

campinas_pr = ds.sel(lat=-22.907104, lon=(-47.063240 + 360) % 360, method='nearest').compute()

In [None]:
# Preview the Campinas series as a table.
# NOTE: campinas_pr holds every variable selected in `ds`, not only pr.
campinas_pr.to_pandas().head()

In [11]:
# GHSL cities with population >= 50k; the file is read as ISO-8859-1, not UTF-8.
cities_csv = "/home/jovyan/ghslcities_popgte50k.csv"
cities = pd.read_csv(cities_csv, sep=",", encoding="ISO-8859-1")
cities.head()


Unnamed: 0,latitude,longitude,countrycode,cityname,pop2015
0,21.340678,-157.893497,USA,Honolulu,512853.6667
1,-17.534103,-149.568053,PYF,Papeete,91521.1246
2,34.923123,-120.434372,USA,Santa Maria,123181.2848
3,36.60772,-121.882378,USA,Monterey,67772.28886
4,34.427664,-119.743693,USA,Santa Barbara,114753.1502


In [10]:
def get_zarr_point_data(point, variable, dataset=None):
    """Return the daily time series of `variable` at the grid cell nearest to `point`.

    Parameters
    ----------
    point : tuple
        ``(longitude, latitude)`` in degrees. Longitude may be in -180..180;
        it is wrapped into 0..360 before the lookup.
    variable : str
        Name of the data variable to extract (e.g. ``"pr"``).
    dataset : xarray.Dataset, optional
        Dataset to read from. Defaults to the notebook-global ``ds``. Reading
        the global is much faster on the cluster than shipping the dataset
        with each task (see the timing notes in this notebook).

    Returns
    -------
    pandas.Series
        Values indexed by ``pandas.Timestamp``.
    """
    source = ds if dataset is None else dataset
    lon, lat = point
    # Assumes the dataset longitudes are 0..360 (the Campinas cell converts the
    # same way). Without the wrap, `method='nearest'` snaps every negative
    # (western-hemisphere) longitude to the first grid column near lon 0.
    data = source[variable].sel(lat=lat, lon=lon % 360, method="nearest").compute()

    return pd.Series(data.data, index=[pd.Timestamp(t) for t in data.time.data])

In [None]:
# Number of cities (rows) to extract. `cities_data` was never defined in this
# notebook (NameError on a fresh kernel); the loaded frame is `cities`.
cities.shape

In [None]:
%%time

zarr_cities_bag = db.from_sequence(
    zip(cities.longitude.values, cities.latitude.values),
    npartitions=50  # Number of partitions should match the number of workers
)

df = zarr_cities_bag.map(get_zarr_point_data).compute()

In [9]:
def get_model_data(items, variables=("pr", "hurs", "tas", "tasmin", "tasmax"), chunks=None):
    """Open, calendar-normalize, rechunk and persist all NetCDF assets of `items`.

    Parameters
    ----------
    items : sequence of STAC items
        Every asset of every item is opened.
    variables : sequence of str, default ("pr", "hurs", "tas", "tasmin", "tasmax")
        Data variables to keep.
    chunks : dict, optional
        Chunking applied before persisting. Defaults to
        ``{"time": -1, "lat": 75, "lon": 90}``, the layout the timing notes in
        this notebook found fastest for per-city extraction.

    Returns
    -------
    xarray.Dataset
        Persisted on the Dask cluster. This call blocks (``wait``) until the
        persist has finished.
    """
    if chunks is None:
        chunks = {"time": -1, "lat": 75, "lon": 90}

    # Presumably xarray reads lazily from these file objects, so they are kept
    # open for the dataset's lifetime rather than closed here.
    files = [
        fsspec.open(asset.href).open() for item in items for asset in item.assets.values()
    ]
    ds = xr.open_mfdataset(files, chunks="auto")

    ds = ds[list(variables)]
    # Map the model calendar onto gregorian dates; days absent from the source
    # calendar are filled with -99.
    ds = ds.convert_calendar(calendar="gregorian", align_on="date", missing=-99)

    ds = ds.chunk(chunks).persist()
    # Waiting for the persist before querying points is much faster (see notes).
    wait(ds)

    return ds
    


In [10]:
def get_model_assets(catalog, model, scenario, period):
    """Search the NEX-GDDP-CMIP6 collection for one model and scenario over a period.

    Parameters
    ----------
    catalog : pystac_client.Client
        Anything with a compatible ``search`` method.
    model : str
        Matched against the ``cmip6:model`` property.
    scenario : str
        Matched against the ``cmip6:scenario`` property.
    period : str
        STAC datetime range, e.g. ``"1980/2014"``.

    Returns
    -------
    list
        The matching items, ordered by item id.
    """
    cmip6_filter = {
        "cmip6:model": {"eq": model},
        "cmip6:scenario": {"eq": scenario},
    }
    found = catalog.search(
        collections=["nasa-nex-gddp-cmip6"],
        datetime=period,
        query=cmip6_filter,
    ).item_collection()
    return sorted(found, key=lambda stac_item: stac_item.id)


In [None]:
def run_historical_pipeline(model, variables=("pr",)):
    """Extract historical (1980-2014) daily series for every city, for one model.

    The earlier draft dispatched to ``get_variable_data``, which is not defined
    anywhere in this notebook (NameError). This version chains the helpers
    that do exist.

    Parameters
    ----------
    model : str
        CMIP6 model name (see ``models``).
    variables : iterable of str, default ("pr",)
        Variables to extract.

    Returns
    -------
    dict[str, pandas.DataFrame]
        One frame per variable. Rows are dates and columns are
        ``<countrycode>_<cityname>`` labels.

    Notes
    -----
    Reads the notebook globals ``catalog``, ``cities`` and ``zarr_cities_bag``,
    and rebinds the global ``ds``.
    """
    global ds
    scenario = "historical"
    period = "1980/2014"

    items = get_model_assets(catalog, model, scenario, period)
    # get_zarr_point_data reads the module-level `ds`, which the notebook notes
    # is much faster than passing the dataset to every task, so publish it there.
    ds = get_model_data(items)

    city_labels = (cities.countrycode + cities.cityname.apply(lambda x: f"_{x}")).values
    results = {}
    # Variables are processed one after another; each bag compute already fans
    # out across the cluster, so nesting computes inside bag tasks is avoided.
    for variable in variables:
        frame = pd.DataFrame(zarr_cities_bag.map(get_zarr_point_data, variable).compute()).T
        frame.columns = city_labels
        results[variable] = frame
    return results
    

**IMPORTANT:** Having `get_zarr_point_data` read the global xarray dataset, instead of receiving it as an argument, dramatically speeds
up the computation (30 s vs. 6 min).

Rechunking to `time: -1, lat: 75, lon: 90` took 3 min 30 s for 2,000 cities, but the connection to the scheduler intermittently timed out.
This happens with the `dask.delayed` approach but not with the `dask.bag` approach.
More importantly, persisting with this chunk size scales well: 10k cities take about 1 min.

Waiting for the rechunk + persist to complete before fetching the city data is the most efficient approach: 5 min 30 s for the full historical range.
Not waiting takes more than 10 min.

Smaller chunk sizes produce warnings that the task size is too large.

Next, loop through each model's variables and save the results.

In [8]:
# Every NEX-GDDP-CMIP6 model to process, in sorted (case-sensitive ASCII)
# order. Later cells select a subset by slicing this list.
models = sorted(
    """
    UKESM1-0-LL NorESM2-MM NorESM2-LM MRI-ESM2-0 MPI-ESM1-2-LR MPI-ESM1-2-HR
    MIROC6 MIROC-ES2L KIOST-ESM KACE-1-0-G IPSL-CM6A-LR INM-CM5-0 INM-CM4-8
    HadGEM3-GC31-MM HadGEM3-GC31-LL GFDL-ESM4 GFDL-CM4 FGOALS-g3
    EC-Earth3-Veg-LR EC-Earth3 CanESM5 CNRM-ESM2-1 CNRM-CM6-1 CMCC-ESM2
    CMCC-CM2-SR5 ACCESS-ESM1-5 ACCESS-CM2 TaiESM1
    """.split()
)
models


['ACCESS-CM2',
 'ACCESS-ESM1-5',
 'CMCC-CM2-SR5',
 'CMCC-ESM2',
 'CNRM-CM6-1',
 'CNRM-ESM2-1',
 'CanESM5',
 'EC-Earth3',
 'EC-Earth3-Veg-LR',
 'FGOALS-g3',
 'GFDL-CM4',
 'GFDL-ESM4',
 'HadGEM3-GC31-LL',
 'HadGEM3-GC31-MM',
 'INM-CM4-8',
 'INM-CM5-0',
 'IPSL-CM6A-LR',
 'KACE-1-0-G',
 'KIOST-ESM',
 'MIROC-ES2L',
 'MIROC6',
 'MPI-ESM1-2-HR',
 'MPI-ESM1-2-LR',
 'MRI-ESM2-0',
 'NorESM2-LM',
 'NorESM2-MM',
 'TaiESM1',
 'UKESM1-0-LL']

In [13]:
def save_variable_data(variable):
    """Extract `variable` for every city from the global `ds` and upload it to S3 as CSV.

    Reads the notebook globals ``ds``, ``zarr_cities_bag``, ``cities``,
    ``model`` and ``period`` (the last two only name the output file).
    Writes ``s3://cities-climate-hazard/{model}_{variable}_{period}.csv`` with
    one column per city (``<countrycode>_<cityname>``) and one row per day.
    """
    def get_zarr_point_data(point, variable):
        lon, lat = point
        # Assumes the dataset longitudes are 0..360 (as in the Campinas cell).
        # Unwrapped negative longitudes would all snap to the column nearest lon 0.
        data = ds[variable].sel(lat=lat, lon=lon % 360, method='nearest').compute()

        return pd.Series(data.data, index=[pd.Timestamp(t) for t in data.time.data])

    data = zarr_cities_bag.map(get_zarr_point_data, variable).compute()
    df = pd.DataFrame(data).T
    df.columns = (cities.countrycode + cities.cityname.apply(lambda x: f"_{x}")).values

    # Credentials are resolved by s3fs/botocore's default chain
    # (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY env vars, ~/.aws/credentials, ...).
    # NOTE(review): the key pair previously hard-coded here is in the notebook
    # history and must be rotated.
    # NOTE(review): index=False drops the date index; confirm consumers don't need it.
    df.to_csv(
        f"s3://cities-climate-hazard/{model}_{variable}_{period.replace('/', '-')}.csv",
        index=False,
    )
    

In [17]:
variables = ['tas', 'tasmax', 'tasmin', 'pr', 'hurs']
period = '1980/2014'
scenario = 'historical'

# Only NorESM2-MM (index 25 of `models`) in this run; widen the slice for more models.
for model in models[25:26]:
    # The helper applies the same search and id-sort; the query now honours
    # `scenario` instead of a duplicated hard-coded "historical".
    items = get_model_assets(catalog, model, scenario, period)
    # get_zarr_point_data reads this module-level `ds`.
    ds = get_model_data(items)

    for variable in variables:
        data = zarr_cities_bag.map(get_zarr_point_data, variable).compute()
        df = pd.DataFrame(data).T
        df.columns = (cities.countrycode + cities.cityname.apply(lambda x: f"_{x}")).values

        # Credentials are resolved by s3fs/botocore's default chain (env vars
        # AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY, ~/.aws/credentials, ...).
        # NOTE(review): rotate the key pair that used to be hard-coded here.
        df.to_csv(
            f"s3://cities-climate-hazard/{model}_{variable}_{period.replace('/', '-')}.csv",
            index=False,
        )

In [13]:
# Rebuild the city bag with one partition per minimum worker (cluster.adapt(minimum=100)).
zarr_cities_bag = db.from_sequence(
    zip(cities.longitude.values, cities.latitude.values),
    npartitions=100,
)

# NOTE(review): `model`, `scenario` and `period` are whatever earlier cells left
# behind (the last model of the loop above). Set them explicitly before
# running this cell on its own.
items = get_model_assets(catalog, model, scenario, period)

In [None]:
%%time


# items = get_model_assets(catalog, model, scenario, period)
ds = get_model_data(items)

In [None]:
def extract_models_data(model):
    """Extract historical (1980-2014) daily city series for `model` and upload one CSV per variable.

    Reads the notebook globals ``catalog``, ``cities`` and ``zarr_cities_bag``.
    Writes ``s3://cities-climate-hazard/{model}_{variable}_1980-2014.csv`` for
    each variable, with one column per city (``<countrycode>_<cityname>``).
    """
    variables = ['tas', 'tasmax', 'tasmin', 'pr', 'hurs']
    period = '1980/2014'
    scenario = 'historical'

    items = get_model_assets(catalog, model, scenario, period)
    ds = get_model_data(items)

    # NOTE(review): this closure captures the local `ds`, so it is serialized
    # with every task. The notebook notes report that this is much slower than
    # reading a module-level `ds`.
    def get_zarr_point_data(point, variable):
        lon, lat = point
        # Assumes the dataset longitudes are 0..360 (as in the Campinas cell).
        # Unwrapped negative longitudes would all snap to the column nearest lon 0.
        data = ds[variable].sel(lat=lat, lon=lon % 360, method='nearest').compute()

        return pd.Series(data.data, index=[pd.Timestamp(t) for t in data.time.data])

    city_labels = (cities.countrycode + cities.cityname.apply(lambda x: f"_{x}")).values
    for variable in variables:
        data = zarr_cities_bag.map(get_zarr_point_data, variable).compute()
        df = pd.DataFrame(data).T
        df.columns = city_labels

        # Credentials are resolved by s3fs/botocore's default chain (env vars
        # AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY, ~/.aws/credentials, ...).
        # NOTE(review): rotate the key pair that used to be hard-coded here.
        df.to_csv(
            f"s3://cities-climate-hazard/{model}_{variable}_{period.replace('/', '-')}.csv",
            index=False,
        )