In [None]:
from tomography import maxlike, Wigner_fock
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
import matplotlib as mpl
from igorwriter import IgorWave
import ipywidgets as widgets
from IPython.display import display
from ipyfilechooser import FileChooser
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default='notebook' # for vscode ,maybe 'colab' on jupyterlab 


In [None]:
def pca(time_array):
    """Principal component analysis of a 2-D sample matrix.

    Args:
        time_array: 2-D array of shape (n_samples, n_features). Presumably rows are
            time samples and columns are channels/pulses — TODO confirm with the caller.

    Returns:
        (w, v): eigenvalues of the (unnormalized) scatter matrix in descending order,
        and the matching unit eigenvectors as the columns of ``v``.
    """
    data = np.asarray(time_array)
    # Center every feature (column) on its own mean. Subtracting a single global mean
    # leaves per-column offsets in the data, and those offsets then dominate the
    # leading principal component instead of the actual variance.
    centered = data - data.mean(axis=0)
    scatter = centered.T @ centered
    w, v = np.linalg.eigh(scatter)  # eigh returns ascending eigenvalues
    return w[::-1], v[:, ::-1]

def plot_bar3d_rho():
    """Show the real and imaginary parts of the global ``rho`` as two 3-D bar charts.

    Reads the globals ``rho`` (2-D array) and ``elev``/``azim`` (view angles; the azimuth
    is applied negated, as in every plot of this notebook). Bars have a unit footprint and
    both panels use a fixed z-range of [-1, 1]. For a square ``rho`` the bar for
    ``rho[i, j]`` stands at (x=j, y=i).
    """
    global rho
    global elev, azim
    fig = plt.figure(figsize=(6, 3))
    col_idx, row_idx = np.meshgrid(np.arange(rho.shape[0]), np.arange(rho.shape[1]))
    n_bars = col_idx.size
    xs = col_idx.ravel()
    ys = row_idx.ravel()
    zs = np.zeros(n_bars)

    panels = (
        (121, np.real, 'Density matrix (real)'),
        (122, np.imag, "Density matrix (imag)"),
    )
    for position, part, zlabel in panels:
        ax = fig.add_subplot(position, projection='3d')
        heights = part(rho).ravel()
        ax.bar3d(xs, ys, zs, np.ones(n_bars), np.ones(n_bars), heights)
        ax.set_zlabel(zlabel)
        ax.set_zlim(-1, 1)
        ax.view_init(elev=elev, azim=-azim)
    plt.show()
    
def calc_wigner_rho():
    """Evaluate the Wigner function of the global ``rho`` on a 101 x 101 grid over [-5, 5]^2.

    The result of ``Wigner_fock.wigner_rho`` is stored in the global ``wigner``. The
    calculator is built from the Fock cutoff ``rho.shape[0] - 1`` and the constants 5 and
    0.01, whose meaning is defined by the ``tomography`` package (TODO confirm).
    """
    global rho
    global wigner
    calculator = Wigner_fock(rho.shape[0] - 1, 5, 0.01)
    axis = np.linspace(-5, 5, 101)
    grid_x, grid_p = np.meshgrid(axis, axis)
    wigner = calculator.wigner_rho(rho, grid_x, grid_p)

def plot_wigner():
    """Render the real part of the global ``wigner`` as a 3-D 'bwr' surface with a colorbar.

    The surface is drawn on a 101 x 101 grid over [-5, 5]^2 (X along columns, P along
    rows). The viewing angle comes from the globals ``elev``/``azim`` (azimuth negated).
    """
    global wigner
    global elev, azim
    axis = np.linspace(-5, 5, 101)
    grid_x, grid_p = np.meshgrid(axis, axis)
    figure = plt.figure(figsize=(5, 5))
    ax = figure.add_subplot(111, projection='3d')
    surface = ax.plot_surface(
        grid_x, grid_p, wigner.real,
        cmap=cm.bwr, linewidth=0, antialiased=False,
    )
    ax.set_xlabel('X')
    ax.set_ylabel('P')
    ax.view_init(elev=elev, azim=-azim)
    figure.colorbar(surface, shrink=0.5, aspect=5)
    plt.show()

