In [28]:
from time import time
timing=False
if timing:
    start_tot=time()
import numpy as np
from tensorflow.keras.models import load_model
import joblib
import pickle 
from PIL import Image
import altair as alt
import pandas as pd

import matplotlib.pyplot as plt
# jupyters notebook Befehl zum direkten Anzeigen von Matplotlib Diagrammen
plt.rcParams['figure.figsize'] = (9, 6)
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 25
colormap={0:'red',1:'green'}
plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
plt.rcParams['lines.linewidth'] = 1.5
plt.rcParams['axes.linewidth'] = 1.2
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True 
plt.rcParams['xtick.labelsize'] = plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['xtick.major.size'] = plt.rcParams['ytick.major.size'] = 7
plt.rcParams['xtick.minor.size'] = plt.rcParams['ytick.minor.size'] = 4
plt.rcParams['xtick.major.width'] = plt.rcParams['ytick.major.width'] = 1.6
plt.rcParams['font.size'] = 12


from scipy import interpolate

#loading observations
from PyAstronomy import pyasl
from observation import load_observations
from chi_computations import chi_window,chi_squared,chi_squared_reduced
from para_transform import para_to_parameterin
from rout import adjust_rout


#image icon
im = Image.open("icon.png")


name='single_45_rinlog' # what network to use
star_name='star_m-only_3' #what network for mass prediction of star
path_data='./data' #where is the downloaded data

input_file=False

knn_switch=False # using Knn
Knn_factor=2 #parameter to adjust the uncertainties

observe=False
residual=False
chi_on=False
chi_mode=  'DIANA' #'squared'   #'squared_reduced'#     
loglike=False
dereddeningdata=False
folder_observation='./Example_observation/DNTau'
file_name='SED_to_fit.dat' 
write_parameterin=False
calc_mdisk=False
    
if timing:
    start=time()
#load NN
scaler=joblib.load(f'{path_data}/scaler/{name}_para_scaler.save')
y_scaler=joblib.load(f'{path_data}/scaler/{name}_sed_scaler.save')
model_saved=load_model(f'{path_data}/NeuralNets/{name}.h5')
    
header_start=np.load(f'{path_data}/header.npy')
header_start=np.concatenate((header_start,['incl']),axis=0)

delete_derived_paras=True
if delete_derived_paras:
    list_derived=['Mstar', 'amC-Zubko[s]', 'Rout']
    len_new_header=len(header_start)-len(list_derived)
    new_header_1=[]
    i_list=[]
    for i in range(len(header_start)):
        if header_start[i] not in list_derived:
            new_header_1.append(header_start[i])
            i_list.append(i)
    header=np.asarray(new_header_1)
txt=str()
with open(f'{path_data}/wavelength.out','r') as f:
    lines=f.readlines()
for line in lines[1:]:
    
    txt=txt+line.strip()+' '  
txt=txt[1:-2].split()
wavelength=np.array(txt,'float64')
  
    
model_star=load_model(f'{path_data}/StarNets/{star_name}.h5')
input_scaler=joblib.load(f'{path_data}/scaler/{star_name}_input_scaler.save')
output_scaler=joblib.load(f'{path_data}/scaler/{star_name}_output_scaler.save')

def calculate_mstar(Teff,Lstar):
    trans_star=input_scaler.transform(np.expand_dims([Teff,Lstar],axis=0))
    pred_star=model_star(trans_star)
    log_mass=output_scaler.inverse_transform(pred_star)[0,0]
    
    return log_mass

def angle_to_mcfost_val(angle):
    rad=angle*np.pi/180
    cosx=np.cos(rad)
    mcfost_incl=10-(cosx-0.05)/0.1    
    return mcfost_incl


def transform_parameter(name,val):
    dummy=np.zeros((1,len(header)))
##    if name in name_log:
#        val=np.log10(val)
    if name=='incl':
        val=float(val)
        val=angle_to_mcfost_val(val)
    if name in header:
        pos=np.where(header==name)[0][0]
        dummy[0,pos]=val
        val=scaler.transform(dummy)[0,pos]
        return val,pos
