In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import pySurrogate
import gwsurrogate
from gwsurrogate.new import spline_coef_evaluation
from gwsurrogate.new import surrogate

import tqdm
import sys, h5py

In [None]:
# Leftover IPython introspection: the trailing `?` asks IPython to show the help
# for gwsurrogate.EvaluateSurrogate.
# NOTE(review): the `sur =` prefix presumably does not bind anything useful here
# (`sur` is really created by the EvaluateSurrogate(...) call in the next cell) --
# confirm, and consider dropping this cell from the final notebook.
sur = gwsurrogate.EvaluateSurrogate?

In [None]:
# Load the amp/phase surrogate that the spline surrogate will be built from.
# An earlier run used '../surrogate_downloads/SpEC_q1_10_NoSpin_nu5thDegPoly.h5';
# the file below was also loaded previously without the `excluded=None` argument.
# NOTE(review): this is an absolute, user-specific path -- adjust it for your machine.
surrogate_h5_path = '/home/balzani57/Downloads/SpEC_q1_10_NoSpin_nu5thDegPoly_exclude_2_0.h5'
sur = gwsurrogate.EvaluateSurrogate(surrogate_h5_path, excluded=None)

In [None]:
# Evaluate the surrogate mode by mode (no mode summation, no synthesized negative-m
# modes) and make sure we get the (2, 0) mode as well.
lm_modes, t, hreal, himag = sur(1.2, mode_sum=False, fake_neg_modes=False)
print(lm_modes)

# Each entry: (data to plot against t, matplotlib format string, extra line options).
# Columns 2 and 0 are presumably positions in lm_modes -- check the printed list.
mode_curves = [
    (hreal[:, 2], 'k', {}),
    (himag[:, 2], 'r--', {}),
    (hreal[:, 0], 'c', {'lw': 2}),
]
for curve, fmt, line_opts in mode_curves:
    plt.plot(t, curve, fmt, **line_opts)
plt.show()

In [None]:
# We first need a complex empirical interpolant for each mode.
# Let's first determine how many basis vectors are needed (22 should be sufficient...)

def mode_evaluations(ell, m, nq):
    """Evaluate one (ell, m) mode of the global surrogate `sur` on a mass-ratio grid.

    The grid consists of `nq` uniformly spaced mass ratios between 1 and 10.
    Returns an array whose i-th entry is hreal + 1j*himag for the i-th mass ratio.
    """
    print('Evaluating the (%s, %s) mode for %s mass ratios'%(ell, m, nq))
    sys.stdout.flush()
    mass_ratios = np.linspace(1., 10., nq)

    def _complex_mode(q):
        # Only the requested mode is evaluated; the mode list and times are not needed here.
        _, _, h_re, h_im = sur(q, mode_sum=False, fake_neg_modes=False, ell=[ell], m=[m])
        return h_re + 1.j*h_im

    return np.array([_complex_mode(mass_ratios[k]) for k in tqdm.trange(nq)])

def max_proj_err(resids):
    """Return the largest Euclidean norm among the rows of `resids`.

    For each row r the squared norm is taken as |r . conj(r)|; the square root
    of the largest such value is returned.
    """
    squared_norms = [abs(np.dot(row, np.conjugate(row))) for row in resids]
    return np.sqrt(np.max(squared_norms))

