In [1]:
import datetime
import glob
import os
from collections import OrderedDict

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import ulmo
import xarray as xr
from scipy.spatial import distance
from shapely import geometry as sgeom
from sklearn.metrics import r2_score
from xgrads import open_CtlDataset
#from pathsWY import *
#from SM_tools import *
# from os import listdir
# from os.path import isfile, join

In [2]:
#########################################################################
# User inputs
#########################################################################

# CSO domain key (the evaluation loop further down re-assigns this per domain)
domain = 'CA'

# Water year of interest: runs from 1 October of the previous calendar
# year through 30 September of the water year itself.
water_year = 2019
stdt = f"{water_year - 1}-10-01"   # start date
eddt = f"{water_year}-09-30"       # end date

# Directory holding the SnowModel .gdat output for this domain, and the
# directory the evaluation .nc files are written to.
gdatPath = f"/scratch/Nina/CSOdata/{domain}/"
assim_file_path = f"/nfs/attic/dfh/Aragon2/CSOassim/{domain}/"

In [3]:
#########################################################################
# SNOTEL Functions
#########################################################################
# functions to get snotel data in each domain to compare to SM outputs

# functions to get SNOTEL stations as geodataframe
def sites_asgdf(ulmo_getsites, stn_proj):
    """Build a point GeoDataFrame from a ulmo.cuahsi.wof.get_sites response.

    Parameters
    ----------
    ulmo_getsites : dict
        Mapping of site key -> site record as returned by get_sites. Each
        record is read for 'code', 'name', 'elevation_m' and
        'location' -> {'longitude', 'latitude'}.
    stn_proj : CRS definition passed straight to the GeoDataFrame.

    Returns
    -------
    GeoDataFrame with columns code, longitude, latitude, name, elevation_m
    and a point geometry built from (longitude, latitude).

    Records without a 'location' key are skipped; at least one SNOTEL site
    has been seen without it.
    """
    records = []
    for site in ulmo_getsites.values():
        if 'location' not in site:
            continue
        loc = site['location']
        records.append(OrderedDict(
            code=site['code'],
            longitude=float(loc['longitude']),
            latitude=float(loc['latitude']),
            name=site['name'],
            elevation_m=site['elevation_m'],
        ))

    sites_df = pd.DataFrame.from_records(records)
    points = gpd.points_from_xy(sites_df['longitude'], sites_df['latitude'])
    return gpd.GeoDataFrame(sites_df, geometry=points, crs=stn_proj)

def get_snotel_stns(domain):
    """Return the stations that fall inside a CSO domain's bounding box.

    The bounding box ('Bbox'), station projection ('stn_proj') and model
    projection ('mod_proj') are read from CSO_domains.json in the
    snowmodel-tools/preprocess_python repository. Station metadata comes from
    the CUAHSI hydroportal: the SCAN service for domain 'NH', SNOTEL otherwise.

    Parameters
    ----------
    domain : str
        Key into CSO_domains.json (e.g. 'OR').

    Returns
    -------
    GeoDataFrame of stations (see sites_asgdf) clipped to the bounding box,
    with a fresh 0..n-1 index and added 'easting'/'northing' columns holding
    the station coordinates in the model projection.
    """
    domains = requests.get(
        "https://raw.githubusercontent.com/snowmodel-tools/preprocess_python/master/CSO_domains.json"
    ).json()
    info = domains[domain]
    bbox = info['Bbox']
    stn_proj = info['stn_proj']
    mod_proj = info['mod_proj']

    # Single-polygon frame used to clip the station list to the domain.
    box_gdf = gpd.GeoDataFrame(
        geometry=[sgeom.box(bbox['lonmin'], bbox['latmin'], bbox['lonmax'], bbox['latmax'])],
        crs=stn_proj,
    )

    # WaterML/WOF WSDL endpoint: New Hampshire is served by SCAN.
    wsdlurl = ("https://hydroportal.cuahsi.org/Scan/cuahsi_1_1.asmx?WSDL"
               if domain == 'NH'
               else "https://hydroportal.cuahsi.org/Snotel/cuahsi_1_1.asmx?WSDL")
    sites = ulmo.cuahsi.wof.get_sites(wsdlurl, user_cache=True)

    stations = gpd.sjoin(sites_asgdf(sites, stn_proj), box_gdf, how="inner")
    stations = stations.drop(columns='index_right').reset_index(drop=True)

    # Projected coordinates for matching against the model grid.
    projected = stations.to_crs(mod_proj)
    stations['easting'] = projected.geometry.x
    stations['northing'] = projected.geometry.y
    return stations


