Skip to content

Commit

Permalink
Adding two options for the MPS simulator: max_bond and method (#3793)
Browse files Browse the repository at this point in the history
The ultimate goal is to truncate the SVD in a way that is closer to the paper. I think what they do is fix the maximum number of singular values.

Doing so means that the estimation of the fidelity reduction must be done at every step (and could be more accurate). However, this requires an option that is not yet part of the Quimb release, so left as a TODO.

Finally, since the options were getting numerous, I decided to factor out the options in a separate class.
  • Loading branch information
tonybruguier committed Feb 17, 2021
1 parent 6aa46d7 commit 6bad7bf
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 52 deletions.
1 change: 1 addition & 0 deletions cirq/contrib/quimb/__init__.py
Expand Up @@ -18,6 +18,7 @@
)

from cirq.contrib.quimb.mps_simulator import (
MPSOptions,
MPSSimulator,
MPSSimulatorStepResult,
MPSState,
Expand Down
94 changes: 53 additions & 41 deletions cirq/contrib/quimb/mps_simulator.py
Expand Up @@ -21,6 +21,7 @@
import math
from typing import Any, Dict, List, Iterator, Optional, Sequence, Set

import dataclasses
import numpy as np
import quimb.tensor as qtn

Expand All @@ -29,35 +30,44 @@
from cirq.sim import simulator


@dataclasses.dataclass(frozen=True)
class MPSOptions:
# Some of these parameters are fed directly to Quimb so refer to the documentation for detail:
# https://quimb.readthedocs.io/en/latest/_autosummary/ \
# quimb.tensor.tensor_core.html#quimb.tensor.tensor_core.tensor_split

# How to split the tensor. Refer to the Quimb documentation for the exact meaning.
method: str = 'svds'
# If integer, the maxmimum number of singular values to keep, regardless of ``cutoff``.
max_bond: Optional[int] = None
# Method with which to apply the cutoff threshold. Refer to the Quimb documentation.
cutoff_mode: str = 'rsum2'
# The threshold below which to discard singular values. Refer to the Quimb documentation.
cutoff: float = 1e-6
# Because the computation is approximate, the sum of the probabilities is not 1.0. This
# parameter is the absolute deviation from 1.0 that is allowed.
sum_prob_atol: float = 1e-3


class MPSSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
"""An efficient simulator for MPS circuits."""

def __init__(
self,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
rsum2_cutoff: float = 1e-3,
sum_prob_atol: float = 1e-3,
simulation_options: 'cirq.contrib.quimb.mps_simulator.MPSOptions' = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
):
"""Creates instance of `MPSSimulator`.
Args:
seed: The random seed to use for this simulator.
rsum2_cutoff: We drop singular values so that the sum of the
square of the dropped singular values divided by the sum of the
square of all the singular values is less than rsum2_cutoff.
This is related to the fidelity of the computation. If we have
N 2D gates, then the estimated fidelity is
(1 - rsum2_cutoff) ** N.
sum_prob_atol: Because the computation is approximate, the sum of
the probabilities is not 1.0. This parameter is the absolute
deviation from 1.0 that is allowed.
simulation_options: Numerical options for the simulation.
grouping: How to group qubits together, if None all are individual.
"""
self.init = True
self._prng = value.parse_random_state(seed)
self.rsum2_cutoff = rsum2_cutoff
self.sum_prob_atol = sum_prob_atol
self.simulation_options = simulation_options
self.grouping = grouping

def _base_iterator(
Expand Down Expand Up @@ -85,8 +95,7 @@ def _base_iterator(
measurements={},
state=MPSState(
qubit_map,
self.rsum2_cutoff,
self.sum_prob_atol,
self.simulation_options,
self.grouping,
initial_state=initial_state,
),
Expand All @@ -95,8 +104,7 @@ def _base_iterator(

state = MPSState(
qubit_map,
self.rsum2_cutoff,
self.sum_prob_atol,
self.simulation_options,
self.grouping,
initial_state=initial_state,
)
Expand Down Expand Up @@ -300,24 +308,15 @@ class MPSState:
def __init__(
self,
qubit_map: Dict['cirq.Qid', int],
rsum2_cutoff: float,
sum_prob_atol: float,
simulation_options: 'cirq.contrib.quimb.mps_simulator.MPSOptions' = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
):
"""Creates and MPSState
Args:
qubit_map: A map from Qid to an integer that uniquely identifies it.
rsum2_cutoff: We drop singular values so that the sum of the
square of the dropped singular values divided by the sum of the
square of all the singular values is less than rsum2_cutoff.
This is related to the fidelity of the computation. If we have
N 2D gates, then the estimated fidelity is
(1 - rsum2_cutoff) ** N.
sum_prob_atol: Because the computation is approximate, the sum of
the probabilities is not 1.0. This parameter is the absolute
deviation from 1.0 that is allowed.
simulation_options: Numerical options for the simulation.
grouping: How to group qubits together, if None all are individual.
initial_state: An integer representing the initial state.
"""
Expand Down Expand Up @@ -354,9 +353,8 @@ def __init__(
n = self.grouping[qubit]
self.M[n] @= qtn.Tensor(x, inds=(self.i_str(i),))
initial_state = initial_state // d
self.rsum2_cutoff = rsum2_cutoff
self.sum_prob_atol = sum_prob_atol
self.num_svd_splits = 0
self.simulation_options = simulation_options
self.estimated_gate_error_list: List[float] = []

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand All @@ -374,12 +372,12 @@ def __str__(self) -> str:
return str(qtn.TensorNetwork(self.M))

def _value_equality_values_(self) -> Any:
return self.qubit_map, self.M, self.rsum2_cutoff, self.sum_prob_atol, self.grouping
return self.qubit_map, self.M, self.simulation_options, self.grouping

def copy(self) -> 'MPSState':
state = MPSState(self.qubit_map, self.rsum2_cutoff, self.sum_prob_atol, self.grouping)
state = MPSState(self.qubit_map, self.simulation_options, self.grouping)
state.M = [x.copy() for x in self.M]
state.num_svd_splits = self.num_svd_splits
state.estimated_gate_error_list = self.estimated_gate_error_list
return state

def state_vector(self) -> np.ndarray:
Expand Down Expand Up @@ -465,7 +463,6 @@ def apply_unitary(self, op: 'cirq.Operation'):
{new_inds[0]: old_inds[0], new_inds[1]: old_inds[1]}
)
else:
self.num_svd_splits += 1
# This is the index on which we do the contraction. We need to add it iff it's
# the first time that we do the joining for that specific pair.
mu_ind = self.mu_str(n, p)
Expand All @@ -479,13 +476,27 @@ def apply_unitary(self, op: 'cirq.Operation'):
left_inds = tuple(set(T.inds) & set(self.M[n].inds)) + (new_inds[0],)
X, Y = T.split(
left_inds,
cutoff=self.rsum2_cutoff,
cutoff_mode='rsum2',
method=self.simulation_options.method,
max_bond=self.simulation_options.max_bond,
cutoff=self.simulation_options.cutoff,
cutoff_mode=self.simulation_options.cutoff_mode,
get='tensors',
absorb='both',
bond_ind=mu_ind,
)

# Equations (13), (14), and (15):
# TODO(tonybruguier): When Quimb 2.0.0 is released, the split()
# function should have a 'renorm' that, when set to None, will
# allow to compute e_n exactly as:
# np.sum(abs((X @ Y).data) ** 2).real / np.sum(abs(T) ** 2).real
#
# The renormalization would then have to be done manually.
#
# However, for now, e_n are just the estimated value.
e_n = self.simulation_options.cutoff
self.estimated_gate_error_list.append(e_n)

self.M[n] = X.reindex({new_inds[0]: old_inds[0]})
self.M[p] = Y.reindex({new_inds[1]: old_inds[1]})
else:
Expand All @@ -505,14 +516,15 @@ def estimation_stats(self):

# The computation below is done for numerical stability, instead of directly using the
# formula:
# estimated_fidelity = (1 - self.rsum2_cutoff) ** self.num_svd_splits
estimated_fidelity = 1.0 + np.expm1(np.log1p(-self.rsum2_cutoff) * self.num_svd_splits)
# estimated_fidelity = \prod_i (1 - estimated_gate_error_list_i)
estimated_fidelity = 1.0 + np.expm1(
sum(np.log1p(-x) for x in self.estimated_gate_error_list)
)
estimated_fidelity = round(estimated_fidelity, ndigits=3)

return {
"num_coefs_used": num_coefs_used,
"memory_bytes": memory_bytes,
"num_svd_splits": self.num_svd_splits,
"estimated_fidelity": estimated_fidelity,
}

Expand Down Expand Up @@ -544,7 +556,7 @@ def perform_measurement(

# Because the computation is approximate, the probabilities do not
# necessarily add up to 1.0, and thus we re-normalize them.
if abs(sum_probs - 1.0) > self.sum_prob_atol:
if abs(sum_probs - 1.0) > self.simulation_options.sum_prob_atol:
raise ValueError('Sum of probabilities exceeds tolerance: {}'.format(sum_probs))
norm_probs = [x / sum_probs for x in probs]

Expand Down
47 changes: 36 additions & 11 deletions cirq/contrib/quimb/mps_simulator_test.py
Expand Up @@ -164,7 +164,9 @@ def test_probs_dont_sum_up_to_one():
q0 = cirq.NamedQid('q0', dimension=2)
circuit = cirq.Circuit(cirq.measure(q0))

simulator = ccq.mps_simulator.MPSSimulator(sum_prob_atol=-0.5)
simulator = ccq.mps_simulator.MPSSimulator(
simulation_options=ccq.mps_simulator.MPSOptions(sum_prob_atol=-0.5)
)

with pytest.raises(ValueError, match="Sum of probabilities exceeds tolerance"):
simulator.run(circuit, repetitions=1)
Expand Down Expand Up @@ -254,7 +256,7 @@ def test_measurement_str():
def test_trial_result_str():
q0 = cirq.LineQubit(0)
final_simulator_state = ccq.mps_simulator.MPSState(
qubit_map={q0: 0}, rsum2_cutoff=1e-3, sum_prob_atol=1e-3
qubit_map={q0: 0}, simulation_options=ccq.mps_simulator.MPSOptions()
)
assert (
str(
Expand All @@ -273,7 +275,7 @@ def test_trial_result_str():

def test_empty_step_result():
q0 = cirq.LineQubit(0)
state = ccq.mps_simulator.MPSState(qubit_map={q0: 0}, rsum2_cutoff=1e-3, sum_prob_atol=1e-3)
state = ccq.mps_simulator.MPSState(qubit_map={q0: 0})
step_result = ccq.mps_simulator.MPSSimulatorStepResult(state, measurements={'0': [1]})
assert (
str(step_result)
Expand All @@ -286,9 +288,18 @@ def test_empty_step_result():

def test_state_equal():
q0, q1 = cirq.LineQubit.range(2)
state0 = ccq.mps_simulator.MPSState(qubit_map={q0: 0}, rsum2_cutoff=1e-3, sum_prob_atol=1e-3)
state1a = ccq.mps_simulator.MPSState(qubit_map={q1: 0}, rsum2_cutoff=1e-3, sum_prob_atol=1e-3)
state1b = ccq.mps_simulator.MPSState(qubit_map={q1: 0}, rsum2_cutoff=1729.0, sum_prob_atol=1e-3)
state0 = ccq.mps_simulator.MPSState(
qubit_map={q0: 0},
simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1e-3, sum_prob_atol=1e-3),
)
state1a = ccq.mps_simulator.MPSState(
qubit_map={q1: 0},
simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1e-3, sum_prob_atol=1e-3),
)
state1b = ccq.mps_simulator.MPSState(
qubit_map={q1: 0},
simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1729.0, sum_prob_atol=1e-3),
)
assert state0 == state0
assert state0 != state1a
assert state1a != state1b
Expand All @@ -313,7 +324,7 @@ def test_supremacy_equal_more_cols():
def test_tensor_index_names():
qubits = cirq.LineQubit.range(12)
qubit_map = {qubit: i for i, qubit in enumerate(qubits)}
state = ccq.mps_simulator.MPSState(qubit_map, rsum2_cutoff=0.1234, sum_prob_atol=1e-3)
state = ccq.mps_simulator.MPSState(qubit_map)

assert state.i_str(0) == "i_00"
assert state.i_str(11) == "i_11"
Expand All @@ -329,16 +340,30 @@ def test_supremacy_big():
q0 = next(iter(qubit_order))
circuit.append(cirq.measure(q0))

mps_simulator = ccq.mps_simulator.MPSSimulator(rsum2_cutoff=5e-5)
result = mps_simulator.simulate(circuit, qubit_order=qubit_order, initial_state=0)
mps_simulator_1 = ccq.mps_simulator.MPSSimulator(
simulation_options=ccq.mps_simulator.MPSOptions(cutoff=5e-5)
)
result_1 = mps_simulator_1.simulate(circuit, qubit_order=qubit_order, initial_state=0)

assert result.final_state.estimation_stats() == {
assert result_1.final_state.estimation_stats() == {
'estimated_fidelity': 0.997,
'memory_bytes': 11008,
'num_svd_splits': 64,
'num_coefs_used': 688,
}

mps_simulator_2 = ccq.mps_simulator.MPSSimulator(
simulation_options=ccq.mps_simulator.MPSOptions(
method='isvd', max_bond=1, cutoff_mode='sum2'
)
)
result_2 = mps_simulator_2.simulate(circuit, qubit_order=qubit_order, initial_state=0)

assert result_2.final_state.estimation_stats() == {
'estimated_fidelity': 1.0,
'memory_bytes': 1568,
'num_coefs_used': 98,
}


def test_simulate_moment_steps_sample():
q0, q1 = cirq.LineQubit.range(2)
Expand Down

0 comments on commit 6bad7bf

Please sign in to comment.