def get_basis_and_errs(evals, tol):
    """Build an SVD basis for `evals` and the projection error for every basis size.

    Parameters
    ----------
    evals : 2-d complex array with one waveform evaluation per row.
    tol : tolerance forwarded to pySurrogate.ei.buildPySurSVD.

    Returns
    -------
    basis : the array returned by buildPySurSVD (one basis vector per row).
    errs : list of length len(basis) + 1.  errs[n] is the largest residual norm
        over all rows of `evals` after subtracting the expansion in the first n
        basis vectors; errs[0] is therefore the largest norm of `evals` itself.
    """
    print('getting basis...')
    sys.stdout.flush()
    basis = pySurrogate.ei.buildPySurSVD(evals, tol)
    # coefs[i, j] = sum_t evals[i, t] * conj(basis[j, t]).  A matrix product gives
    # this directly; broadcasting evals[:, :, np.newaxis] against basis.T and
    # summing would first allocate a (len(evals), n_times, len(basis)) temporary.
    coefs = evals.dot(basis.conjugate().T)
    errs = [max_proj_err(evals)]
    print('getting projection errors...')
    sys.stdout.flush()
    for n in tqdm.trange(1, len(basis)+1):
        resids = evals - coefs[:, :n].dot(basis[:n])
        errs.append(max_proj_err(resids))
    return basis, errs

In [None]:
# Try out building a basis for a few (ell, emm) modes and plot how the
# projection error falls as basis vectors are added.
modes = [(8,8),(2,2),(3,2)]

nTS = 100 # Should be enough

for mode in modes:
    ell, emm = mode
    hlm = mode_evaluations(ell, emm, nTS)
    basis, errs = get_basis_and_errs(hlm, 1.e-10)
    basis_sizes = range(len(errs))  # errs[k] belongs to a basis of k vectors
    plt.semilogy(basis_sizes, errs, label='(%i, %i) mode'%(ell,emm))

plt.legend(frameon=False)
plt.show()

In [None]:
# We know that a basis size of 22 should be enough.
# We could improve the computational cost by using a smaller basis for some modes.
# We could also build one single (larger) basis for all modes, which can accelerate RapidPE.
nTS = 100
basis_tol = 1.e-10
n_basis = 22

# Both dictionaries are keyed by (ell, m).
empirical_interpolant_bases = {}
empirical_node_indices = {}

for ell, m in lm_modes:
    mode_evals = mode_evaluations(ell, m, nTS)
    basis, errs = get_basis_and_errs(mode_evals, basis_tol)
    basis = basis[:n_basis]
    # errs[k] is the projection error with the first k basis vectors (errs[0] is
    # the empty basis), so the error of the truncated basis sits at index
    # len(basis).  Keep n_basis + 1 entries so that errs[-1] refers to it; slicing
    # to n_basis would report the error of a basis one vector smaller.
    errs = errs[:n_basis + 1]
    print('Largest projection error of final basis is %s'%(errs[-1]))
    print('Getting empirical interpolant...')
    sys.stdout.flush()
    ei_basis, node_indices = pySurrogate.ei.buildPySurEI(basis)
    empirical_interpolant_bases[ell, m] = ei_basis
    empirical_node_indices[ell, m] = node_indices

In [None]:
# Now we need to build a spline for the real/imaginary parts of each empirical node.
# For simplicity we will use the same knots for all modes and nodes.

# Here's how we build a simple spline:
nQ = 10
qvals = np.linspace(1., 10., nQ)
yvals = np.sin(qvals)
spline = spline_coef_evaluation.UniformSpacingCubicSplineND(
    (nQ,),
    origin=[1.],
    spacings=[9. / (nQ-1)],  # uniform knot spacing on [1, 10]
)
# Do the LU decomposition of the spline coefficient matrix once...
spline.decompose()
# ...then use it to obtain the spline coefficients for these knot values.
spline.solve(yvals)
coefs = spline.coefs

# And we can evaluate it:
ts_grid = surrogate.TensorSplineGrid([qvals])

def eval_spline(ts_grid, coefs, x):
    """Evaluate a 1-d spline at the scalar point `x`.

    `ts_grid` is called as ts_grid([x]) and must return (imin_vals, eval_prods).
    The result is the sum of the four entries of `coefs` starting at
    imin_vals[0], each weighted by the matching entry of eval_prods.

    NOTE(review): the clamping bounds are hard-coded.  The upper bound matches
    the q <= 10 grids used in this notebook, but the lower bound (1e-12) does
    not match their lower edge q = 1 -- confirm that this is intended.
    """
    edge = 1.e-12
    # Nudge to avoid being outside valid range
    if x - 10. > -edge:
        x = 10. - edge
    elif x < edge:
        x = edge

    first_indices, weights = ts_grid([x])
    start = first_indices[0]
    window = coefs[start: start+4]
    return np.sum(window * weights)

