In [34]:
#! /usr/bin/env python
#################################################################
###  This program is part of PyRite  v1.0                     ### 
###  Copy Right (c): 2020, Yunmeng Cao                        ###  
###  Author: Yunmeng Cao                                      ###                                                          
###  Email : ymcmrs@gmail.com                                 ###
###  Univ. : King Abdullah University of Science & Technology ###   
#################################################################
from numpy import *
import sys
import os
import re
import subprocess
import argparse
import numpy as np
import h5py
import glob

from scipy.interpolate import griddata
import scipy.interpolate as intp
from scipy.optimize import leastsq
from scipy.stats.stats import pearsonr

from pyrite import elevation_models
from pyrite import _utils as ut
from pykrige import OrdinaryKriging
from pykrige import variogram_models
#import matlab.engine # using matlab to estimate the variogram parameters
from mintpy.utils import ptime
###############################################################

model_dict = {'linear': elevation_models.linear_elevation_model,
                      'onn': elevation_models.onn_elevation_model,
                      'onn_linear': elevation_models.onn_linear_elevation_model,
                      'exp': elevation_models.exp_elevation_model,
                      'exp_linear': elevation_models.exp_linear_elevation_model}

residual_dict = {'linear': elevation_models.residuals_linear,
                      'onn': elevation_models.residuals_onn,
                      'onn_linear': elevation_models.residuals_onn_linear,
                      'exp': elevation_models.residuals_exp,
                      'exp_linear': elevation_models.residuals_exp_linear}

initial_dict = {'linear': elevation_models.initial_linear,
                      'onn': elevation_models.initial_onn,
                      'onn_linear': elevation_models.initial_onn_linear,
                      'exp': elevation_models.initial_exp,
                      'exp_linear': elevation_models.initial_exp_linear}

para_numb_dict = {'linear': 2,
                  'onn' : 3,
                  'onn_linear':4,
                  'exp':2,
                  'exp_linear':3}


variogram_dict = {'linear': variogram_models.linear_variogram_model,
                      'power': variogram_models.power_variogram_model,
                      'gaussian': variogram_models.gaussian_variogram_model,
                      'spherical': variogram_models.spherical_variogram_model,
                      'exponential': variogram_models.exponential_variogram_model,
                      'hole-effect': variogram_models.hole_effect_variogram_model}

def get_fname_list(date_list,area,hr):
    
    flist = []
    for k0 in date_list:
        f0 = 'ERA-5{}_{}_{}.grb'.format(area, k0, hr)
        flist.append(f0)
        
    return flist

def get_meta_corner(meta):
    if 'Y_FIRST' in meta.keys():
        length = int(meta['LENGTH'])
        width = int(meta['WIDTH'])
        lat0 = float(meta['Y_FIRST'])
        lon0 = float(meta['X_FIRST'])
        lat_step = float(meta['Y_STEP'])
        lon_step = float(meta['X_STEP'])
        lat1 = lat0 + lat_step * (length - 1)
        lon1 = lon0 + lon_step * (width - 1)
        
        NORTH = lat0
        SOUTH = lat1  
        WEST = lon0
        EAST = lon1    
        
    else:       
        lats = [float(meta['LAT_REF{}'.format(i)]) for i in [1,2,3,4]]
        lons = [float(meta['LON_REF{}'.format(i)]) for i in [1,2,3,4]]
        
        lat0 = np.max(lats[:])
        lat1 = np.min(lats[:])
        lon0 = np.min(lons[:])
        lon1 = np.max(lons[:])
        
        NORTH = lat0 + 0.1
        SOUTH = lat1 + 0.1
        WEST = lon0 + 0.1
        EAST = lon1 + 0.1
    
    return WEST,SOUTH,EAST,NORTH


def era5_time(research_time0):
    
    research_time =  round(float(research_time0) / 3600)
    if len(str(research_time)) == 1:
        time0 = '0' + str(research_time)
    else:
        time0 = str(research_time)
    
    return time0

