In [3]:
###### upscale model bootstrap prediction
###### upscale model bootstrap variance

import os
import sys
import urllib.request
import re
import argparse
import numpy as np
import pandas as pd
import netCDF4 as nc
import xarray as xr
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
import pickle
import datetime
from scipy import stats
from osgeo import gdal

# load model
site_df = pd.read_csv('site_list.csv')
filter_df = site_df.loc[(site_df['WETLAND_CL']=='Fen') | (site_df['WETLAND_CL']=='Bog') | (site_df['WETLAND_CL']=='Wet tundra')]

models = []
for row in filter_df.itertuples(index=True, name='Pandas'):
    id = row.NewID
    models.append(pickle.load(open('model_%s.sav'%(id),'rb')))

nbar_path = "Data/DailyNBAR"
smap_path = "Data/DailySMAP"
merra_path = "Data/merra2"
topo_path = "Data/TOPO"
out_path = "Upscale/bootstrap_mean"
out_vpath = "Upscale/bootstrap_std"


merra_vars = ['pa','tas', 'spfh','ts1','ts2','ts3','rsds','rsdl','le','h']
topo_vars = ['dem','slope','cti','spi']
lat_range = slice(90.0,45.0)
lon_range = slice(-180.0, 180.0)

# loop per day
for filename in os.listdir(nbar_path):
    filedate = os.path.splitext(filename)[0]
    filext = os.path.splitext(filename)[1]
    features = np.empty(shape=3)
    if filext == '.tif':
        print(filedate)
        out_file = os.path.join(out_path, filename)
        out_vfile = os.path.join(out_vpath, filename)
        if os.path.isfile(out_file):
            continue
        else:
            # read daily predictor data
            nbar_file = os.path.join(nbar_path, filename)
            smap_file = os.path.join(smap_path, filename)
            ds_gdal = gdal.Open(nbar_file)
            nbar_data = ds_gdal.ReadAsArray()
            smap_data = gdal.Open(smap_file).ReadAsArray()
            features = np.concatenate((nbar_data, smap_data), axis=0)
            print(nbar_data.shape)
            print(smap_data.shape)
            for var in merra_vars:
                ds = xr.open_dataset('Data/merra2/merra2.%s.daily.nc'%(var))
                data = ds[var].sel(time=filedate, lat=lat_range, lon=lon_range)
                features = np.concatenate((features, data), axis=0)
                print(data.shape)
                ds.close()
            for tvar in topo_vars:
                tdata = gdal.Open('Data/TOPO/%s.tif'%(tvar)).ReadAsArray()
                tdata = np.array([tdata])
                features = np.concatenate((features, tdata), axis=0)
                print(tdata.shape)
                
                
            #predict a day
            mm = features
            img_to_arr = mm.reshape(mm.shape[0],-1).T
            print(img_to_arr.shape)
            img_to_arr = np.nan_to_num(img_to_arr)
            rf_preds = np.stack([model.predict(img_to_arr) for model in models], axis=1)
            pred_mean = np.mean(rf_preds, axis=1)
            pred_std = np.std(rf_preds, axis=1)
            mean_img_out = pred_mean.reshape(features.shape[1],features.shape[2])
            std_img_out = pred_std.reshape(features.shape[1],features.shape[2])


            ##### write out prediction image
            driver = gdal.GetDriverByName("GTiff")        
            rows = features.shape[1]
            cols = features.shape[2]
            outdata = driver.Create(out_file, cols, rows, 1, gdal.GDT_Float32)
            outdata.SetGeoTransform(ds_gdal.GetGeoTransform())##sets same geotransform as input
            outdata.SetProjection(ds_gdal.GetProjection())##sets same projection as input
            outdata.GetRasterBand(1).WriteArray(mean_img_out)
            outdata.FlushCache() ##saves to disk!!
            outdata = None
            
            outdata = driver.Create(out_vfile, cols, rows, 1, gdal.GDT_Float32)
            outdata.SetGeoTransform(ds_gdal.GetGeoTransform())##sets same geotransform as input
            outdata.SetProjection(ds_gdal.GetProjection())##sets same projection as input
            outdata.GetRasterBand(1).WriteArray(std_img_out)
            outdata.FlushCache() ##saves to disk!!
            outdata = None         
            ds_gdal = None
            

2022-07-26
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
2022-07-27
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
2022-07-31
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
2022-07-30
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
2022-07-29
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
2022-07-28
(7, 456, 3644)
(3, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)


  decode_timedelta=decode_timedelta,
  condition |= data == fv
  decode_timedelta=decode_timedelta,
  decode_timedelta=decode_timedelta,


(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1, 456, 3644)
(1661664, 24)
