In [None]:
from matplotlib import pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual, FileUpload, Button
from IPython.display import display
import pandas as pd
import numpy as np
import ipywidgets as widgets
import scipy.integrate as integrate
from scipy import interpolate
from scipy.spatial import distance
import seaborn as sns
import math
from scipy import constants
import os
import io
import sys 

In [None]:
reference = FileUpload(accept='.csv',multiple=False,description="Reference")
to_process = FileUpload(accept='.csv',multiple=False,description="To process")
display(reference)
display(to_process)

In [None]:
def gaussian_op(x, mu, τ):
    return np.exp((-np.subtract(x,mu)**2.0)/(2*(τ/2.35)**2))

def tau_c_inv(x, A = 0,B = 0,C = 0,T = 0):
    return (A + B*T**3*x**2*10**-11 + C*x**4*10**-5) 

def parse_filename(df,basename):
    basename_list=basename.replace(" ","_").split("_")
    if basename_list[0] == "Calc":
        df["Dataset"]=" ".join(basename_list[2:])
        df["Kind"]="Simulation"
        df["Temperature"]=basename_list[1]
    else:
        df["Dataset"]=",".join(["experimental",basename_list[1]])
        df["Kind"]="Experiment"
        df["Temperature"]=basename_list[-1]
    df["Filename"]=basename
    df["GDOS (normalized)"]=(df["GDOS"]/df["GDOS"].max())
    return df

def read_uploaded_file(file_data):
    df=pd.read_csv(io.BytesIO(file_data["content"]),names=["Energy (meV)","GDOS","Error"],skiprows=1)
    basename,extension=os.path.splitext(file_data["name"])
    df=parse_filename(df,basename)
    return df


def apply_RLT_Broadening(df,A = 0,B = 0,C = 0,T = 0):
    dE = df["Energy (meV)"].diff().mean()
    df["RLT_Broadened_Total"]=df["Energy (meV)"].apply(lambda x : A + B*T**3*x**2*10**-11 + C*x**4*10**-5)
    df["RLT_Broadened_Boundary"]=df["Energy (meV)"].apply(lambda x : A)
    df["RLT_Broadened_Umklapp"]=df["Energy (meV)"].apply(lambda x : B*T**3*x**2*10**-11)
    df["RLT_Broadened_Impurity"]=df["Energy (meV)"].apply(lambda x : C*x**4*10**-5)
    df_melted=pd.melt(df,id_vars="Energy (meV)",value_vars=["RLT_Broadened_Total","RLT_Broadened_Boundary","RLT_Broadened_Umklapp","RLT_Broadened_Impurity"],var_name="Components")

    
    fig, ax = plt.subplots(figsize=(8,5))
    sns.lineplot(data=df_melted
        , x='Energy (meV)'
        , y='value'
        , style='Components'
        , hue='Components'            
     )
    ax.set(title=r'RLT Broadening Function', xlabel=r'Energy (meV)', ylabel=r'Scattering Rate')

    
    window_size=np.array(3*df["RLT_Broadened_Total"]/dE).astype(int)
    df.reset_index(drop=True,inplace=True)
    index=df.index
    index_min=df.index.min()
    index_max=df.index.max()
    window_range=np.array([
        np.clip(index-window_size,index_min,index_max),
        np.clip(index+window_size,index_min,index_max)
    ])
    df_size=df.shape[0]
    gdos_window=np.zeros([df_size,df_size])
    energy_window=np.zeros([df_size,df_size])
    for i,(left,right) in enumerate(window_range.T):
        gdos_window[left:right,i]=df["GDOS"].iloc[left:right]
        energy_window[left:right,i]=df["Energy (meV)"].iloc[left:right]
    gdos_blurred=np.sum(gaussian_op(energy_window,df["Energy (meV)"].values,df["RLT_Broadened_Total"].values)*np.divide(gdos_window,df["RLT_Broadened_Total"].values),axis=0)
    df["GDOS (RLT Broadened)"]=gdos_blurred/gdos_blurred.max()
    return df
    
def dict2list(data_dict):
    data_list=list()
    for k,v in data_dict.items():
        data_list.append({"name":k,"content":v["content"]})
    return data_list


def calc_difference(df_ref,df_calc,filename):
    #We use this variable to query the original calculated data.
    var_original="%s (original)"%filename
    df_calc_original=df_calc.query("Filename == '%s'"%var_original)
    interpolation_original=interpolate.interp1d(df_calc_original["Energy (meV)"].values, df_calc_original["GDOS (normalized)"].values)
    
    df_calc_RLT_Broadened=df_calc.query("Filename != '%s'"%var_original)
    interpolation_RLT_Broadened=interpolate.interp1d(df_calc_RLT_Broadened["Energy (meV)"].values, df_calc_RLT_Broadened["GDOS (normalized)"].values)
    
    min_energy=max(df_calc_original["Energy (meV)"].min(),df_calc_RLT_Broadened["Energy (meV)"].min())
    max_energy=min(df_calc_original["Energy (meV)"].max(),df_calc_RLT_Broadened["Energy (meV)"].max())
    df_ref_in_range=df_ref.query("%f < `Energy (meV)` and `Energy (meV)` < %f"%(min_energy,max_energy))
    ref_energy=df_ref_in_range["Energy (meV)"].values
    ref_GDOS=df_ref_in_range["GDOS (normalized)"].values
    
    df_result=pd.DataFrame({"Energy (meV)":ref_energy
                  ,"|ref-original|":np.abs(ref_GDOS-interpolation_original(ref_energy))
                  ,"|ref-RLT_Broadened|":np.abs(ref_GDOS-interpolation_RLT_Broadened(ref_energy))
                 })
    
    df_result_melted=pd.melt(df_result,id_vars="Energy (meV)",value_vars=["|ref-original|","|ref-RLT_Broadened|"],var_name="Comparison",value_name="Difference")
    distance_original=distance.cdist(interpolation_original(ref_energy).reshape(1,-1), ref_GDOS.reshape(1,-1), 'euclidean')
    distance_RLT_Broadened=distance.cdist(interpolation_RLT_Broadened(ref_energy).reshape(1,-1), ref_GDOS.reshape(1,-1), 'euclidean')

    return df_result_melted, distance_original[0],distance_RLT_Broadened[0]