def ceil_to_5(x):
    """Return the closest number in multiple of 5 in the larger direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    if x % 5 == 0:
        return x
    return x + (5 - x % 5)

def floor_to_5(x):
    """Return the closest number in multiple of 5 in the lesser direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    return x - x % 5

def floor_to_1(x):
    """Return the closest number in multiple of 5 in the lesser direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    return x - x % 1

def ceil_to_1(x):
    """Return the closest number in multiple of 5 in the larger direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    if x % 1 == 0:
        return x
    return x + (1 - x % 1)

def floor_to_2(x):
    """Return the closest number in multiple of 5 in the lesser direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    return x - x % 2

def ceil_to_2(x):
    """Return the closest number in multiple of 5 in the larger direction"""
    assert isinstance(x, (int, np.int16, np.int32, np.int64)), 'input number is not int: {}'.format(type(x))
    if x % 2 == 0:
        return x
    return x + (2 - x % 2)

def get_snwe(wsen, min_buffer=0.5, multi_1=True):
    # get bounding box
    lon0, lat0, lon1, lat1 = wsen
    # lat/lon0/1 --> SNWE
    S = np.floor(min(lat0, lat1) - min_buffer).astype(int)
    N = np.ceil( max(lat0, lat1) + min_buffer).astype(int)
    W = np.floor(min(lon0, lon1) - min_buffer).astype(int)
    E = np.ceil( max(lon0, lon1) + min_buffer).astype(int)

    # SNWE in multiple of 5
    if multi_1:
        S = floor_to_2(S)
        W = floor_to_2(W)
        N = ceil_to_2(N)
        E = ceil_to_2(E)
    return (S, N, W, E)


def get_fname_list(date_list,area,hr):
    
    flist = []
    for k0 in date_list:
        f0 = 'ERA-5{}_{}_{}.grb'.format(area, k0, hr)
        flist.append(f0)
        
    return flist
    
def snwe2str(snwe):
    """Get area extent in string"""
    if not snwe:
        return None

    area = ''
    s, n, w, e = snwe

    if s < 0:
        area += '_S{}'.format(abs(s))
    else:
        area += '_N{}'.format(abs(s))

    if n < 0:
        area += '_S{}'.format(abs(n))
    else:
        area += '_N{}'.format(abs(n))

    if w < 0:
        area += '_W{}'.format(abs(w))
    else:
        area += '_E{}'.format(abs(w))

    if e < 0:
        area += '_W{}'.format(abs(e))
    else:
        area += '_E{}'.format(abs(e))
    return area

def read_par_orb(slc_par):
    
    date0 = ut.read_gamma_par(slc_par,'read', 'date')
    date = date0.split(' ')[0] + date0.split(' ')[1] + date0.split(' ')[2]
    
    first_sar = ut.read_gamma_par(slc_par,'read', 'start_time')
    first_sar = str(float(first_sar.split('s')[0]))

    first_state = ut.read_gamma_par(slc_par,'read', 'time_of_first_state_vector')
    first_state = str(float(first_state.split('s')[0]))
    
    intv_state = ut.read_gamma_par(slc_par,'read', 'state_vector_interval')
    intv_state = str(float(intv_state.split('s')[0]))
    
    out0 = date + '_orbit0'
    out = date + '_orbit'
    
    if os.path.isfile(out0): 
        os.remove(out0)
    if os.path.isfile(out): 
        os.remove(out)
    
    with open(slc_par) as f:
        Lines = f.readlines()
    
    with open(out0, 'a') as fo:
        k0 = 0
        for Line in Lines:
            if 'state_vector_position' in Line:
                kk = float(first_state) + float(intv_state)*k0
                fo.write(str(kk - float(first_sar)) +' ' + Line)
                k0 = k0+1
                
    call_str = "awk '{print $1,$3,$4,$5}' " + out0 + " >" + out
    os.system(call_str)
    
    orb_data = ut.read_txt2array(out)
    #t_orb = Orb[:,0]
    #X_Orb = Orb[:,1]
    #Y_Orb = Orb[:,2]
    #Z_Orb = Orb[:,3]
    return orb_data

def remove_ramp(lat,lon,data):
    # mod = a*x + b*y + c*x*y
    lat = lat/180*np.pi
    lon = lon/180*np.pi  
    lon0 = lon*np.cos(lat) # to get isometrics coordinates
    
    p0 = [0.0001,0.0001,0.0001,0.0000001]
    plsq = leastsq(residual_trend,p0,args = (lat,lon0,data))
    para = plsq[0]
    data_trend = data - func_trend(lat,lon0,para)
    corr, _ = pearsonr(data, func_trend(lat,lon0,para))
    return data_trend, para, corr

def func_trend(lat,lon,p):
    a0,b0,c0,d0 = p
    
    return a0 + b0*lat + c0*lon +d0*lat*lon

def residual_trend(p,lat,lon,y0):
    a0,b0,c0,d0 = p 
    return y0 - func_trend(lat,lon,p)

def func_trend_model(lat,lon,p):
    lat = lat/180*np.pi
    lon = lon/180*np.pi  
    lon0 = lon*np.cos(lat) # to get isometrics coordinates
    a0,b0,c0,d0 = p
    
    return a0 + b0*lat + c0*lon0 +d0*lat*lon0
    
def kriging_levels(lonlos,latlos,hgtlvs,dwetlos,attr, maxdem, mindem, Rescale,kriging_points_numb):
    R = 6371 
    # Get interp grids    
    lonStep = lonlos[0,1,0] - lonlos[0,0,0]
    latStep = latlos[1,0,0] - latlos[0,0,0]
    
    minlon,maxlon,minlat,maxlat = ut.get_sar_area(attr)
    
    minlon = minlon - 0.5 #extend 0.5 degree
    maxlon = maxlon + 0.5
    minlat = minlat - 0.5
    maxlat = maxlat + 0.5
    
    lonStep = lonStep/Rescale
    latStep = latStep/Rescale
    
    lonv = np.arange(minlon,maxlon,lonStep)
    latv = np.arange(maxlat,minlat,latStep)
    
    #print(lonlos[:,:,5])
    #print(latlos[:,:,5])
    #print(minlon)
    #print(minlat)
    
    lonvv,latvv = np.meshgrid(lonv,latv)
    
    lonvv_all = lonvv.flatten()
    latvv_all = latvv.flatten()
    
    # Get index of the useful hgtlvs
    mindex,maxdex = ut.get_hgt_index(hgtlvs,mindem,maxdem)
    kl = np.arange(mindex,(maxdex+1))
    hgtuseful = hgtlvs[kl]
    row,col = lonvv.shape
    nh = len(kl)
    
    dwet_intp = np.zeros((row,col,nh))

    def resi_func(m,d,y):
        variogram_function =variogram_dict['spherical'] 
        return  y - variogram_function(m,d)
    
    for i in range(len(kl)):
        #print_progress(i+1, nh, prefix='Kriging wet-delay levels:', suffix='complete', decimals=1, barLength=50, elapsed_time=None)
        lat = latlos[:,:,kl[i]]
        lon = lonlos[:,:,kl[i]]
        lat = lat.flatten()
        lon = lon.flatten()
        
        dwet0 = dwetlos[:,:,kl[i]]
        dwet0 = dwet0.flatten()
        
        print(dwet0[0:20])
        print(np.mean(dwet0))
        dwet0_cor, para, corr= remove_ramp(lat,lon,dwet0)
        print(dwet0_cor[0:20])
        print(para)
        print(corr)
        trend = func_trend_model(latvv_all,lonvv_all,para)
        
        uk = OrdinaryKriging(lon, lat, dwet0_cor, coordinates_type = 'geographic', nlags=50)
        Semivariance_trend = 2*(uk.semivariance)    
        x0 = (uk.lags)/180*np.pi*R
        y0 = Semivariance_trend
        max_length = 2/3*max(x0)
        range0 = max_length/2
        LL0 = x0[x0< max_length]
        SS0 = y0[x0< max_length]
        sill0 = max(SS0)
        sill0 = sill0.tolist()

        p0 = [sill0, range0, 0.0001]    
        vari_func = variogram_dict['spherical']        
        tt, _ = leastsq(resi_func,p0,args = (LL0,SS0))   
        corr, _ = pearsonr(SS0, vari_func(tt,LL0))
        #print(tt)
        #print(corr)
        if tt[2] < 0:
            tt[2] =0
        para = tt
        para[1] = para[1]/R/np.pi*180

        uk.variogram_model_parameters = para
        z0,s0 = uk.execute('grid', lonv, latv, n_closest_points = kriging_points_numb, backend='loop')
        z0 = z0.flatten() + trend
        dwet_intp[:,:,i] = z0.reshape((row,col))

    return dwet_intp, lonvv, latvv, hgtuseful

###############################################################
    
geo_file = '/Users/caoy0a/Documents/SCRATCH/NewZealand/T175D/geometryRadar.h5'
meta = ut.read_attr(geo_file) 
lats = ut.read_hdf5(geo_file,datasetName='latitude')[0]
lons = ut.read_hdf5(geo_file,datasetName='longitude')[0]
heis = ut.read_hdf5(geo_file,datasetName='height')[0]    

slc_par = '/Users/caoy0a/Documents/SCRATCH/NewZealand/T175D/20190325.rslc.par'
ERA5_file = '/Users/caoy0a/Documents/SCRATCH/NewZealand/T175D/ERA-5_S43_S36_E174_E180_20190406_17.grb'

attr = ut.read_attr(geo_file)
orb_data = read_par_orb(slc_par)
date0 = ut.read_gamma_par(slc_par,'read', 'date')

start_sar = ut.read_gamma_par(slc_par,'read', 'start_time')
earth_R = ut.read_gamma_par(slc_par,'read', 'earth_radius_below_sensor')
end_sar = ut.read_gamma_par(slc_par,'read', 'end_time')
cent_time = ut.read_gamma_par(slc_par,'read', 'center_time')

attr['EARTH_RADIUS'] = str(float(earth_R.split('m')[0]))
attr['END_TIME'] = str(float(end_sar.split('s')[0]))
attr['START_TIME'] = str(float(start_sar.split('s')[0]))
attr['DATE'] = '20190325'
attr['CENTER_TIME'] = str(float(cent_time.split('s')[0]))
root_path = os.getcwd()

cdic = ut.initconst()
fname_list = glob.glob(era5_raw_dir + '/ERA*' + date0 + '*')    
w,s,e,n = get_meta_corner(attr)
wsen = (w,s,e,n)    
snwe = get_snwe(wsen, min_buffer=0.5, multi_1=True)
area = snwe2str(snwe)
hour = era5_time(attr['CENTER_TIME'])

lvls,latlist,lonlist,gph,tmp,vpr = ut.get_ecmwf('ERA5',ERA5_file,cdic, humidity='Q')
mean_lon = np.mean(lonlist.flatten())
#print(latlist)

if mean_lon > 180:    
    lonlist = lonlist - 360.0 # change to insar format lon [-180, 180]
lonlist[lonlist<0] = lonlist[lonlist<0] + 360
    
    
#print(lonlist)
    
lats0 = lats.flatten()
lons0 = lons.flatten()
mean_geoid = ut.get_geoid_point(np.nanmean(lats0),np.nanmean(lons0))
print('Average geoid height: ' + str(mean_geoid))
heis = heis + mean_geoid # geoid correct
heis0 = heis.flatten()

maxdem = max(heis0)
mindem = min(heis0)
if mindem < -200:
    mindem = -100.0

hh = heis0[~np.isnan(lats0)]
lala = lats0[~np.isnan(lats0)]
lolo = lons0[~np.isnan(lats0)]


sar_wet = np.zeros((heis.shape),dtype = np.float32)   
sar_wet0 = sar_wet.flatten()

sar_dry = np.zeros((heis.shape),dtype = np.float32)   
sar_dry0 = sar_dry.flatten()

hgtlvs = [ -200, 0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400
          ,2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000
          ,5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 11000, 12000, 13000
          ,14000, 15000, 16000, 17000, 18000, 19000, 20000, 25000, 30000, 35000, 40000]

hgtlvs = np.asarray(hgtlvs)
Presi,Tempi,Vpri = ut.intP2H(lvls,hgtlvs,gph,tmp,vpr,cdic,verbose=False)

print(np.mean(lon_intp))

lon_intp[lon_intp<0] = lon_intp[lon_intp<0] + 360
print(lon_intp[:,:,5])
a=1
if a==1:
    # calc los coords
    #print('Start to calc LOS locations ...')
    #lat_intp, lon_intp, los_intp = ut.get_LOS3D_coords(latlist,lonlist,hgtlvs, orb_data, attr)
    LosP,LosT,LosV = ut.get_LOS_parameters(latlist,lonlist,Presi,Tempi,Vpri,lat_intp, lon_intp,'kriging',20)
    ddrylos,dwetlos = ut.losPTV2del(LosP,LosT,LosV,los_intp ,cdic,verbose=False)
    #print(dwetlos[:,:,5])
    # interp grid delays
    print('Start to interpolate delays ...')
    #dwet_intp,ddry_intp, lonvv, latvv, hgtuse = ut.pyrite_griddata_los(lon_intp,lat_intp,hgtlvs,ddrylos,dwetlos,attr, maxdem, mindem, 10,'kriging',20)
    
    dwet_intp, lonvv, latvv, hgtuse = kriging_levels(lon_intp,lat_intp,hgtlvs,dwetlos,attr,maxdem, mindem, 10,20)




Average geoid height: 18.303332084186863
177.0695003632868
[[174.01258213 174.26222268 174.51186797 174.76151724 175.01116979
  175.26082655 175.51048681 175.76014989 176.00981671 176.25948607
  176.50915886 176.75883443 177.00851216 177.2581929  177.50787555
  177.75756096 178.0072485  178.25693762 178.5066291  178.75632239
  179.00601691 179.25571345 179.50541102 179.75511039 180.00481101]
 [174.01246094 174.26210432 174.51175176 174.76140253 175.01105759
  175.26071568 175.51037772 175.76004303 176.00971094 176.25938235
  176.5090561  176.75873307 177.00841261 177.25809411 177.50777843
  177.75746496 178.00715309 178.25684367 178.50653564 178.75622982
  179.00592562 179.2556225  179.50532125 179.7550213  180.00472212]
 [174.0123407  174.26198572 174.51163528 174.76128863 175.01094507
  175.26060554 175.5102688  175.75993578 176.0096058  176.25927821
  176.5089539  176.75863223 177.00831258 177.2579958  177.5076808
  177.75736843 178.00705808 178.25674917 178.50644253 178.75613758
  

Calculating LOS parameters: [##################################################] 100.0%    complete

Start to interpolate delays ...
[0.21094543 0.20832632 0.20781608 0.20663493 0.20592126 0.20875947
 0.21083806 0.21339021 0.21708065 0.22238359 0.22672238 0.22895269
 0.22883842 0.22852062 0.22678698 0.22497802 0.22401275 0.2211308
 0.21679052 0.2117799 ]
0.16405986712535595
[ 0.02110848  0.01685412  0.01470863  0.01189224  0.00954332  0.01074629
  0.01118964  0.01210654  0.01416174  0.01782944  0.02053298  0.02112805
  0.01937853  0.0174255   0.01405661  0.01061241  0.00801189  0.0034947
 -0.00248082 -0.00912669]
[ -8.38086749 -11.82931022   3.60162559   4.9948934 ]
0.7456492154460718
[0.1876342  0.18601749 0.18597782 0.18497421 0.18420666 0.18783478
 0.19024055 0.19263548 0.19656385 0.20264132 0.20737971 0.20989612
 0.21008327 0.21001211 0.2085641  0.20690447 0.2060198  0.20305802
 0.1989289  0.1940293 ]
0.14929217935140557
[ 0.01640602  0.01316377  0.01149857  0.00886942  0.00647633 