In [1]:
import numpy as np
from tensorflow.keras.models import load_model
import joblib
import pickle 

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, TextBox, RangeSlider
%matplotlib qt
# IPython magic: show Matplotlib figures in a separate interactive Qt window so the slider/button widgets below respond
# Global Matplotlib styling for the interactive figure.
plt.rcParams['figure.figsize'] = (9, 6)

# Named font sizes, kept as module-level constants.
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 25
colormap = {0: 'red', 1: 'green'}

# Every rcParam is set exactly once. Note that the default text size and
# the tick labels use 12, not SMALL_SIZE.
plt.rcParams.update({
    'font.size': 12,                  # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': 12,            # tick labels
    'ytick.labelsize': 12,
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure suptitle
    'lines.linewidth': 1.5,
    'axes.linewidth': 1.2,
    # Inward ticks on all four sides.
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.top': True,
    'ytick.right': True,
    'xtick.major.size': 7,
    'ytick.major.size': 7,
    'xtick.minor.size': 4,
    'ytick.minor.size': 4,
    'xtick.major.width': 1.6,
    'ytick.major.width': 1.6,
})


In [2]:
# --- run configuration ---------------------------------------------------
path_data = './data'               # root directory of the downloaded data
name = 'single_23_wo_wrongall3'    # which trained network to load

input_file = True                  # read slider start values from Para.in

knn_switch = True                  # enable the KNN uncertainty band and distance check
Knn_factor = 2                     # band half-width in units of the KNN-predicted error

# Loading the data: scalers, network, parameter names, wavelength grid and KNN

In [3]:
# Fitted scalers for the network inputs (parameters) and outputs (SED).
# NOTE: joblib files are pickles -- only load files from a trusted source.
scaler, y_scaler = (
    joblib.load(f'{path_data}/scaler/{name}_para_scaler.save'),
    joblib.load(f'{path_data}/scaler/{name}_sed_scaler.save'),
)

In [4]:
# Trained Keras network: takes the scaled parameter row and returns the
# scaled log10 SED (callers apply y_scaler.inverse_transform, then 10**).
model_saved=load_model(f'{path_data}/NeuralNets/{name}.h5')

In [5]:
# Parameter names in feature-column order; the inclination is appended as
# the last column.
header = np.concatenate([np.load(f'{path_data}/header.npy'), ['incl']])
print(header)

['Mstar' 'Lstar' 'Teff' 'fUV' 'pUV' 'amin' 'amax' 'apow' 'a_settle'
 'Mg0.7Fe0.3SiO3[s]' 'amC-Zubko[s]' 'fPAH' 'PAH_charged' 'Mdisk' 'Rin'
 'Rtaper' 'Rout' 'epsilon' 'MCFOST_H0' 'MCFOST_BETA' 'incl']


In [6]:
# wavelength.out: the first line is skipped and the remaining lines are
# joined into one space-separated string. The [1:-2] slice then drops its
# first character and its last two (the trailing separator plus one more)
# -- presumably the brackets of a printed array; verify against the file.
with open(f'{path_data}/wavelength.out', 'r') as f:
    lines = f.readlines()
txt = ''.join(line.strip() + ' ' for line in lines[1:])
txt = txt[1:-2].split()
wavelength = np.array(txt, 'float64')
  

In [7]:

if knn_switch:
    # KNN settings; both are part of the file names the KNN and its
    # reference distances were saved under.
    n_neighbors = 20
    weights = 'distance'
    # SECURITY: pickle.load can execute arbitrary code -- only load KNN
    # files from a trusted source.
    with open(f'{path_data}/Knns/{name}_{n_neighbors}_{weights}', 'br') as file_pi:
        knn = pickle.load(file_pi)
    # Reference neighbour distances used to normalise the live distances in
    # the plot title (presumably min/mean/max over the 'test' sample --
    # TODO confirm how they were computed).
    dat = 'test'
    mean_sample, min_sample, max_sample = (
        np.load(f'{path_data}/Knns/sample_values/{name}_{n_neighbors}_{weights}_{dat}_mean.npy'),
        np.load(f'{path_data}/Knns/sample_values/{name}_{n_neighbors}_{weights}_{dat}_min.npy'),
        np.load(f'{path_data}/Knns/sample_values/{name}_{n_neighbors}_{weights}_{dat}_max.npy'),
    )

# Preparing the data handling: parameter scaling and slider definitions

In [8]:

def transform_parameter(name, val, header_arr=None, fitted_scaler=None):
    """Scale a single parameter value into the network's input space.

    The fitted scaler works on full feature rows. The value is therefore
    placed into an otherwise-zero row of length ``len(header)``, the row is
    transformed, and the entry in the same column is read back. This assumes
    the scaler treats columns independently (e.g. a per-feature min/max or
    standard scaler) -- TODO confirm.

    Parameters
    ----------
    name : str
        Parameter name; must be an entry of the header.
    val : float
        Value in slider units (log10 for log-scaled parameters).
    header_arr : array-like of str, optional
        Column names in feature order. Defaults to the global ``header``.
    fitted_scaler : object with a ``transform`` method, optional
        Input-parameter scaler. Defaults to the global ``scaler``.

    Returns
    -------
    (float, int)
        The scaled value and its column index in the feature row.

    Raises
    ------
    ValueError
        If ``name`` is not one of the header columns.
    """
    if header_arr is None:
        header_arr = header
    if fitted_scaler is None:
        fitted_scaler = scaler
    header_arr = np.asarray(header_arr)

    matches = np.flatnonzero(header_arr == name)
    if matches.size == 0:
        # Fail with a readable message instead of an opaque IndexError.
        raise ValueError(f'Unknown parameter {name!r}; expected one of {list(header_arr)}')
    pos = int(matches[0])

    dummy = np.zeros((1, len(header_arr)))
    dummy[0, pos] = val
    return fitted_scaler.transform(dummy)[0, pos], pos

In [9]:
# Slider definition for every model parameter, in display order.
#   label    : slider label (mathtext)
#   lims     : [min, max] of the slider, in slider units
#   x0       : start value, in the same units as lims
#   priority : 1-based index into color_list for the slider background;
#              groups the parameters (1: star/UV, 2: disk, 3: dust/PAH)
slider_dict = {
    'Mstar': {
        'label': r'$log(M_{star}) [M_{sun}]$',
        'lims': [-0.69, 0.39],
        'x0': 0.06,
        'priority': 1},

    'Teff': {
        'label': r'$log(T_{eff})$',
        'lims': [3.5, 4.0],
        'x0': 3.69,
        'priority': 1},

    'Lstar': {
        'label': r'$log(L_{star})$',
        'lims': [-1.3, 1.7],
        'x0': 0.79,
        'priority': 1},

    'fUV': {
        'label': r'$log(fUV)$',
        'lims': [-3, -1],
        'x0': -1.57,
        'priority': 1},

    'pUV': {
        'label': r'$log(pUV)$',
        'lims': [-0.3, 0.39],
        'x0': -0.02,
        'priority': 1},

    'Mdisk': {
        'label': r'$log(Mass_{disk})$',
        'lims': [-5, 0],
        'x0': -1.367,
        'priority': 2},

    'incl': {
        'label': r'$incl [Deg]$',
        'lims': [0, 9],
        'x0': 2,
        'priority': 2},

    'Rin': {
        'label': r'$log(R_{in}[AU])$',
        'lims': [-2.00, 2.00],
        'x0': -1.34,
        'priority': 2},

    'Rtaper': {
        'label': r'$log(R_{taper}[AU])$',
        'lims': [0.7, 2.5],
        'x0': 1.95,
        'priority': 2},

    'Rout': {
        'label': r'$log(R_{out}[AU])$',
        'lims': [1.3, 3.14],
        'x0': 2.556,
        'priority': 2},

    'epsilon': {
        'label': r'$\epsilon$',
        'lims': [0, 2.5],
        'x0': 1,
        'priority': 2},

    'MCFOST_BETA': {
        'label': r'$\beta$',
        'lims': [0.9, 1.4],
        'x0': 1.15,
        'priority': 2},

    'MCFOST_H0': {
        'label': 'MCFOST_H0[AU]',
        'lims': [3, 35],
        'x0': 12,
        'priority': 2},

    'a_settle': {
        'label': r'$log(a_{settle})$',
        'lims': [-5, -1],
        'x0': -3,
        'priority': 3},

    'amin': {
        'label': r'$log(a_{min})$',
        'lims': [-3, -1],
        'x0': -1.5,
        'priority': 3},

    'amax': {
        'label': r'$log(a_{max})$',
        'lims': [2.48, 4],
        'x0': 3.6,
        'priority': 3},

    'apow': {
        'label': r'$a_{pow}$',
        'lims': [3, 5],
        'x0': 3.6,
        'priority': 3},

    'amC-Zubko[s]': {
        'label': r'amC-Zubko[s]',
        'lims': [0.05, 0.3],
        'x0': 0.18,
        'priority': 3},

    'fPAH': {
        'label': r'$log(fPAH)$',
        'lims': [-3.5, 0],
        'x0': -1.5,
        'priority': 3},

    'PAH_charged': {
        'label': r'PAH_charged',
        'lims': [0, 1],
        # Explicit start value (the range midpoint) so that every entry has x0.
        'x0': 0.5,
        'priority': 3},
}