def plot_wigner_plotly():
    """Show the real part of the global ``wigner`` as an interactive Plotly surface.

    Columns of ``wigner`` are laid out along x and rows along p, both spanning [-5, 5],
    which is the same layout the matplotlib plots use. The diverging colorscale is
    centred on W = 0 and runs blue (negative) -> white -> red (positive), matching the
    'bwr'/'seismic' colormaps of the matplotlib plots.
    """
    global wigner
    z = wigner.real
    n_rows, n_cols = z.shape
    # Plotly Surface places z[i][j] at (x[j], y[i]): x needs one entry per column and
    # y one entry per row (the original swapped these, which breaks non-square arrays).
    x = np.linspace(-5, 5, n_cols)
    y = np.linspace(-5, 5, n_rows)
    colorscale = [[0, "rgb(0,0,255)"], [0.5, "rgb(255,255,255)"], [1, "rgb(255,0,0)"]]
    # cmid=0 keeps white at W = 0 even when the data range is asymmetric.
    fig = go.Figure(data=[go.Surface(z=z, x=x, y=y, colorscale=colorscale, cmid=0)])
    fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                      highlightcolor="limegreen", project_z=True))
    fig.update_layout(title='wigner', autosize=False,
                      width=500, height=500,
                      margin=dict(l=65, r=50, b=65, t=90),
                      scene=dict(xaxis_title='x', yaxis_title='p'))
    fig.show()


def init():
    """Reset the measurement-file registry used by the tomography routines.

    Rebinds the globals ``q_files`` (quadrature file paths) and ``l_degs`` (one phase
    in degrees per file) to fresh empty lists, and clears ``shot_file``.
    """
    global q_files, l_degs, shot_file
    q_files, l_degs, shot_file = [], [], ''

def _shot_noise_normalize(quadratures, shot, hbar=1):
    """Express raw quadrature samples in shot-noise units.

    Subtracts the mean of the shot-noise record ``shot`` and divides by its standard
    deviation. The result is then rescaled so the shot record itself would have
    variance ``hbar / 2``.
    """
    shot = np.asarray(shot)
    return (np.asarray(quadratures) - shot.mean()) / shot.std() * np.sqrt(hbar / 2)


def _moving_average(signal, window):
    """Centred moving average of a 1-D ``signal``, returned with the same length as ``signal``.

    Each output sample is the mean of the samples that actually fall inside the window.
    The first/last ``window // 2`` points are therefore not pulled toward zero by the
    implicit zero padding of ``np.convolve(..., mode='same')``. The window is clipped to
    the signal length, because ``mode='same'`` would otherwise return ``window`` samples
    for a shorter signal.
    """
    signal = np.asarray(signal, dtype=float)
    if signal.size == 0:
        return signal.copy()
    window = max(1, min(int(window), signal.size))
    kernel = np.ones(window)
    sums = np.convolve(signal, kernel, mode='same')
    counts = np.convolve(np.ones(signal.size), kernel, mode='same')
    return sums / counts


def _load_quadratures(drift_window=None):
    """Concatenate column ``pulse_id`` of every file in ``q_files`` together with per-sample phases.

    Reads the globals ``q_files``, ``l_degs`` (degrees, one per file) and ``pulse_id``.
    When ``drift_window`` is given, the moving average of column ``pulse_id + 1`` over
    ``drift_window`` samples is subtracted from each file's quadratures. The last such
    average is kept in the global ``quad_m_avg``.

    Returns:
        (quadratures, phases): 1-D float arrays of equal length; phases in radians.

    Raises:
        ValueError: if ``q_files`` is empty.
    """
    global quad_m_avg
    quad_parts, phase_parts = [], []
    for q_filename, ph in zip(q_files, l_degs):
        data = np.loadtxt(q_filename)  # parse each (potentially large) text file only once
        quad = data[:, pulse_id]
        if drift_window is not None:
            # NOTE(review): the drift estimate comes from column pulse_id + 1, not from
            # pulse_id itself — presumably a reference channel; confirm this is intended.
            quad_m_avg = _moving_average(data[:, pulse_id + 1], drift_window)
            quad = quad - quad_m_avg
        quad_parts.append(quad)
        phase_parts.append(np.full(quad.shape, ph * np.pi / 180.0))
    if not quad_parts:
        raise ValueError('q_files is empty; add quadrature files before running tomography')
    return np.concatenate(quad_parts), np.concatenate(phase_parts)