def fetch(sitecode, variablecode, domain,start_date, end_date):
    """Download one variable for one station from the CUAHSI WOF service.

    Parameters
    ----------
    sitecode : str
        Station code without the network prefix (e.g. '1166_OR_SNTL').
    variablecode : str
        Fully qualified variable code (e.g. 'SNOTEL:WTEQ_D').
    domain : str
        CSO domain key; 'NH' queries the SCAN network, all others SNOTEL.
    start_date, end_date : str
        Request window passed through to ulmo ('yyyy-mm-dd').

    Returns
    -------
    pandas.DataFrame indexed by timestamp with a float 'value' column
    (-9999 replaced by NaN), keeping only rows whose
    'quality_control_level_code' equals '1'; or None if the request or any
    parsing step fails. A partially processed frame is never returned.
    """
    print(sitecode, variablecode, domain,start_date, end_date)
    # WaterML/WOF WSDL endpoint url 
    if domain == 'NH':
        wsdlurl = "https://hydroportal.cuahsi.org/Scan/cuahsi_1_1.asmx?WSDL"
        network = 'SCAN:'
    else:
        wsdlurl = "https://hydroportal.cuahsi.org/Snotel/cuahsi_1_1.asmx?WSDL"
        network = 'SNOTEL:'

    try:
        # Request data from the server
        site_values = ulmo.cuahsi.wof.get_values(
            wsdlurl, network+sitecode, variablecode, start=start_date, end=end_date
        )
        values_df = pd.DataFrame.from_dict(site_values['values'])
        # Timestamp index
        values_df['datetime'] = pd.to_datetime(values_df['datetime'])
        values_df = values_df.set_index('datetime')
        # Numeric values with the -9999 nodata sentinel mapped to NaN
        values_df['value'] = pd.to_numeric(values_df['value']).replace(-9999, np.nan)
        # Keep only records at quality-control level '1'
        values_df = values_df[values_df['quality_control_level_code'] == '1']
    except Exception as exc:
        # Best effort: report the reason and let the caller drop the station.
        print("Unable to fetch %s (%s)" % (variablecode, exc))
        return None

    return values_df