In [10]:
# Disabled slider entry (and matching log_dict scale) for the silicate
# fraction 'Mg0.7Fe0.3SiO3[s]', kept as an inert string literal. Wherever the
# feature row is built, the silicate column is instead set to
# 0.75 - amC-Zubko[s], so it has no slider of its own.
'''    
    'Mg0.7Fe0.3SiO3[s]':{
        'label':r'Mg0.7Fe0.3SiO3[s]',
        'lims':[0.45, 0.7],
        'x0':0.57, 
        'priority':4},
        
'Mg0.7Fe0.3SiO3[s]': 'linear','

'''

"    \n    'Mg0.7Fe0.3SiO3[s]':{\n        'label':r'Mg0.7Fe0.3SiO3[s]',\n        'lims':[0.45, 0.7],\n        'x0':0.57, \n        'priority':4},\n        \n'Mg0.7Fe0.3SiO3[s]': 'linear','\n\n"

In [11]:
# Axis scale of every slider parameter. 'log' means the slider value is
# log10 of the physical quantity; 'linear' means it is used as-is.
log_dict = {
    'Mstar': 'log',
    'Lstar': 'log',
    'Teff': 'log',
    'fUV': 'log',
    'pUV': 'log',
    'amin': 'log',
    'amax': 'log',
    'apow': 'linear',
    'a_settle': 'log',
    'amC-Zubko[s]': 'linear',
    'fPAH': 'log',
    'PAH_charged': 'linear',
    'Mdisk': 'log',
    'Rin': 'log',
    'Rtaper': 'log',
    'Rout': 'log',
    'epsilon': 'linear',
    'MCFOST_H0': 'linear',
    'MCFOST_BETA': 'linear',
    'incl': 'linear',
}

In [12]:
# Copy each parameter's axis scale ('log' / 'linear') into its slider entry.
for param, axis_scale in log_dict.items():
    slider_dict[param]['scale'] = axis_scale

In [13]:
# Make sure every log-scaled slider is expressed in log10 units. Entries
# whose label already mentions 'log' are assumed to be in log units.
for key in slider_dict:
    entry = slider_dict[key]
    if entry['scale'] != 'log':
        continue
    if 'log' in entry['label']:
        print(entry['label'] + ': fine')
        continue
    # The entry was written in linear units. Convert the label, the limits
    # AND the start value together, so x0 stays inside the slider range.
    label = entry['label']
    if len(label) >= 2 and label.startswith('$') and label.endswith('$'):
        entry['label'] = '$log(' + label[1:-1] + ')$'
    else:
        # Plain-text label: wrap it without slicing characters off.
        entry['label'] = 'log(' + label + ')'
    low, high = entry['lims']
    entry['lims'] = [np.log10(low), np.log10(high)]
    if 'x0' in entry:
        entry['x0'] = np.log10(entry['x0'])

$log(M_{star}) [M_{sun}]$: fine
$log(T_{eff})$: fine
$log(L_{star})$: fine
$log(fUV)$: fine
$log(pUV)$: fine
$log(Mass_{disk})$: fine
$log(R_{in}[AU])$: fine
$log(R_{taper}[AU])$: fine
$log(R_{out}[AU])$: fine
$log(a_{settle})$: fine
$log(a_{min})$: fine
$log(a_{max})$: fine
$log(fPAH)$: fine


In [14]:
# Print the start value of every slider in physical (linear) units.
for key in slider_dict:
    down, up = slider_dict[key]['lims']
    # Fall back to the middle of the range when no start value is given.
    middle = slider_dict[key].get('x0', (up - down) / 2 + down)
    if slider_dict[key]['scale'] == 'log':
        middle = 10**middle
    print(f'{middle:12.6e} {key}')

