# 手動最適化

In [11]:
import numpy as np
from convinience import Spectrum, Band, synthesisSpectrum
from convinience import gaussian_profile, compensation, calc_totalPower ,attenuation
from mso58trans import MSO58Wrapper, MSO58transData
from shutter import servo_shutter
from WSMethods import uploadProfile, getWSrange
from sklearn.decomposition import PCA
from csv import DictWriter
import datetime
import copy
from time import sleep
import os
import ipywidgets as widgets
from IPython.display import display
from matplotlib import pyplot as plt
from ipyfilechooser import FileChooser

In [12]:
# 各種設定周り ここにすべてまとめる

class Setting:
    """Central place for every instrument address and measurement parameter.

    Fields left as ``None`` here (``ws_Freqrange``, ``optSpectrum``) are meant
    to be assigned after construction; :meth:`condition_approval` then checks
    that nothing required is still ``None``.
    """

    def __init__(self):
        # --- WaveShaper ---
        self.ws_ip = '169.254.6.8'
        self.ws_Freq_resolution = 0.001  # THz (0.001 THz = 1 GHz)
        self.ws_Freqrange = None
        self.optSpectrum: Spectrum = None

        # --- MSO58 oscilloscope / homodyne detector ---
        self.mso58_address = 'USB::0x0699::0x0530::C043144::INSTR'
        self.homodyne_port = 1
        self.fastFramePulseNum = 5

        def take_hd_sample(volts):
            # Quadrature = the sample at index `hd_idx` (a module-level global,
            # looked up on every call) along the last axis of `volts`.
            rows = volts.reshape(-1, volts.shape[-1])
            return rows[:, hd_idx].reshape(volts.shape[:-1])

        # Alternative: first principal component of each window.
        # self.quadrature_method = lambda volts: PCA().fit_transform(volts.reshape((-1,volts.shape[-1])))[:,0].reshape(volts.shape[:-1])
        self.quadrature_method = take_hd_sample

        # --- LO power target, pump shutter, result log ---
        self.targetLOPower = 1
        self.shutter_port = 'COM3'
        self.pumpShutterPort: int = 2
        self.param_csv = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + 'record.csv'

    def condition_approval(self):
        """Assert that every required setting has been assigned (is not None).

        Checks run in a fixed order and stop at the first ``None`` with a bare
        ``AssertionError``. Like any ``assert``, they are skipped entirely
        under ``python -O``.
        """
        # Lambdas keep each attribute read lazy, so e.g. optSpectrum's fields
        # are only touched after optSpectrum itself was found to be set.
        required = (
            lambda: self.ws_Freqrange,
            lambda: self.optSpectrum,
            lambda: self.optSpectrum.powerdensity,
            lambda: self.optSpectrum.frequency,
            lambda: self.optSpectrum.wavelength,
            lambda: self.optSpectrum.phase,
            lambda: self.fastFramePulseNum,
            lambda: self.quadrature_method,
            lambda: self.targetLOPower,
            lambda: self.shutter_port,
            lambda: self.pumpShutterPort,
            lambda: self.param_csv,
        )
        for read_value in required:
            assert read_value() is not None


# Build the global configuration object used by every cell below.
setting = Setting()

# WaveShaper frequency grid (THz, per the ws_Freq_resolution comment in Setting).
# The limits are hard-coded; the trailing comment shows the alternative of
# querying them from the device with getWSrange.
# NOTE(review): np.arange with a float step excludes endFreq, and rounding can
# make its length differ by one from (endFreq - startFreq) / step -- confirm
# the grid size the WaveShaper expects.
startFreq, endFreq    = 191.05, 196.475# getWSrange(setting.ws_ip)
setting.ws_Freqrange  = np.arange(startFreq, endFreq, setting.ws_Freq_resolution)

# Measured optical spectrum, loaded from .npy files in the working directory;
# the phase starts flat (all zeros, same shape as the frequency axis).
# NOTE(review): wavelength is never assigned here although condition_approval
# requires it; presumably Spectrum derives it from frequency -- confirm.
setting.optSpectrum                = Spectrum()
setting.optSpectrum.band           = Band()
setting.optSpectrum.frequency      = np.load('optFreq.npy')
setting.optSpectrum.powerdensity   = np.load('optSpectrum.npy')
setting.optSpectrum.phase          = np.zeros_like(setting.optSpectrum.frequency)