def get_snotel_data(gdf,sddt, eddt,var,domain,units='metric'):
    '''
    Download a daily SNOTEL/SCAN variable for every station in gdf.

    gdf - pandas geodataframe of SNOTEL sites (its 'code' column is used).
          Stations whose data cannot be fetched or fail the checks below are
          dropped from gdf IN PLACE, and gdf's index is reset.
    sddt - start date string 'yyyy-mm-dd'
    eddt - end date string 'yyyy-mm-dd'
    var - snotel variable of interest (e.g. 'WTEQ')
    domain - CSO domain key; 'NH' uses the SCAN network, all others SNOTEL
    units - 'metric' (default) or 'imperial'

    A station is dropped when it cannot be fetched, when more than 2% of the
    days in [sddt, eddt] are missing, or when its mean value is negative.

    Returns (gdf, stn_data): stn_data is a DataFrame indexed by day from sddt
    to eddt with one column per retained station. With units='metric',
    WTEQ/SNWD/PRCP/PREC are converted in -> m and TAVG/TMIN/TMAX/TOBS
    degF -> degC.
    '''
    max_missing_frac = 0.02  # tolerated fraction of missing days per station
    days = pd.date_range(start=sddt, end=eddt)
    stn_data = pd.DataFrame(index=days)
    network = 'SCAN:' if domain == 'NH' else 'SNOTEL:'

    # Collect failures and drop them after the loop, instead of mutating gdf
    # while iterating over its 'code' column.
    dropped = []
    for sitecode in list(gdf.code):
        try:
            data = fetch(sitecode, network+var+'_D', domain, start_date=sddt, end_date=eddt)
            if data is None:
                dropped.append(sitecode)
                continue
            # Align to the requested days so absent days count as missing too.
            values = data['value'].reindex(days)
            if values.isna().mean() > max_missing_frac:
                print('More than 2% of days missing')
                dropped.append(sitecode)
                continue
            if values.mean() < 0:
                print('Average value is <0, removing station')
                dropped.append(sitecode)
                continue
            stn_data[sitecode] = values
        except Exception as exc:
            # Best effort: any per-station failure removes only that station.
            print('Dropping station %s: %s' % (sitecode, exc))
            dropped.append(sitecode)

    if dropped:
        gdf.drop(gdf.index[gdf['code'].isin(dropped)], inplace=True)
    gdf.reset_index(drop=True, inplace=True)

    # stn_data only holds retained stations, so convert the whole frame.
    if units == 'metric':
        if var in ('WTEQ', 'SNWD', 'PRCP', 'PREC'):
            stn_data = 0.0254 * stn_data          # SNOTEL [in] -> [m]
        elif var in ('TAVG', 'TMIN', 'TMAX', 'TOBS'):
            stn_data = (stn_data - 32) * 5/9      # SNOTEL [F] -> [C]
    return gdf, stn_data


In [4]:
# function to compute model performance metrics
def calc_metrics(mod_swe,stn_swe):
    """Score modelled SWE against observed SWE at one station.

    Parameters
    ----------
    mod_swe : array-like
        Modelled SWE time series; anything np.asarray accepts (e.g. an
        xarray DataArray).
    stn_swe : array-like
        Observed station SWE on the same time steps as mod_swe.

    Returns
    -------
    list of 5 values: [R2, MBE, RMSE, NSE, KGE].
    All five are NaN when no days remain after filtering, or when the
    filtered model series has a negative mean (treated as an undefined
    SnowModel point).

    Days where both series are zero, and days where the station is NaN, are
    excluded before scoring.
    """
    nan_stats = [np.nan] * 5

    mod = np.asarray(mod_swe, dtype=float)
    stn = np.asarray(stn_swe, dtype=float)

    # keep days with snow in at least one series and a valid observation
    keep = ((stn != 0) | (mod != 0)) & ~np.isnan(stn)
    mod = mod[keep]
    stn = stn[keep]

    if mod.size == 0:
        return nan_stats
    if np.mean(mod) < 0:
        print('undefined point in SnowModel')
        return nan_stats

    err = mod - stn

    # sklearn's r2_score(y_true, y_pred) is 1 - SSE/SST, so apart from
    # degenerate cases it equals the NSE computed below.
    r2 = r2_score(stn, mod)
    # mean bias error
    mbe = np.mean(err)
    # root mean squared error
    rmse = np.sqrt(np.mean(err ** 2))
    # Nash-Sutcliffe efficiency, 1 = perfect
    nse = 1 - np.sum(err ** 2) / np.sum((stn - np.mean(stn)) ** 2)
    # Kling-Gupta efficiency, 1 = perfect:
    # 1 - sqrt((r-1)^2 + (alpha-1)^2 + (beta-1)^2), the root spans all terms
    kge_r = np.corrcoef(stn, mod)[1, 0]
    kge_std = np.std(mod) / np.std(stn)
    kge_mean = np.mean(mod) / np.mean(stn)
    kge = 1 - np.sqrt((kge_r - 1) ** 2 + (kge_std - 1) ** 2 + (kge_mean - 1) ** 2)

    return [r2, mbe, rmse, nse, kge]

