/
mul_protocol.py
74 lines (60 loc) · 2.71 KB
/
mul_protocol.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from cirq.protocols.resolve_parameters import is_parameterized
# This is a special indicator value used to determine whether or not the caller
# provided a 'default' argument. It is only ever compared by identity (`is`),
# so callers can never accidentally pass an equal-but-different value.
RaiseTypeErrorIfNotProvided: Any = ([],)


def _equals_scalar(value: Any, target: int) -> bool:
    """Returns whether `value == target` is unambiguously true.

    Array-like values (e.g. numpy arrays) compare element-wise and produce an
    array whose truth value is ambiguous (`bool()` raises ValueError). Such
    comparisons, and comparisons that raise TypeError, are treated as
    "not equal" rather than propagating the error.
    """
    try:
        return bool(value == target)
    except (TypeError, ValueError):
        return False


def mul(lhs: Any, rhs: Any, default: Any = RaiseTypeErrorIfNotProvided) -> Any:
    """Returns lhs * rhs, or else a default if the operator is not implemented.

    Dispatch mirrors the `*` operator: `lhs.__mul__(rhs)` is tried first, and
    `rhs.__rmul__(lhs)` is tried if the former is missing or returns
    NotImplemented. Unlike `lhs * rhs`, the caller may supply `default` to get
    a fallback value instead of a TypeError, which is convenient for methods
    that want to return NotImplemented themselves.

    When one operand equals 1 (or -1) and the other is parameterized (per
    `is_parameterized`), the parameterized operand (or its negation) is used
    as the result instead of the computed product, so that factors of 1.0 do
    not build up in symbolic expressions.

    Args:
        lhs: Left hand side of the multiplication.
        rhs: Right hand side of the multiplication.
        default: Default value to return if the multiplication is not defined.
            If no default is specified, a TypeError is raised when the
            multiplication fails.

    Returns:
        The product of the two inputs, or else the default value if the product
        is not defined.

    Raises:
        TypeError:
            lhs doesn't have __mul__ or it returned NotImplemented
            AND rhs doesn't have __rmul__ or it returned NotImplemented
            AND a default value isn't specified.
    """
    # Use left-hand-side's __mul__.
    left_mul = getattr(lhs, '__mul__', None)
    result = NotImplemented if left_mul is None else left_mul(rhs)

    # Fallback to right-hand-side's __rmul__.
    if result is NotImplemented:
        right_mul = getattr(rhs, '__rmul__', None)
        result = NotImplemented if right_mul is None else right_mul(lhs)

    # Don't build up factors of 1.0 vs sympy Symbols. The equality tests go
    # through _equals_scalar so that array-like operands (whose `== 1` yields
    # an array with an ambiguous truth value) don't raise ValueError here.
    # Each check is independent, so a later match overrides an earlier one.
    if _equals_scalar(lhs, 1) and is_parameterized(rhs):
        result = rhs
    if _equals_scalar(rhs, 1) and is_parameterized(lhs):
        result = lhs
    if _equals_scalar(lhs, -1) and is_parameterized(rhs):
        result = -rhs
    if _equals_scalar(rhs, -1) and is_parameterized(lhs):
        result = -lhs

    # Output.
    if result is not NotImplemented:
        return result
    if default is not RaiseTypeErrorIfNotProvided:
        return default
    raise TypeError(f"unsupported operand type(s) for *: '{type(lhs)}' and '{type(rhs)}'")
# pylint: enable=function-redefined, redefined-builtin