# Fail fast (AssertionError) if any required setting is still None.
setting.condition_approval()

### 最適化フェーズ

In [13]:
class State:
    """Holds the spectrum currently being tried and the one tried before it.

    Keeping exactly one previous spectrum is enough to step back a single
    parameter set.
    """

    def __init__(self):
        # Both slots start empty; nothing has been generated yet.
        self.current_spectrum = self.former_spectrum = None


# Module-wide default state used when functions are called without a state.
state = State()

In [14]:

def quad(w, pulseNum, quadrature_func):
    """Cut every frame of ``w`` into ``pulseNum`` equal windows and reduce them.

    Args:
        w: 2-D array indexed as ``w[frame, sample]``.
        pulseNum: number of pulses (equal-length windows) per frame.
        quadrature_func: callable receiving a float64 array of shape
            ``(frames, pulseNum, window)`` and returning the quadratures.

    Returns:
        Whatever ``quadrature_func`` returns for the windowed data.
    """
    frame_length = w.shape[1]
    window = int(frame_length / pulseNum)
    starts = (int(k * frame_length / pulseNum) for k in range(pulseNum))
    # Stack along axis 1 -> (frames, pulseNum, window); cast to float64 so the
    # reducer always sees the same dtype regardless of the scope's raw format.
    windows = np.stack([w[:, s:s + window] for s in starts], axis=1).astype(np.float64)
    return quadrature_func(windows)


def evaluation(quadratures):
    """Score a quadrature array by its variance over all elements pooled."""
    pooled_variance = quadratures.var()
    return pooled_variance



def param_gen(state_arg: State = None):
    """Yield LO-compensated spectra for a grid of Gaussian filter shapes.

    Sweeps 40 centre wavelengths in [1543, 1547] (units as expected by
    ``gaussian_profile``, presumably nm), 20 bandwidths in [0.2, 1] and a
    single dispersion value (0). For each combination a Gaussian target
    filter is built on the WaveShaper grid ``setting.ws_Freqrange`` and passed,
    together with ``setting.optSpectrum``, to ``compensation``.

    Before each yield the previous spectrum is moved to
    ``state_arg.former_spectrum`` and the new one is stored in
    ``state_arg.current_spectrum``, so one step can be undone.

    Args:
        state_arg: State object to update; defaults to the module-level ``state``.

    Yields:
        The spectrum returned by ``compensation`` for each parameter set.
    """
    if state_arg is None:
        state_arg = state
    for center_wavelength in np.linspace(1543, 1547, 40):
        for band_width in np.linspace(0.2, 1, 20):
            for dispersion in [0]:
                band = Band()
                band.frequency = setting.ws_Freqrange
                filterSpectrum = gaussian_profile(
                    centerWavelength=center_wavelength,
                    bandWidth=band_width,
                    # The Band describing the WaveShaper grid. This used to
                    # pass `band_width` (a float) -- a typo, since the Band
                    # built just above is the grid the profile must live on.
                    band=band,
                    dispersion=dispersion,
                )
                compensated_Spectrum = compensation(
                    setting.optSpectrum,
                    targetSpectrum=filterSpectrum,
                    band=band,
                )
                state_arg.former_spectrum = state_arg.current_spectrum
                state_arg.current_spectrum = compensated_Spectrum
                yield compensated_Spectrum

def param_unwind(state_arg: State = None):
    """Step ``state_arg`` back by one parameter set.

    The former spectrum becomes current again and the former slot is
    cleared, so only a single step can be undone.

    Args:
        state_arg: State object to modify; defaults to the module-level ``state``.
    """
    target = state if state_arg is None else state_arg
    target.current_spectrum, target.former_spectrum = target.former_spectrum, None