# Compare the spline on a grid ten times denser than the knots with the knot values.
qtest = np.linspace(1., 10., 10 * nQ)
ytest = np.array([eval_spline(ts_grid, coefs, q) for q in qtest])
spline_demo_series = [
    (qvals, yvals, 'o', 'knots'),
    (qtest, ytest, '--', 'spline interpolation'),
]
for xs, ys, fmt, series_label in spline_demo_series:
    plt.plot(xs, ys, fmt, label=series_label)
plt.legend(frameon=False)
plt.show()


In [None]:
# Let's show that a simple spline converges well:

def simple_spline_errors(nQ, nQTest):
    """Pointwise error of an nQ-knot spline of sin(q) on [1, 10].

    The spline is fitted to sin at `nQ` uniformly spaced knots and evaluated at
    `nQTest` uniformly spaced points; the absolute difference from sin at those
    points is returned as an array.
    """
    knots = np.linspace(1., 10., nQ)
    fitter = spline_coef_evaluation.UniformSpacingCubicSplineND((nQ,), origin=[1.], spacings=[9. / (nQ-1)])
    fitter.decompose()
    fitter.solve(np.sin(knots))
    fitted_coefs = fitter.coefs
    grid = surrogate.TensorSplineGrid([knots])
    samples = np.linspace(1., 10., nQTest)
    approx = np.array([eval_spline(grid, fitted_coefs, q) for q in samples])
    return abs(approx - np.sin(samples))

# Knot counts 5, 10, 20, ... (doubling ten times); test on twice the largest count.
nqvals = [5 * 2**k for k in range(11)]
nQTest = 2 * nqvals[-1]
qtest = np.linspace(1., 10., nQTest)
errs = []
for i in tqdm.trange(len(nqvals)):
    errs.append(simple_spline_errors(nqvals[i], nQTest))

# Pointwise error curves, one per knot count.
for nq, err in zip(nqvals, errs):
    plt.semilogy(qtest, err, label='nQ = %s'%(nq))
plt.legend()
plt.show()

# Worst-case error against the number of knots.
worst_errs = [np.max(err) for err in errs]
plt.loglog(nqvals, worst_errs, 'o')
plt.show()

In [None]:
# Now let's build a whole spline surrogate

def get_mode_spline_coefs(ell, m, qvals, spline):
    """Spline coefficients for the real and imaginary parts of one mode's empirical nodes.

    The global surrogate `sur` is evaluated for the (ell, m) mode at every q in
    `qvals` and sampled at the indices stored in empirical_node_indices[ell, m].
    For each node, `spline.solve` is run on the values across `qvals` and a copy
    of `spline.coefs` is kept.

    Returns (real_coefs, imag_coefs), two arrays with one entry per node.
    """
    nodes = empirical_node_indices[ell, m]
    re_samples, im_samples = [], []
    for q in qvals:
        _, _, h_re, h_im = sur(q, mode_sum=False, fake_neg_modes=False, ell=[ell], m=[m])
        re_samples.append(h_re[nodes])
        im_samples.append(h_im[nodes])
    re_samples = np.array(re_samples)
    im_samples = np.array(im_samples)

    def _solve(values):
        # spline.coefs is copied because the following solve is stored on the same attribute.
        spline.solve(values)
        return np.copy(spline.coefs)

    real_coefs, imag_coefs = [], []
    for k in range(len(nodes)):
        real_coefs.append(_solve(re_samples[:, k]))
        imag_coefs.append(_solve(im_samples[:, k]))
    return np.array(real_coefs), np.array(imag_coefs)