1.148154e+00 Mstar
4.897788e+03 Teff
6.165950e+00 Lstar
2.691535e-02 fUV
9.549926e-01 pUV
4.295364e-02 Mdisk
2.000000e+00 incl
4.570882e-02 Rin
8.912509e+01 Rtaper
3.597493e+02 Rout
1.000000e+00 epsilon
1.150000e+00 MCFOST_BETA
1.200000e+01 MCFOST_H0
1.000000e-03 a_settle
3.162278e-02 amin
3.981072e+03 amax
3.600000e+00 apow
1.800000e-01 amC-Zubko[s]
3.162278e-02 fPAH
5.000000e-01 PAH_charged


In [21]:
# Optionally override the slider start values (x0) with those from Para.in.
# Format: one "<value> <parameter>" pair per line, with values in physical
# units (the same line format exporting() writes to Para.out).
if input_file:
    with open('Para.in', 'r') as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                continue  # tolerate blank lines, e.g. a trailing newline
            value = float(split_line[0])
            parameter = split_line[1]
            if parameter not in slider_dict:
                print('Unknown parameter')
                print(parameter)
                continue
            # Sliders of log-scaled parameters work in log10 units.
            if slider_dict[parameter]['scale'] == 'log':
                value = np.log10(value)
            slider_dict[parameter]['x0'] = value

# plotting

In [23]:
def change_range(ar_val, axis=None):
    """RangeSlider callback: set the x (wavelength) limits.

    Parameters
    ----------
    ar_val : pair of float
        ``(log10(lambda_min), log10(lambda_max))`` from the slider.
    axis : matplotlib Axes, optional
        Axes to update; defaults to the global ``ax``.
    """
    if axis is None:
        axis = ax
    # Unpack explicitly: RangeSlider may pass a plain tuple (recent
    # matplotlib versions do), and ``10 ** tuple`` raises TypeError.
    low, high = ar_val
    axis.set_xlim(10**low, 10**high)

def change_flux(ar_val, axis=None):
    """RangeSlider callback: set the y (flux) limits.

    Parameters
    ----------
    ar_val : pair of float
        ``(log10(flux_min), log10(flux_max))`` from the slider.
    axis : matplotlib Axes, optional
        Axes to update; defaults to the global ``ax``.
    """
    if axis is None:
        axis = ax
    # Unpack explicitly: RangeSlider may pass a plain tuple (recent
    # matplotlib versions do), and ``10 ** tuple`` raises TypeError.
    low, high = ar_val
    axis.set_ylim(10**low, 10**high)

def update(val):
    """Slider callback: rebuild the feature row, re-predict and redraw.

    Every parameter slider is re-read (not just the one that moved) and its
    scaled value is written into the global ``features`` row. The SED is then
    predicted and, if ``knn_switch`` is set, the KNN uncertainty band and
    the neighbour-distance title are refreshed.

    Parameters
    ----------
    val : float
        Value passed by the Slider callback; unused because all sliders are
        re-read.
    """
    for key in slider_dict:
        value = slider_dict[key]['slider'].val
        val_trans, pos = transform_parameter(key, value)
        features[0, pos] = val_trans
        if key == 'amC-Zubko[s]':
            # The silicate fraction has no slider; it is tied to the carbon
            # fraction so that the two sum to 0.75.
            val_trans_sio, pos_sio = transform_parameter('Mg0.7Fe0.3SiO3[s]', 0.75 - value)
            features[0, pos_sio] = val_trans_sio

    # Network output is the scaled log10 SED: undo the scaling, then the log.
    data = 10**(y_scaler.inverse_transform(model_saved.predict(features)))[0]

    if knn_switch:
        # Distances to the KNN's nearest stored models, normalised by the
        # reference values loaded with the KNN (title: average = 1).
        dist, _ = knn.kneighbors(features)
        min_dist = np.min(dist) / min_sample
        mean_dist = np.mean(dist) / mean_sample
        max_dist = np.max(dist) / max_sample
        txt = ''
        if mean_dist <= 1.0:
            colortitle = 'tab:green'
        elif mean_dist <= 2.0:
            colortitle = 'tab:orange'
        else:
            # Also reached for NaN, which used to leave colortitle unbound.
            colortitle = 'tab:red'
            txt = 'Warning!! Few models! '
        ax.set_title(txt+'Distance to neighbors (average=1): Minimum %4.2f, Mean %4.2f, Maximum %4.2f' %(min_dist,mean_dist,max_dist), color=colortitle, fontsize=12)

        # Re-evaluate the KNN uncertainty at the *current* parameters; the
        # start-up estimate would keep the band width frozen no matter where
        # the sliders are moved.
        error_knn = knn.predict(features)
        log_data = np.log10(data)
        lower.set_ydata(10**(log_data - Knn_factor * error_knn[0]))
        higher.set_ydata(10**(log_data + Knn_factor * error_knn[0]))

    l.set_ydata(data)
    fig.canvas.draw_idle()