def set_param(state_arg: State = None):
    """Upload the spectrum currently held in ``state_arg`` to the WaveShaper.

    Args:
        state_arg: State whose ``current_spectrum`` is uploaded; defaults to
            the module-level ``state``.
    """
    target = state if state_arg is None else state_arg
    spectrum = target.current_spectrum
    # Power adjustment is disabled (it needs the LO total power):
    # spectrum = attenuation(spectrum, total_power - setting.targetLOPower)

    # Port value 1 for every frequency bin (presumably all light to output
    # port 1); the negated powerdensity is sent as the profile -- presumably
    # an attenuation, confirm against uploadProfile.
    port_map = np.ones(spectrum.powerdensity.shape)
    uploadProfile(
        setting.ws_ip,
        spectrum.frequency,
        -(spectrum.powerdensity),
        spectrum.phase,
        port_map,
    )



def _measure_quadratures():
    """Trigger one single acquisition on the MSO58 and return its quadratures.

    The transfer is retried once per second until it succeeds. Only
    ``Exception`` subclasses are retried: a bare ``except:`` would also
    swallow KeyboardInterrupt and make this retry loop impossible to stop.
    """
    osc_handler = MSO58Wrapper(setting.mso58_address)
    osc_handler.push_single()
    while True:
        try:
            _, waveform = osc_handler.transfer2byte(setting.homodyne_port)
        except Exception:
            sleep(1)
            continue
        return quad(waveform, setting.fastFramePulseNum, setting.quadrature_method)


def _to_db(variance):
    """Convert a (positive) variance to decibels, ``10 * log10(variance)``."""
    return 10 * np.log10(variance)


def loop():
    """Sweep WaveShaper filter shapes and log shot-noise vs. pumped variance.

    For each spectrum yielded by ``param_gen(state)``:

    1. compute the LO power with ``calc_totalPower`` and upload the filter
       profile to the WaveShaper, then wait 3 s;
    2. close the pump shutter and measure the quadrature variance (shot noise);
    3. open the pump shutter and measure the variance again;
    4. append the LO power, both variances in dB and their difference to
       ``setting.param_csv`` (header written only when the file is new) and
       print the same row.

    The sweep stops before the next step when a global ``stop_flag`` is
    truthy; a missing ``stop_flag`` is treated as False.
    """
    for compensated_Spectrum in param_gen(state):
        if globals().get('stop_flag', False):
            break
        total_power = calc_totalPower(setting.optSpectrum, compensated_Spectrum)
        # Power adjustment (currently disabled):
        # compensated_Spectrum = attenuation(compensated_Spectrum, total_power - setting.targetLOPower)
        filterPort = np.ones(compensated_Spectrum.powerdensity.shape)
        uploadProfile(
            setting.ws_ip,
            compensated_Spectrum.frequency,
            -(compensated_Spectrum.powerdensity),
            compensated_Spectrum.phase,
            filterPort,
        )
        sleep(3)  # presumably settling time for the new WaveShaper profile

        # Pump blocked -> shot-noise reference.
        servo_shutter(port=setting.shutter_port).close([setting.pumpShutterPort])
        shot_var = evaluation(_measure_quadratures())

        # Pump on -> squeezed / anti-squeezed variance.
        servo_shutter(port=setting.shutter_port).open([setting.pumpShutterPort])
        squeeze_var = evaluation(_measure_quadratures())

        # dB needs log10 (as in quadrature_plot); the previous 10*np.log used
        # the natural log and inflated every value by ln(10) ~= 2.30.
        shot_db = _to_db(shot_var)
        squeeze_db = _to_db(squeeze_var)
        row = {
            'LOpower': total_power,
            'shotnoise': shot_db,
            'squeeze': squeeze_db,
            'diff': squeeze_db - shot_db,
        }

        is_exists = os.path.exists(setting.param_csv)
        with open(setting.param_csv, 'a', newline="") as f:
            writer = DictWriter(f, fieldnames=['LOpower', 'shotnoise', 'squeeze', 'diff'])
            if not is_exists:
                writer.writeheader()
            writer.writerow(row)
        print(row)

In [15]:
def read_osc():
    """Read one acquisition from the MSO58 into the globals ``t`` and ``w``.

    On success the global ``now`` is set to the time the read started; the
    save helpers use ``now`` to name their files, so the timestamp always
    belongs to the data actually loaded. On failure the error is printed and
    ``t``, ``w`` and ``now`` are all left unchanged.
    """
    global t, w
    global now
    started = datetime.datetime.now()
    try:
        handler = MSO58transData(setting.mso58_address)
        # Same channel as the automated loop instead of a hard-coded 1
        # (setting.homodyne_port defaults to 1).
        new_t, new_w = handler.transfer2byte(setting.homodyne_port)
    except Exception as e:
        print('failed')
        print(e)
        return
    # Only stamp `now` once data arrived; refreshing it on failure paired a
    # new timestamp with the previous (stale) t / w.
    t, w, now = new_t, new_w, started
    print(f't.shape={t.shape} w.shape={w.shape}')
    

