Skip to content

Commit

Permalink
Some more type annotations and format strings (#2935)
Browse files Browse the repository at this point in the history
  • Loading branch information
viathor committed Apr 23, 2020
1 parent 10fc655 commit e08f94a
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 99 deletions.
29 changes: 12 additions & 17 deletions cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"""

import collections
from typing import Dict, List, Iterator, Sequence
from typing import Any, Dict, List, Iterator, Sequence

import numpy as np
from cirq.ops.global_phase_op import GlobalPhaseOperation
Expand Down Expand Up @@ -188,21 +188,16 @@ class CliffordTrialResult(simulator.SimulationTrialResult):
def __init__(self, params: study.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: 'CliffordState') -> None:

super().__init__(params=params,
measurements=measurements,
final_simulator_state=final_simulator_state)

self.final_state = final_simulator_state

def __str__(self):
def __str__(self) -> str:
samples = super().__str__()
final = self._final_simulator_state

return 'measurements: {}\noutput state: {}'.format(samples, final)

def __repr__(self):
return super().__repr__()
return f'measurements: {samples}\noutput state: {final}'


class CliffordSimulatorStepResult(simulator.StepResult):
Expand All @@ -222,7 +217,7 @@ def __init__(self, state, measurements):
self.measurements = measurements
self.state = state

def __str__(self):
def __str__(self) -> str:

def bitstring(vals):
return ''.join('1' if v else '0' for v in vals)
Expand All @@ -234,12 +229,12 @@ def bitstring(vals):
if len(results) == 0:
measurements = ''
else:
measurements = ' '.join(
['{}={}'.format(key, val) for key, val in results]) + '\n'
measurements = ' '.join([f'{key}={val}' for key, val in results
]) + '\n'

final = self.state

return '{}{}'.format(measurements, final)
return f'{measurements}{final}'

def _simulator_state(self):
return self.state
Expand Down Expand Up @@ -296,24 +291,24 @@ def _from_json_dict_(cls, qubit_map, tableau, ch_form, **kwargs):

return state

def _value_equality_values_(self):
def _value_equality_values_(self) -> Any:
return self.qubit_map, self.tableau, self.ch_form

def copy(self):
def copy(self) -> 'CliffordState':
state = CliffordState(self.qubit_map)
state.tableau = self.tableau.copy()
state.ch_form = self.ch_form.copy()

return state

def __repr__(self):
def __repr__(self) -> str:
return repr(self.ch_form)

def __str__(self):
def __str__(self) -> str:
"""Return the wavefunction string representation of the state."""
return str(self.ch_form)

def to_numpy(self):
def to_numpy(self) -> np.ndarray:
return self.ch_form.to_state_vector()

def stabilizers(self) -> List[DensePauliString]:
Expand Down
56 changes: 27 additions & 29 deletions cirq/sim/clifford/clifford_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Any, Dict, List
import numpy as np

import cirq
Expand Down Expand Up @@ -51,7 +51,7 @@ def bits(s):
self.xs[i, i] = True
self.zs[self.n + i, i] = True

def _json_dict_(self):
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['n', 'rs', 'xs', 'zs'])

@classmethod
Expand All @@ -60,7 +60,6 @@ def _from_json_dict_(cls, n, rs, xs, zs, **kwargs):
state.rs = rs
state.xs = xs
state.zs = zs

return state

def __eq__(self, other):
Expand All @@ -71,65 +70,64 @@ def __eq__(self, other):
np.array_equal(self.xs, other.xs) and
np.array_equal(self.zs, other.zs))

def copy(self):
def copy(self) -> 'CliffordTableau':
state = CliffordTableau(self.n)
state.rs = self.rs.copy()
state.xs = self.xs.copy()
state.zs = self.zs.copy()

return state

def __repr__(self):
return "stabilizers: [{}]".format(", ".join(
[repr(stab) for stab in self.stabilizers()]))
def __repr__(self) -> str:
stabilizers = ", ".join([repr(stab) for stab in self.stabilizers()])
return f'stabilizers: [{stabilizers}]'

def __str__(self):
string = ""
def __str__(self) -> str:
string = ''

for i in range(self.n, 2 * self.n):
string += "- " if self.rs[i] else "+ "
string += '- ' if self.rs[i] else '+ '

for k in range(0, self.n):
if self.xs[i, k] & (not self.zs[i, k]):
string += "X "
string += 'X '
elif (not self.xs[i, k]) & self.zs[i, k]:
string += "Z "
string += 'Z '
elif self.xs[i, k] & self.zs[i, k]:
string += "Y "
string += 'Y '
else:
string += "I "
string += 'I '

if i < 2 * self.n - 1:
string += "\n"
string += '\n'

return string

def _str_full_(self):
string = ""
def _str_full_(self) -> str:
string = ''

string += "stable" + " " * max(self.n * 2 - 3, 1)
string += "| destable\n"
string += "-" * max(7, self.n * 2 + 3) + "+" + "-" * max(
10, self.n * 2 + 4) + "\n"
string += 'stable' + ' ' * max(self.n * 2 - 3, 1)
string += '| destable\n'
string += '-' * max(7, self.n * 2 + 3) + '+' + '-' * max(
10, self.n * 2 + 4) + '\n'

for j in range(self.n):
for i in [j + self.n, j]:
string += "- " if self.rs[i] else "+ "
string += '- ' if self.rs[i] else '+ '

for k in range(0, self.n):
if self.xs[i, k] & (not self.zs[i, k]):
string += "X%d" % k
string += 'X%d' % k
elif (not self.xs[i, k]) & self.zs[i, k]:
string += "Z%d" % k
string += 'Z%d' % k
elif self.xs[i, k] & self.zs[i, k]:
string += "Y%d" % k
string += 'Y%d' % k
else:
string += " "
string += ' '

if i == j + self.n:
string += " " * max(0, 4 - self.n * 2) + " | "
string += ' ' * max(0, 4 - self.n * 2) + ' | '

string += "\n"
string += '\n'

return string

Expand Down
14 changes: 7 additions & 7 deletions cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Any, Dict, Union
import numpy as np

import cirq
Expand All @@ -32,7 +32,7 @@ class StabilizerStateChForm():

def __init__(self,
num_qubits: int,
initial_state: Union[int, np.ndarray] = 0):
initial_state: Union[int, np.ndarray] = 0) -> None:
"""Initializes StabilizerStateChForm
Args:
num_qubits: The number of qubits in the system
Expand Down Expand Up @@ -65,7 +65,7 @@ def bits(s):
if val:
self._X(self.n - i - 1)

def _json_dict_(self):
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(
self, ['n', 'G', 'F', 'M', 'gamma', 'v', 's', 'omega'])

Expand All @@ -83,7 +83,7 @@ def _from_json_dict_(cls, n, G, F, M, gamma, v, s, omega, **kwargs):

return copy

def _value_equality_values_(self):
def _value_equality_values_(self) -> Any:
return (self.n, self.G, self.F, self.M, self.gamma, self.v, self.v,
self.s, self.omega)

Expand All @@ -100,13 +100,13 @@ def copy(self) -> 'cirq.StabilizerStateChForm':

return copy

def __str__(self):
def __str__(self) -> str:
"""Return the wavefunction string representation of the state."""
return cirq.dirac_notation(self.to_state_vector())

def __repr__(self):
def __repr__(self) -> str:
"""Return the CH form representation of the state. """
return 'StabilizerStateChForm(num_qubits={!r})'.format(self.n)
return f'StabilizerStateChForm(num_qubits={self.n!r})'

def inner_product_of_state_and_x(self, x: int) -> Union[float, complex]:
""" Returns the amplitude of x'th element of
Expand Down
36 changes: 17 additions & 19 deletions cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import collections

from typing import Dict, Iterator, List, TYPE_CHECKING, Type, Union
from typing import Any, Dict, Iterator, List, TYPE_CHECKING, Tuple, Type, Union

import numpy as np

Expand Down Expand Up @@ -465,22 +465,21 @@ class DensityMatrixSimulatorState():
"""

def __init__(self, density_matrix: np.ndarray,
qubit_map: Dict[ops.Qid, int]):
qubit_map: Dict[ops.Qid, int]) -> None:
self.density_matrix = density_matrix
self.qubit_map = qubit_map
self._qid_shape = simulator._qubit_map_to_shape(qubit_map)