def _reconstruct_rho(quadratures, phases, max_photon, conv_th):
    """Normalize ``quadratures`` to shot-noise units, reconstruct the global ``rho`` and plot it.

    Reads the globals ``shot_file`` and ``pulse_id``. The reconstruction is done by
    ``maxlike`` from the ``tomography`` package (presumably maximum likelihood, as the
    name suggests).
    """
    global rho
    shot = np.loadtxt(shot_file)[:, pulse_id]
    normalized = _shot_noise_normalize(quadratures, shot)
    rho = maxlike(normalized.ravel(), np.asarray(phases).ravel(),
                  max_photon=max_photon, conv_th=conv_th)
    plot_bar3d_rho()
    plt.show()


def tomography(max_photon=14, conv_th=1e-15):
    """Reconstruct the density matrix from the registered quadrature files and plot it.

    Uses the globals ``q_files``/``l_degs`` (see ``init``), ``shot_file`` and ``pulse_id``,
    and stores the result in the global ``rho``.

    Args:
        max_photon: Fock-space cutoff passed to ``maxlike``.
        conv_th: convergence threshold passed to ``maxlike``.
    """
    quadratures, phases = _load_quadratures()
    _reconstruct_rho(quadratures, phases, max_photon, conv_th)


def set_mpl_font_size():
    """Apply the global ``mpl_font_size`` as matplotlib's default font size."""
    plt.rcParams["font.size"] = mpl_font_size


def tomography_m_avg(avg_num=1000, max_photon=14, conv_th=1e-15):
    """Like ``tomography``, but first subtracts a moving-average drift estimate from each file.

    Args:
        avg_num: moving-average window, in samples.
        max_photon: Fock-space cutoff passed to ``maxlike``.
        conv_th: convergence threshold passed to ``maxlike``.
    """
    quadratures, phases = _load_quadratures(drift_window=avg_num)
    _reconstruct_rho(quadratures, phases, max_photon, conv_th)

def save_rho():
    """Write the global ``rho`` to ``rho_<comment>.npy`` in the current working directory."""
    global rho
    global comment
    target = ''.join(['rho_', comment, '.npy'])
    np.save(target, rho)

def save_wigner():
    """Write the real part of the global ``wigner`` to ``wigner_<comment>.npy`` (current directory)."""
    global wigner
    global comment
    real_part = wigner.real
    np.save(''.join(['wigner_', comment, '.npy']), real_part)

def save_wigner_itx():
    """Export the real part of the global ``wigner`` as an Igor text wave ``wigner_<comment>.itx``.

    Each dimension is scaled to start at -5 with the per-point step of an evenly spaced
    [-5, 5] grid of that dimension's length. For the default 101-point grid from
    ``calc_wigner_rho`` the step is 0.1, as before. A ``wigner`` loaded from .npy with
    another resolution now also gets the correct scale.
    """
    global wigner
    global comment
    w_real = wigner.real
    wave = IgorWave(w_real, name='wigner_' + comment)
    # NOTE(review): igorwriter presumably maps numpy axis 0 (rows) to Igor's 'x' dimension.
    # The matplotlib plots draw rows along p and columns along x, so Igor's x/y would be
    # p/x here — confirm this orientation is intended.
    for dim, n_points in zip(('x', 'y'), w_real.shape):
        wave.set_dimscale(dim, -5, 10 / max(n_points - 1, 1))
    wave.save_itx('wigner_' + comment + '.itx', image=True)