def importing(event):
    """Button callback: set the parameter sliders from ``Para.in``.

    Each non-blank line is ``<value> <parameter>`` with the value in
    physical units. Values of log-scaled parameters are converted to log10
    before being set on the slider. Unknown parameters are reported and
    skipped. Each ``set_val`` fires the slider's connected callbacks.

    Parameters
    ----------
    event : matplotlib event
        Click event passed by the Button; unused.
    """
    with open('Para.in', 'r') as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                continue  # tolerate blank lines, e.g. a trailing newline
            value = float(split_line[0])
            parameter = split_line[1]
            if parameter not in slider_dict:
                print('Unknown parameter')
                print(parameter)
                continue
            if slider_dict[parameter]['scale'] == 'log':
                value = np.log10(value)
            slider_dict[parameter]['slider'].set_val(value)

def exporting(event):
    """Button callback: write the current slider values to ``Para.out``.

    One ``<value> <parameter>`` line per slider, with the value in physical
    units (log-scaled sliders are converted back with 10**value). This is
    the same line format the Para.in readers parse.

    Parameters
    ----------
    event : matplotlib event
        Click event passed by the Button; unused.
    """
    with open('Para.out', 'w') as out:
        for key, entry in slider_dict.items():
            slider_value = entry['slider'].val
            physical = 10**slider_value if entry['scale'] == 'log' else slider_value
            out.write(f'{physical:12.6e} {key}\n')
            
def reset(event):
    """Button callback: return every parameter slider to its initial value.

    The wavelength and flux range sliders are not reset here.

    Parameters
    ----------
    event : matplotlib event
        Click event passed by the Button; unused.
    """
    for entry in slider_dict.values():
        entry['slider'].reset()

In [30]:
# Slider background colours, indexed by priority - 1; the last two entries
# are also used for the range sliders and buttons.
color_list = ['bisque', 'lightsteelblue', 'lightgreen', 'lightgoldenrodyellow', 'skyblue', 'lightgrey']

fig, ax = plt.subplots()
plt.subplots_adjust(left=0.2, bottom=0.41, top=0.95)

# Build the feature row for the start values and predict the initial SED.
features = np.zeros((1, len(header)))
for key in slider_dict:
    down, up = slider_dict[key]['lims']
    # Fall back to the middle of the range when no start value is given.
    middle = slider_dict[key].get('x0', (up - down) / 2 + down)
    val_trans, pos = transform_parameter(key, middle)
    features[0, pos] = val_trans
    if key == 'amC-Zubko[s]':
        # The silicate fraction is tied to the carbon fraction (sum 0.75).
        middle_sio = 0.75 - middle
        val_trans_sio, pos_sio = transform_parameter('Mg0.7Fe0.3SiO3[s]', middle_sio)
        features[0, pos_sio] = val_trans_sio
# Network output is the scaled log10 SED: undo the scaling, then the log.
data = 10**(y_scaler.inverse_transform(model_saved.predict(features)))[0]
if knn_switch:
    # Grey band: +/- Knn_factor times the KNN-predicted uncertainty in log10 flux.
    error_knn = knn.predict(features)
    higher, = plt.plot(wavelength, 10**(np.log10(data) + Knn_factor * error_knn[0]), color='grey', alpha=1)
    lower, = plt.plot(wavelength, 10**(np.log10(data) - Knn_factor * error_knn[0]), color='grey', alpha=1)
    dist, neighbor_ar = knn.kneighbors(features)
    min_dist, mean_dist, max_dist = np.min(dist) / min_sample, np.mean(dist) / mean_sample, np.max(dist) / max_sample
    txt = ''
    if mean_dist <= 1.0:
        colortitle = 'tab:green'
    elif mean_dist <= 2.0:
        colortitle = 'tab:orange'
    else:
        # Also reached for NaN, which used to leave colortitle unbound.
        colortitle = 'tab:red'
        txt = 'Warning!! Few models! '
    ax.set_title(txt+'Distance to neighbors (average=1): Minimum %4.2f, Mean %4.2f, Maximum %4.2f' %(min_dist,mean_dist,max_dist), color=colortitle, fontsize=12)
        