def build_spline_surrogate(nQ):
    """Build a FastTensorSplineSurrogate with `nQ` uniformly spaced knots in q on [1, 10].

    Relies on notebook globals: `lm_modes` (modes to include), `t` (used as the
    surrogate's domain), and the `empirical_interpolant_bases` /
    `empirical_node_indices` dictionaries filled in by an earlier cell.

    Returns the constructed surrogate.
    """
    qvals = np.linspace(1., 10., nQ)
    spline = spline_coef_evaluation.UniformSpacingCubicSplineND((nQ,), origin=[1.], spacings=[9. / (nQ-1)])
    # Decompose once; the same decomposition is reused for every node of every mode.
    spline.decompose()

    mode_data = {}
    for i in tqdm.trange(len(lm_modes)):
        ell, m = lm_modes[i]
        real_coefs, imag_coefs = get_mode_spline_coefs(ell, m, qvals, spline)
        mode_data[ell, m] = (empirical_interpolant_bases[ell, m], real_coefs, imag_coefs)

    spline_surrogate = surrogate.FastTensorSplineSurrogate(
            # Fill in the knot count; the template used to be passed unformatted,
            # leaving a literal '%s' in every surrogate's name.
            name = 'SpEC_1d_nonspinning_%s_spline_knots' % nQ,
            domain = t,
            param_space = surrogate.ParamSpace('Nonspinning_q10', [surrogate.ParamDim('q', 1, 10)]),
            knot_vecs = [qvals],
            mode_data = mode_data,
            modes = lm_modes,
    )

    return spline_surrogate

def waveform_norm(h):
    """Square root of the sum of |h**2| over all entries of `h`."""
    squared_magnitudes = np.abs(h**2)
    return np.sqrt(np.sum(squared_magnitudes))

def waveform_error(h1, h2):
    """Norm of h1 - h2 relative to the norm of h1 (h1 is the reference)."""
    difference_norm = waveform_norm(h1 - h2)
    reference_norm = waveform_norm(h1)
    return difference_norm / reference_norm

def many_h_evals(qvals):
    """Evaluate every mode of the global surrogate `sur` at each q in `qvals`.

    Returns a list with one entry per q: the transpose of hreal + 1j*himag.
    """
    def _all_modes(q):
        _, _, h_re, h_im = sur(q, mode_sum=False, fake_neg_modes=False)
        return (h_re + 1.j*h_im).T

    return [_all_modes(qvals[i]) for i in tqdm.trange(len(qvals))]

def test_spline_surrogate(spline_surrogate, nqtest, h_evals):
    """Relative waveform error of `spline_surrogate` on a uniform q grid.

    The grid has `nqtest` points on [1, 10].  h_evals[i] is used as the
    reference for the i-th grid point, so `h_evals` has to come from the same
    grid.  Modes are stacked in the order of the global `lm_modes`.

    Returns an array with one waveform_error value per grid point.
    """
    qtest = np.linspace(1., 10., nqtest)

    def _relative_error(i):
        reference = h_evals[i]
        spline_modes = spline_surrogate([qtest[i]])
        stacked = np.array([spline_modes[k] for k in lm_modes])
        return waveform_error(reference, stacked)

    return np.array([_relative_error(i) for i in tqdm.trange(len(qtest))])



In [None]:
# Build a spline surrogate for several knot counts and record its error on a
# common 150-point test grid.
errs = []
nqs = [5, 10, 20, 40, 80]
nqtest = 150
qtest = np.linspace(1., 10., nqtest)
h_evals = many_h_evals(qtest)
for nq in nqs:
    spline_sur = build_spline_surrogate(nq)
    nq_errs = test_spline_surrogate(spline_sur, nqtest, h_evals)
    errs.append(nq_errs)
    print('%s %s' % (nq, np.max(nq_errs)))
    sys.stdout.flush()