In [5]:
# function to edit text files
def replace_line(file_name, line_num, text):
    ''' 
    Overwrite a single line of a text file in place.

    file_name = file to edit
    line_num = 0-based index of the line to replace
    text = new text to put in (include a trailing newline if one is wanted)

    Raises IndexError if the file has no line at index line_num.
    '''
    # Both handles are closed deterministically, even on error.
    with open(file_name, 'r') as fh:
        lines = fh.readlines()
    lines[line_num] = text
    with open(file_name, 'w') as fh:
        fh.writelines(lines)

In [6]:
# function to return the nearest easting and northing to pt in a raster
def nearest_grid(ds, pt):
    """
    Return the (lon, lat) coordinate pair in ds closest to pt.

    ds : dataset exposing 'lon' and 'lat' coordinates. The coordinate values
         are compared directly with pt (Euclidean distance), so they must be
         in the same units/projection as pt.
    pt : input point, tuple (x, y) in the same units as ds.lon / ds.lat
    output:
        lon, lat of the nearest grid point with non-NaN coordinates

    Raises ValueError if ds lacks a 'lon' or 'lat' coordinate.
    """
    coord_names = list(ds.coords)
    missing = [c for c in ('lon', 'lat') if c not in coord_names]
    if missing:
        raise ValueError("dataset is missing coordinate(s): %s" % ", ".join(missing))

    df_loc = ds[['lon', 'lat']].to_dataframe().reset_index()
    loc_valid = df_loc.dropna()
    pts = loc_valid[['lon', 'lat']].to_numpy()
    idx = distance.cdist([pt], pts).argmin()
    return loc_valid['lon'].iloc[idx], loc_valid['lat'].iloc[idx]

