# Run Bayesian optimization of the HXR FEL

In [None]:
# Environment setup for the SLAC production servers:
# limit OpenMP to 6 threads.
import os

os.environ["OMP_NUM_THREADS"] = "6"

## Read PV info from YAML files

In [1]:
import sys
import yaml
sys.path.append("../")

from common import get_pv_objects, save_reference_point, set_magnet_strengths, \
    measure_pvs

In [3]:
# Load the variable bounds (PV name -> [lower, upper]) and the tracked PV
# objects. The bounds file is read inside a `with` block so its handle is
# closed immediately instead of being left to the garbage collector.
with open("../pv_bounds.yml", encoding="utf-8") as bounds_file:
    pv_bounds = yaml.safe_load(bounds_file)
pv_objects = get_pv_objects("../tracked_pvs.yml")

In [4]:
# Show the loaded bounds: PV name -> [lower, upper]
pv_bounds

{'QUAD:IN20:121:BCTRL': [-0.021, 0.021],
 'QUAD:IN20:122:BCTRL': [-0.021, 0.021],
 'QUAD:IN20:361:BCTRL': [-4.32, -1.08],
 'QUAD:IN20:371:BCTRL': [1.09, 4.31],
 'QUAD:IN20:425:BCTRL': [-7.56, -1.08],
 'QUAD:IN20:441:BCTRL': [-1.08, 7.56],
 'QUAD:IN20:511:BCTRL': [-1.08, 7.56],
 'QUAD:IN20:525:BCTRL': [-7.56, -1.08],
 'QUAD:LI26:201:BCTRL': [10.2, 12.3],
 'QUAD:LI26:301:BCTRL': [-12.5, -5.7],
 'QUAD:LI26:401:BCTRL': [8.8, 12.2],
 'QUAD:LI26:501:BCTRL': [-4.5, -3.7],
 'QUAD:LI26:601:BCTRL': [11.2, 12.8],
 'QUAD:LI26:701:BCTRL': [-14.5, -13.0],
 'QUAD:LI26:801:BCTRL': [12.2, 14.1],
 'QUAD:LI26:901:BCTRL': [-10.0, -6.8],
 'SOLN:IN20:121:BCTRL': [0.377, 0.498]}

## Load reference point
Also define a function that writes the reference values back to the PVs.

In [None]:
# Reference magnet settings (PV name -> value). The file is read inside a
# `with` block so the handle is closed right away.
with open("reference.yml", encoding="utf-8") as reference_file:
    reference = yaml.safe_load(reference_file)


def reset_pvs():
    """Write the reference settings from ``reference.yml`` back to the PVs.

    Calls ``set_magnet_strengths`` with ``validate=False``. Presumably this
    skips waiting for the BACT readbacks to settle; confirm in
    ``common.set_magnet_strengths``.
    """
    set_magnet_strengths(reference, pv_objects, validate=False)

## Define measurement function

In [None]:
import numpy as np
import time


def _hxr_pulse_intensity(gmd_readings, accumulation_time_sec, *,
                         shots_per_sec=120, percentile=80.0):
    """Summarize the most recent gas-detector readings as one intensity value.

    The window is the last ``round(shots_per_sec * accumulation_time_sec)``
    entries of ``gmd_readings``. NaN entries in the window are dropped. The
    result is the ``percentile``-th percentile of what remains, or NaN when
    nothing remains. A zero-shot window also yields NaN; the original
    ``[-0:]`` slice would have returned the whole buffer.
    """
    # NOTE(review): shots_per_sec=120 mirrors the original hard-coded factor,
    # presumably one GDET reading per shot at 120 Hz. TODO confirm.
    readings = np.asarray(gmd_readings, dtype=float)
    n_shots = int(round(shots_per_sec * accumulation_time_sec))
    window = readings[-n_shots:] if n_shots > 0 else readings[:0]
    window = window[~np.isnan(window)]
    if window.size == 0:
        # np.percentile cannot summarize an empty array; report "no signal".
        return np.nan
    return np.percentile(window, percentile)