In [None]:
# One error curve per knot count, over the test grid.
for knot_count, knot_errs in zip(nqs, errs):
    plt.semilogy(qtest, knot_errs, label='n knots = %s'%(knot_count))
plt.legend()
plt.show()

In [None]:
# Add a couple more (better) spline surrogates...
# A knot count is only built if it is not in nqs yet, and nqs/errs are extended
# together.  Re-running this cell therefore does nothing instead of appending
# duplicate entries and rebuilding the (expensive) surrogates.
for nq in [160, 320]:
    if nq in nqs:
        continue
    spline_sur = build_spline_surrogate(nq)
    nq_errs = test_spline_surrogate(spline_sur, nqtest, h_evals)
    nqs = nqs + [nq]
    errs.append(nq_errs)
    print('%s %s' % (nq, np.max(nq_errs)))
    sys.stdout.flush()

In [None]:
# Same plot as before, now including the extra knot counts.
labelled_errs = [('n knots = %s'%(nq), err) for nq, err in zip(nqs, errs)]
for curve_label, err in labelled_errs:
    plt.semilogy(qtest, err, label=curve_label)
plt.legend()
plt.show()

In [None]:
# It looks like we reach the empirical interpolation error cutoff around 100 knots,
# so that is the knot count used for the surrogate we keep.
n_knots_final = 100
spline_sur = build_spline_surrogate(n_knots_final)

In [None]:
%%timeit
# Time a single evaluation of the spline surrogate at q = 1.2.
spline_sur([1.2])

In [None]:
%%timeit
# Time a single evaluation of the original surrogate at q = 1.2, for comparison.
sur(1.2, mode_sum=False, fake_neg_modes=False)

In [None]:
# Save the spline surrogate (relative path, so it lands in the working directory):
spline_surrogate_file = 'SpEC_q10_nonspinning_spline_surrogate.h5'
spline_sur.save(spline_surrogate_file)

In [None]:
# Load the spline surrogate back from disk and evaluate both surrogates at q = pi:
loaded_surrogate = surrogate.FastTensorSplineSurrogate()
loaded_surrogate.load('SpEC_q10_nonspinning_spline_surrogate.h5')
h_modes = loaded_surrogate([np.pi])
_, t, hreal, himag = sur(q=np.pi, mode_sum=False, fake_neg_modes=False)

# Column 2 of hreal is compared with the (2, 2) entry of the spline surrogate;
# presumably column 2 is the (2, 2) mode -- verify against lm_modes.
h22_reference = hreal[:, 2]
h22_spline = np.real(h_modes[2, 2])
plt.plot(t, h22_reference, 'k', label='Original amp/phase surrogate')
plt.plot(t, h22_spline, 'r--', label='spline surrogate-of-a-surrogate')
plt.plot(t, abs(h22_reference - h22_spline), 'c', label='error')
plt.legend(frameon=False, loc='upper left')
plt.show()

# TODO: save in old gwsurrogate format

In [None]:
# Display the first breakpoint vector of the surrogate's spline grid.
spline_sur.ts_grid.breakpoint_vecs[0]

In [None]:
# Open questions while mapping this object onto the gwsurrogate format:
# Q: Where are the EIM indices? (How does a mode get evaluated from the ei and spline data?)
#    Is other data needed for gwsurrogate?
# Q: data for parameter domain?
# Q: data for affine map?

# General description of the surrogate
print(spline_sur.param_space)
print(spline_sur.name)
print(spline_sur.mode_list)
print(spline_sur.mode_indices)
print(spline_sur.domain)

# Sizes of the per-mode containers
print(len(spline_sur.cre))
print(len(spline_sur.cim))
print(len(spline_sur.ei))

# Shapes for the entry at index 2 and for the first breakpoint vector
print(spline_sur.ei[2].shape)
print(spline_sur.cre[2].shape)
print(spline_sur.ts_grid.breakpoint_vecs[0].shape)

