Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature wildcard serialization #279

Merged
merged 3 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion pygsti/objectivefns/wildcardbudget.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
import numpy as _np

from pygsti import tools as _tools
from pygsti.baseobjs import NicelySerializable as _NicelySerializable

#pos = lambda x: x**2
pos = abs


class WildcardBudget(object):
class WildcardBudget(_NicelySerializable):
"""
A fixed wildcard budget.

Expand Down Expand Up @@ -339,6 +340,12 @@ def update_probs(self, probs_in, probs_out, freqs, layout, precomp=None, probs_f

qvec, W = _adjust_qvec_to_be_nonnegative_and_unit_sum(qvec, W, min_qvec, circ, tol)

initialTVD = 0.5 * sum(_np.abs(qvec - fvec)) # update current TVD
if initialTVD <= W + tol: # TVD is "in-budget" for this circuit due to adjustment; leave as is
probs_out[elInds] = qvec
if return_deriv: p_deriv[elInds] = 0.0
continue

#recompute A-D b/c we've updated qvec
A = _np.logical_and(qvec > fvec, fvec > 0); sum_fA = sum(fvec[A]); sum_qA = sum(qvec[A])
B = _np.logical_and(qvec < fvec, fvec > 0); sum_fB = sum(fvec[B]); sum_qB = sum(qvec[B])
Expand Down Expand Up @@ -461,6 +468,15 @@ def update_probs(self, probs_in, probs_out, freqs, layout, precomp=None, probs_f

return p_deriv if return_deriv else None

def _to_nice_serialization(self):
state = super()._to_nice_serialization()
state.update({'wildcard_vector': list(self.wildcard_vector)})
return state

@classmethod
def _from_nice_serialization(cls, state): # memo holds already de-serialized objects
return cls(_np.array(state['wildcard_vector'], 'd'))


class PrimitiveOpsWildcardBudgetBase(WildcardBudget):
"""
Expand Down Expand Up @@ -540,6 +556,24 @@ def num_primitive_ops(self):
def wildcard_error_per_op(self):
return {lbl: val for lbl, val in zip(self.primitive_op_labels, self.per_op_wildcard_vector)}

def _to_nice_serialization(self):
state = super()._to_nice_serialization()
state.update({'primitive_op_labels': [str(lbl) for lbl in self.primitive_op_labels],
'idle_name': self._idlename,
})
return state

@classmethod
def _from_nice_serialization(cls, state): # memo holds already de-serialized objects
primitive_op_labels = cls._parse_primitive_op_labels(state['primitive_op_labels'])
return cls(primitive_op_labels, _np.array(state['wildcard_vector'], 'd'), state['idle_name'])

@classmethod
def _parse_primitive_op_labels(cls, label_strs):
from pygsti.circuits.circuitparser import parse_label as _parse_label
return [(_parse_label(lbl_str) if (lbl_str != 'SPAM') else 'SPAM')
for lbl_str in label_strs]