def save_osc():
    """Save globals ``w`` and ``t`` as ``<timestamp>_wdata_<comment>.npy`` /
    ``<timestamp>_tdata_<comment>.npy`` in the working directory.

    The timestamp comes from the global ``now``; ``w`` is written first.
    """
    stamp = now.strftime('%Y%m%d_%H%M%S')
    w_path = stamp + '_wdata_' + comment + '.npy'
    t_path = stamp + '_tdata_' + comment + '.npy'
    np.save(w_path, w)
    np.save(t_path, t)

def var_plot():
    """Plot, for every sample position, the variance of global ``w`` across frames."""
    sample_variance = w[:, :].var(axis=0)
    plt.plot(sample_variance)
    plt.show()


def quadrature_plot():
    """Plot per-pulse waveforms, per-frame quadratures, their variance and a histogram.

    Reads the globals ``w`` (indexed ``w[frame, sample]``), ``pulseNum`` and
    ``hd_idx``. Writes the globals ``frames`` (number of frames),
    ``quadratures`` (float array ``(frames, pulseNum)``: sample ``hd_idx`` of
    every pulse window) and ``quadrature`` (the last pulse's quadrature).

    Raises:
        ValueError: if ``pulseNum`` is smaller than 1.
    """
    global quadrature, quadratures, frames

    if pulseNum < 1:
        raise ValueError(f'pulseNum must be >= 1, got {pulseNum}')

    frames = w.shape[0]
    frame_length = w.shape[1]
    duration = int(frame_length / pulseNum)
    period = [int(i * frame_length / pulseNum) for i in range(pulseNum)]
    volts = w.T  # (samples, frames)
    n_cols = max(pulseNum, 10)
    # Frame whose raw waveform is drawn and pulse that is histogrammed. Both
    # were hard-coded to 1; clamp so single-frame / single-pulse data works.
    shown_frame = min(1, frames - 1)
    hist_pulse = min(1, pulseNum - 1)

    # Row 1: raw waveform of each pulse window; row 2: quadrature vs frame.
    fig = plt.figure(figsize=(40, 5))
    ax1 = None
    ax2 = None
    for i, s in enumerate(period):
        ax1 = fig.add_subplot(2, n_cols, i + 1, sharey=ax1)
        ax1.plot(volts[s:(s + duration), shown_frame])
        ax1.set_title(f'{i}')
        ax2 = fig.add_subplot(2, n_cols, i + 1 + n_cols, sharey=ax2)
        ax2.scatter(range(frames), volts[s + hd_idx, :], s=0.01)
        ax2.set_title(f'{i}')
    plt.show()

    # Vectorised replacement of the per-frame / per-pulse Python loop.
    quadratures = volts[[s + hd_idx for s in period], :].T.astype(float)
    quadrature = volts[period[-1] + hd_idx, :]
    varList = quadratures.var(axis=0)

    plt.figure()
    plt.plot(varList)
    plt.gca().get_yaxis().get_major_formatter().set_useOffset(False)
    plt.show()

    plt.plot(10 * np.log10(varList))
    plt.title('varList(dB)')
    plt.show()  # was `plt.show` without the call, so this figure was not shown here

    plt.figure()
    plt.hist(quadratures[:, hist_pulse], bins=100)
    plt.show()
    plt.show()

def save_quadratures():
    """Write global ``quadratures`` as text to ``<timestamp>hd_<comment>.txt``.

    The timestamp comes from the global ``now`` (no separator before ``hd_``).
    """
    stamp = now.strftime('%Y%m%d_%H%M%S')
    target_path = stamp + 'hd_' + comment + '.txt'
    np.savetxt(target_path, quadratures)