def do_measurement(inputs, *, accumulation_time_sec=None):
    """Apply magnet settings, measure the other tracked PVs, compute the objective.

    Parameters
    ----------
    inputs : dict
        PV name -> setpoint. If a ``"FEL_ACCUMULATION_TIME_SEC"`` entry is
        present, it sets the accumulation window in seconds. That entry is
        not sent to ``set_magnet_strengths``.
    accumulation_time_sec : float, optional
        Window in seconds, used when ``inputs`` has no
        ``"FEL_ACCUMULATION_TIME_SEC"`` entry.

    Returns
    -------
    dict
        The ``measure_pvs`` results plus two extra keys:

        - ``"hxr_pulse_intensity"``: the 80th percentile of the non-NaN
          ``GDET:FEE1:241:ENRCHSTCUHBR`` readings in the window, or NaN if
          there are none.
        - ``"time"``: ``time.time()`` at the end of the measurement.

    Raises
    ------
    KeyError
        If no accumulation time is available. This is raised before any
        magnet is changed.
    """
    accumulation_key = "FEL_ACCUMULATION_TIME_SEC"

    # Resolve the window first so a missing setting fails before the machine
    # is touched.
    fel_measure_time = inputs.get(accumulation_key, accumulation_time_sec)
    if fel_measure_time is None:
        raise KeyError(
            f"{accumulation_key!r} not in inputs and no accumulation_time_sec given"
        )
    # NOTE(review): presumably the Xopt evaluator receives only VOCS
    # variables (and VOCS constants, if declared). The current VOCS has no
    # such constant, so supply it there or via the keyword. Confirm.

    # Set PVs and wait for BACT to settle to the correct values (validate=True).
    # pv_objects is passed in the same position as in reset_pvs().
    magnet_settings = {
        name: value for name, value in inputs.items() if name != accumulation_key
    }
    set_magnet_strengths(magnet_settings, pv_objects, validate=True)

    # Measure every tracked PV except the ones that were just set.
    results = measure_pvs([name for name in pv_objects if name not in inputs])

    # Averaged pulse intensity for HXR.
    results["hxr_pulse_intensity"] = _hxr_pulse_intensity(
        results["GDET:FEE1:241:ENRCHSTCUHBR"], fel_measure_time
    )
    results["time"] = time.time()
    return results

### Test measurement function

In [None]:
# Smoke test: no setpoints are given, so every tracked PV is measured.
# NOTE(review): do_measurement also needs an accumulation time
# ("FEL_ACCUMULATION_TIME_SEC"), which an empty dict does not provide.
# Confirm that this cell runs as intended.
do_measurement({})

## Set up optimization

In [None]:
from xopt import Xopt, VOCS
import pandas as pd
import matplotlib.pyplot as plt
from copy import deepcopy
from xopt import Xopt, Evaluator
from xopt.generators.bayesian import BayesianExplorationGenerator, ExpectedImprovementGenerator,UpperConfidenceBoundGenerator
from xopt.utils import get_local_region

### VOCS

In [None]:
# VOCS built straight from pv_bounds.yml:
# - every PV listed there is a variable;
# - the HXR pulse intensity is maximized;
# - the HXR pulse intensity is also constrained to stay above 0.1.
vocs = VOCS(
    variables=pv_bounds,
    objectives=dict(hxr_pulse_intensity="MAXIMIZE"),
    constraints=dict(hxr_pulse_intensity=["GREATER_THAN", 0.1]),
)

### Generator

In [None]:
# Acquisition-optimization budget
NUM_RESTARTS = 20
NUM_MC_SAMPLES = 120

# Upper-confidence-bound generator using xopt's "safety" TuRBO controller
# (see the xopt docs for that controller's semantics).
generator = UpperConfidenceBoundGenerator(
    vocs=vocs,
    turbo_controller="safety",
)

# Numerical optimizer of the acquisition function: restarts and time cap
generator.numerical_optimizer.n_restarts = NUM_RESTARTS
generator.numerical_optimizer.max_time = 10

# Monte Carlo sample count and number of interpolation points
generator.n_monte_carlo_samples = NUM_MC_SAMPLES
generator.n_interpolate_points = 5

# GP model: do not use the low-noise prior
generator.gp_constructor.use_low_noise_prior = False

### Evaluator

In [None]:
# Wrap do_measurement so Xopt can call it for each candidate point.
evaluator = Evaluator(function=do_measurement)

### Xopt

In [None]:
# The Xopt object ties together the generator, the VOCS and the evaluator.
# Its dump file is named after the session start timestamp.
X = Xopt(generator=generator, vocs=vocs, evaluator=evaluator)
ts = time.time()
X.dump_file = f"./{ts}_BO_FEL.yml"

## Perform optimization

In [None]:
# Reset the PVs to the reference settings loaded from reference.yml
# before starting the optimization.
reset_pvs()

In [None]:
# set up random sampling in a local region
# local region around reference
# NOTE(review): get_local_region(..., fraction=0.1) presumably returns bounds
# spanning 10% of each VOCS variable range, centred on the given point.
# Confirm in xopt.utils.
# Used as custom_bounds by the optional random evaluation cell below.
reference_local_region = get_local_region(reference, vocs, fraction=0.1)

# local region around current
#current_local_region = get_local_region(
#    measure_pvs(vocs.variable_names), vocs, fraction=0.1
#)

In [5]:
# optionally load data from file
# from common import load_data
# fname =
# X.add_data(load_data(fname))

In [None]:
# optional random evaluate
# X.random_evaluate(5, custom_bounds=reference_local_region)

In [None]:
# step xopt: run n_steps optimization iterations, printing the iteration
# index before each one as a progress indicator
n_steps = 10
for i in range(n_steps):
    print(i)
    X.step()

In [None]:
# Visualize: one histogram per VOCS variable over all evaluated points
o = X.data.hist(X.vocs.variable_names, figsize=(20, 20))