def circuit_budget(self, circuit):
"""
Get the amount of wildcard budget, or "outcome-probability-slack" for `circuit`.
Expand Down Expand Up @@ -812,6 +846,10 @@ def update_circuit_probs(probs, freqs, circuit_budget):

qvec, W = _adjust_qvec_to_be_nonnegative_and_unit_sum(qvec, W, min(qvec), base_tol)

initialTVD = 0.5 * sum(_np.abs(qvec - fvec)) # update current TVD
if initialTVD <= W + tol: # TVD is "in-budget" for this circuit due to adjustment; leave as is
return qvec

#Note: must ensure that A,B,C,D are *disjoint*
fvec_equals_qvec = _np.logical_and(fvec - base_tol <= qvec, qvec <= fvec + base_tol) # fvec == qvec
A = _np.where(_np.logical_and(qvec > fvec + base_tol, fvec > 0))[0]
Expand Down Expand Up @@ -1022,6 +1060,24 @@ def _per_op_wildcard_error_deriv_from_vector(self, wildcard_vector):
ret[i, self.primitive_op_param_index[lbl]] = 1.0
return ret

def _to_nice_serialization(self):
state = super()._to_nice_serialization()
param_index_by_primitive_op = [self.primitive_op_param_index[lbl] for lbl in self.primitive_op_labels]
state.update({'param_index_by_primitive_op': param_index_by_primitive_op,
'trivial_param_mapping': self.trivial_param_mapping})
return state

@classmethod
def _from_nice_serialization(cls, state): # memo holds already de-serialized objects
primitive_op_labels = cls._parse_primitive_op_labels(state['primitive_op_labels'])
if state['trivial_param_mapping']:
primitive_ops = primitive_op_labels
else:
primitive_ops = {lbl: i for lbl, i in zip(primitive_op_labels, state['param_index_by_primitive_op'])}

budget = {lbl: val for lbl, val in zip(primitive_op_labels, state['wildcard_vector'])}
return cls(primitive_ops, budget, state['idle_name'])


class PrimitiveOpsSingleScaleWildcardBudget(PrimitiveOpsWildcardBudgetBase):
"""
Expand Down Expand Up @@ -1090,6 +1146,18 @@ def _per_op_wildcard_error_deriv_from_vector(self, wildcard_vector):
"""
return self.reference_values.reshape((self.num_primitive_ops, 1)).copy()

def _to_nice_serialization(self):
state = super()._to_nice_serialization()
state.update({'reference_values': list(self.reference_values),
'reference_name': self.reference_name})
return state

@classmethod
def _from_nice_serialization(cls, state): # memo holds already de-serialized objects
primitive_op_labels = cls._parse_primitive_op_labels(state['primitive_op_labels'])
alpha = state['wildcard_vector'][0]
return cls(primitive_op_labels, state['reference_values'], alpha, state['idle_name'], state['reference_name'])

@property
def description(self):
"""
Expand Down
19 changes: 10 additions & 9 deletions pygsti/protocols/gst.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,13 +1873,13 @@ def _add_badfit_estimates(results, base_estimate_label, badfit_options,
try:
budget_dict = _compute_wildcard_budget(objfn_cache, mdc_objfn, parameters, badfit_options, printer - 1)
for chain_name, (unmodeled, active_constraint_list) in budget_dict.items():
base_estimate.extra_parameters[chain_name + "_unmodeled_error"] = unmodeled
base_estimate.extra_parameters[chain_name + "_unmodeled_error"] = unmodeled.to_nice_serialization()
base_estimate.extra_parameters[chain_name + "_unmodeled_active_constraints"] \
= active_constraint_list
if len(budget_dict) > 0: # also store first chain info w/empty chain name (convenience)
first_chain = next(iter(budget_dict))
unmodeled, active_constraint_list = budget_dict[first_chain]
base_estimate.extra_parameters["unmodeled_error"] = unmodeled
base_estimate.extra_parameters["unmodeled_error"] = unmodeled.to_nice_serialization()
base_estimate.extra_parameters["unmodeled_active_constraints"] = active_constraint_list
except NotImplementedError as e:
printer.warning("Failed to get wildcard budget - continuing anyway. Error was:\n" + str(e))
Expand All @@ -1897,11 +1897,11 @@ def _add_badfit_estimates(results, base_estimate_label, badfit_options,
budget = _compute_wildcard_budget_1d_model(base_estimate, objfn_cache, mdc_objfn, parameters,
badfit_options, printer - 1)

base_estimate.extra_parameters['wildcard1d' + "_unmodeled_error"] = budget
base_estimate.extra_parameters['wildcard1d' + "_unmodeled_error"] = budget.to_nice_serialization()
sserita marked this conversation as resolved.
Show resolved Hide resolved
base_estimate.extra_parameters['wildcard1d' + "_unmodeled_active_constraints"] \
= None

base_estimate.extra_parameters["unmodeled_error"] = budget
base_estimate.extra_parameters["unmodeled_error"] = budget.to_nice_serialization()
base_estimate.extra_parameters["unmodeled_active_constraints"] = None
except NotImplementedError as e:
printer.warning("Failed to get wildcard budget - continuing anyway. Error was:\n" + str(e))
Expand Down Expand Up @@ -2344,10 +2344,11 @@ def _evaluate_constraints(wv):

circ_ind_max = _np.argmax(percircuit_constraint)
if glob_constraint > 0:
active_constraints['global'] = glob_constraint,
active_constraints['global'] = float(glob_constraint),
if percircuit_constraint[circ_ind_max] > 0:
active_constraints['percircuit'] = (circ_ind_max, global_circuits_to_use[circ_ind_max],
percircuit_constraint[circ_ind_max])
active_constraints['percircuit'] = (int(circ_ind_max), global_circuits_to_use[circ_ind_max].str,
float(percircuit_constraint[circ_ind_max]))
#Note: make sure active_constraints is JSON serializable (this is why we put the circuit *str* in)
else:
if budget_was_optimized:
printer.log((" - Element %.3g is %.3g. This is below %.3g, so trialing snapping to zero"
Expand Down Expand Up @@ -2377,8 +2378,8 @@ def _evaluate_constraints(wv):
if 'global' in active_constraints:
printer.log(" global constraint:" + str(active_constraints['global']))
if 'percircuit' in active_constraints:
_, circuit, constraint_amt = active_constraints['percircuit']
printer.log(" per-circuit constraint:" + circuit.str + " = " + str(constraint_amt))
_, circuit_str, constraint_amt = active_constraints['percircuit']
printer.log(" per-circuit constraint:" + circuit_str + " = " + str(constraint_amt))
else:
printer.log("(no active constraints for " + "--".join(primOp_labels[i]) + ")")
printer.log("")
Expand Down
21 changes: 15 additions & 6 deletions pygsti/report/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pygsti.models.explicitmodel import ExplicitOpModel as _ExplicitOpModel
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from pygsti.objectivefns import objectivefns as _objfns
from pygsti.objectivefns import wildcardbudget as _wildcardbudget
from pygsti.circuits.circuit import Circuit as _Circuit
from pygsti.circuits.circuitlist import CircuitList as _CircuitList
from pygsti.circuits.circuitstructure import PlaquetteGridCircuitStructure as _PlaquetteGridCircuitStructure
Expand Down Expand Up @@ -357,9 +358,13 @@ def _create_master_switchboard(ws, results_dict, confidence_level,
switchBd.eff_ds[d, i] = NA
switchBd.scaled_submxs_dict[d, i] = NA

switchBd.wildcard_budget_optional[d, i] = est.parameters.get("unmodeled_error", None)
wildcard = est.parameters.get("unmodeled_error", None)
if isinstance(wildcard, dict): # assume a serialized budget object
wildcard = _wildcardbudget.WildcardBudget.from_nice_serialization(wildcard)

switchBd.wildcard_budget_optional[d, i] = wildcard
if est.parameters.get("unmodeled_error", None):
switchBd.wildcard_budget[d, i] = est.parameters['unmodeled_error']
switchBd.wildcard_budget[d, i] = wildcard
else:
switchBd.wildcard_budget[d, i] = NA

Expand All @@ -370,7 +375,7 @@ def _create_master_switchboard(ws, results_dict, confidence_level,
est.models['target'])
except AttributeError: # Implicit models don't support everything, like set_all_parameterizations
switchBd.mdl_gaugeinv_ep[d, i] = None
except AssertionError: # if target is badly off, this can fail with an imaginary part assertion
except (ValueError, AssertionError): # if target is badly off, e.g. an imaginary part assertion
switchBd.mdl_gaugeinv_ep[d, i] = None

switchBd.mdl_final[d, i, :] = [est.models.get(l, NA) for l in gauge_opt_labels]
Expand Down Expand Up @@ -1216,10 +1221,13 @@ def construct_standard_report(results, title="auto",
#check if the wildcard budget is an instance
#of the diamond distance model, in which case we
#will add an extra flag/plot to the report.
if (isinstance(est.parameters['unmodeled_error'], PrimitiveOpsSingleScaleWildcardBudget)
and est.parameters['unmodeled_error'].reference_name == 'diamond distance'):
wildcard = est.parameters['unmodeled_error']
if isinstance(wildcard, dict): # assume a serialized budget object
wildcard = _wildcardbudget.WildcardBudget.from_nice_serialization(wildcard)
if (isinstance(wildcard, PrimitiveOpsSingleScaleWildcardBudget)
and wildcard.reference_name == 'diamond distance'):
flags.add('DiamondDistanceWildcard')

if combine_robust:
flags.add('CombineRobust')

Expand Down Expand Up @@ -1506,6 +1514,7 @@ def construct_nqnoise_report(results, title="auto",
'colorBoxPlotKeyPlot': ws.BoxKeyPlot(switchBd.prep_fiducials, switchBd.meas_fiducials),
'bestGatesetGaugeOptParamsTable': ws.GaugeOptParamsTable(switchBd.goparams),
'gramBarPlot': ws.GramMatrixBarPlot(switchBd.ds, switchBd.mdl_target, 10, switchBd.fiducials_tup)
# Note by EGN 11/10/2022 - I don't think 'gramBarPlot' is needed here, maybe just a copy/paste oversight?
}

report_params = {
Expand Down