def save_hist_img():
    """Histogram pulse 1 of global ``quadratures``, show it, and save it as PNG.

    The file is named ``<timestamp>hist_<comment>.png`` using the global ``now``.
    """
    figure = plt.figure()
    plt.hist(quadratures[:, 1], bins=100)
    plt.show()
    stamp = now.strftime('%Y%m%d_%H%M%S')
    figure.savefig(stamp + 'hist_' + comment + '.png')


def quadrature_sweep_plot():
    """Show how one pulse's quadrature changes as ``hd_idx`` is swept by -5..+4.

    The frames of global ``w`` are cut into ``pulseNum`` windows with ``quad``.
    For each offset, sample ``hd_idx + offset`` of pulse ``sweep_target_pulse``
    is scattered against frame number (row 1) and histogrammed (row 2).
    Offsets that fall outside the window are labelled "out of range" and left
    empty, instead of silently wrapping (negative index) or raising IndexError.

    Reads the globals ``w``, ``pulseNum``, ``hd_idx`` and ``sweep_target_pulse``.
    """
    offsets = range(-5, 5)
    duration = int(w.shape[1] / pulseNum)
    # Cut frames into pulse windows once with the shared helper:
    # (frames, pulseNum, duration).
    volts = quad(w, pulseNum, lambda windows: windows)

    fig, ax = plt.subplots(2, len(offsets), sharey="row", figsize=(30, 5))
    for col, delta_hd_idx in enumerate(offsets):
        sample_idx = hd_idx + delta_hd_idx
        scatter_ax, hist_ax = ax[0, col], ax[1, col]
        if not 0 <= sample_idx < duration:
            scatter_ax.set_title(f'{sample_idx} (out of range)')
            continue
        target = volts[:, sweep_target_pulse, sample_idx]
        scatter_ax.scatter(range(target.shape[0]), target, s=0.01)
        scatter_ax.set_title(f'{sample_idx}')
        hist_ax.hist(target, bins=100)
    plt.show()

In [None]:
# Run the whole parameter sweep. This drives the WaveShaper, pump shutter and
# oscilloscope and appends one CSV row per parameter set, so it runs for a
# long time. Before running, make sure the globals hd_idx (read by
# setting.quadrature_method; set via the UI's 変数反映 button) and stop_flag
# (checked before every step; set it True to stop early) are defined.
loop()

