Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch-backed forward simulation #390

Open
wants to merge 64 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
fbe23e9
resolve deprecation warning
rileyjmurray Jan 17, 2024
985404f
tiny bugfix
rileyjmurray Jan 17, 2024
8f36247
starting point for building the TorchForwardSimulator class
rileyjmurray Jan 17, 2024
842c0f7
notes
rileyjmurray Jan 18, 2024
1b698e6
infrastructure
rileyjmurray Jan 19, 2024
cbc15b5
I understand how I am stuck and will get help
rileyjmurray Jan 24, 2024
b3ac3da
change list comprehension into for-loop in order to simplify setting …
rileyjmurray Jan 25, 2024
73363d1
leave comments describing object inheritance structures for states, p…
rileyjmurray Jan 25, 2024
9983d1b
comments indicating class types of povm-related objects
rileyjmurray Jan 25, 2024
bd345f6
improve readability
rileyjmurray Jan 26, 2024
2e76f32
remove unnecessary dependence of certain Evotypes on trivial Cython b…
rileyjmurray Jan 26, 2024
bd82b41
left out of last commit
rileyjmurray Jan 26, 2024
e158c21
comments explaining that densitymx_slow is really "superket_slow"
rileyjmurray Jan 26, 2024
c6b4d8f
left out of last commit
rileyjmurray Jan 26, 2024
ae73090
remove commented-out functions which I now clearly understand we do n…
rileyjmurray Jan 26, 2024
ffa7ea0
remove abstraction layers in TorchForwardSimulator
rileyjmurray Jan 26, 2024
d787025
remove more abstractions
rileyjmurray Jan 27, 2024
b510b2e
remove references to new TorchLayerRules class and discussion surroun…
rileyjmurray Jan 27, 2024
6fc59dd
make an apparent limitation of TorchForwardSimulator (and I suppose a…
rileyjmurray Jan 27, 2024
6aac2af
remove unused function
rileyjmurray Jan 27, 2024
107b26b
explicitly override the function that iterates over circuits and call…
rileyjmurray Jan 27, 2024
c1fcfc2
get array representations of all quantities as prep work before compu…
rileyjmurray Jan 31, 2024
761496c
use torch to compute circuit probabilities (infrastructure not in pla…
rileyjmurray Feb 1, 2024
abdfdc7
progress toward bypassing explicit calls to _rep fields of various mo…
rileyjmurray Feb 1, 2024
9b56b2a
more progress on modelmember.torch_base(...) pattern
rileyjmurray Feb 1, 2024
0c9b103
demonstrate how we can access povm data through the TPPOVM abstractio…
rileyjmurray Feb 1, 2024
243b757
write basic TPPOVM.torch_base function. Need to modify that function …
rileyjmurray Feb 1, 2024
0bea829
forward simulation codepath that computes gradients seems to work. Ha…
rileyjmurray Feb 1, 2024
b88643a
can build the entire vector of outcome probabilities as a torch Tenso…
rileyjmurray Feb 1, 2024
c1eacb3
make a function that lets us access the torch representation of compu…
rileyjmurray Feb 2, 2024
3ef9502
simplified torch_cache
rileyjmurray Feb 2, 2024
7073544
step toward what we need for torch jacfwd function
rileyjmurray Feb 2, 2024
aa5c4e7
progress toward functional evaluation in TPPOVM.torch_base. Need to a…
rileyjmurray Feb 2, 2024
0bc3736
add a static_torch_base function
rileyjmurray Feb 2, 2024
6658c47
progress toward statelessness
rileyjmurray Feb 2, 2024
852d8a6
more functional
rileyjmurray Feb 2, 2024
b6bc0f0
created (and put to work) a new StatelessModel helper class
rileyjmurray Feb 2, 2024
9855144
I can successfully call jacfwd and get reasonable output. Next step i…
rileyjmurray Feb 2, 2024
f85716b
IT IS ALIVE
rileyjmurray Feb 2, 2024
14f1af4
note some opportunities for improved efficiency
rileyjmurray Feb 2, 2024
2c6be95
simplified StatelessModel and StatelessCircuit
rileyjmurray Feb 3, 2024
23207f7
remove unnecessary comments
rileyjmurray Feb 3, 2024
3a04a31
clean up TorchForwardSimulator
rileyjmurray Feb 3, 2024
6c2e5f3
revert change that helped with debugging once-upon-a-time, but wasn`t…
rileyjmurray Feb 3, 2024
eb79162
Have meaningful comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
3461335
improve comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
1cc944c
remove unused function
rileyjmurray Feb 3, 2024
0e2f051
undo change
rileyjmurray Feb 3, 2024
cfa9232
removed unused file
rileyjmurray Feb 3, 2024
cf05d9a
documentation
rileyjmurray Feb 6, 2024
f312b92
remove comment logged as GitHub Issue #397
rileyjmurray Feb 6, 2024
a55efde
unify the API for torch_base and getting necessary ModelMember metadata
rileyjmurray Feb 6, 2024
a8f6145
remove old comments and unused imports. Style tweaks.
rileyjmurray Feb 6, 2024
e72dbad
formally declare the stateless_data and torch_base functions in the M…
rileyjmurray Feb 6, 2024
d2c8d38
reenable commented-out tests in test_forwardsim.py
rileyjmurray Feb 7, 2024
2435a50
gracefully handle when pytorch is not installed
rileyjmurray Feb 7, 2024
2e4c3cf
stash
rileyjmurray Feb 15, 2024
a3ffa68
better workaround for circular imports in type annotations
rileyjmurray May 6, 2024
f5383b9
Create Torchable subclass of ModelMember
rileyjmurray May 7, 2024
ac2e8e7
remove static constant from TorchForwardSimulator class
rileyjmurray May 7, 2024
5a1be5d
docstring changes
rileyjmurray May 7, 2024
1ec6909
docstring changes
rileyjmurray May 7, 2024
957192a
clean up TPState constructor. Add documentation for TPPOVM. Change im…
rileyjmurray May 22, 2024
07537f3
fix handling lack of pytorch
rileyjmurray May 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions pygsti/evotypes/densitymx_slow/effectreps.py
Expand Up @@ -13,12 +13,13 @@
import numpy as _np

# import functools as _functools
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import matrixtools as _mt


class EffectRep(_basereps.EffectRep):
class EffectRep:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@enielse, any chance you recall the reasoning behind this import structure? I agree with Riley that this doesn't appear to be doing anything (and so is safe to remove) but it is worth double-checking. My guess is that this may be a vestige of future expansion plans that weren't required ultimately.

"""Any representation of an "effect" in the sense of a POVM."""

def __init__(self, state_space):
self.state_space = _StateSpace.cast(state_space)

Expand All @@ -27,6 +28,10 @@ def probability(self, state):


class EffectRepConjugatedState(EffectRep):
"""
A real superket representation of an "effect" in the sense of a POVM.
Internally uses a StateRepDense object to hold the real superket.
"""

def __init__(self, state_rep):
self.state_rep = state_rep
Expand Down
17 changes: 15 additions & 2 deletions pygsti/evotypes/densitymx_slow/opreps.py
Expand Up @@ -17,7 +17,6 @@
from scipy.sparse.linalg import LinearOperator

from .statereps import StateRepDense as _StateRepDense
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import internalgates as _itgs
Expand All @@ -26,7 +25,11 @@
from ...tools import optools as _ot


class OpRep(_basereps.OpRep):
class OpRep:
"""
A real superoperator on Hilbert-Schmidt space.
"""

def __init__(self, state_space):
self.state_space = state_space

Expand All @@ -41,6 +44,10 @@ def adjoint_acton(self, state):
raise NotImplementedError()

def aslinearoperator(self):
"""
Return a SciPy LinearOperator that accepts superket representations of vectors
in Hilbert-Schmidt space and returns a vector of that same representation.
"""
def mv(v):
if v.ndim == 2 and v.shape[1] == 1: v = v[:, 0]
in_state = _StateRepDense(_np.ascontiguousarray(v, 'd'), self.state_space, None)
Expand All @@ -54,6 +61,12 @@ def rmv(v):


class OpRepDenseSuperop(OpRep):
"""
A real superoperator on Hilbert-Schmidt space.
The operator's action (and adjoint action) work with Hermitian matrices
stored as *vectors* in their real superket representations.
"""

def __init__(self, mx, basis, state_space):
state_space = _StateSpace.cast(state_space)
if mx is None:
Expand Down
13 changes: 10 additions & 3 deletions pygsti/evotypes/densitymx_slow/statereps.py
Expand Up @@ -14,7 +14,6 @@

import numpy as _np

from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import optools as _ot
Expand All @@ -25,13 +24,17 @@
_fastcalc = None


class StateRep(_basereps.StateRep):
class StateRep:
"""A real superket representation of an element in Hilbert-Schmidt space."""

def __init__(self, data, state_space):
#vec = _np.asarray(vec, dtype='d')
assert(data.dtype == _np.dtype('d'))
self.data = _np.require(data.copy(), requirements=['OWNDATA', 'C_CONTIGUOUS'])
self.state_space = _StateSpace.cast(state_space)
assert(len(self.data) == self.state_space.dim)
ds0 = self.data.shape[0]
assert(ds0 == self.state_space.dim)
assert(ds0 == self.data.size)

def __reduce__(self):
return (StateRep, (self.data, self.state_space), (self.data.flags.writeable,))
Expand Down Expand Up @@ -62,6 +65,10 @@ def __str__(self):


class StateRepDense(StateRep):
"""
An almost-trivial wrapper around StateRep.
Implements the "base" property and defines a trivial "base_has_changed" function.
"""

def __init__(self, data, state_space, basis):
#ignore basis for now (self.basis = basis in future?)
Expand Down
1 change: 1 addition & 0 deletions pygsti/forwardsims/__init__.py
Expand Up @@ -12,6 +12,7 @@

from .forwardsim import ForwardSimulator
from .mapforwardsim import SimpleMapForwardSimulator, MapForwardSimulator
from .torchfwdsim import TorchForwardSimulator, TORCH_ENABLED
from .matrixforwardsim import SimpleMatrixForwardSimulator, MatrixForwardSimulator
from .termforwardsim import TermForwardSimulator
from .weakforwardsim import WeakForwardSimulator
2 changes: 1 addition & 1 deletion pygsti/forwardsims/forwardsim.py
Expand Up @@ -373,7 +373,7 @@ def create_layout(self, circuits, dataset=None, resource_alloc=None,
if 'epp' in array_types:
derivative_dimensions = (self.model.num_params, self.model.num_params)
elif 'ep' in array_types:
derivative_dimensions = (self.model.num_params)
derivative_dimensions = (self.model.num_params,)
else:
derivative_dimensions = tuple()
return _CircuitOutcomeProbabilityArrayLayout.create_from(circuits, self.model, dataset, derivative_dimensions,
Expand Down
237 changes: 237 additions & 0 deletions pygsti/forwardsims/torchfwdsim.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment that a number of the methods defined in this file could use additional work fleshing out docstrings. I think this would also be useful for some of the private methods (for the sake of future us).

@@ -0,0 +1,237 @@
"""
Defines the TorchForwardSimulator class
"""
#***************************************************************************************************
# Copyright 2024, National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights
# in this software.
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************

from __future__ import annotations
from typing import Tuple, Optional, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from pygsti.baseobjs.label import Label
from pygsti.models.explicitmodel import ExplicitOpModel
from pygsti.circuits.circuit import SeparatePOVMCircuit
from pygsti.layouts.copalayout import CircuitOutcomeProbabilityArrayLayout

from pygsti.modelmembers.torchable import Torchable
from collections import OrderedDict
import warnings as warnings

import numpy as np
try:
import torch
TORCH_ENABLED = True
except ImportError:
TORCH_ENABLED = False

from pygsti.forwardsims.forwardsim import ForwardSimulator


"""Efficiency ideas
* Compute the jacobian in blocks of rows at a time (iterating over the blocks in parallel). Ideally pytorch
would recognize how the computations decompose, but we should check to make sure it does.

* Recycle some of the work in setting up the Jacobian function.
Calling circuit.expand_instruments_and_separate_povm(model, outcomes) inside the StatelessModel constructor
might be expensive. It only need to happen once during an iteration of GST.
"""


class StatelessCircuit:
"""
Helper data structure useful for simulating a specific circuit quantum (including prep,
applying a sequence of gates, and applying a POVM to the output of the last gate).

The forward simulation can only be done when we have access to a dict that maps
pyGSTi Labels to certain PyTorch Tensors.
"""

def __init__(self, spc: SeparatePOVMCircuit):
self.prep_label = spc.circuit_without_povm[0]
self.op_labels = spc.circuit_without_povm[1:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the point where you'd be making use of this it should generally be the case that the circuit has already had it's prep labels and POVM labels prepended, but it can be the case that circuit_without_povm is length 1 (i.e. only a prep label), which would happen when calling Model.complete_circuit on the empty circuit. So this needs an edge case check.

self.povm_label = spc.povm_label
return


class StatelessModel:
"""
A container for the information in an ExplicitOpModel that's "stateless"
in the sense of object-oriented programming.

Currently, that information is just specifications of the model's
circuits, and model parameter metadata.

StatelessModels have functions to (1) extract stateful data from an
ExplicitOpModel, (2) reformat that data into particular PyTorch
Tensors, and (3) run the forward simulation using that data. There
is also a function that combines (2) and (3).
"""

def __init__(self, model: ExplicitOpModel, layout):
circuits = []
for _, circuit, outcomes in layout.iter_unique_circuits():
expanded_circuit_outcomes = circuit.expand_instruments_and_separate_povm(model, outcomes)
if len(expanded_circuit_outcomes) > 1:
raise NotImplementedError("I don't know what to do with this.")
spc = list(expanded_circuit_outcomes.keys())[0]
c = StatelessCircuit(spc)
circuits.append(c)
self.circuits = circuits

self.param_metadata = []
for lbl, obj in model._iter_parameterized_objs():
assert isinstance(obj, Torchable)
param_type = type(obj)
param_data = (lbl, param_type) + (obj.stateless_data(),)
self.param_metadata.append(param_data)
self.num_params = len(self.param_metadata)
return

def get_free_parameters(self, model: ExplicitOpModel):
"""
Return an ordered dict that maps pyGSTi Labels to PyTorch Tensors.
The Labels correspond to parameterized objects in "model".
The Tensors correspond to the current values of an object's parameters.
For the purposes of forward simulation, we intend that the following
equivalence holds:

model == (self, [dict returned by this function]).

That said, the values in this function's returned dict need to be
formatted by get_torch_cache BEFORE being used in forward simulation.
"""
free_params = OrderedDict()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of OrderedDict isn't strictly needed here, but we can refactor that along with other instances at a later date per the currently open github issue.

for i, (lbl, obj) in enumerate(model._iter_parameterized_objs()):
gpind = obj.gpindices_as_array()
vec = obj.to_vector()
vec = torch.from_numpy(vec)
assert int(gpind.size) == int(np.prod(vec.shape))
# ^ a sanity check that we're interpreting the results of obj.to_vector()
# correctly. Future implementations might need us to also keep track of
# the "gpind" variable. Right now we get around NOT using that variable
# by using an OrderedDict and by iterating over parameterized objects in
# the same way that "model"s does.
assert self.param_metadata[i][0] == lbl
# ^ If this check fails then it invalidates our assumptions about how
# we're using OrderedDict objects.
free_params[lbl] = vec
return free_params

def get_torch_cache(self, free_params: OrderedDict[Label, torch.Tensor], grad: bool):
"""
Returns a dict mapping pyGSTi Labels to PyTorch tensors. The dict makes it easy
to simulate a stateful model implied by (self, free_params). It is obtained by
applying invertible transformations --- defined in various ModelMember subclasses
--- on the tensors stored in free_params.

If ``grad`` is True, then the values in the returned dict are preparred for use
in PyTorch's backpropogation functionality. If we want to compute a Jacobian of
circuit outcome probabilities then such functionality is actually NOT needed.
Therefore for purposes of computing Jacobians this should be set to False.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we want to use backprogation for Jacobians? The jacobians for circuit outcome probabilities will generally be wide and short, which I had understood to be the ideal use case for reverse-mode AD.

torch_cache = dict()
for i, fp_val in enumerate(free_params.values()):

if grad: fp_val.requires_grad_(True)
metadata = self.param_metadata[i]

fp_label = metadata[0]
fp_type = metadata[1]
param_t = fp_type.torch_base(metadata[2], fp_val)
torch_cache[fp_label] = param_t

return torch_cache

def circuit_probs(self, torch_cache: Dict[Label, torch.Tensor]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the rationale behind putting this method here instead of in the forward simulator?

probs = []
for c in self.circuits:
superket = torch_cache[c.prep_label]
superops = [torch_cache[ol] for ol in c.op_labels]
povm_mat = torch_cache[c.povm_label]
for superop in superops:
superket = superop @ superket
circuit_probs = povm_mat @ superket
probs.append(circuit_probs)
probs = torch.concat(probs)
return probs

def jac_friendly_circuit_probs(self, *free_params: Tuple[torch.Tensor]):
"""
This function combines parameter reformatting and forward simulation.
It's needed so that we can use PyTorch to compute the Jacobian of
the map from a model's free parameters to circuit outcome probabilities.
"""
assert len(free_params) == len(self.param_metadata) == self.num_params
free_params = {self.param_metadata[i][0] : free_params[i] for i in range(self.num_params)}
torch_cache = self.get_torch_cache(free_params, grad=False)
probs = self.circuit_probs(torch_cache)
return probs


class TorchForwardSimulator(ForwardSimulator):

"""
A forward simulator that leverages automatic differentiation in PyTorch.
"""
def __init__(self, model : Optional[ExplicitOpModel] = None):
if not TORCH_ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
super(ForwardSimulator, self).__init__(model)

@staticmethod
def separate_state(model: ExplicitOpModel, layout, grad=False):
slm = StatelessModel(model, layout)
free_params = slm.get_free_parameters(model)
torch_cache = slm.get_torch_cache(free_params, grad)
return slm, torch_cache

@staticmethod
def _check_copa_layout(layout: CircuitOutcomeProbabilityArrayLayout):
# I need to verify some assumptions on what layout.iter_unique_circuits()
# returns. Looking at the implementation of that function, the assumptions
# can be framed in terms of the "layout._element_indicies" OrderedDict.
eind = layout._element_indices
assert isinstance(eind, OrderedDict)
items = iter(eind.items())
k_prev, v_prev = next(items)
assert k_prev == 0
assert v_prev.start == 0
for k, v in items:
assert k == k_prev + 1
assert v.start == v_prev.stop
k_prev = k
v_prev = v
return v_prev.stop

def _bulk_fill_probs(self, array_to_fill, layout, stripped_abstractions: Optional[tuple] = None):
if stripped_abstractions is None:
slm, torch_cache = TorchForwardSimulator.separate_state(self.model, layout)
else:
slm, torch_cache = stripped_abstractions

layout_len = TorchForwardSimulator._check_copa_layout(layout)
probs = slm.circuit_probs(torch_cache)
array_to_fill[:layout_len] = probs.cpu().detach().numpy().flatten()
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why pass for this function and an empty return for the next?


def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill):
slm = StatelessModel(self.model, layout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a fresh StatelessModel for every jacobian call sounds expensive. For _bulk_fill_probs you had an optional stripped_abstractions argument that allowed for reuse of a previously generated one, does that not work here as well?

free_params = slm.get_free_parameters(self.model)
torch_cache = slm.get_torch_cache(free_params, grad=False)
if pr_array_to_fill is not None:
self._bulk_fill_probs(pr_array_to_fill, layout, (slm, torch_cache))

argnums = tuple(range(slm.num_params))
J_func = torch.func.jacfwd(slm.jac_friendly_circuit_probs, argnums=argnums)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See question above about reverse vs. forward mode AD.

free_param_tup = tuple(free_params.values())
J_val = J_func(*free_param_tup)
J_val = torch.column_stack(J_val)
J_np = J_val.cpu().detach().numpy()
array_to_fill[:] = J_np
return