slider_dict={
    'Mstar':{
        'label':r'$log_{10}(M_{star}) [M_{sun}]$',
        'lims':[-0.69, 0.39],
        'x0':0.06,
        'priority':1}
        ,
    
    'Teff':{
        'label':r'$log_{10}(T_{eff})$',
        'lims':[3.5, 4.0], 
        'x0':3.69,
        'priority':1},
    
    'Lstar':{
        'label':r'$log_{10}(L_{star})$',
        'lims':[-1.3, 1.7],
        'x0':0.79,
        'priority':1}, 
    'fUV':{
        'label':r'$log_{10}(fUV)$',
        'lims':[-3, -1],
        'x0':-1.57, 
        'priority':1},
    
    'pUV':{
        'label':r'$log_{10}(pUV)$',
        'lims':[-0.3, 0.39],
        'x0':-0.02, 
        'priority':1},
    
    'Mdisk':{
        'label':r'$log_{10}(M_{disk})$',
        'lims':[-5, 0],
        'x0':-1.367, 
        'priority':2},
    
    'incl':{
        'label':r'$incl [Deg]$',
        'lims':[0.0, 90.0],
        'x0':20.0,
        'priority':2},
    
    'Rin':{
        'label':r'$log_{10}(R_{in}[AU])$',
        'lims':[-2.00, 2.00], 
        'x0':-1.34,
        'priority':2},
   
     'Rtaper':{
        'label':r'$log_{10}(R_{taper}[AU])$',
        'lims':[0.7, 2.5],
         'x0':1.95, 
        'priority':2},
    
    'Rout':{
        'label':r'$log_{10}(R_{out}[AU])$',
        'lims':[1.3, 3.14],
        'x0':2.556, 
        'priority':2},
    
    'epsilon':{
        'label':r'$\epsilon$',
        'lims':[0, 2.5],
        'x0':1, 
        'priority':2},
    
    'MCFOST_BETA':{
        'label':r'$\beta$',
        'lims':[0.9, 1.4],
        'x0':1.15, 
        'priority':2},
    
    'MCFOST_H0':{
        'label':'H_0[AU]',
        'lims':[3, 35],
        'x0':12, 
        'priority':2},    
    
    'a_settle':{
        'label':r'$log_{10}(a_{settle})$',
        'lims':[-5, -1],
        'x0':-3, 
        'priority':3},
    
    'amin':{
        'label':r'$log_{10}(a_{min})$',
        'lims':[-3, -1],
        'x0':-1.5, 
        'priority':3},
    
    
    'amax':{
        'label':r'$log_{10}(a_{max})$',
        'lims':[2.48, 4],
        'x0':3.6, 
        'priority':3},
    
    'apow':{
        'label':r'$a_{pow}$',
        'lims':[3, 5],
        'x0':3.6, 
        'priority':3},
    
    'Mg0.7Fe0.3SiO3[s]':{
        'label':r'Mg0.7Fe0.3SiO3[s]',
        'lims':[0.45, 0.7],
        'x0':0.57, 
        'priority':3},
    
    'amC-Zubko[s]':{
        'label':r'amC-Zubko[s]',
        'lims':[0.05, 0.3],
        'x0':0.18, 
        'priority':3},
    
    'fPAH':{
        'label':r'$log_{10}(fPAH)$',
        'lims':[-3.5, 0],
        'x0':-1.5, 
        'priority':3},
    
    'PAH_charged':{
        'label':r'PAH_charged',
        'lims':[0, 1], 
        'priority':3},
}
    