def save_fig_pdf():
    """Save the global ``fig`` (set by the ``graph_*`` functions) to ``<comment>.pdf``, then close it.

    Calls ``fig.savefig`` directly instead of re-pointing the current pyplot figure
    manager's canvas at ``fig``. That hack created an empty figure when none was active
    and left a manager wired to a foreign figure.
    """
    global fig
    global comment
    fig.savefig(f'{comment}.pdf')
    plt.close(fig)

def save_fig_png():
    """Save the global ``fig`` (set by the ``graph_*`` functions) to ``<comment>.png`` at 300 dpi, then close it.

    Calls ``fig.savefig`` directly instead of re-pointing the current pyplot figure
    manager's canvas at ``fig``. That hack created an empty figure when none was active
    and left a manager wired to a foreign figure.
    """
    global fig
    global comment
    fig.savefig(f'{comment}.png', dpi=300)
    plt.close(fig)


def _draw_rho_bars(ax, values, zlabel, fock_dim, bar_width=0.7, color_limit=0.5):
    """Draw a real-valued matrix on the 3-D Axes ``ax`` as colored bars and attach a colorbar.

    Args:
        ax: 3-D Axes to draw on.
        values: real 2-D array; ``values[n, m]`` is drawn centred at (x=n, y=m).
        zlabel: z-axis label.
        fock_dim: number of Fock states shown; sets the x/y limits and tick range.
        bar_width: footprint of each bar along x and y.
        color_limit: bar colors span the 'seismic' colormap over [-color_limit, color_limit].
    """
    n_idx, m_idx = np.meshgrid(np.arange(values.shape[0]), np.arange(values.shape[1]),
                               indexing='ij')
    heights = values.ravel()
    x = n_idx.ravel() - bar_width / 2.0
    y = m_idx.ravel() - bar_width / 2.0
    footprint = np.full(heights.size, bar_width)

    # mplot3d shades bars with a negative dz badly, so every bar is drawn upward with
    # height |v| from a base of min(v, 0): negative entries hang below z = 0.
    # NOTE(review): the z-limits are fixed to (0, 0.5) below; whether bars below 0 stay
    # visible depends on mplot3d's clipping behavior — confirm this is intended.
    bottoms = np.minimum(heights, 0)

    cmap = plt.get_cmap('seismic')
    norm = mpl.colors.Normalize(vmin=-color_limit, vmax=color_limit)
    colors = cmap(norm(heights))  # colored by the signed value, not by |v|

    ax.bar3d(x, y, bottoms, footprint, footprint, np.abs(heights), color=colors, zsort='max')
    ax.set_zlabel(zlabel)
    ax.set_xlabel('n')
    ax.set_ylabel('m')
    ax.set_zlim(0, 0.5)
    ticks = np.arange(0, fock_dim, 2)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlim(-1, fock_dim)
    ax.set_ylim(-1, fock_dim)
    color_scale = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    ax.figure.colorbar(color_scale, ax=ax, shrink=0.3, pad=0.15)


def graph_rho():
    """Plot the real and imaginary parts of the top-left block of the global ``rho`` as 3-D bars.

    Only the top-left ``rho_fock_dim`` x ``rho_fock_dim`` block is shown. Uses the globals
    ``elev``/``azim`` for the view (azimuth negated) and stores the figure in the global
    ``fig`` so ``save_fig_pdf``/``save_fig_png`` can write it.
    """
    global fig
    global rho
    global rho_fock_dim
    rho_local = rho[:rho_fock_dim, :rho_fock_dim]

    fig = plt.figure(figsize=(16, 8))
    ax_real = plt.subplot2grid((1, 2), (0, 0), label='real', projection='3d', proj_type='ortho')
    ax_imag = plt.subplot2grid((1, 2), (0, 1), label='imag', projection='3d', proj_type='ortho')
    plt.tight_layout()

    panels = (
        (ax_real, np.real(rho_local), 'Density matrix (real)'),
        (ax_imag, np.imag(rho_local), 'Density matrix (imag)'),
    )
    for ax, values, zlabel in panels:
        ax.view_init(elev=elev, azim=-azim)
        _draw_rho_bars(ax, values, zlabel, rho_fock_dim)

    plt.tight_layout()
    plt.show()