print(spline_sur._h5_data_keys) # what is this used for?
print(spline_sur._h5_subordinate_keys) # what is this used for?

In [None]:
# List the instance attribute names stored on the spline surrogate.
spline_sur.__dict__.keys()

In [None]:
# Q: not exactly 100? OK?
# Q: not exactly dt=.1? OK?
# First sample, last sample, and the spacing between the first two samples of the domain.
print(spline_sur.domain[0])
print(spline_sur.domain[-1])
print(spline_sur.domain[1] - spline_sur.domain[0])

In [None]:
mode = '(2, 2)'
mode_index = spline_sur.mode_indices[mode]


# Q: is this how to access (2,2) mode data ?
surrogate_mode_data = {}
surrogate_mode_data["surrogate_mode_type"] = 'waveform_basis'
surrogate_mode_data['parameterization']  = 'q_to_q'
surrogate_mode_data['affine_map'] = 'none'
surrogate_mode_data['t_units'] = 'TOverMtot'
surrogate_mode_data['B'] = spline_sur.ei[mode_index]
# The spline surrogate is built on mass ratios 1 <= q <= 10 (ParamDim('q', 1, 10)
# and the linspace(1., 10., ...) knot grids), so the fit interval is [1, 10].
# fit_max used to be 1.0, the same value as fit_min, i.e. a zero-width interval.
# NOTE(review): confirm that fit_min/fit_max are meant to be the bounds of q.
surrogate_mode_data['fit_min'] = 1.0
surrogate_mode_data['fit_max'] = 10.0


# Q: fit amp/phase instead of h? how was (2,0) mode fitted?
# Q: how does tmin/tmax/dt and times fit together? Always include just times at which basis is sampled? a bit more
# space but less confusing
#surrogate_mode_data['tmin'] = spline_sur.domain[-1]
#surrogate_mode_data['tmax'] = spline_sur.domain[0]
surrogate_mode_data['times'] = spline_sur.domain
#surrogate_mode_data['dt'] = spline_sur.domain[1] - spline_sur.domain[0]
#surrogate_mode_data['quadrature_weights'] =

# Q: data set has no eim_indices.
#_eim_indices_grp        = 'eim_indices'       # .txt... text: eim_indices contains 2 vectors if amp/phase
#_eim_indices_phase_grp  = 'eim_indices_phase' # rolled into eim_indices

# _eim_amp_grp    = 'eim_amp' # text analog? only used to plot
#  _eim_phase_grp  = 'eim_phase' # text analog? only used to plot

#  _fitparams_phase_grp      = 'fitparams_phase' #  .txt
#  _fit_type_phase_grp       = 'fit_type_phase'   # .txt
#  _fitparams_amp_grp        = 'fitparams_amp'  # .txt
#  _fit_type_amp_grp         = 'fit_type_amp' # .txt
#  _fit_type_norm_grp        = 'fit_type_norm' # .txt
#  _fitparams_norm_grp       = 'fitparams_norm' # .txt



In [None]:
from gwsurrogate import surrogateIO

In [None]:
# Create an H5Surrogate object for 'testsur.h5'.
# NOTE(review): the 'w' argument is presumably a write mode -- confirm against
# gwsurrogate.surrogateIO before relying on it (an existing file may be overwritten).
writeh5 = surrogateIO.H5Surrogate('testsur.h5','w')

In [None]:
# Show the `required` attribute of the H5Surrogate object.
print(writeh5.required)
# The cell used to end with an unfinished `print writeh5.` attribute lookup, which
# is a SyntaxError and kept the whole cell from running.
# TODO: print the attribute that was meant to be inspected here.

In [None]:
# NOTE(review): called without arguments; the surrogate_mode_data dictionary built
# earlier is not passed in here -- check prepare_mode_data's signature to see
# whether it should be.
writeh5.prepare_mode_data()