Skip to content

Commit

Permalink
Simplify type annotations in mps_simulator.py (#3926)
Browse files Browse the repository at this point in the history
These give me type errors when running mypy locally. It seems to happen on intermittently, for reasons I haven't full understood, but in any event it seems better to just simplify them since these refer to types defined in the same module.
  • Loading branch information
maffoo committed Mar 18, 2021
1 parent 47f5908 commit 3d369a0
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,15 @@ class MPSOptions:

class MPSSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
'cirq.contrib.quimb.mps_simulator.MPSSimulatorStepResult',
'cirq.contrib.quimb.mps_simulator.MPSTrialResult',
'cirq.contrib.quimb.mps_simulator.MPSState',
],
simulator.SimulatesIntermediateState['MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState'],
):
"""An efficient simulator for MPS circuits."""

def __init__(
self,
noise: 'cirq.NOISE_MODEL_LIKE' = None,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
simulation_options: 'cirq.contrib.quimb.mps_simulator.MPSOptions' = MPSOptions(),
simulation_options: MPSOptions = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
):
"""Creates instance of `MPSSimulator`.
Expand All @@ -88,7 +84,7 @@ def __init__(

def _base_iterator(
self, circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: int
) -> Iterator['cirq.contrib.quimb.mps_simulator.MPSSimulatorStepResult']:
) -> Iterator['MPSSimulatorStepResult']:
"""Iterator over MPSSimulatorStepResult from Moments of a Circuit
Args:
Expand Down Expand Up @@ -144,8 +140,8 @@ def _create_simulator_trial_result(
self,
params: study.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: 'cirq.contrib.quimb.mps_simulator.MPSState',
) -> 'cirq.contrib.quimb.mps_simulator.MPSTrialResult':
final_simulator_state: 'MPSState',
) -> 'MPSTrialResult':
"""Creates a single trial results with the measurements.
Args:
Expand Down Expand Up @@ -183,7 +179,7 @@ def _run(
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)

measurements = {} # type: Dict[str, List[np.ndarray]]
measurements: Dict[str, List[np.ndarray]] = {}

for _ in range(repetitions):
all_step_results = self._base_iterator(
Expand Down Expand Up @@ -292,7 +288,7 @@ class MPSState:
def __init__(
self,
qubit_map: Dict['cirq.Qid', int],
simulation_options: 'cirq.contrib.quimb.mps_simulator.MPSOptions' = MPSOptions(),
simulation_options: MPSOptions = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
):
Expand Down

0 comments on commit 3d369a0

Please sign in to comment.