def _graph_wigner_panels(surface_style, zlim, zticks=None):
    """Draw the two-panel Wigner figure shared by ``graph_wigner``/``graph_wigner2``/``graph_wigner3``.

    The left panel is a filled-contour top view and the right panel is an orthographic
    3-D surface. Both use the 'seismic' colormap normalized to +-1/(pi*hbar) with hbar = 1,
    the largest magnitude a Wigner function can reach. The function reads the globals
    ``wigner``, ``elev`` and ``azim``, and stores the figure in the global ``fig`` so
    ``save_fig_pdf``/``save_fig_png`` can write it.

    Args:
        surface_style: extra ``plot_surface`` keyword arguments (line/edge/alpha styling).
        zlim: ``(bottom, top)`` z-limits of the 3-D panel.
        zticks: optional explicit z tick positions for the 3-D panel.
    """
    global fig
    hbar = 1
    w_bound = 1 / (np.pi * hbar)
    w_real = wigner.real

    # Columns are drawn along x and rows along p over [-5, 5]. Deriving the grid from the
    # array shape lets a wigner loaded from .npy with another resolution plot as well
    # (the default 101 x 101 grid is reproduced exactly).
    n_p, n_x = w_real.shape
    X, P = np.meshgrid(np.linspace(-5, 5, n_x), np.linspace(-5, 5, n_p))

    fig = plt.figure(figsize=(10, 5))
    ax_top = plt.subplot2grid((1, 2), (0, 0), label='top')
    ax_3d = plt.subplot2grid((1, 2), (0, 1), label='3d', projection='3d', proj_type='ortho')
    plt.tight_layout()

    ax_top.set_aspect('equal', adjustable='box')
    ax_top.contourf(X, P, w_real, cmap='seismic',
                    norm=Normalize(vmin=-w_bound, vmax=w_bound), levels=100)
    color_scale = plt.cm.ScalarMappable(cmap='seismic', norm=Normalize(vmin=-w_bound, vmax=w_bound))
    # A standalone ScalarMappable belongs to no Axes, so the Axes to steal space from must
    # be given explicitly (Matplotlib >= 3.8 raises instead of falling back to gca()).
    fig.colorbar(color_scale, ax=ax_top).set_label('wigner function')
    ax_top.set_xlabel('x')
    ax_top.set_ylabel('p')

    ax_3d.plot_surface(X, P, w_real, cmap='seismic',
                       norm=Normalize(vmin=-w_bound, vmax=w_bound),
                       antialiased=True, rcount=100, ccount=100, shade=True,
                       **surface_style)
    ax_3d.set_zlim3d(bottom=zlim[0], top=zlim[1])
    if zticks is not None:
        ax_3d.set_zticks(zticks)
    ax_3d.set_xlabel('x')
    ax_3d.set_ylabel('p')
    ax_3d.view_init(elev=elev, azim=-azim)

    plt.tight_layout()
    plt.show()


def graph_wigner():
    """Plot the global ``wigner``: contour top view plus an opaque surface, z-range +-1/(pi*hbar)."""
    hbar = 1
    _graph_wigner_panels(
        dict(linewidth=0),
        (-1 / (np.pi * hbar), 1 / (np.pi * hbar)),
    )


def graph_wigner2():
    """Plot the global ``wigner``: translucent wire-edged surface scaled to the data's min/max."""
    w_real = wigner.real
    w_min, w_max = np.min(w_real), np.max(w_real)
    _graph_wigner_panels(
        dict(linewidth=0.3, alpha=0.5, edgecolor='black'),
        (w_min, w_max),
        zticks=[w_min, 0, w_max],
    )


def graph_wigner3():
    """Plot the global ``wigner``: translucent wire-edged surface, z-range +-0.5/(pi*hbar)."""
    hbar = 1
    _graph_wigner_panels(
        dict(linewidth=0.3, alpha=0.5, edgecolor='black'),
        (-0.5 / (np.pi * hbar), 0.5 / (np.pi * hbar)),
    )