# Predicted SED drawn as '+' markers; update() refreshes the artist ``l``.
t = wavelength
s = data
l, = ax.plot(t, s, marker='+', linestyle='none')

# Log-log axes; initial view from the shortest wavelength up to 1e3 micron.
ax.axis([np.min(wavelength), 1.E+3, 10**(-12), 10**(-7)])
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale('log')
ax.set_xlabel(r'$ \lambda \, [\mu m]$', fontsize=13)
ax.set_ylabel(r'$ \nu F_\nu [erg/cm^2/s]$', fontsize=13)

# Lay out the parameter sliders in two columns, stacked bottom-up: the
# first half of slider_dict goes on the left, the rest on the right.
l_head = len(slider_dict)
n_left = (l_head + 1) // 2  # number of indices i with i < l_head/2
for i, key in enumerate(slider_dict):
    if i < n_left:
        frame = [0.17, 0.06 + i * 0.026, 0.25, 0.02]
    else:
        frame = [0.67, 0.06 + (i - n_left) * 0.026, 0.25, 0.02]

    # The slider background colour is chosen by the parameter's priority group.
    plot = plt.axes(frame, facecolor=color_list[slider_dict[key]['priority'] - 1])
    label = slider_dict[key]['label']
    down, up = slider_dict[key]['lims']
    # Fall back to the middle of the range when no start value is given.
    middle = slider_dict[key].get('x0', (up - down) / 2 + down)
    slider = Slider(plot, label, down, up, valinit=middle)
    slider_dict[key]['slider'] = slider


# Wavelength range slider (log10 of lambda in micron), along the bottom edge.
lam = plt.axes([0.2, 0.01, 0.25, 0.015], facecolor=color_list[-1])
lam_label = r'$log(\lambda_{min}) / log(\lambda_{max})$'
lam_slider = RangeSlider(ax=lam, label=lam_label, valmin=-1.2, valmax=4.2,
                         valinit=[-1, 3], valstep=0.2)
lam_slider.on_changed(change_range)

# Flux range slider (log10 of nu F_nu), vertical, left of the main axes.
flux = plt.axes([0.07, 0.4, 0.01, 0.5], facecolor=color_list[-1])
flux_label = r'$log(\nu F)$ range'
flux_slider = RangeSlider(ax=flux, label=flux_label, valmin=-18, valmax=-5,
                          valinit=[-12, -7], valstep=0.2, orientation='vertical')
flux_slider.on_changed(change_flux)


# Fresh feature row for update(), which rewrites each slider's column
# (plus the derived silicate column) on every call.
features = np.zeros((1, len(header)))

# Re-predict whenever any parameter slider moves.
for entry in slider_dict.values():
    entry['slider'].on_changed(update)

# Disabled alternative to the wavelength RangeSlider: TextBox inputs for
# lambda_min / lambda_max, kept as an inert string literal. The note inside
# says it is too slow. This is the only use of the imported TextBox.
'''

This part gives the user the option to set new limits but it is super slow!
lam_max = plt.axes([0.5, 0.01, 0.1, 0.04])
lam_max_txt =TextBox(lam_max, r'$\lambda_{max}$')

lam_min = plt.axes([0.3, 0.01, 0.1, 0.04])
lam_min_txt =TextBox(lam_min, r'$\lambda_{min}$')



def submit_max(value):
    print(value)
    if value!='':
        ax.set_xlim(right=float(value))

def submit_min(value):
    print(value)
    if value!='':
        ax.set_xlim(left=float(value))

    
lam_min_txt.on_submit(submit_min)
lam_max_txt.on_submit(submit_max)
'''

# Action buttons along the bottom edge. The Button objects stay bound to
# module-level names; matplotlib widgets need a live reference to keep
# responding.
impor = plt.axes([0.6, 0.01, 0.1, 0.04])
import_button = Button(ax=impor, label='Import', color=color_list[-2], hovercolor='paleturquoise')
import_button.on_clicked(importing)

export = plt.axes([0.70, 0.01, 0.1, 0.04])
export_button = Button(ax=export, label='Export', color=color_list[-2], hovercolor='paleturquoise')
export_button.on_clicked(exporting)

resetax = plt.axes([0.85, 0.01, 0.1, 0.04])
button = Button(ax=resetax, label='Reset', color='lightgoldenrodyellow', hovercolor='ivory')
button.on_clicked(reset)

plt.show()