
PyTorch-backed forward simulation #390

Status: Open. Wants to merge 62 commits into base branch develop.
Conversation

@rileyjmurray (Collaborator) commented Jan 18, 2024

This PR introduces TorchForwardSimulator, a forward simulator (for computing circuit outcome probabilities) based on PyTorch. It uses automatic differentiation to compute the Jacobian of the map from model parameters to circuit outcome probabilities. In the future we could extend it to do computations on a system's GPU, or to use PyTorch-based optimization algorithms instead of pyGSTi's custom algorithms for MLE.

Approach

My approach required creating a new ModelMember subclass called Torchable. This subclass declares two required methods, stateless_data and torch_base, whose meanings are given below:

def stateless_data(self) -> Tuple:
    """
    Return this ModelMember's data that is considered constant for purposes
    of model fitting. Note: the word "stateless" here is used in the sense
    of object-oriented programming.
    """
    raise NotImplementedError()

@staticmethod
def torch_base(sd: Tuple, t_param: Tensor) -> Tensor:
    """
    Suppose "obj" is an instance of some Torchable subclass. If we compute

        vec = obj.to_vector()
        t_param = torch.from_numpy(vec)
        sd = obj.stateless_data()
        t = type(obj).torch_base(sd, t_param)

    then t will be a PyTorch Tensor that represents "obj" in a canonical
    numerical way. The meaning of "canonical" is implementation dependent.
    If type(obj) implements the ``.base`` attribute, then a reasonable
    implementation will probably satisfy np.allclose(obj.base, t.numpy()).
    """
    raise NotImplementedError()
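To make the contract concrete, here is a minimal sketch of a hypothetical Torchable-style class. ToyDenseOp is illustrative only (it is not a pyGSTi class); it treats every entry of a dense matrix as a free parameter and satisfies the round-trip property described in the docstring:

```python
import numpy as np
import torch

# Hypothetical toy "model member" whose d x d matrix entries are all free
# parameters. ToyDenseOp is illustrative only; it is not a pyGSTi class.
class ToyDenseOp:
    def __init__(self, mat):
        self.base = np.array(mat, dtype=np.double)

    def to_vector(self):
        # Free parameters, as a flat numpy vector.
        return self.base.ravel().copy()

    def stateless_data(self):
        # Data held constant during model fitting: here, just the dimension.
        return (self.base.shape[0],)

    @staticmethod
    def torch_base(sd, t_param):
        # Rebuild the canonical dense representation from the parameters.
        dim = sd[0]
        return t_param.reshape(dim, dim)

obj = ToyDenseOp([[1.0, 0.0], [0.0, 1.0]])
vec = obj.to_vector()
t_param = torch.from_numpy(vec)
sd = obj.stateless_data()
t = type(obj).torch_base(sd, t_param)
assert np.allclose(obj.base, t.numpy())
```

Because torch_base is a pure function of (stateless data, parameter tensor), PyTorch can differentiate through it with respect to t_param.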

In principle, TorchForwardSimulator can handle any model whose constituent parameterized ModelMembers are Torchable. So far I've only extended TPState, FullTPOp, and TPPOVM to be Torchable; these are the classes used in "full TP" GST.

The Python file that contains TorchForwardSimulator also defines two helper classes: StatelessCircuit and StatelessModel. I think it's fine to keep these classes as purely internal implementation-specific constructs for now. Depending on future performance optimizations of TorchForwardSimulator we might want to put them elsewhere in pyGSTi.

What should come after this PR

We should compare performance of TorchForwardSimulator to MapForwardSimulator on problems of interest. There's a chance that the former isn't faster than the latter with the current implementation. If that's the case then I should look at possible performance optimizations specifically inside TorchForwardSimulator.

We should add implementations of stateless_data and torch_base to GST models beyond "Full TP" (in particular I'd like to try CPTP).

Incidental changes

My implementation originally interacted with the following evotype classes

    <class 'pygsti.evotypes.densitymx[_slow].statereps.StateRepDense'>
    <class 'pygsti.evotypes.densitymx[_slow].opreps.OpRepDenseSuperop'>
    <class 'pygsti.evotypes.densitymx[_slow].effectreps.EffectRepConjugatedState'>

Where I write [_slow] in the class names above, substitute either the empty string or _slow, depending on the default evotype specified in evotypes.py.

To my surprise, I found that interacting with evotypes was neither necessary nor sufficient for what I wanted to accomplish. So while I did make changes in evotypes/densitymx_slow/ to remove unnecessary class inheritances and to add documentation, those changes were only to make life a little easier for future pyGSTi contributors.

rileyjmurray (review comment): This change is to resolve a deprecation warning.

rileyjmurray (review comment): This change is just to improve readability.

@rileyjmurray rileyjmurray marked this pull request as ready for review February 7, 2024 13:54
@rileyjmurray rileyjmurray requested review from a team as code owners February 7, 2024 13:54
@sserita (Contributor) commented Feb 7, 2024

I'm giving a few quick comments here because I think this PR will actually take me some time to get through.

The main thing I am concerned and thinking about is the choice to extend the ModelMember class as opposed to adjusting the evotypes. The general purpose of the split between modelmembers and evotypes is so that we don't have to go through and implement these abstract methods in all modelmembers - we can make the change in the evotype and then any modelmember works.
It is totally possible that extending the API is the best way to do this. The TermSimulator follows a similar pattern, where API extensions are the cleanest approach. But I'll probably spend some time thinking about whether this is the way we want to implement this. Probably a point of discussion for us in our dev meetings in the upcoming weeks.

@rileyjmurray (Collaborator, Author) replied:

@sserita, regarding

The main thing I am concerned and thinking about is the choice to extend the ModelMember class as opposed to adjusting the evotypes. The general purpose of the split between modelmembers and evotypes is so that we don't have to go through and implement these abstract methods in all modelmembers - we can make the change in the evotype and then any modelmember works.

Unfortunately there's no way to do this just through evotypes. Using pytorch's AD capabilities requires knowing the free parameters in a modelmember and how those free parameters map to the common parameterization-agnostic representation (i.e., representations in evotypes).
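As a toy illustration of this point (the map below is a made-up stand-in, not pyGSTi's actual parameterization): once the parameters-to-representation map is expressed in torch, the Jacobian of outcome probabilities follows directly from automatic differentiation.

```python
import torch

# Toy map from free parameters to a circuit outcome probability, in a dense
# superoperator picture. All objects here are illustrative stand-ins.
def probs(t_param: torch.Tensor) -> torch.Tensor:
    op = t_param.reshape(2, 2)                              # parameterized "gate"
    rho = torch.tensor([1.0, 0.0], dtype=torch.double)      # fixed state prep
    effect = torch.tensor([0.0, 1.0], dtype=torch.double)   # fixed POVM effect
    return (effect @ (op @ rho)).reshape(1)

t_param = torch.zeros(4, dtype=torch.double)
jac = torch.autograd.functional.jacobian(probs, t_param)
# jac has shape (1, 4): one outcome probability, four free parameters.
```

The key is that probs must be written in terms of the free parameters; a parameterization-agnostic dense representation (as stored in the evotypes) does not expose that map, which is why the evotype layer alone is not enough.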

Comment on lines 26 to 32
# Below: variables for type annotations.
# We have to create variable aliases rather than importing the types
# directly, since importing the types would cause circular imports.
Label = TypeVar('Label')
ExplicitOpModel = TypeVar('ExplicitOpModel')
SeparatePOVMCircuit = TypeVar('SeparatePOVMCircuit')
CircuitOutcomeProbabilityArrayLayout = TypeVar('CircuitOutcomeProbabilityArrayLayout')
rileyjmurray (review comment):
Try to use from __future__ import annotations to avoid use of TypeVars.
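For reference, a minimal stdlib-only sketch of that pattern (the Decimal import is a stand-in for pyGSTi types such as ExplicitOpModel):

```python
from __future__ import annotations  # annotations become strings; not evaluated at runtime
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imports used only in annotations. This block never runs at runtime,
    # so it cannot trigger circular imports.
    from decimal import Decimal

def scale(x: Decimal, factor: int) -> Decimal:
    # The annotations above are plain strings at runtime, so no TypeVar
    # aliases are needed.
    return x * factor
```

Static type checkers still resolve the real types, while the module imports cleanly at runtime.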

@rileyjmurray (Collaborator, Author) commented:

Notes from today's meeting:

  • I removed empty base classes for OpRep, EffectRep, StateRep. I should make sure we aren't doing type checking elsewhere in pyGSTi that would be broken by this. (There's a chance we are and it simply isn't showing up in current tests.)

@sserita sserita added this to the 0.9.13 milestone Apr 2, 2024
@rileyjmurray (Collaborator, Author) commented:

@coreyostrove, @sserita, @enielse: this is ready for review.

Comment on lines +88 to +93
first_basis_vec = torch.zeros(size=(1, dim), dtype=torch.double)
first_basis_vec[0,0] = dim ** 0.25
t_param_mat = t_param.reshape((num_effects - 1, dim))
t_func = first_basis_vec - t_param_mat.sum(axis=0, keepdim=True)
t = torch.row_stack((t_param_mat, t_func))
return t
rileyjmurray (review comment):
@coreyostrove, @sserita: is the dim ** 0.25 scale appropriate when the underlying vector space is something other than "tensor product of qubit space"?
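For readers outside the thread, a self-contained sketch of what the snippet above computes (shapes and parameter values below are made up for illustration): the final effect is the complement that forces all effects to sum to dim ** 0.25 times the first standard basis vector, i.e. the vectorized identity in a normalized basis.

```python
import torch

dim, num_effects = 4, 3
t_param = torch.arange((num_effects - 1) * dim, dtype=torch.double)

first_basis_vec = torch.zeros(size=(1, dim), dtype=torch.double)
first_basis_vec[0, 0] = dim ** 0.25
t_param_mat = t_param.reshape((num_effects - 1, dim))
# Complement effect: whatever is left over after the freely parameterized effects.
t_func = first_basis_vec - t_param_mat.sum(axis=0, keepdim=True)
t = torch.row_stack((t_param_mat, t_func))

# By construction, the effects sum to the scaled first basis vector.
assert torch.allclose(t.sum(axis=0), first_basis_vec[0])
```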

def torch_base(sd: Tuple[int], t_param: _Torchable.Tensor) -> _Torchable.Tensor:
    torch = _Torchable.torch_handle
    dim = sd[0]
    t_const = (dim ** -0.25) * torch.ones(1, dtype=torch.double)
rileyjmurray (review comment):
@coreyostrove, @sserita: same question as above. Is the scale factor of dim ** -0.25 appropriate for general spaces?
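For context on where the constant comes from (assuming the first basis element is the normalized identity I / sqrt(d) on a d-dimensional Hilbert space): any trace-1 density matrix rho has coefficient Tr(rho @ I / sqrt(d)) = 1 / sqrt(d) = dim ** -0.25, since the superoperator dimension is dim = d ** 2. A small numpy sanity check, including a qutrit dimension that is not a tensor product of qubit spaces:

```python
import numpy as np

for d in (2, 3, 4, 8):  # d = 3 is a qutrit: not a tensor product of qubits
    rho = np.eye(d) / d                  # maximally mixed state, trace 1
    ident_elt = np.eye(d) / np.sqrt(d)   # normalized identity basis element
    coeff = np.trace(rho @ ident_elt)
    superop_dim = d ** 2
    assert np.isclose(coeff, superop_dim ** -0.25)
```

This suggests the constant depends only on the basis normalization rather than on qubit structure, though that is one reading of the question, not a confirmed answer from the maintainers.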