In [7]:
# Function to extract SM swe values at pixels co-located with stations
# and compute performance metrics.
# Data saved out as netcdfs
def SM_eval(domain,stdt,eddt, gdatPath):  
    """Extract SnowModel SWE at station locations and score it for every run.

    For each ``*.gdat`` file in ``gdatPath``, the first line of
    ``gdatPath/swed.ctl`` is rewritten to point at that file, the dataset is
    opened with open_CtlDataset, and the grid cell nearest each station (in
    the domain's model projection) is compared with the station's observed
    SWE using calc_metrics.

    Parameters
    ----------
    domain : str
        CSO domain key in CSO_domains.json (e.g. 'OR').
    stdt, eddt : str
        Start / end dates 'yyyy-mm-dd'. Each run's swed series at a station
        must have exactly one value per day in this range.
    gdatPath : str
        Directory containing the .gdat files and swed.ctl.

    Returns
    -------
    ds : xarray.Dataset
        'swe' with dims (assim_run, station, date).
    dss : xarray.Dataset
        'score' with dims (assim_run, station, metrics) for
        R2, MBE, RMSE, NSE, KGE.

    Side effect: swed.ctl is modified in place and is left pointing at the
    last run.
    """
    #get domain info
    domains_resp = requests.get("https://raw.githubusercontent.com/snowmodel-tools/preprocess_python/master/CSO_domains.json")
    domains = domains_resp.json()
    mod_proj = domains[domain]['mod_proj']

    # execute functions to get snotel data
    stn_gdf = get_snotel_stns(domain)
    STNgdf, stn_swe_all = get_snotel_data(stn_gdf,stdt,eddt,'WTEQ',domain)

    # Run names are the .gdat basenames without the extension. They are
    # derived from the paths themselves rather than by slicing off a fixed
    # number of characters, which only matched one particular gdatPath.
    filenames = sorted(
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob(os.path.join(gdatPath, "*.gdat"))
    )

    date = pd.date_range(stdt,eddt,freq='d')
    station = STNgdf['code'].values
    metrics = ['R2','MBE','RMSE','NSE','KGE']

    # Station positions in the model projection do not change between runs,
    # so reproject once instead of twice per station per run.
    stn_pts = [(g.x, g.y) for g in STNgdf.to_crs(mod_proj).geometry]

    # [#runs, #stations, #days] and [#runs, #stations, #metrics]; NaN-filled
    # so any entry that is never written stays visibly missing.
    data = np.full([len(filenames), len(station), len(date)], np.nan)
    statsdata = np.full([len(filenames), len(station), len(metrics)], np.nan)

    ctlFile = os.path.join(gdatPath, "swed.ctl")
    for f, run in enumerate(filenames):
        print(f+1,' of', len(filenames))
        # point the GrADS descriptor at this run's .gdat file
        replace_line(ctlFile, 0, r"DSET ^"+run+".gdat\n")
        modswe = open_CtlDataset(ctlFile)
        try:
            for i, (nam, pt) in enumerate(zip(station, stn_pts)):
                print(nam)
                # SnowModel grid cell nearest the station
                long, lati = nearest_grid(modswe, pt)
                mod = modswe.swed.sel(lon=long,lat=lati)
                # observed station SWE
                stn = stn_swe_all[nam].values
                statsdata[f,i,:] = calc_metrics(mod,stn)
                data[f,i,:] = mod.values
        finally:
            # release the dataset before swed.ctl is rewritten for the next run
            modswe.close()

    #save SM swe output as netcdf
    SMswe = xr.DataArray(
        data,
        dims=('assim_run', 'station', 'date'), 
        coords={'assim_run': filenames, 
                'station': station, 'date': date})

    SMswe.attrs['long_name']= 'Assimilation SWE at stations'
    SMswe.attrs['standard_name']= 'assim_swe'

    d = OrderedDict()
    d['assim_run'] = ('assim_run', filenames)
    d['station'] = ('station', station)
    d['date'] = ('date', date)
    d['swe'] = SMswe

    ds = xr.Dataset(d)
    ds.attrs['description'] = "SnowModel swe at stations"
    ds.attrs['model_output'] = "SWE [m]"

    ds.assim_run.attrs['standard_name'] = "assimilation_run"
    ds.assim_run.attrs['axis'] = "run_id"

    ds.station.attrs['long_name'] = "station_id"
    ds.station.attrs['axis'] = "station"

    ds.date.attrs['long_name'] = "date"
    ds.date.attrs['axis'] = "date"

    #save performance stats as netcdf
    assimstats = xr.DataArray(
        statsdata,
        dims=('assim_run', 'station', 'metrics'), 
        coords={'assim_run': filenames, 
                'station': station, 'metrics': metrics})

    assimstats.attrs['long_name']= 'Performance metrics at stations'
    assimstats.attrs['standard_name']= 'metrics'

    dd = OrderedDict()
    dd['assim_run'] = ('assim_run', filenames)
    dd['station'] = ('station', station)
    dd['metrics'] = ('metrics', metrics)
    dd['score'] = assimstats

    dss = xr.Dataset(dd)
    dss.attrs['description'] = "Performance metrics at stations"
    dss.attrs['model_output'] = "R^2 MBE RMSE NSE KGE"

    dss.assim_run.attrs['standard_name'] = "assimilation_run"
    dss.assim_run.attrs['axis'] = "run_id"

    dss.station.attrs['long_name'] = "station_id"
    dss.station.attrs['axis'] = "station"

    dss.metrics.attrs['long_name'] = "performance_metrics"
    dss.metrics.attrs['axis'] = "metrics"
    
    return ds, dss

In [8]:
# Domains to evaluate (full set kept for reference).
#domain_list = ['CA','CO_N','CO_S','OR','UT','WA','WA_SQ','WY']
domain_list = ['OR']