log_dict={'Mstar': 'log', 'Lstar': 'log', 'Teff': 'log', 'fUV': 'log', 'pUV': 'log', 'amin': 'log', 'amax': 'log',
      'apow': 'linear', 'a_settle': 'log', 'Mg0.7Fe0.3SiO3[s]': 'linear', 'amC-Zubko[s]': 'linear', 'fPAH': 'log',
   'PAH_charged': 'linear', 'Mdisk': 'log', 'Rin': 'log', 'Rtaper': 'log', 'Rout': 'log', 'epsilon': 'linear',
   'MCFOST_H0': 'linear', 'MCFOST_BETA': 'linear', 'incl': 'linear'}#,'Dist[pc]':'linear'}
for key in log_dict:
    slider_dict[key]['scale']=log_dict[key]
    
for key in slider_dict:
    if slider_dict[key]['scale']=='log':
        if 'log' in slider_dict[key]['label']:
            print(slider_dict[key]['label']+': fine')
        else:
            slider_dict[key]['label']='$log('+slider_dict[key]['label'][1:-1]+')$'
            low=slider_dict[key]['lims'][0]
            high=slider_dict[key]['lims'][1]
            slider_dict[key]['lims']=[np.log10(low),np.log10(high)]            
dist_start=100
if input_file:
    with open('Para.in','r') as f:
        lines=f.readlines()
        for line in lines:
            split_line=line.split()
            value=float(split_line[0])
            parameter=split_line[1]
            if parameter in slider_dict.keys():
                if slider_dict[parameter]['scale']=='log':
                    value=np.log10(value)
                slider_dict[parameter]['x0']=value
                
            else:
                if parameter=='Dist[pc]':
                    dist_start=value
                if parameter=='E(B-V)':
                    e_bvstart=value
                    
                if parameter=='R(V)':
                    R_Vstart=value    
def change_range(ar_val):
    ax.set_xlim(10**(ar_val))
    if residual:
        ax_res.set_xlim(10**ar_val)
def change_flux(ar_val):
    ax.set_ylim(10**(ar_val))
        

def change_dist(dist,data):
    new_data=data*(100/dist)**2
    return new_data
    
def reddening( lam,flux, e_bv, R_V):
        # lam in mu m 
        fluxRed = pyasl.unred(lam*10**4, flux, ebv=-e_bv, R_V=R_V)
        return fluxRed

def spline(lam,nuflux,new_lam):

    #interpolation on a double logarithmic scale
    s=interpolate.InterpolatedUnivariateSpline(np.log10(lam),np.log10(nuflux))
    interp=10**(s(np.log10(new_lam)))
#    return interp #returning the nex SED
    return interp #for visualisation 
if timing:
    end=time()
    loading_time=end-start
    
