In [1]:
import msrvxc
import numpy as np
import pickle

import matplotlib.pyplot as plt
from tqdm import tqdm

from astroquery.sdss import SDSS
from astroquery.gaia import Gaia

KeyboardInterrupt: 

In [None]:
# Build the BOSZ model grid over this wavelength window (units presumably
# Angstrom, matching the SDSS spectra used below — confirm in msrvxc).
# The returned names suggest interpolators over raw and normalized model flux.
BOSZ_WL_RANGE = [3600, 9000]
wvl, raw_values, interp_bosz, interp_bosz_norm = msrvxc.build.build_bosz_grid(
    wl_range=BOSZ_WL_RANGE
)

In [None]:
from pathlib import Path

# Cache the interpolators and wavelength grid so a later session can reload
# them (see the reload cell) instead of rebuilding the grid.
# Create the target directory first: open(..., 'wb') does not create parent
# directories, so a fresh checkout without grids/ would fail here.
Path('grids').mkdir(parents=True, exist_ok=True)

with open('grids/interp_bosz_norm.pkl', 'wb') as f:
    pickle.dump(interp_bosz_norm, f)

with open('grids/interp_bosz.pkl', 'wb') as f:
    pickle.dump(interp_bosz, f)

np.save('grids/wavl.npy', wvl)

In [None]:
# Reload the cached interpolators and wavelength grid written by the save cell.
# These pickles are produced locally by this notebook; never unpickle files from
# an untrusted source (pickle.load can execute arbitrary code).
with open('grids/interp_bosz.pkl', 'rb') as f:
    interp_bosz = pickle.load(f)

with open('grids/interp_bosz_norm.pkl', 'rb') as f:
    interp_bosz_norm = pickle.load(f)

# Load the saved wavelength array instead of re-deriving it with a hard-coded
# np.linspace(3600, 9000, 23074): the saved array is exactly the one returned
# alongside the interpolators, so the two cannot drift out of sync.
wvl = np.load('grids/wavl.npy')

In [None]:
#params = (4999.848788360524,
#3.812724027560155,
#0.4084144628961131,
#0.5593886093528815)
#
#plt.plot(wvl, interp_bosz_norm(params), alpha = 0.3)
#plt.plot(wvl, interp_bosz((4213, 4.6, -2.3, 0)), alpha = 0.3)

In [None]:
#import msrvxc
#
#
#sampler, nwalkers, ndim = msrvxc.fit.fit_rv(interp_bosz_norm, wvl, interp_bosz_norm((4213, 4.1, -2.3, 100)))

In [None]:
#fig, axes = plt.subplots(4, figsize=(10, 7), sharex=True)
#samples = sampler.get_chain()
#labels = ["Teff", "logg", "Z", "rv"]
#for i in range(ndim):
#    ax = axes[i]
#    ax.plot(samples[:, :, i], "k", alpha=0.3)
#    ax.set_xlim(0, len(samples))
#    ax.set_ylabel(labels[i])
#    ax.yaxis.set_label_coords(-0.1, 0.5)
#
#axes[-1].set_xlabel("step number");

In [None]:
#flat_samples = sampler.get_chain(discard=100, thin=15, flat=True)
#print(flat_samples.shape)

In [None]:
#import corner
#
#fig = corner.corner(
#    flat_samples, labels=labels, truths=[4213, 4.1, -2.3, 0]
#);

In [None]:
#X = np.median(flat_samples,axis=0)
#T, logg, Z, rv = X[:4]
#
#print(T, logg, Z, rv)

In [None]:
# Gaia DR3 sources with a radial velocity in [-100, 100] and RV-template
# Teff/logg inside the ranges below, joined to their SDSS DR13 best neighbour.
# The SDSS object id is exposed as `bestobjid` for the SDSS query cell.
GAIA_ADQL = """SELECT gaia.radial_velocity, gaia.radial_velocity_error, gaia.rv_template_teff, gaia.rv_template_logg, sdss.original_ext_source_id as bestobjid
    FROM gaiadr3.gaia_source as gaia
    JOIN gaiaedr3.sdssdr13_best_neighbour as sdss
    ON gaia.source_id = sdss.source_id      
    WHERE gaia.radial_velocity BETWEEN -100 AND 100 AND
    gaia.rv_template_teff BETWEEN 3500 AND 7000 AND
    gaia.rv_template_logg BETWEEN 2.5 AND 5"""

# Run as an asynchronous job: synchronous Gaia TAP jobs (launch_job) return a
# capped number of rows (2000 per the astroquery Gaia docs) and truncate the
# result silently, which would quietly shrink the sample for a cross-match
# of this size.
job1 = Gaia.launch_job_async(GAIA_ADQL, dump_to_file=False)
d1 = job1.get_results()

In [None]:
from astropy.table import Table, vstack, hstack

# Look up SDSS spectra (plate/mjd/fiberID/subClass) for every Gaia match,
# querying SkyServer in batches of object ids via an IN (...) clause.
SDSS_BATCH_SIZE = 100