def Plot_RLT_Broadened(reference
                      ,to_process
                       ,button_save
                      ,RLT_Broadened_Temperature
                      ,Boundary_scattering
                      ,Umklapp_scattering
                      ,Impurity_scattering
                      ,show_original=True
                      ,show_data_frame=False):
    
    if len(reference.value) == 0 or len(to_process.value) == 0:
        print("Upload data first")
        return 
    if type(reference.value) == tuple:
        data=reference.value
    else:
        data=dict2list(reference.value)
    
    df_experiment=read_uploaded_file(data[0])
    if df_experiment.empty:
        print("Empty set")
        return

    if type(to_process.value) == tuple:
        data=to_process.value
    else:
        data=dict2list(to_process.value)

    df_data=read_uploaded_file(data[0])
    if df_data.empty:
        print("Empty set")
        return    
    df_simulation=apply_RLT_Broadening(df_data.query("`Energy (meV)` < %f"%(df_experiment["Energy (meV)"].max())).copy()
                                 ,A=Boundary_scattering
                                 ,B=Umklapp_scattering
                                 ,C=Impurity_scattering
                                 ,T=RLT_Broadened_Temperature
                                )
    
    filename=df_data["Filename"].unique()[0]
    df_simulation.rename(columns={"GDOS (normalized)":"%s (original)"%filename, "GDOS (RLT Broadened)":"%s (RLT Broadened)"%filename},inplace=True)
    df_simulation_melted=pd.melt(df_simulation,id_vars="Energy (meV)",value_vars=["%s (original)"%filename,"%s (RLT Broadened)"%filename],var_name="Filename",value_name="GDOS (normalized)")

    if show_data_frame:
        print(df_experiment)
        print(df_simulation_melted)
      
        
    df_concatenated=pd.concat([df_experiment,df_simulation_melted],ignore_index=True).reset_index(drop=True)
    if not show_original:
        df_concatenated = df_concatenated.query("Dataset != 'Simulation original'")
    
    fig, ax = plt.subplots(figsize=(8,5))
    sns.lineplot(data=df_concatenated
        , x='Energy (meV)'
        , y='GDOS (normalized)'
        ,hue='Filename'
     )
    ax.spines['bottom'].set_color('0')
    ax.spines['top'].set_color('1')
    ax.spines['right'].set_color('1')
    ax.spines['left'].set_color('0')
    ax.tick_params(direction='out', width=3, bottom=True, left=True)
    ax.grid(False)
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1),ncol=3)
 

    ax.set_title('Reference T=%s, Data T=%s and RLT Broadening T=%d K'%(
                       df_experiment["Temperature"].unique()[0]
                      ,df_simulation["Temperature"].unique()[0]
                      ,RLT_Broadened_Temperature

    )
                )
    
    df_difference,distance_original,distance_RLT_Broadened=calc_difference(df_experiment,df_simulation_melted,filename)
    fig, ax = plt.subplots(figsize=(8,5))
    sns.lineplot(data=df_difference
        , x='Energy (meV)'
        , y='Difference'
        ,hue='Comparison'
     )
    ax.spines['bottom'].set_color('0')
    ax.spines['top'].set_color('1')
    ax.spines['right'].set_color('1')
    ax.spines['left'].set_color('0')
    ax.tick_params(direction='out', width=3, bottom=True, left=True)
    ax.grid(False)
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1),ncol=3)
 
    ax.set_title('Distance(|ref-original|)=%f, Distance(|ref-RLT_Broadened|)=%f'%(distance_original,distance_RLT_Broadened) )

    button_save.data={"df_difference":df_difference,"df_concatenated":df_concatenated,"filename":filename}
    
def save(arg):
    path="../data/"+arg.data["filename"]
    arg.data["df_difference"].to_csv(path+"_difference.csv",index=False)
    arg.data["df_concatenated"].to_csv(path+"_result.csv",index=False)

button=Button(description="Save")
button.on_click(save)
display(button)
        
d=interact(Plot_RLT_Broadened
                  ,reference=fixed(reference)
                  ,to_process=fixed(to_process)
           ,button_save=fixed(button)
        ,RLT_Broadened_Temperature=widgets.IntSlider(min=100,max=1000,value=500,step=10, description="RLT Broadened Temperature", style = {'description_width': 'initial'},layout=widgets.Layout(width='500px', height='40px'),continuous_update=False)
        ,Boundary_scattering=widgets.FloatSlider(min=0.01, max=10, value=0, step=0.01, description="Boundary Scattering", style = {'description_width': 'initial'}, layout=widgets.Layout(width='500px', height='40px'))
        ,Umklapp_scattering=widgets.FloatSlider(min=0.01, max=10, value=0, step=0.01, description="Umklapp Scattering", style = {'description_width': 'initial'}, layout=widgets.Layout(width='500px', height='40px'))
        ,Impurity_scattering=widgets.FloatSlider(min=0.01, max=10, value=0, step=0.01, description="Impurity Scattering", style = {'description_width': 'initial'}, layout=widgets.Layout(width='500px', height='40px'))
        )