In [1]:
%load_ext autoreload

In [2]:
import numpy as np
import os
import pickle

from enterprise.pulsar import Pulsar
from enterprise.signals import parameter
from enterprise.signals import selections
from enterprise.signals import utils
from enterprise.signals import signal_base
from enterprise.signals import white_signals
from enterprise.signals import gp_signals
from enterprise.signals import deterministic_signals
import enterprise.constants as const

from utils import sample_utils as su
from PTMCMCSampler.PTMCMCSampler import PTSampler as ptmcmc

import matplotlib.pyplot as plt
%autoreload 2

In [3]:
# read in data pickles
def _load_pickle(path):
    """Open `path` in binary mode and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# pulsar objects (iterated over below to build the per-pulsar models)
psrs = _load_pickle('/home/pbaker/nanograv/data/nano11_DE436.pkl')

# mapping of parameter name -> value, later handed to pta.set_default_params
noise_params = _load_pickle('/home/pbaker/nanograv/data/nano11_setpars.pkl')

In [4]:
# First and last TOA of each pulsar; the extremes over the whole array
# define the common time span handed to the red-noise Fourier basis.
first_toas = np.array([psr.toas.min() for psr in psrs])
last_toas = np.array([psr.toas.max() for psr in psrs])

tmin = first_toas.min()
tmax = last_toas.max()
Tspan = tmax - tmin

### white noise

In [6]:
# Every white-noise parameter is split per backend and declared Constant
# with no value here; values are filled in when the PTA is built
# (pta.set_default_params).
selection = selections.Selection(selections.by_backend)

# efac -> MeasurementNoise
efac = parameter.Constant()
ef = white_signals.MeasurementNoise(efac=efac, selection=selection)

# log10_equad -> EquadNoise
equad = parameter.Constant()
eq = white_signals.EquadNoise(log10_equad=equad, selection=selection)

# log10_ecorr -> EcorrKernelNoise
ecorr = parameter.Constant()
ec = white_signals.EcorrKernelNoise(log10_ecorr=ecorr, selection=selection)

# combined white-noise model
wn = ef + eq + ec

### red noise

In [7]:
# power-law spectrum hyper-parameters
rn_gamma = parameter.Uniform(0, 7)
rn_log10_A = parameter.LinearExp(-20, -11)

# power-law spectrum evaluated on a 30-component Fourier basis that uses
# the array-wide Tspan
rn_pl = utils.powerlaw(
    log10_A=rn_log10_A,
    gamma=rn_gamma,
)
rn = gp_signals.FourierBasisGP(
    rn_pl,
    components=30,
    Tspan=Tspan,
)

### burst with memory (BWM)

In [8]:
# The amplitude is the only BWM parameter given a prior here.
amp_name = 'bwm_log10_A'
bwm_log10_A = parameter.LinearExp(-18, -11)(amp_name)

# The remaining waveform arguments are pinned to Constant values.
# NOTE(review): t0 looks like an MJD epoch — confirm the unit expected by bwm_delay.
t0_name = 'bwm_t0'
t0 = parameter.Constant(56064.80)(t0_name)

pol_name = 'bwm_pol'
pol = parameter.Constant(1.25)(pol_name)

costh_name = 'bwm_costheta'
costh = parameter.Constant(0.0)(costh_name)

phi_name = 'bwm_phi'
phi = parameter.Constant(2.75)(phi_name)

# waveform function, then the deterministic BWM signal built from it
bwm_wf = utils.bwm_delay(
    log10_h=bwm_log10_A,
    t0=t0,
    cos_gwtheta=costh,
    gwphi=phi,
    gwpol=pol,
)
bwm = deterministic_signals.Deterministic(bwm_wf, name='bwm')

### timing model

In [9]:
# Timing-model signal from enterprise's gp_signals, built with use_svd=True.
# It shows up in the PTA summary as "<pulsar>_linear_timing_model_svd"
# with 0 parameters.
tm = gp_signals.TimingModel(use_svd=True)

### BayesEphem

In [10]:
# BayesEphem term: enterprise's PhysicalEphemerisSignal, built with
# use_epoch_toas=True. It is added to the per-pulsar model when the PTA is
# constructed.
# NOTE(review): presumably the source of the jup_orb / mass / frame_drift
# parameters grouped in the sampler setup — confirm against pta.param_names.
be = deterministic_signals.PhysicalEphemerisSignal(use_epoch_toas=True)

### construct PTA

In [11]:
# construct PTA
# single-pulsar model: timing model + white noise + red noise + BWM + BayesEphem
mod = tm + wn + rn + bwm + be

# instantiate the model once per pulsar and collect the results into a PTA
pulsar_models = [mod(psr) for psr in psrs]
pta = signal_base.PTA(pulsar_models)

# fill in the Constant parameters from the loaded noise values
# (each assignment is reported on the enterprise INFO log)
pta.set_default_params(noise_params)

INFO: enterprise.signals.signal_base: Setting B1855+09_430_ASP_efac to 1.15823
INFO: enterprise.signals.signal_base: Setting B1855+09_430_PUPPI_efac to 1.12256
INFO: enterprise.signals.signal_base: Setting B1855+09_L-wide_ASP_efac to 1.08416
INFO: enterprise.signals.signal_base: Setting B1855+09_L-wide_PUPPI_efac to 1.37942
INFO: enterprise.signals.signal_base: Setting B1855+09_430_ASP_log10_equad to -7.56768
INFO: enterprise.signals.signal_base: Setting B1855+09_430_PUPPI_log10_equad to -6.17493
INFO: enterprise.signals.signal_base: Setting B1855+09_L-wide_ASP_log10_equad to -6.5087
INFO: enterprise.signals.signal_base: Setting B1855+09_L-wide_PUPPI_log10_equad to -6.53584
INFO: enterprise.signals.signal_base: Setting B1855+09_430_ASP_log10_ecorr to -8.15092
INFO: enterprise.signals.signal_base: Setting B1855+09_430_PUPPI_log10_ecorr to -6.29747
INFO: enterprise.signals.signal_base: Setting B1855+09_L-wide_ASP_log10_ecorr to -6.09581
INFO: enterprise.signals.signal_base: Setting B1855

INFO: enterprise.signals.signal_base: Setting J0645+5158_Rcvr_800_GUPPI_log10_equad to -7.0759
INFO: enterprise.signals.signal_base: Setting J0645+5158_Rcvr1_2_GUPPI_log10_ecorr to -7.83345
INFO: enterprise.signals.signal_base: Setting J0645+5158_Rcvr_800_GUPPI_log10_ecorr to -7.15591
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr1_2_GASP_efac to 1.05466
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr1_2_GUPPI_efac to 1.0801
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr_800_GASP_efac to 1.14312
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr_800_GUPPI_efac to 1.1831
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr1_2_GASP_log10_equad to -6.54824
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr1_2_GUPPI_log10_equad to -6.49869
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr_800_GASP_log10_equad to -6.67014
INFO: enterprise.signals.signal_base: Setting J1012+5307_Rcvr_800_GUPPI_log10_equad 

INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr1_2_GUPPI_efac to 1.03958
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr_800_GASP_efac to 1.13148
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr_800_GUPPI_efac to 1.06597
INFO: enterprise.signals.signal_base: Setting J1713+0747_S-wide_ASP_efac to 1.12037
INFO: enterprise.signals.signal_base: Setting J1713+0747_S-wide_PUPPI_efac to 1.1154
INFO: enterprise.signals.signal_base: Setting J1713+0747_L-wide_ASP_log10_equad to -7.55444
INFO: enterprise.signals.signal_base: Setting J1713+0747_L-wide_PUPPI_log10_equad to -7.90314
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr1_2_GASP_log10_equad to -7.31355
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr1_2_GUPPI_log10_equad to -8.03232
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr_800_GASP_log10_equad to -6.94987
INFO: enterprise.signals.signal_base: Setting J1713+0747_Rcvr_800_GUPPI_log10_equad to -7.13031
INF

INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr1_2_GUPPI_efac to 1.051
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GASP_efac to 0.989049
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GUPPI_efac to 1.04172
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr1_2_GASP_log10_equad to -7.44335
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr1_2_GUPPI_log10_equad to -8.07509
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GASP_log10_equad to -6.62237
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GUPPI_log10_equad to -7.33716
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr1_2_GASP_log10_ecorr to -8.47751
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr1_2_GUPPI_log10_ecorr to -7.12103
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GASP_log10_ecorr to -7.92153
INFO: enterprise.signals.signal_base: Setting J1909-3744_Rcvr_800_GUPPI_log

INFO: enterprise.signals.signal_base: Setting J2043+1711_L-wide_PUPPI_log10_ecorr to -8.48264
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GASP_efac to 0.981934
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GUPPI_efac to 1.1003
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr_800_GASP_efac to 1.40428
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr_800_GUPPI_efac to 1.13029
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GASP_log10_equad to -5.69371
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GUPPI_log10_equad to -6.3098
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr_800_GASP_log10_equad to -6.51163
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr_800_GUPPI_log10_equad to -6.28187
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GASP_log10_ecorr to -6.0959
INFO: enterprise.signals.signal_base: Setting J2145-0750_Rcvr1_2_GUPPI_log10_ecorr to 

In [12]:
# Print the PTA's text summary: one entry per signal giving its name, its
# signal class, its parameter count and the parameters it holds.
print(pta.summary())

Signal Name                              Signal Class                   no. Parameters      
B1855+09_linear_timing_model_svd         TimingModel                    0                   

params:
__________________________________________________________________________________________
B1855+09_efac                            MeasurementNoise               0                   

params:
B1855+09_430_ASP_efac:Constant=1.15823                                                    
B1855+09_430_PUPPI_efac:Constant=1.12256                                                  
B1855+09_L-wide_ASP_efac:Constant=1.08416                                                 
B1855+09_L-wide_PUPPI_efac:Constant=1.37942                                               
__________________________________________________________________________________________
B1855+09_equad                           EquadNoise                     0                   

params:
B1855+09_430_ASP_log10_equad:Constant=-7.56768         

### sampler

In [13]:
# initial point: use the stored value when noise_params has an entry for the
# parameter, otherwise take a draw from p.sample()
x0 = np.hstack([noise_params[p.name] if p.name in noise_params
                else p.sample() for p in pta.params])
ndim = len(x0)

# initial jump covariance matrix
# set initial cov stdev to (starting order of magnitude)/10
#
# A component that starts at exactly 0 has no order of magnitude:
# log10(0) = -inf would yield a stdev of 0 and hence a singular covariance.
# Such components get ZERO_START_STDEV instead (the value a start of
# magnitude 1 would receive).
ZERO_START_STDEV = 0.1
magnitude = np.abs(x0)
nonzero = magnitude != 0
stdev = np.full(ndim, ZERO_START_STDEV)
stdev[nonzero] = 10**np.floor(np.log10(magnitude[nonzero])) / 10
cov = np.diag(stdev**2)

In [37]:
# Start every parameter whose name contains "jup_" at 0.
# NOTE(review): presumably so the chain begins at the unperturbed Jupiter
# orbit — confirm.
# enumerate() yields the position directly, so no linear .index() search is
# needed for each match.
for idx, pname in enumerate(pta.param_names):
    if "jup_" in pname:
        x0[idx] = 0

In [39]:
# generate custom sampling groups
# Snapshot the names once so every group is built from the same ordering.
param_names = list(pta.param_names)


def _indices_containing(names, *fragments):
    """Return the positions in `names` of every entry containing any of `fragments`.

    Positions come back in ascending order and each position appears at most
    once, even when an entry matches several fragments.
    """
    return [idx for idx, name in enumerate(names)
            if any(frag in name for frag in fragments)]


groups = [list(range(ndim))]  # all params

# one group per pulsar: every parameter whose name contains the pulsar name
for psr in psrs:
    groups.append(_indices_containing(param_names, psr.name))

# bwm params
bwm_group = _indices_containing(param_names, 'bwm_')
# duplicate BWM group for more proposals
groups.extend([bwm_group] * 5)

# all BE params (kept as BE_group: the proposal setup reads it again)
BE_group = _indices_containing(param_names, 'jup_orb', 'mass', 'frame_drift')
groups.append(BE_group)

# jup_orb elements + GWs (jup_orb indices first, then the bwm indices)
groups.append(_indices_containing(param_names, 'jup_orb') + bwm_group)

In [40]:
# Output directory handed to the sampler; the summary and parameter-name
# files are written into it as well.
outdir = '/home/pbaker/nanograv/bwm/ULvT/py3_test'

# Sampler over the ndim-dimensional parameter space, driven by the PTA's
# log-likelihood and log-prior, the initial jump covariance and the custom
# jump groups.
sampler = ptmcmc(
    ndim,
    pta.get_lnlikelihood,
    pta.get_lnprior,
    cov,
    groups=groups,
    outDir=outdir,
    resume=False,
)

In [41]:
# Save the model summary alongside the sampler output.
sumfile = os.path.join(outdir, 'summary.txt')
with open(sumfile, 'w') as summary_file:
    summary_file.write(pta.summary())

# Save the parameter names, one per line, in pta.param_names order.
outfile = os.path.join(outdir, 'params.txt')
with open(outfile, 'w') as params_file:
    params_file.writelines(pname + '\n' for pname in pta.param_names)

In [42]:
# additional proposals

# Neither imported name is referenced directly in this cell.
# NOTE(review): presumably imported so the pickled objects loaded below can
# be reconstructed — confirm before removing.
from utils.sample_utils import EmpiricalDistribution2D
from scipy.stats import gaussian_kde

# prior draw over every parameter (weight 1)
full_prior = su.build_prior_draw(pta, pta.param_names, name='full_prior')
sampler.addProposalToCycle(full_prior, 1)

# log-uniform draw for the BWM amplitude over (-18, -11) (weight 5)
GWA_loguni = su.build_loguni_draw(
    pta, 'bwm_log10_A', (-18,-11), name='GWA_loguni')
sampler.addProposalToCycle(GWA_loguni, 5)

# RN empirical (weight 10); Nmax is a quarter of the number of loaded
# distributions
with open("/home/pbaker/nanograv/data/nano11_RNdistr.pkl", "rb") as distr_file:
    distr = pickle.load(distr_file)
Non4 = len(distr) // 4
RN_emp = su.EmpDistrDraw(
    distr, pta.param_names, Nmax=Non4, name='RN_empirical')
sampler.addProposalToCycle(RN_emp, 10)

# BE w/ jup KDE
# prior draw restricted to the parameters indexed by BE_group (weight 5)
BE_params = [pta.param_names[idx] for idx in BE_group]
BE_prior = su.build_prior_draw(pta, BE_params, name='BE_prior')
sampler.addProposalToCycle(BE_prior, 5)

# draw built from the pickled Jupiter KDE (weight 5)
with open("/home/pbaker/nanograv/data/nano11_jup_kde.pkl", "rb") as kde_file:
    jup_kde = pickle.load(kde_file)
BE_kde = su.JupOrb_KDE_Draw(jup_kde, pta.param_names, 'jup_kde')
sampler.addProposalToCycle(BE_kde, 5)

### test fancy proposals

In [49]:
# Call the Jupiter-KDE proposal once from the starting point and show how
# far each component moved, together with the second value it returns.
y, lqxy = BE_kde(x0, 1, 1)
displacement = x0 - y
print(displacement, lqxy)

[ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.01039054 -0.00902458  0.00271835  0.00152535
 -0.00506367 -0.00974308] -1.6310071184107606


### Sample!

In [50]:
# thinning factor handed to the sampler
thin = 50
# total number of iterations
# NOTE(review): presumably chosen so that 5000 samples remain after thinning — confirm
Nsamp = 5000 * thin

sampler.sample(
    x0,
    Nsamp,
    SCAMweight=30,
    AMweight=20,
    DEweight=50,
    burn=int(5e4),
    thin=thin,
)

  logpdf = np.log(self.prior(value, **kwargs))


Finished 7.20 percent in 8653.250625 s Acceptance rate = 0.299882

KeyboardInterrupt: 