stardats = []
bestobjids = d1['bestobjid']
# Ceiling division: floor division silently dropped the final partial batch.
n_batches = -(-len(bestobjids) // SDSS_BATCH_SIZE)

for batch_idx in tqdm(range(n_batches)):
    batch = bestobjids[batch_idx * SDSS_BATCH_SIZE:(batch_idx + 1) * SDSS_BATCH_SIZE]
    # Build the IN list explicitly: str(tuple(...)) produces "(x,)" for a
    # one-element batch (invalid SQL) and embeds numpy reprs such as
    # "np.int64(...)" under NumPy >= 2.
    id_list = ", ".join(str(objid) for objid in batch)
    SDSS_QUERY = """select bestObjID as bestobjid, plate, mjd, fiberID, subClass
        from dbo.SpecObjAll
        where bestObjID in ({})""".format(id_list)
    try:
        # Query once and reuse the result (previously each batch was sent twice).
        result = SDSS.query_sql(SDSS_QUERY)
    except Exception as err:
        # Best-effort: skip a failing batch, but report it rather than hiding it
        # (a bare except also swallowed KeyboardInterrupt).
        tqdm.write(f"SDSS batch {batch_idx} failed: {err!r}")
        continue
    if result is not None:
        stardats.append(result)

if not stardats:
    raise RuntimeError("No SDSS spectra were returned for any batch of Gaia matches.")
spec = vstack(stardats)

In [None]:
from astropy.table import Table, join

# Inner-join the SDSS spectrum identifiers onto the Gaia rows through the
# shared `bestobjid` column, then show the combined table.
data = join(left=spec, right=d1, keys=['bestobjid'])
data

In [None]:
# Row of `data` to inspect; later cells also index `data` with `i`.
i = 1

spectra = SDSS.get_spectra(
    plate=data['plate'][i],
    fiberID=data['fiberID'][i],
    mjd=data['mjd'][i],
)
# NOTE(review): this rebinds `spec`, which previously held the stacked SDSS
# table used by the join cell; re-running that join after this cell will fail.
spec = spectra[0]

In [None]:
# Columns from HDU 1 of the fetched spectrum (presumably the SDSS coadd table):
# loglam is log10(wavelength), so undo the log.
spec_table = spec[1].data
wl = 10**spec_table['loglam']
fl = spec_table['flux']
ivar = spec_table['ivar']

# Reference point from Gaia's RV template and RV; the third ("Z") value -2.3
# looks like a fixed guess, since the Gaia query returns no metallicity.
params = (data['rv_template_teff'][i], data['rv_template_logg'][i], -2.3, data['radial_velocity'][i])

# Preview: continuum-normalized observed flux against the normalized grid model.
# NOTE(review): avg_size=1000 here, but the fit cell uses the default — confirm.
normalized_preview = msrvxc.utils.continuum_normalize(wl, fl, avg_size=1000)[1]
plt.plot(wl, normalized_preview)
plt.plot(wvl, interp_bosz_norm(params), alpha=0.3)

In [None]:
# Print the index of every NaN wavelength sample.
# Uses its own name (`pix`) instead of `i`: looping with `i` overwrote the row
# index `i` chosen above, so re-running the fetch/extract cells afterwards
# silently selected a different star (or raised IndexError).
for pix in np.flatnonzero(np.isnan(wl)):
    print(pix)

In [None]:
# Fit the continuum-normalized observed spectrum (with its inverse variance)
# against the normalized BOSZ interpolator.
normalized_flux = msrvxc.utils.continuum_normalize(wl, fl)[1]
sampler, nwalkers, ndim = msrvxc.fit.fit_rv(interp_bosz_norm, wl, normalized_flux, ivar)

In [None]:
# MCMC trace plot: one panel per fitted parameter, walkers overplotted.
samples = sampler.get_chain()
labels = ["Teff", "logg", "Z", "rv"]

# Size the figure from `ndim` instead of a hard-coded 4. The median cell slices
# X[:4], so the fit may return more than four parameters, and axes[i] then
# raised IndexError. squeeze=False keeps `axes` indexable when ndim == 1.
fig, axes = plt.subplots(ndim, figsize=(10, 7), sharex=True, squeeze=False)
axes = axes[:, 0]
# Loop with `dim` (not `i`) so the row index `i` used by earlier cells survives.
for dim, ax in enumerate(axes):
    ax.plot(samples[:, :, dim], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[dim] if dim < len(labels) else f"param {dim}")
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes[-1].set_xlabel("step number");

In [None]:
# Flatten the chain, discarding the first 100 steps as burn-in and keeping
# every 15th step, then report the resulting array shape.
flat_samples = sampler.get_chain(flat=True, thin=15, discard=100)
print(flat_samples.shape)

In [None]:
import corner

# Posterior corner plot; the Gaia-derived `params` tuple is drawn as reference
# lines (its third entry is the hard-coded -2.3, not a measured value).
fig = corner.corner(flat_samples, labels=labels, truths=params)

In [None]:
# Posterior median of every parameter; the first four are taken as the
# (Teff, logg, Z, rv) point evaluated below.
X = np.median(flat_samples, axis=0)
T, logg, Z, rv = X[:4]
median_params = (T, logg, Z, rv)

# Length of the model flux, presumably checked against len(wvl) for the overlay.
print(len(interp_bosz_norm(median_params)))

# NOTE(review): this overlays the raw SDSS flux (scaled by 1e17) on the
# un-normalized model, whereas the fit used normalized flux and model —
# confirm the intended comparison and the flux units.
plt.plot(wl, fl * 1e17, label='spectrum')
plt.plot(wvl, interp_bosz(median_params), alpha=0.3, label='median parameters')
plt.legend()