def _qid_shape_(self):
def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def _value_equality_values_(self):
def _value_equality_values_(self) -> Any:
return (self.density_matrix.tolist(), self.qubit_map)

def __repr__(self):
return ("cirq.DensityMatrixSimulatorState("
"density_matrix=np.array({!r}), "
"qubit_map={!r})".format(self.density_matrix.tolist(),
self.qubit_map))
def __repr__(self) -> str:
return ('cirq.DensityMatrixSimulatorState('
f'density_matrix=np.array({self.density_matrix.tolist()!r}), '
f'qubit_map={self.qubit_map!r})')


@value.value_equality(unhashable=True)
Expand Down Expand Up @@ -533,19 +532,18 @@ def __init__(self, params: study.ParamResolver,
self.final_density_matrix = np.reshape(
final_simulator_state.density_matrix, (size, size))

def _value_equality_values_(self):
def _value_equality_values_(self) -> Any:
measurements = {
k: v.tolist() for k, v in sorted(self.measurements.items())
}
return (self.params, measurements, self._final_simulator_state)

def __str__(self):
def __str__(self) -> str:
samples = super().__str__()
return 'measurements: {}\nfinal density matrix:\n{}'.format(
samples, self.final_density_matrix)

def __repr__(self):
return ("cirq.DensityMatrixTrialResult(params={!r}, measurements={!r}, "
"final_simulator_state={!r})".format(
self.params, self.measurements,
self._final_simulator_state))
return (f'measurements: {samples}\n'
f'final density matrix:\n{self.final_density_matrix}')

def __repr__(self) -> str:
return ('cirq.DensityMatrixTrialResult('
f'params={self.params!r}, measurements={self.measurements!r}, '
f'final_simulator_state={self._final_simulator_state!r})')
17 changes: 8 additions & 9 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,13 @@ def __init__(self,
self.measurements = measurements
self._final_simulator_state = final_simulator_state

def __repr__(self):
return ('cirq.SimulationTrialResult(params={!r}, '
'measurements={!r}, '
'final_simulator_state={!r})').format(
self.params, self.measurements, self._final_simulator_state)
def __repr__(self) -> str:
return (f'cirq.SimulationTrialResult(params={self.params!r}, '
f'measurements={self.measurements!r}, '
f'final_simulator_state={self._final_simulator_state!r})')

def __str__(self) -> str:

def __str__(self):
def bitstring(vals):
separator = ' ' if np.max(vals) >= 10 else ''
return separator.join(str(int(v)) for v in vals)
Expand All @@ -534,8 +534,7 @@ def bitstring(vals):
[(key, bitstring(val)) for key, val in self.measurements.items()])
if not results:
return '(no measurements)'
return ' '.join(
['{}={}'.format(key, val) for key, val in results])
return ' '.join([f'{key}={val}' for key, val in results])

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Text output in Jupyter."""
Expand All @@ -545,7 +544,7 @@ def _repr_pretty_(self, p: Any, cycle: bool) -> None:
else:
p.text(str(self))

def _value_equality_values_(self):
def _value_equality_values_(self) -> Any:
measurements = {k: v.tolist() for k, v in
sorted(self.measurements.items())}
return (self.params, measurements, self._final_simulator_state)
Expand Down

0 comments on commit e08f94a

Please sign in to comment.