def main():
    should_tell_me_more = False
    
    if should_tell_me_more:
        tell_me_more()
        
    else:
        #st.write(chi_on)
        
        folder_observation='./Example_observation/DNTau'
        file_name='SED_to_fit.dat' 
        if observe:
            lam_obs,flux_obs,sig_obs,filer_names,e_bvstart,R_Vstart=load_observations(folder_observation,file_name,dereddening_data=False)
        else:
            e_bvstart=0.1
            R_Vstart=3.1

        if timing:
            start=time()

        questions = {
        'Complexity': ['Single zone', 'Two-zone'], 
        'Two-zone flavor': ['discontinues', 'continues','smooth'],
        'Input version':['Slider','Text only']}
        color_list=['bisque','lightsteelblue', 'lightgreen','lightgoldenrodyellow','skyblue','lightgrey']

        #fig,ax = plt.figure(figsize=(9,9))
       


        #distance
        dist_start=100
       
        features=np.zeros((1,len(header)))
        if timing:
            start2=time()
        for key in header:
            #print(key)name=slider_dict[key]['label']
            mini,maxi=slider_dict[key]['lims']
            name=slider_dict[key]['label']
            try:
                value=float(slider_dict[key]['x0'])
            except:
                value=float((maxi+mini)/2)
            if timing:
                end=time()
                sidebar_time=end-start

            if calc_mdisk:
                if key=='Mdisk':
                    init_mass=10**middle
            val_trans, pos=transform_parameter(key,value)
            features[0,pos]=val_trans
            if key=='amC-Zubko[s]':
                middle_sio=0.75-middle
                val_trans_sio, pos_sio=transform_parameter('Mg0.7Fe0.3SiO3[s]',middle_sio)
                features[0,pos_sio]=val_trans_sio
             
            #print(val_trans)  
        #print(features)
        if timing:
            end2=time()
            para_time=end2-start2
            start=time()
        data=10**(y_scaler.inverse_transform(model_saved(features)))[0]
        data=change_dist(dist_start,data)
        if not dereddeningdata:
            #reddening
            data=reddening(wavelength,data,e_bvstart,R_Vstart)
        if timing:
            end=time()
            pred_time=end-start
        str_title=''
        if knn_switch:
            error_knn=knn.predict(features)
            higher, =ax.plot(wavelength,10**(np.log10(data)+Knn_factor*error_knn[0]),color='grey',alpha=1)
            lower, =ax.plot(wavelength,10**(np.log10(data)-Knn_factor*error_knn[0]),color='grey',alpha=1)
            dist,neighbor_ar=knn.kneighbors(features)
            min_dist,mean_dist,max_dist=np.min(dist)/min_sample,np.mean(dist)/mean_sample,np.max(dist)/max_sample
            txt=''
            if mean_dist<=1.0:
                colortitle='tab:green'
            elif 1.0<=mean_dist<=2.0:
                colortitle='tab:orange'
            elif 1.5<=mean_dist:
                colortitle='tab:red'
                txt='Warning!! Few models! '
            str_title=txt+'Distance to neighbors (average=1): Minimum %4.2f, Mean %4.2f, Maximum %4.2f' %(min_dist,mean_dist,max_dist)
            ax.set_title(str_title,color=colortitle,fontsize=12)
        if calc_mdisk:
            m_disk=calc_mass(data)
            rat_mass=m_disk/init_mass
            str_title_new=str_title+'\n '+r'$M_{disk,calc}: %8.2e ,  M_{calc}/M_{model}: %8.2e $' %(m_disk,rat_mass)
            ax.set_title(str_title_new,fontsize=12)
        if timing:
            start=time()
        t=wavelength
        s = data
        
        df_array=np.concatenate((np.expand_dims(wavelength,axis=0),np.expand_dims(data,axis=0)),axis=0).T
        
        df=pd.DataFrame(df_array,columns=['lambda','SED'])
        print(df.shape)
        print(df)
        alt.Chart(df).mark_point().encode(
            x=alt.X('lambda',scale=alt.Scale(type="log")),y=alt.Y('SED',scale=alt.Scale(type="log")),
                tooltip=[alt.Tooltip('lam', title=r'$ \nu F_\nu [erg/cm^2/s]$'),
                alt.Tooltip("SED", title="Price (USD)"),]
        )
        
        
        

$log_{10}(M_{star}) [M_{sun}]$: fine
$log_{10}(T_{eff})$: fine
$log_{10}(L_{star})$: fine
$log_{10}(fUV)$: fine
$log_{10}(pUV)$: fine
$log_{10}(M_{disk})$: fine
$log_{10}(R_{in}[AU])$: fine
$log_{10}(R_{taper}[AU])$: fine
$log_{10}(R_{out}[AU])$: fine
$log_{10}(a_{settle})$: fine
$log_{10}(a_{min})$: fine
$log_{10}(a_{max})$: fine
$log_{10}(fPAH)$: fine


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


In [29]:
if __name__ == '__main__':


    main()

(140, 2)
           lambda           SED
0        0.092705  3.197110e-11
1        0.095792  3.905116e-11
2        0.098980  4.726113e-11
3        0.102275  5.610754e-11
4        0.105679  6.568645e-11
..            ...           ...
135   6320.582160  5.192181e-15
136   7419.147930  2.696589e-15
137   8708.652860  1.370704e-15
138  10222.283700  6.910858e-16
139  12000.000000  3.402165e-16

[140 rows x 2 columns]
