Skip to content

Commit

Permalink
Changed samples_box to box
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Apr 15, 2021
1 parent 15d1091 commit 53ab9f4
Showing 1 changed file with 78 additions and 81 deletions.
159 changes: 78 additions & 81 deletions species/plot/plot_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,25 @@ def plot_posterior(tag: str,

species_db = database.Database()

samples_box = species_db.get_samples(tag, burnin=burnin)
samples = samples_box.samples
box = species_db.get_samples(tag, burnin=burnin)
samples = box.samples

# index_sel = [0, 1, 8, 9, 14]
# samples = samples[:, index_sel]
#
# for i in range(13, 9, -1):
# del samples_box.parameters[i]
# del box.parameters[i]
#
# del samples_box.parameters[2]
# del samples_box.parameters[2]
# del samples_box.parameters[2]
# del samples_box.parameters[2]
# del samples_box.parameters[2]
# del samples_box.parameters[2]
# del box.parameters[2]
# del box.parameters[2]
# del box.parameters[2]
# del box.parameters[2]
# del box.parameters[2]
# del box.parameters[2]

ndim = len(samples_box.parameters)
ndim = len(box.parameters)

if not inc_pt_param and samples_box.spectrum == 'petitradtrans':
if not inc_pt_param and box.spectrum == 'petitradtrans':
pt_param = ['tint', 't1', 't2', 't3', 'alpha', 'log_delta']

index_del = []
Expand All @@ -215,35 +215,35 @@ def plot_posterior(tag: str,
for i in range(100):
pt_item = f't{i}'

if pt_item in samples_box.parameters:
param_index = np.argwhere(np.array(samples_box.parameters) == pt_item)[0]
if pt_item in box.parameters:
param_index = np.argwhere(np.array(box.parameters) == pt_item)[0]
index_del.append(param_index)
item_del.append(pt_item)

else:
break

for item in pt_param:
if item in samples_box.parameters and item not in item_del:
param_index = np.argwhere(np.array(samples_box.parameters) == item)[0]
if item in box.parameters and item not in item_del:
param_index = np.argwhere(np.array(box.parameters) == item)[0]
index_del.append(param_index)
item_del.append(item)

samples = np.delete(samples, index_del, axis=1)
ndim -= len(index_del)

for item in item_del:
samples_box.parameters.remove(item)
box.parameters.remove(item)

if samples_box.attributes['chemistry'] == 'free':
samples_box.parameters.append('c_h_ratio')
samples_box.parameters.append('o_h_ratio')
samples_box.parameters.append('c_o_ratio')
if box.attributes['chemistry'] == 'free':
box.parameters.append('c_h_ratio')
box.parameters.append('o_h_ratio')
box.parameters.append('c_o_ratio')

ndim += 3

abund_index = {}
for i, item in enumerate(samples_box.parameters):
for i, item in enumerate(box.parameters):
if item == 'CH4':
abund_index['CH4'] = i

Expand Down Expand Up @@ -290,52 +290,52 @@ def plot_posterior(tag: str,
for i, item in enumerate(samples):
abund = {}

if 'CH4' in samples_box.parameters:
if 'CH4' in box.parameters:
abund['CH4'] = item[abund_index['CH4']]

if 'CO' in samples_box.parameters:
if 'CO' in box.parameters:
abund['CO'] = item[abund_index['CO']]

if 'CO_all_iso' in samples_box.parameters:
if 'CO_all_iso' in box.parameters:
abund['CO_all_iso'] = item[abund_index['CO_all_iso']]

if 'CO2' in samples_box.parameters:
if 'CO2' in box.parameters:
abund['CO2'] = item[abund_index['CO2']]

if 'FeH' in samples_box.parameters:
if 'FeH' in box.parameters:
abund['FeH'] = item[abund_index['FeH']]

if 'H2O' in samples_box.parameters:
if 'H2O' in box.parameters:
abund['H2O'] = item[abund_index['H2O']]

if 'H2S' in samples_box.parameters:
if 'H2S' in box.parameters:
abund['H2S'] = item[abund_index['H2S']]

if 'Na' in samples_box.parameters:
if 'Na' in box.parameters:
abund['Na'] = item[abund_index['Na']]

if 'NH3' in samples_box.parameters:
if 'NH3' in box.parameters:
abund['NH3'] = item[abund_index['NH3']]

if 'K' in samples_box.parameters:
if 'K' in box.parameters:
abund['K'] = item[abund_index['K']]

if 'PH3' in samples_box.parameters:
if 'PH3' in box.parameters:
abund['PH3'] = item[abund_index['PH3']]

if 'TiO' in samples_box.parameters:
if 'TiO' in box.parameters:
abund['TiO'] = item[abund_index['TiO']]

if 'VO' in samples_box.parameters:
if 'VO' in box.parameters:
abund['VO'] = item[abund_index['VO']]

if 'VO' in samples_box.parameters:
if 'VO' in box.parameters:
abund['VO'] = item[abund_index['VO']]

c_h_ratio[i], o_h_ratio[i], c_o_ratio[i] = retrieval_util.calc_metal_ratio(abund)

if vmr and samples_box.spectrum == 'petitradtrans' and \
samples_box.attributes['chemistry'] == 'free':
if vmr and box.spectrum == 'petitradtrans' and \
box.attributes['chemistry'] == 'free':
print('Changing mass fractions to number fractions...', end='', flush=True)

# Get all available line species
Expand All @@ -351,10 +351,10 @@ def plot_posterior(tag: str,
# Initiate a dictionary for the log10 mass fraction of the metals
log_x_abund = {}

for param_item in samples_box.parameters:
for param_item in box.parameters:
if param_item in line_species:
# Get the index of the parameter
param_index = samples_box.parameters.index(param_item)
param_index = box.parameters.index(param_item)

# Store log10 mass fraction in the dictionary
log_x_abund[param_item] = samples_item[param_index]
Expand All @@ -365,10 +365,10 @@ def plot_posterior(tag: str,
# Calculate the mean molecular weight from the input mass fractions
mmw = retrieval_util.mean_molecular_weight(x_abund)

for param_item in samples_box.parameters:
for param_item in box.parameters:
if param_item in line_species:
# Get the index of the parameter
param_index = samples_box.parameters.index(param_item)
param_index = box.parameters.index(param_item)

# Overwrite the sample with the log10 number fraction
samples_item[param_index] = np.log10(10.**samples_item[param_index] *
Expand All @@ -378,17 +378,14 @@ def plot_posterior(tag: str,
updated_samples[i, ] = samples_item

# Overwrite the samples in the SamplesBox
samples_box.samples = updated_samples
box.samples = updated_samples

print(' [DONE]')

print('Median sample:')
for key, value in box.median_sample.items():
print(f' - {key} = {value:.2e}')

samples = box.samples
ndim = samples.shape[-1]

if 'gauss_mean' in box.parameters:
param_index = np.argwhere(np.array(box.parameters) == 'gauss_mean')[0]
samples[:, param_index] *= 1e3 # (um) -> (nm)
Expand All @@ -404,52 +401,52 @@ def plot_posterior(tag: str,
for key, value in box.prob_sample.items():
print(f' - {key} = {value:.2e}')

for item in samples_box.parameters:
for item in box.parameters:
if item[0:11] == 'wavelength_':
param_index = samples_box.parameters.index(item)
param_index = box.parameters.index(item)

# (um) -> (nm)
samples_box.samples[:, param_index] *= 1e3
box.samples[:, param_index] *= 1e3

print(f'Plotting the posterior: {output}...', end='', flush=True)

if 'H2O' in samples_box.parameters:
if 'H2O' in box.parameters:
samples = np.column_stack((samples, c_h_ratio, o_h_ratio, c_o_ratio))

if inc_luminosity:
if 'teff' in samples_box.parameters and 'radius' in samples_box.parameters:
teff_index = np.argwhere(np.array(samples_box.parameters) == 'teff')[0]
radius_index = np.argwhere(np.array(samples_box.parameters) == 'radius')[0]
if 'teff' in box.parameters and 'radius' in box.parameters:
teff_index = np.argwhere(np.array(box.parameters) == 'teff')[0]
radius_index = np.argwhere(np.array(box.parameters) == 'radius')[0]

lum_planet = 4. * np.pi * (samples[..., radius_index]*constants.R_JUP)**2 * \
constants.SIGMA_SB * samples[..., teff_index]**4. / constants.L_SUN

if 'disk_teff' in samples_box.parameters and 'disk_radius' in samples_box.parameters:
teff_index = np.argwhere(np.array(samples_box.parameters) == 'disk_teff')[0]
radius_index = np.argwhere(np.array(samples_box.parameters) == 'disk_radius')[0]
if 'disk_teff' in box.parameters and 'disk_radius' in box.parameters:
teff_index = np.argwhere(np.array(box.parameters) == 'disk_teff')[0]
radius_index = np.argwhere(np.array(box.parameters) == 'disk_radius')[0]

lum_disk = 4. * np.pi * (samples[..., radius_index]*constants.R_JUP)**2 * \
constants.SIGMA_SB * samples[..., teff_index]**4. / constants.L_SUN

samples = np.append(samples, np.log10(lum_planet+lum_disk), axis=-1)
samples_box.parameters.append('luminosity')
box.parameters.append('luminosity')
ndim += 1

samples = np.append(samples, lum_disk/lum_planet, axis=-1)
samples_box.parameters.append('luminosity_disk_planet')
box.parameters.append('luminosity_disk_planet')
ndim += 1

else:
samples = np.append(samples, np.log10(lum_planet), axis=-1)
samples_box.parameters.append('luminosity')
box.parameters.append('luminosity')
ndim += 1

elif 'teff_0' in samples_box.parameters and 'radius_0' in samples_box.parameters:
elif 'teff_0' in box.parameters and 'radius_0' in box.parameters:
luminosity = 0.

for i in range(100):
teff_index = np.argwhere(np.array(samples_box.parameters) == f'teff_{i}')
radius_index = np.argwhere(np.array(samples_box.parameters) == f'radius_{i}')
teff_index = np.argwhere(np.array(box.parameters) == f'teff_{i}')
radius_index = np.argwhere(np.array(box.parameters) == f'radius_{i}')

if len(teff_index) > 0 and len(radius_index) > 0:
luminosity += 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \
Expand All @@ -459,34 +456,34 @@ def plot_posterior(tag: str,
break

samples = np.append(samples, np.log10(luminosity), axis=-1)
samples_box.parameters.append('luminosity')
box.parameters.append('luminosity')
ndim += 1

# teff_index = np.argwhere(np.array(samples_box.parameters) == 'teff_0')
# radius_index = np.argwhere(np.array(samples_box.parameters) == 'radius_0')
# teff_index = np.argwhere(np.array(box.parameters) == 'teff_0')
# radius_index = np.argwhere(np.array(box.parameters) == 'radius_0')
#
# luminosity_0 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \
# * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN
#
# samples = np.append(samples, np.log10(luminosity_0), axis=-1)
# samples_box.parameters.append('luminosity_0')
# box.parameters.append('luminosity_0')
# ndim += 1
#
# teff_index = np.argwhere(np.array(samples_box.parameters) == 'teff_1')
# radius_index = np.argwhere(np.array(samples_box.parameters) == 'radius_1')
# teff_index = np.argwhere(np.array(box.parameters) == 'teff_1')
# radius_index = np.argwhere(np.array(box.parameters) == 'radius_1')
#
# luminosity_1 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \
# * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN
#
# samples = np.append(samples, np.log10(luminosity_1), axis=-1)
# samples_box.parameters.append('luminosity_1')
# box.parameters.append('luminosity_1')
# ndim += 1
#
# teff_index_0 = np.argwhere(np.array(samples_box.parameters) == 'teff_0')
# radius_index_0 = np.argwhere(np.array(samples_box.parameters) == 'radius_0')
# teff_index_0 = np.argwhere(np.array(box.parameters) == 'teff_0')
# radius_index_0 = np.argwhere(np.array(box.parameters) == 'radius_0')
#
# teff_index_1 = np.argwhere(np.array(samples_box.parameters) == 'teff_1')
# radius_index_1 = np.argwhere(np.array(samples_box.parameters) == 'radius_1')
# teff_index_1 = np.argwhere(np.array(box.parameters) == 'teff_1')
# radius_index_1 = np.argwhere(np.array(box.parameters) == 'radius_1')
#
# luminosity_0 = 4. * np.pi * (samples[..., radius_index_0[0]]*constants.R_JUP)**2 \
# * constants.SIGMA_SB * samples[..., teff_index_0[0]]**4. / constants.L_SUN
Expand All @@ -495,7 +492,7 @@ def plot_posterior(tag: str,
# * constants.SIGMA_SB * samples[..., teff_index_1[0]]**4. / constants.L_SUN
#
# samples = np.append(samples, np.log10(luminosity_0/luminosity_1), axis=-1)
# samples_box.parameters.append('luminosity_ratio')
# box.parameters.append('luminosity_ratio')
# ndim += 1

# r_tmp = samples[..., radius_index_0[0]]*constants.R_JUP
Expand All @@ -504,26 +501,26 @@ def plot_posterior(tag: str,
# m_mdot = (3600.*24.*365.25)*lum_diff*r_tmp/constants.GRAVITY/constants.M_JUP**2
#
# samples = np.append(samples, m_mdot, axis=-1)
# samples_box.parameters.append('m_mdot')
# box.parameters.append('m_mdot')
# ndim += 1

if inc_mass:
if 'logg' in samples_box.parameters and 'radius' in samples_box.parameters:
logg_index = np.argwhere(np.array(samples_box.parameters) == 'logg')[0]
radius_index = np.argwhere(np.array(samples_box.parameters) == 'radius')[0]
if 'logg' in box.parameters and 'radius' in box.parameters:
logg_index = np.argwhere(np.array(box.parameters) == 'logg')[0]
radius_index = np.argwhere(np.array(box.parameters) == 'radius')[0]

mass_samples = read_util.get_mass(samples[..., logg_index], samples[..., radius_index])

samples = np.append(samples, mass_samples, axis=-1)
samples_box.parameters.append('mass')
box.parameters.append('mass')
ndim += 1

else:
warnings.warn('Samples with the log(g) and radius are required for \'inc_mass=True\'.')

if inc_loglike:
# Get ln(L) of the samples
ln_prob = samples_box.ln_prob[..., np.newaxis]
ln_prob = box.ln_prob[..., np.newaxis]

# Normalized by the maximum ln(L)
ln_prob -= np.amax(ln_prob)
Expand All @@ -538,14 +535,14 @@ def plot_posterior(tag: str,
prob /= np.sum(prob)

samples = np.append(samples, np.log10(prob), axis=-1)
samples_box.parameters.append('log_prob')
box.parameters.append('log_prob')
ndim += 1

if isinstance(title_fmt, list) and len(title_fmt) != ndim:
raise ValueError(f'The number of items in the list of \'title_fmt\' ({len(title_fmt)}) is '
f'not equal to the number of dimensions of the samples ({ndim}).')

labels = plot_util.update_labels(samples_box.parameters)
labels = plot_util.update_labels(box.parameters)

# Check if parameter values were fixed

Expand Down

0 comments on commit 53ab9f4

Please sign in to comment.