Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch-backed forward simulation #390

Open
wants to merge 64 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
fbe23e9
resolve deprecation warning
rileyjmurray Jan 17, 2024
985404f
tiny bugfix
rileyjmurray Jan 17, 2024
8f36247
starting point for building the TorchForwardSimulator class
rileyjmurray Jan 17, 2024
842c0f7
notes
rileyjmurray Jan 18, 2024
1b698e6
infrastructure
rileyjmurray Jan 19, 2024
cbc15b5
I understand how I am stuck and will get help
rileyjmurray Jan 24, 2024
b3ac3da
change list comprehension into for-loop in order to simplify setting …
rileyjmurray Jan 25, 2024
73363d1
leave comments describing object inheritance structures for states, p…
rileyjmurray Jan 25, 2024
9983d1b
comments indicating class types of povm-related objects
rileyjmurray Jan 25, 2024
bd345f6
improve readability
rileyjmurray Jan 26, 2024
2e76f32
remove unnecessary dependence of certain Evotypes on trivial Cython b…
rileyjmurray Jan 26, 2024
bd82b41
left out of last commit
rileyjmurray Jan 26, 2024
e158c21
comments explaining that densitymx_slow is really "superket_slow"
rileyjmurray Jan 26, 2024
c6b4d8f
left out of last commit
rileyjmurray Jan 26, 2024
ae73090
remove commented-out functions which I now clearly understand we do n…
rileyjmurray Jan 26, 2024
ffa7ea0
remove abstraction layers in TorchForwardSimulator
rileyjmurray Jan 26, 2024
d787025
remove more abstractions
rileyjmurray Jan 27, 2024
b510b2e
remove references to new TorchLayerRules class and discussion surroun…
rileyjmurray Jan 27, 2024
6fc59dd
make an apparent limitation of TorchForwardSimulator (and I suppose a…
rileyjmurray Jan 27, 2024
6aac2af
remove unused function
rileyjmurray Jan 27, 2024
107b26b
explicitly override the function that iterates over circuits and call…
rileyjmurray Jan 27, 2024
c1fcfc2
get array representations of all quantities as prep work before compu…
rileyjmurray Jan 31, 2024
761496c
use torch to compute circuit probabilities (infrastructure not in pla…
rileyjmurray Feb 1, 2024
abdfdc7
progress toward bypassing explicit calls to _rep fields of various mo…
rileyjmurray Feb 1, 2024
9b56b2a
more progress on modelmember.torch_base(...) pattern
rileyjmurray Feb 1, 2024
0c9b103
demonstrate how we can access povm data through the TPPOVM abstractio…
rileyjmurray Feb 1, 2024
243b757
write basic TPPOVM.torch_base function. Need to modify that function …
rileyjmurray Feb 1, 2024
0bea829
forward simulation codepath that computes gradients seems to work. Ha…
rileyjmurray Feb 1, 2024
b88643a
can build the entire vector of outcome probabilities as a torch Tenso…
rileyjmurray Feb 1, 2024
c1eacb3
make a function that lets us access the torch representation of compu…
rileyjmurray Feb 2, 2024
3ef9502
simplified torch_cache
rileyjmurray Feb 2, 2024
7073544
step toward what we need for torch jacfwd function
rileyjmurray Feb 2, 2024
aa5c4e7
progress toward functional evaluation in TPPOVM.torch_base. Need to a…
rileyjmurray Feb 2, 2024
0bc3736
add a static_torch_base function
rileyjmurray Feb 2, 2024
6658c47
progress toward statelessness
rileyjmurray Feb 2, 2024
852d8a6
more functional
rileyjmurray Feb 2, 2024
b6bc0f0
created (and put to work) a new StatelessModel helper class
rileyjmurray Feb 2, 2024
9855144
I can successfully call jacfwd and get reasonable output. Next step i…
rileyjmurray Feb 2, 2024
f85716b
IT IS ALIVE
rileyjmurray Feb 2, 2024
14f1af4
note some opportunities for improved efficiency
rileyjmurray Feb 2, 2024
2c6be95
simplified StatelessModel and StatelessCircuit
rileyjmurray Feb 3, 2024
23207f7
remove unnecessary comments
rileyjmurray Feb 3, 2024
3a04a31
clean up TorchForwardSimulator
rileyjmurray Feb 3, 2024
6c2e5f3
revert change that helped with debugging once-upon-a-time, but wasn`t…
rileyjmurray Feb 3, 2024
eb79162
Have meaningful comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
3461335
improve comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
1cc944c
remove unused function
rileyjmurray Feb 3, 2024
0e2f051
undo change
rileyjmurray Feb 3, 2024
cfa9232
removed unused file
rileyjmurray Feb 3, 2024
cf05d9a
documentation
rileyjmurray Feb 6, 2024
f312b92
remove comment logged as GitHub Issue #397
rileyjmurray Feb 6, 2024
a55efde
unify the API for torch_base and getting necessary ModelMember metadata
rileyjmurray Feb 6, 2024
a8f6145
remove old comments and unused imports. Style tweaks.
rileyjmurray Feb 6, 2024
e72dbad
formally declare the stateless_data and torch_base functions in the M…
rileyjmurray Feb 6, 2024
d2c8d38
reenable commented-out tests in test_forwardsim.py
rileyjmurray Feb 7, 2024
2435a50
gracefully handle when pytorch is not installed
rileyjmurray Feb 7, 2024
2e4c3cf
stash
rileyjmurray Feb 15, 2024
a3ffa68
better workaround for circular imports in type annotations
rileyjmurray May 6, 2024
f5383b9
Create Torchable subclass of ModelMember
rileyjmurray May 7, 2024
ac2e8e7
remove static constant from TorchForwardSimulator class
rileyjmurray May 7, 2024
5a1be5d
docstring changes
rileyjmurray May 7, 2024
1ec6909
docstring changes
rileyjmurray May 7, 2024
957192a
clean up TPState constructor. Add documentation for TPPOVM. Change im…
rileyjmurray May 22, 2024
07537f3
fix handling lack of pytorch
rileyjmurray May 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions pygsti/evotypes/densitymx_slow/effectreps.py
Expand Up @@ -13,12 +13,13 @@
import numpy as _np

# import functools as _functools
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import matrixtools as _mt


class EffectRep(_basereps.EffectRep):
class EffectRep:
"""Any representation of an "effect" in the sense of a POVM."""

def __init__(self, state_space):
self.state_space = _StateSpace.cast(state_space)

Expand All @@ -27,6 +28,10 @@ def probability(self, state):


class EffectRepConjugatedState(EffectRep):
"""
A real superket representation of an "effect" in the sense of a POVM.
Internally uses a StateRepDense object to hold the real superket.
"""

def __init__(self, state_rep):
self.state_rep = state_rep
Expand Down
17 changes: 15 additions & 2 deletions pygsti/evotypes/densitymx_slow/opreps.py
Expand Up @@ -17,7 +17,6 @@
from scipy.sparse.linalg import LinearOperator

from .statereps import StateRepDense as _StateRepDense
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import internalgates as _itgs
Expand All @@ -26,7 +25,11 @@
from ...tools import optools as _ot


class OpRep(_basereps.OpRep):
class OpRep:
"""
A real superoperator on Hilbert-Schmidt space.
"""

def __init__(self, state_space):
self.state_space = state_space

Expand All @@ -41,6 +44,10 @@ def adjoint_acton(self, state):
raise NotImplementedError()

def aslinearoperator(self):
"""
Return a SciPy LinearOperator that accepts superket representations of vectors
in Hilbert-Schmidt space and returns a vector of that same representation.
"""
def mv(v):
if v.ndim == 2 and v.shape[1] == 1: v = v[:, 0]
in_state = _StateRepDense(_np.ascontiguousarray(v, 'd'), self.state_space, None)
Expand All @@ -54,6 +61,12 @@ def rmv(v):


class OpRepDenseSuperop(OpRep):
"""
A real superoperator on Hilbert-Schmidt space.
The operator's action (and adjoint action) work with Hermitian matrices
stored as *vectors* in their real superket representations.
"""

def __init__(self, mx, basis, state_space):
state_space = _StateSpace.cast(state_space)
if mx is None:
Expand Down
13 changes: 10 additions & 3 deletions pygsti/evotypes/densitymx_slow/statereps.py
Expand Up @@ -14,7 +14,6 @@

import numpy as _np

from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import optools as _ot
Expand All @@ -25,13 +24,17 @@
_fastcalc = None


class StateRep(_basereps.StateRep):
class StateRep:
"""A real superket representation of an element in Hilbert-Schmidt space."""

def __init__(self, data, state_space):
#vec = _np.asarray(vec, dtype='d')
assert(data.dtype == _np.dtype('d'))
self.data = _np.require(data.copy(), requirements=['OWNDATA', 'C_CONTIGUOUS'])
self.state_space = _StateSpace.cast(state_space)
assert(len(self.data) == self.state_space.dim)
ds0 = self.data.shape[0]
assert(ds0 == self.state_space.dim)
assert(ds0 == self.data.size)

def __reduce__(self):
return (StateRep, (self.data, self.state_space), (self.data.flags.writeable,))
Expand Down Expand Up @@ -62,6 +65,10 @@ def __str__(self):


class StateRepDense(StateRep):
"""
An almost-trivial wrapper around StateRep.
Implements the "base" property and defines a trivial "base_has_changed" function.
"""

def __init__(self, data, state_space, basis):
#ignore basis for now (self.basis = basis in future?)
Expand Down
1 change: 1 addition & 0 deletions pygsti/forwardsims/__init__.py
Expand Up @@ -12,6 +12,7 @@

from .forwardsim import ForwardSimulator
from .mapforwardsim import SimpleMapForwardSimulator, MapForwardSimulator
from .torchfwdsim import TorchForwardSimulator
from .matrixforwardsim import SimpleMatrixForwardSimulator, MatrixForwardSimulator
from .termforwardsim import TermForwardSimulator
from .weakforwardsim import WeakForwardSimulator
2 changes: 1 addition & 1 deletion pygsti/forwardsims/forwardsim.py
Expand Up @@ -373,7 +373,7 @@ def create_layout(self, circuits, dataset=None, resource_alloc=None,
if 'epp' in array_types:
derivative_dimensions = (self.model.num_params, self.model.num_params)
elif 'ep' in array_types:
derivative_dimensions = (self.model.num_params)
derivative_dimensions = (self.model.num_params,)
else:
derivative_dimensions = tuple()
return _CircuitOutcomeProbabilityArrayLayout.create_from(circuits, self.model, dataset, derivative_dimensions,
Expand Down
251 changes: 251 additions & 0 deletions pygsti/forwardsims/torchfwdsim.py
@@ -0,0 +1,251 @@
"""
Defines the TorchForwardSimulator class
"""
#***************************************************************************************************
# Copyright 2024, National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights
# in this software.
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************

from collections import OrderedDict
import warnings as warnings
from typing import Tuple, Optional, TypeVar, Union, List, Dict
import importlib as _importlib
import warnings as _warnings
from pygsti.tools import slicetools as _slct

import numpy as np
import scipy.linalg as la
try:
import torch
TORCH_ENABLED = True
except ImportError:
TORCH_ENABLED = False

from pygsti.forwardsims.forwardsim import ForwardSimulator

# Below: imports only needed for typehints
from pygsti.circuits import Circuit
from pygsti.baseobjs.resourceallocation import ResourceAllocation
Label = TypeVar('Label')
ExplicitOpModel = TypeVar('ExplicitOpModel')
SeparatePOVMCircuit = TypeVar('SeparatePOVMCircuit')
CircuitOutcomeProbabilityArrayLayout = TypeVar('CircuitOutcomeProbabilityArrayLayout')
# ^ declare to avoid circular references


"""
Proposal:
There are lots of places where we use np.dot in the codebase.
I think we're much better off replacing with the @ operator
unless we're using the "out" keyword of np.dot. Reason being:
different classes of ndarray-like objects (like pytorch Tensors!)
overload @ in whatever way that they need.
"""

"""Efficiency ideas
* Compute the jacobian in blocks of rows at a time (iterating over the blocks in parallel). Ideally pytorch
would recognize how the computations decompose, but we should check to make sure it does.

* Recycle some of the work in setting up the Jacobian function.
Calling circuit.expand_instruments_and_separate_povm(model, outcomes) inside the StatelessModel constructor
might be expensive. It only need to happen once during an iteration of GST.
"""

class StatelessCircuit:

def __init__(self, spc: SeparatePOVMCircuit):
self.prep_label = spc.circuit_without_povm[0]
self.op_labels = spc.circuit_without_povm[1:]
self.povm_label = spc.povm_label
return


class StatelessModel:

def __init__(self, model: ExplicitOpModel, layout):
circuits = []
for _, circuit, outcomes in layout.iter_unique_circuits():
expanded_circuit_outcomes = circuit.expand_instruments_and_separate_povm(model, outcomes)
if len(expanded_circuit_outcomes) > 1:
raise NotImplementedError("I don't know what to do with this.")
spc = list(expanded_circuit_outcomes.keys())[0]
c = StatelessCircuit(spc)
circuits.append(c)
self.circuits = circuits

self.param_metadata = []
for lbl, obj in model._iter_parameterized_objs():
param_type = type(obj)
typestr = str(param_type)
if 'TPPOVM' in typestr:
param_data = (lbl, param_type, len(obj), obj.dim)
elif 'FullTPOp' in typestr:
param_data = (lbl, param_type, obj.dim)
elif 'TPState' in typestr:
param_data = (lbl, param_type, obj.dim)
else:
raise ValueError()
self.param_metadata.append(param_data)
self.num_params = len(self.param_metadata)
return

def get_free_parameters(self, model: ExplicitOpModel):
d = OrderedDict()
for i, (lbl, obj) in enumerate(model._iter_parameterized_objs()):
gpind = obj.gpindices_as_array()
vec = obj.to_vector()
vec = torch.from_numpy(vec)
assert int(gpind.size) == int(np.prod(vec.shape))
assert self.param_metadata[i][0] == lbl
d[lbl] = vec
return d

def get_torch_cache(self, free_params: OrderedDict[Label, torch.Tensor], grad: bool):
torch_cache = dict()
for i, fp_val in enumerate(free_params.values()):

if grad: fp_val.requires_grad_(True)
metadata = self.param_metadata[i]
fp_label = metadata[0]
fp_type = metadata[1]
fp_tstr = str(fp_type)

if ('FullTPOp' in fp_tstr) or ('TPState' in fp_tstr):
param_t = fp_type.torch_base(metadata[2], fp_val)
elif 'TPPOVM' in fp_tstr:
param_t = fp_type.torch_base(metadata[2], metadata[3], fp_val)
else:
raise ValueError()
torch_cache[fp_label] = param_t

return torch_cache

def circuit_probs(self, torch_cache: Dict[Label, torch.Tensor]):
probs = []
for c in self.circuits:
superket = torch_cache[c.prep_label]
superops = [torch_cache[ol] for ol in c.op_labels]
povm_mat = torch_cache[c.povm_label]
for superop in superops:
superket = superop @ superket
circuit_probs = povm_mat @ superket
probs.append(circuit_probs)
probs = torch.concat(probs)
return probs

def functional_circuit_probs(self, *free_params: Tuple[torch.Tensor]):
assert len(free_params) == len(self.param_metadata) == self.num_params
free_params = {self.param_metadata[i][0] : free_params[i] for i in range(self.num_params)}
torch_cache = self.get_torch_cache(free_params, grad=False)
probs = self.circuit_probs(torch_cache)
return probs


class TorchForwardSimulator(ForwardSimulator):
"""
A forward simulator that leverages automatic differentiation in PyTorch.
"""
def __init__(self, model : Optional[ExplicitOpModel] = None):
if not TORCH_ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
super(ForwardSimulator, self).__init__(model)

@staticmethod
def separate_state(model: ExplicitOpModel, layout, grad=False):
slm = StatelessModel(model, layout)
free_params = slm.get_free_parameters(model)
torch_cache = slm.get_torch_cache(free_params, grad)
return slm, torch_cache

@staticmethod
def _check_copa_layout(layout: CircuitOutcomeProbabilityArrayLayout):
# I need to verify some assumptions on what layout.iter_unique_circuits()
# returns. Looking at the implementation of that function, the assumptions
# can be framed in terms of the "layout._element_indicies" OrderedDict.
eind = layout._element_indices
assert isinstance(eind, OrderedDict)
items = iter(eind.items())
k_prev, v_prev = next(items)
assert k_prev == 0
assert v_prev.start == 0
for k, v in items:
assert k == k_prev + 1
assert v.start == v_prev.stop
k_prev = k
v_prev = v
return v_prev.stop

def _bulk_fill_probs(self, array_to_fill, layout, stripped_abstractions: Optional[tuple] = None):
if stripped_abstractions is None:
slm, torch_cache = TorchForwardSimulator.separate_state(self.model, layout)
else:
slm, torch_cache = stripped_abstractions

layout_len = TorchForwardSimulator._check_copa_layout(layout)
probs = slm.circuit_probs(torch_cache)
array_to_fill[:layout_len] = probs.cpu().detach().numpy().flatten()
pass

def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill):
slm = StatelessModel(self.model, layout)
free_params = slm.get_free_parameters(self.model)
torch_cache = slm.get_torch_cache(free_params, grad=False)
if pr_array_to_fill is not None:
self._bulk_fill_probs(pr_array_to_fill, layout, (slm, torch_cache))

argnums = tuple(range(slm.num_params))
J_func = torch.func.jacfwd(slm.functional_circuit_probs, argnums=argnums)
free_param_tup = tuple(free_params.values())
J_val = J_func(*free_param_tup)
J_val = torch.column_stack(J_val)
J_np = J_val.cpu().detach().numpy()
array_to_fill[:] = J_np
return


"""
Running GST produces the following traceback if I set a breakpoint inside the
loop over expanded_circuit_outcomes.items() in self._compute_circuit_outcome_probabilities(...).

I think something's happening where accessing the objects here (via the debugger)
makes some object set "self.dirty=True" for the ComplementPOVMEffect.

UPDATE
The problem shows up when we try to access effect.base for some FullPOVMEffect object "effect".
CONFIRMED
FullPOVMEffect resolves an attempt to access to .base attribute by a default implementation
in its DenseEffectInterface subclass. The last thing that function does is set
self.dirty = True.

pyGSTi/pygsti/forwardsims/forwardsim.py:562: in _bulk_fill_probs_block
self._compute_circuit_outcome_probabilities(array_to_fill[element_indices], circuit,
pyGSTi/pygsti/forwardsims/torchfwdsim.py:177: in _compute_circuit_outcome_probabilities
if povmrep is None:
pyGSTi/pygsti/forwardsims/torchfwdsim.py:177: in <listcomp>
if povmrep is None:
pyGSTi/pygsti/models/model.py:1479: in circuit_layer_operator
self._clean_paramvec()
pyGSTi/pygsti/models/model.py:679: in _clean_paramvec
clean_obj(obj, lbl)
pyGSTi/pygsti/models/model.py:675: in clean_obj
clean_obj(subm, _Label(lbl.name + ":%d" % i, lbl.sslbls))
pyGSTi/pygsti/models/model.py:676: in clean_obj
clean_single_obj(obj, lbl)
pyGSTi/pygsti/models/model.py:666: in clean_single_obj
w = obj.to_vector()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pygsti.modelmembers.povms.complementeffect.ComplementPOVMEffect object at 0x2a79e31f0>

def to_vector(self):
'''<Riley removed comment block>'''
> raise ValueError(("ComplementPOVMEffect.to_vector() should never be called"
" - use TPPOVM.to_vector() instead"))
E ValueError: ComplementPOVMEffect.to_vector() should never be called - use TPPOVM.to_vector() instead

"""