In [23]:
def simple_ui():
    """Build and display the ipywidgets control panel for manual measurements.

    The panel has four rows plus an output area:

    - Row 1: oscilloscope read/save and the plotting / saving helpers.
    - Row 2: advance the parameter sweep one step, undo one step, upload the
      current parameters.
    - Row 3: load previously saved ``w`` / ``t`` arrays from ``.npy`` files.
    - Row 4: the text/int fields are copied into the notebook globals
      ``comment``, ``pulseNum`` (also ``setting.fastFramePulseNum``),
      ``hd_idx`` and ``sweep_target_pulse`` when "変数反映" is pressed.

    Button handlers print into a bordered Output widget.
    """
    button_clear_output = widgets.Button(description='表示クリア')
    button_read_osc = widgets.Button(description='オシロ読込')
    button_save_osc = widgets.Button(description='オシロ保存')
    button_var_plot = widgets.Button(description='varプロット')
    button_quadrature_plot = widgets.Button(description='quadratureプロット')
    button_quadrature_sweep_plot = widgets.Button(description='hd_idxスイーププロット')
    button_save_quadratures = widgets.Button(description='quadrature保存')
    button_save_hist_img = widgets.Button(description='hist保存')

    button_param_gen = widgets.Button(description='param_gen')
    button_param_unwind = widgets.Button(description='param_unwind')
    button_set_param = widgets.Button(description='set_param')
    button_measure_diff_1push = widgets.Button(description='measure_diff_1push')

    filechooser = FileChooser('./')
    button_load_as_wdata_npy = widgets.Button(description='wdata.npyとして読込')
    button_load_as_tdata_npy = widgets.Button(description='tdata.npyとして読込')
    button_input_field = widgets.Button(description='変数反映')

    text_comment = widgets.Text(value='', placeholder='文字を入力', description='comment', disabled=False)
    int_pulseNum = widgets.IntText(value=5, description='pulseNum')
    int_hd_idx = widgets.IntText(value=5, description='hd_idx')
    int_sweep_target_pulse = widgets.IntText(value=0, description='sweep_target_pulse')

    # `layout` (was misspelled `layour`, so the border was silently dropped).
    output = widgets.Output(layout={'border': '1px solid black'})

    def wrapped_func_factory(func):
        """Return a click handler that runs `func` with its prints sent to `output`."""
        def new_func(ui_element):
            with output:
                print(f"exec func {func.__name__}")
                func()
                print(f"complete {func.__name__}")
        return new_func

    # param_gen is a generator function: calling it only creates a generator
    # object, so wiring it directly made the button do nothing. Keep one
    # generator alive across clicks and advance it one step per click.
    param_iter = None

    def param_gen_step():
        nonlocal param_iter
        if param_iter is None:
            param_iter = param_gen()
        try:
            next(param_iter)
        except StopIteration:
            print('param_gen exhausted; the next click restarts the sweep')
            param_iter = None

    button_clear_output.on_click(lambda button: output.clear_output(wait=False))
    button_read_osc.on_click(wrapped_func_factory(read_osc))
    button_save_osc.on_click(wrapped_func_factory(save_osc))
    button_var_plot.on_click(wrapped_func_factory(var_plot))
    button_quadrature_plot.on_click(wrapped_func_factory(quadrature_plot))
    button_quadrature_sweep_plot.on_click(wrapped_func_factory(quadrature_sweep_plot))
    button_save_quadratures.on_click(wrapped_func_factory(save_quadratures))
    button_save_hist_img.on_click(wrapped_func_factory(save_hist_img))
    button_param_gen.on_click(wrapped_func_factory(param_gen_step))
    button_param_unwind.on_click(wrapped_func_factory(param_unwind))
    button_set_param.on_click(wrapped_func_factory(set_param))
    # NOTE(review): button_measure_diff_1push has no handler defined yet.

    def load_npy_factory(variable_name):
        """Return a loader that reads the chosen .npy file into global `w` or `t`."""
        def load_npy():
            global w, t
            choosed_file_path = filechooser.selected
            if choosed_file_path is None:
                # Nothing picked in the FileChooser yet; np.load(None) would fail.
                print('no file selected')
                return
            if variable_name == 'w':
                w = np.load(choosed_file_path)
                print(f'w.shape={w.shape}')
            elif variable_name == 't':
                t = np.load(choosed_file_path)
                print(f't.shape={t.shape}')
        return load_npy

    button_load_as_wdata_npy.on_click(wrapped_func_factory(load_npy_factory('w')))
    button_load_as_tdata_npy.on_click(wrapped_func_factory(load_npy_factory('t')))

    def load_input_field():
        """Copy the input widgets' values into the notebook-level globals."""
        global comment
        global pulseNum
        global hd_idx
        global sweep_target_pulse
        comment = text_comment.value
        pulseNum = int_pulseNum.value
        setting.fastFramePulseNum = pulseNum
        hd_idx = int_hd_idx.value
        sweep_target_pulse = int_sweep_target_pulse.value
    button_input_field.on_click(lambda button: load_input_field())

    display(
        widgets.HBox([button_clear_output, button_read_osc, button_save_osc, button_var_plot, button_quadrature_plot, button_quadrature_sweep_plot, button_save_quadratures, button_save_hist_img]),
        widgets.HBox([button_param_gen, button_param_unwind, button_set_param, button_measure_diff_1push]),
        widgets.HBox([filechooser, button_load_as_wdata_npy, button_load_as_tdata_npy]),
        widgets.HBox([text_comment, int_pulseNum, int_hd_idx, int_sweep_target_pulse, button_input_field]),
        output)

In [24]:
# Build and display the manual-measurement control panel.
simple_ui()

HBox(children=(Button(description='表示クリア', style=ButtonStyle()), Button(description='オシロ読込', style=ButtonStyle…

HBox(children=(Button(description='param_gen', style=ButtonStyle()), Button(description='param_unwind', style=…

HBox(children=(FileChooser(path='C:\Users\ruofan\storage\waveshaper\optimize', filename='', title='', show_hid…

HBox(children=(Text(value='', description='comment', placeholder='文字を入力'), IntText(value=5, description='pulse…

Output()