for domain in domain_list:
    print(domain)
    # Per-domain SnowModel input (.gdat) and evaluation output (.nc) folders.
    gdatPath = f"/scratch/Nina/CSOdata/{domain}/"
    assim_file_path = f"/nfs/attic/dfh/Aragon2/CSOassim/{domain}/"

    # Compare every SnowModel run with the domain's stations.
    ds, dss = SM_eval(domain, stdt, eddt, gdatPath)

    # Write SWE time series first, then the performance metrics.
    for prefix, dataset in (('assim_swe_', ds), ('assim_stats_', dss)):
        outfilepath = f"{assim_file_path}{prefix}{water_year}.nc"
        dataset.to_netcdf(outfilepath, format='NETCDF4', engine='netcdf4')

OR
1166_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
434_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
1025_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30


<suds.sax.document.Document object at 0x7faa9e358b70>


Unable to fetch SNOTEL:WTEQ_D
1024_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30


<suds.sax.document.Document object at 0x7faa9e2abfd0>


Unable to fetch SNOTEL:WTEQ_D
976_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30


<suds.sax.document.Document object at 0x7faa9e50ec18>


Unable to fetch SNOTEL:WTEQ_D
526_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
545_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
614_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
619_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
719_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
733_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
1167_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30
815_OR_SNTL SNOTEL:WTEQ_D OR 2018-10-01 2019-09-30


In [None]:
# Diagnostic plot: modelled vs observed SWE at every station for the first
# SnowModel run of `domain` (i.e. the last domain processed by the loop above).
domains_resp = requests.get("https://raw.githubusercontent.com/snowmodel-tools/preprocess_python/master/CSO_domains.json")
domains = domains_resp.json()
mod_proj = domains[domain]['mod_proj']

# execute functions to get snotel data
stn_gdf = get_snotel_stns(domain)
STNgdf, stn_swe_all = get_snotel_data(stn_gdf,stdt,eddt,'WTEQ',domain)

# SnowModel run names: .gdat basenames without the extension (independent of
# how long the gdatPath prefix is)
filenames = sorted(
    os.path.splitext(os.path.basename(path))[0]
    for path in glob.glob(os.path.join(gdatPath, "*.gdat"))
)
if not filenames:
    raise FileNotFoundError("no .gdat files found in %s" % gdatPath)

f=0
ctlFile = os.path.join(gdatPath, "swed.ctl")
# point the GrADS descriptor at the selected run
text = r"DSET ^"+filenames[f]+".gdat\n"
replace_line(ctlFile, 0, text)
modswe = open_CtlDataset(ctlFile)

# station coordinates in the model projection, computed once
stn_xy = STNgdf.to_crs(mod_proj).geometry

# one row per station; squeeze=False keeps axs indexable even for 1 station
fig, axs = plt.subplots(nrows=len(STNgdf), ncols=1, squeeze=False, figsize=(14, 2*len(STNgdf)), facecolor='w', edgecolor='k')
axs = axs[:, 0]
fig.subplots_adjust(hspace = .5, wspace=.5)

for i in range(len(STNgdf)):
    nam = STNgdf.code[i]

    # define point
    pt = (stn_xy.iloc[i].x, stn_xy.iloc[i].y)

    # get nearest easting and northing to SM output
    long, lati = nearest_grid(modswe, pt)

    # select SM grid cell and subset to point
    mod = modswe.swed.sel(lon=long,lat=lati)

    # select station observations
    stn = stn_swe_all[nam].values

    # calculate performance statistics
    stats = calc_metrics(mod,stn)

    axs[i].plot(mod,'r',label = 'SM swe [m]',alpha = .5)
    axs[i].plot(stn,'b',label = 'station swe [m]',alpha = .5)
    axs[i].set_title(nam+' '+str(stats[0]))
    if i == 0:
        axs[i].legend()