In [None]:


def simple_ui():
    """Build and display the ipywidgets control panel for the tomography workflow.

    Every action button runs one module-level function inside a shared ``Output``
    widget, so printed text, figures and tracebacks appear below the panel. The FileChooser
    registers quadrature files (with the phase from the angle field) or the shot file, or
    loads ``rho``/``wigner`` from .npy. The text/number fields are copied into the globals
    ``comment``, ``pulse_id``, ``elev``, ``azim``, ``rho_fock_dim`` and ``mpl_font_size``
    when '変数反映' is pressed, and once when the panel is built, so the shown defaults are
    usable immediately.
    """
    button_clear_output = widgets.Button(description='表示クリア')
    button_init = widgets.Button(description='初期化')
    button_tomography = widgets.Button(description='トモグラフィー')
    button_tomography_m_avg = widgets.Button(description='トモグラフィー移動平均')
    button_rho_plot = widgets.Button(description='rhoプロット')
    button_wigner_rho = widgets.Button(description='計算rho→wigner')
    button_wigner_plot = widgets.Button(description='wignerプロット')
    button_save_rho = widgets.Button(description='rho保存')
    button_save_wigner = widgets.Button(description='wigner保存')
    button_save_wigner_itx = widgets.Button(description='wigner→itx保存')
    button_wigner_plotly = widgets.Button(description='wignerプロット3D')
    button_save_fig_pdf = widgets.Button(description='save_fig_pdf')
    button_save_fig_png = widgets.Button(description='save_fig_png')
    button_set_mpl_font_size = widgets.Button(description='set_mpl_font_size')
    filechooser = FileChooser('./')
    float_quad_deg = widgets.FloatText(value=0, description='角度deg')
    button_add_to_q_files = widgets.Button(description='q_files追加')
    button_add_to_shot_file = widgets.Button(description='shot_file追加')
    button_load_as_rho_npy = widgets.Button(description='rho.npy読込')
    button_load_as_wigner_npy = widgets.Button(description='wigner.npy読込')
    button_input_field = widgets.Button(description='変数反映')

    text_comment = widgets.Text(value='', placeholder='文字を入力', description='comment', disabled=False)
    int_pulse_id = widgets.IntText(value=0, description='pulse_id')
    int_elev = widgets.IntText(value=10, description='elev')
    int_azim = widgets.IntText(value=60, description='azim')
    int_rho_fock_dim = widgets.IntText(value=15, description='rho_fock_dim')
    float_mpl_font_size = widgets.IntText(value=10, description='mpl_font_size')

    button_graph_rho = widgets.Button(description='graph_rho', layout={'width': 'auto'})
    button_graph_wigner = widgets.Button(description='graph_wigner', layout={'width': 'auto'})
    button_graph_wigner2 = widgets.Button(description='graph_wigner2', layout={'width': 'auto'})
    button_graph_wigner3 = widgets.Button(description='graph_wigner3', layout={'width': 'auto'})

    # `layout`, not `layour`: the misspelled keyword is not a widget trait, so the border
    # was never applied.
    output = widgets.Output(layout={'border': '1px solid black'})

    def wrapped_func_factory(func):
        """Return an ``on_click`` handler that runs ``func()`` with its output captured in ``output``."""
        def new_func(ui_element):
            with output:
                print(f"exec func {func.__name__}")
                func()
                print(f"complete {func.__name__}")
        return new_func

    button_clear_output.on_click(lambda button: output.clear_output(wait=False))
    for button, func in (
        (button_init, init),
        (button_tomography, tomography),
        (button_tomography_m_avg, tomography_m_avg),
        (button_rho_plot, plot_bar3d_rho),
        (button_wigner_rho, calc_wigner_rho),
        (button_wigner_plot, plot_wigner),
        (button_save_rho, save_rho),
        (button_save_wigner, save_wigner),
        (button_save_wigner_itx, save_wigner_itx),
        (button_wigner_plotly, plot_wigner_plotly),
        (button_save_fig_pdf, save_fig_pdf),
        (button_save_fig_png, save_fig_png),
        (button_graph_rho, graph_rho),
        (button_graph_wigner, graph_wigner),
        (button_graph_wigner2, graph_wigner2),
        (button_graph_wigner3, graph_wigner3),
        (button_set_mpl_font_size, set_mpl_font_size),
    ):
        button.on_click(wrapped_func_factory(func))

    def selected_path():
        """Return the file selected in ``filechooser``, or None (after printing a hint) if there is none."""
        path = filechooser.selected
        if path is None:
            print('No file selected in the file chooser.')
        return path

    def load_npy_factory(variable_name):
        """Return a handler that loads the selected .npy file into the global ``rho`` or ``wigner``."""
        def load_npy():
            global rho, wigner
            chosen_file_path = selected_path()
            if chosen_file_path is None:
                return
            if variable_name == 'rho':
                rho = np.load(chosen_file_path)
                print(f'rho.shape={rho.shape}')
            elif variable_name == 'wigner':
                wigner = np.load(chosen_file_path)
                print(f'wigner.shape={wigner.shape}')
        return load_npy

    def edit_path_factory(variable_name):
        """Return a handler that registers the selected file as a quadrature file or as the shot file."""
        def edit_path():
            global q_files, shot_file, l_degs
            chosen_file_path = selected_path()
            if chosen_file_path is None:
                return
            if variable_name == 'q_files':
                q_files.append(chosen_file_path)
                l_degs.append(float_quad_deg.value)
            elif variable_name == 'shot_file':
                shot_file = chosen_file_path
            print(f"shot_file {shot_file}")
            print('q_files')
            for path, deg in zip(q_files, l_degs):
                print(path, deg)
        return edit_path

    button_add_to_q_files.on_click(wrapped_func_factory(edit_path_factory('q_files')))
    button_add_to_shot_file.on_click(wrapped_func_factory(edit_path_factory('shot_file')))
    button_load_as_rho_npy.on_click(wrapped_func_factory(load_npy_factory('rho')))
    button_load_as_wigner_npy.on_click(wrapped_func_factory(load_npy_factory('wigner')))

    def load_input_field():
        """Copy the input widgets' current values into the module globals used by the plots."""
        global comment
        global pulse_id
        global elev, azim
        global rho_fock_dim
        global mpl_font_size
        comment = text_comment.value
        pulse_id = int_pulse_id.value
        elev = int_elev.value
        azim = int_azim.value
        rho_fock_dim = int_rho_fock_dim.value
        mpl_font_size = float_mpl_font_size.value
    button_input_field.on_click(lambda button: load_input_field())

    # Seed the globals from the displayed defaults so the first click of any button does not
    # hit a NameError, and create the file registry if '初期化' has never run.
    load_input_field()
    if not all(name in globals() for name in ('q_files', 'l_degs', 'shot_file')):
        init()

    display(
        widgets.VBox([
            widgets.HBox([button_clear_output, button_init, button_tomography, button_rho_plot, button_wigner_rho, button_wigner_plot, button_save_rho, button_save_wigner]),
            widgets.HBox([button_save_wigner_itx, button_tomography_m_avg, button_wigner_plotly]),
            widgets.HBox([button_save_fig_pdf, button_save_fig_png, button_set_mpl_font_size]),
            widgets.HBox([filechooser, float_quad_deg, button_add_to_q_files, button_add_to_shot_file, button_load_as_rho_npy, button_load_as_wigner_npy]),
            widgets.HBox([text_comment, int_pulse_id, int_elev, int_azim]),
            widgets.HBox([int_rho_fock_dim, float_mpl_font_size, button_input_field]),
            widgets.HBox([button_graph_rho, button_graph_wigner, button_graph_wigner2, button_graph_wigner3]),
            output,
        ])
    )
    

In [None]:
# Build and display the interactive control panel; every analysis step is driven from its buttons.
simple_ui()