resolve_parameters is slow for cirq.Duration #6269

Closed
andbe91 opened this issue Aug 29, 2023 · 4 comments · Fixed by #6270
Labels
area/parameters parameter resolution, parameterized gates, operations area/performance kind/bug-report Something doesn't seem to work.

@andbe91
Collaborator

andbe91 commented Aug 29, 2023

Description of the issue
Including just one extra parameter to resolve seems to dramatically increase the time resolve_parameters takes.

How to reproduce the issue

import cirq
import sympy

qubits = cirq.GridQubit.rect(10, 10)

prep_moment = cirq.Moment((cirq.X**(sympy.Symbol(f"half_turns_{str(q)}"))).on(q) for q in qubits)
wait_moment = cirq.Moment(cirq.wait(*qubits, nanos=sympy.Symbol("delay_ns")))
measure_moment = cirq.Moment(cirq.measure(*qubits))
circuit = cirq.Circuit.from_moments(prep_moment, wait_moment, measure_moment)

sweep = cirq.Zip(*[cirq.Linspace(f"half_turns_{str(q)}", 0, 2, 256) for q in qubits]) * cirq.Points("delay_ns", [1])

for s in sweep:
  cirq.resolve_parameters(circuit, s)

This takes ~14 seconds on my machine. However, if I exclude wait_moment from the circuit such that there is one less parameter, then it only takes ~2 seconds. Note that the length of the sweep is the same in both cases. I have attached cProfiles below in the details.
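
Profiles like the ones below can be gathered with the standard-library cProfile and pstats modules. The following is a minimal sketch, not the exact script used for this report: the `resolve` name matches the wrapper visible in the profile output, but its body here is a stand-in workload (an assumption, so that the snippet runs without Cirq installed), and `profile_by_cumulative_time` is a hypothetical helper.

```python
import cProfile
import io
import pstats


def resolve():
    # Stand-in workload (an assumption for illustration); in the report above
    # this would be the loop calling cirq.resolve_parameters(circuit, s).
    return sum(i * i for i in range(100_000))


def profile_by_cumulative_time(func, limit=20):
    """Run func under cProfile and return the stats report as a string."""
    profiler = cProfile.Profile()
    profiler.runcall(func)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    # "cumulative" produces the "Ordered by: cumulative time" layout.
    stats.sort_stats("cumulative").print_stats(limit)
    return buffer.getvalue()


report = profile_by_cumulative_time(resolve)
print(report)
```

Running the same helper on the circuit with and without wait_moment should give two reports that can be compared line by line.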

Fast version

         5239420 function calls (5082235 primitive calls) in 2.327 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.327    2.327 {built-in method builtins.exec}
        1    0.000    0.000    2.327    2.327 <string>:1(<module>)
        1    0.008    0.008    2.327    2.327 2604498839.py:1(resolve)
51968/256    0.127    0.000    2.254    0.009 resolve_parameters.py:136(resolve_parameters)
      256    0.001    0.000    2.248    0.009 circuit.py:1289(_resolve_parameters_)
      256    0.029    0.000    2.238    0.009 moment.py:270(_resolve_parameters_)
    25600    0.010    0.000    1.079    0.000 value_equality_attr.py:83(_value_equality_ne)
51200/25600    0.093    0.000    1.070    0.000 value_equality_attr.py:72(_value_equality_eq)
    25600    0.028    0.000    0.823    0.000 gate_operation.py:290(_resolve_parameters_)
    51200    0.055    0.000    0.598    0.000 expr.py:143(__eq__)
    51200    0.027    0.000    0.534    0.000 sympify.py:486(_sympify)
    51200    0.046    0.000    0.507    0.000 sympify.py:92(sympify)
    25600    0.119    0.000    0.436    0.000 numbers.py:1031(__new__)
    25600    0.039    0.000    0.420    0.000 eigen_gate.py:356(_resolve_parameters_)
954590/953054    0.179    0.000    0.299    0.000 {built-in method builtins.isinstance}
    51200    0.035    0.000    0.299    0.000 common_gates.py:330(_value_equality_values_)
    51200    0.020    0.000    0.264    0.000 eigen_gate.py:318(_value_equality_values_)
    25600    0.061    0.000    0.261    0.000 resolver.py:84(value_of)
    51200    0.038    0.000    0.243    0.000 eigen_gate.py:308(_canonical_exponent)
    25600    0.026    0.000    0.240    0.000 gate_operation.py:77(with_gate)
    25600    0.019    0.000    0.206    0.000 raw_types.py:224(on)
    25600    0.018    0.000    0.188    0.000 gate_operation.py:53(__init__)
    51200    0.051    0.000    0.178    0.000 resolver.py:276(_resolve_value)
    25600    0.017    0.000    0.170    0.000 raw_types.py:204(validate_args)
    25700    0.040    0.000    0.163    0.000 eigen_gate.py:289(_period)
    25600    0.057    0.000    0.149    0.000 raw_types.py:1036(_validate_qid_shape)
    51200    0.029    0.000    0.124    0.000 gate_operation.py:153(_value_equality_values_)
      256    0.030    0.000    0.119    0.000 moment.py:84(__init__)
   205827    0.048    0.000    0.118    0.000 abc.py:117(__instancecheck__)
    25600    0.025    0.000    0.106    0.000 numbers.py:1191(_new)
    25400    0.031    0.000    0.096    0.000 libmpf.py:410(from_float)
    76958    0.042    0.000    0.089    0.000 {built-in method builtins.any}
    51200    0.036    0.000    0.089    0.000 gate_operation.py:142(_group_interchangeable_qubits)
   205827    0.055    0.000    0.071    0.000 {built-in method _abc._abc_instancecheck}
77415/77158    0.031    0.000    0.068    0.000 resolve_parameters.py:66(is_parameterized)
   412110    0.064    0.000    0.067    0.000 {built-in method builtins.getattr}
      257    0.005    0.000    0.065    0.000 sweeps.py:118(__iter__)
   102912    0.035    0.000    0.063    0.000 _compat.py:104(wrapped_no_args)
    25600    0.022    0.000    0.059    0.000 libmpf.py:291(from_man_exp)
    25600    0.019    0.000    0.058    0.000 common_gates.py:166(_with_exponent)
    25858    0.020    0.000    0.056    0.000 gate_operation.py:278(_is_parameterized_)
    25600    0.045    0.000    0.056    0.000 numbers.py:151(mpf_norm)
    25700    0.024    0.000    0.050    0.000 eigen_gate.py:406(_approximate_common_period)
 1281/257    0.002    0.000    0.049    0.000 sweeps.py:243(_gen)
    25400    0.015    0.000    0.048    0.000 sympify.py:58(_is_numpy_instance)
    51457    0.023    0.000    0.047    0.000 eigen_gate.py:350(_is_parameterized_)
      257    0.011    0.000    0.047    0.000 sweeps.py:312(param_tuples)
102912/25856    0.025    0.000    0.044    0.000 op_tree.py:97(flatten_to_ops)
    52224    0.018    0.000    0.043    0.000 resolver.py:69(__init__)
    25700    0.021    0.000    0.041    0.000 eigen_gate.py:221(_eigen_shifts)
   155536    0.032    0.000    0.040    0.000 {built-in method builtins.hasattr}
    25600    0.020    0.000    0.039    0.000 common_gates.py:94(__init__)
    25600    0.019    0.000    0.031    0.000 libmpf.py:64(dps_to_prec)
    51000    0.029    0.000    0.029    0.000 libmpf.py:153(_normalize)
    26212    0.014    0.000    0.029    0.000 sweeps.py:413(param_tuples)
    51968    0.019    0.000    0.027    0.000 resolver.py:236(__bool__)
    25600    0.020    0.000    0.025    0.000 basic.py:113(__new__)
   181248    0.025    0.000    0.025    0.000 resolver.py:80(param_dict)
    25600    0.014    0.000    0.024    0.000 qid_shape_protocol.py:81(qid_shape)
   179459    0.024    0.000    0.024    0.000 gate_operation.py:64(gate)
    25700    0.017    0.000    0.020    0.000 eigen_gate.py:299(<listcomp>)
    52224    0.014    0.000    0.019    0.000 resolver.py:64(__new__)
    51200    0.016    0.000    0.019    0.000 raw_types.py:1048(<genexpr>)
    25600    0.015    0.000    0.018    0.000 eigen_gate.py:94(__init__)
    25400    0.009    0.000    0.018    0.000 libintmath.py:91(python_bitcount)
    76200    0.017    0.000    0.017    0.000 sympify.py:64(<genexpr>)
    26368    0.006    0.000    0.016    0.000 abc.py:121(__subclasscheck__)
   128900    0.015    0.000    0.015    0.000 {built-in method builtins.len}
    25700    0.015    0.000    0.015    0.000 sweeps.py:498(_values)
   102400    0.015    0.000    0.015    0.000 gate_operation.py:69(qubits)
    25700    0.013    0.000    0.013    0.000 eigen_gate.py:298(<setcomp>)
   102400    0.011    0.000    0.011    0.000 value_equality_attr.py:223(<lambda>)
    25700    0.011    0.000    0.011    0.000 common_gates.py:154(_eigen_components)
    26368    0.010    0.000    0.010    0.000 {built-in method _abc._abc_subclasscheck}
    25400    0.009    0.000    0.009    0.000 {built-in method _bisect.bisect_right}
    51400    0.009    0.000    0.009    0.000 eigen_gate.py:436(<genexpr>)
    25700    0.008    0.000    0.008    0.000 eigen_gate.py:234(<listcomp>)
      256    0.008    0.000    0.008    0.000 {built-in method builtins.sum}
    25600    0.007    0.000    0.007    0.000 expr.py:125(__hash__)
    25400    0.006    0.000    0.006    0.000 {built-in method math.frexp}
    26000    0.006    0.000    0.006    0.000 {built-in method builtins.max}
    51400    0.006    0.000    0.006    0.000 {built-in method builtins.abs}
    25600    0.006    0.000    0.006    0.000 {built-in method builtins.round}
    25600    0.005    0.000    0.005    0.000 common_gates.py:151(_qid_shape_)
    26056    0.005    0.000    0.005    0.000 {built-in method __new__ of type object at 0x7fc988fa4bc0}
      256    0.001    0.000    0.005    0.000 circuit.py:1754(_from_moments)
      256    0.002    0.000    0.004    0.000 circuit.py:1727(__init__)
      256    0.001    0.000    0.004    0.000 circuit.py:1283(_is_parameterized_)

Slow version

         26603534 function calls (26001163 primitive calls) in 14.645 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   14.645   14.645 {built-in method builtins.exec}
        1    0.000    0.000   14.645   14.645 <string>:1(<module>)
        1    0.023    0.023   14.645   14.645 2604498839.py:1(resolve)
53248/256    0.154    0.000   14.538    0.057 resolve_parameters.py:136(resolve_parameters)
      256    0.003    0.000   14.528    0.057 circuit.py:1289(_resolve_parameters_)
      512    0.038    0.000   14.510    0.028 moment.py:270(_resolve_parameters_)
    25856    0.034    0.000   12.737    0.000 gate_operation.py:290(_resolve_parameters_)
26624/25856    0.096    0.000   11.993    0.000 resolver.py:84(value_of)
      256    0.002    0.000   11.701    0.046 wait_gate.py:80(_resolve_parameters_)
      256    0.002    0.000   11.689    0.046 duration.py:100(_resolve_parameters_)
      256    0.206    0.001   11.618    0.045 basic.py:764(subs)
233516/104236    0.328    0.000    8.765    0.000 cache.py:69(wrapper)
    25856    0.028    0.000    7.876    0.000 decorators.py:254(_func)
    25856    0.024    0.000    7.831    0.000 decorators.py:129(binary_op_wrapper)
    25856    0.022    0.000    7.803    0.000 expr.py:222(__mul__)
    25856    0.272    0.000    7.549    0.000 operations.py:46(__new__)
    25856    0.594    0.000    6.706    0.000 mul.py:178(flatten)
258560/155136    0.162    0.000    4.882    0.000 assumptions.py:452(getit)
155136/25856    0.777    0.000    4.795    0.000 assumptions.py:464(_ask)
   155136    1.114    0.000    3.535    0.000 random.py:380(shuffle)
  2766592    1.770    0.000    2.402    0.000 random.py:239(_randbelow_with_getrandbits)
    25856    0.037    0.000    1.728    0.000 expr.py:855(_eval_is_negative)
    25856    0.039    0.000    1.722    0.000 expr.py:845(_eval_is_positive)
   259584    0.174    0.000    1.305    0.000 sympify.py:92(sympify)
    25856    0.012    0.000    1.287    0.000 value_equality_attr.py:83(_value_equality_ne)
51712/25856    0.110    0.000    1.275    0.000 value_equality_attr.py:72(_value_equality_eq)
    51712    0.273    0.000    0.986    0.000 numbers.py:1031(__new__)
    51712    0.381    0.000    0.900    0.000 basic.py:1867(_aresame)
   207872    0.120    0.000    0.900    0.000 sympify.py:486(_sympify)
52224/26112    0.123    0.000    0.888    0.000 compatibility.py:498(ordered)
   103680    0.144    0.000    0.880    0.000 expr.py:143(__eq__)
    25856    0.194    0.000    0.734    0.000 basic.py:906(<listcomp>)
    25856    0.032    0.000    0.720    0.000 basic.py:957(_subs)
2381853/2379293    0.487    0.000    0.637    0.000 {built-in method builtins.isinstance}
    51712    0.256    0.000    0.616    0.000 mul.py:455(_gather)
    25600    0.046    0.000    0.500    0.000 eigen_gate.py:356(_resolve_parameters_)
    26112    0.078    0.000    0.487    0.000 symbol.py:399(__new__)
    25856    0.016    0.000    0.483    0.000 mul.py:29(_mulsort)
    25856    0.036    0.000    0.467    0.000 {method 'sort' of 'list' objects}
    25856    0.117    0.000    0.431    0.000 basic.py:189(compare)
  4303106    0.396    0.000    0.396    0.000 {method 'getrandbits' of '_random.Random' objects}
    25856    0.009    0.000    0.389    0.000 basic.py:925(<lambda>)
    25856    0.100    0.000    0.379    0.000 compatibility.py:476(_nodes)
    51200    0.042    0.000    0.356    0.000 common_gates.py:330(_value_equality_values_)
    25856    0.033    0.000    0.341    0.000 gate_operation.py:77(with_gate)
    26112    0.078    0.000    0.336    0.000 symbol.py:274(__new_stage2__)
    51200    0.024    0.000    0.314    0.000 eigen_gate.py:318(_value_equality_values_)
    77824    0.050    0.000    0.311    0.000 symbol.py:425(_hashable_content)
    25856    0.024    0.000    0.298    0.000 raw_types.py:224(on)
    51200    0.046    0.000    0.290    0.000 eigen_gate.py:308(_canonical_exponent)
    25856    0.021    0.000    0.274    0.000 gate_operation.py:53(__init__)
336640/284928    0.133    0.000    0.271    0.000 expr.py:125(__hash__)
    25856    0.009    0.000    0.264    0.000 basic.py:926(<lambda>)
    77824    0.099    0.000    0.262    0.000 symbol.py:309(_hashable_content)
    25856    0.162    0.000    0.255    0.000 compatibility.py:314(default_sort_key)
    25856    0.021    0.000    0.253    0.000 raw_types.py:204(validate_args)
    25856    0.034    0.000    0.240    0.000 basic.py:1518(count)
    51712    0.057    0.000    0.236    0.000 numbers.py:1191(_new)
  2766848    0.231    0.000    0.231    0.000 {method 'bit_length' of 'int' objects}
    51712    0.230    0.000    0.230    0.000 expr.py:865(_eval_is_extended_positive_negative)
    25856    0.078    0.000    0.226    0.000 raw_types.py:1036(_validate_qid_shape)
    52480    0.063    0.000    0.223    0.000 resolver.py:276(_resolve_value)
      512    0.058    0.000    0.221    0.000 moment.py:84(__init__)
   129083    0.089    0.000    0.221    0.000 {built-in method builtins.any}
    25856    0.129    0.000    0.221    0.000 basic.py:1792(_exec_constructor_postprocessors)
    51412    0.071    0.000    0.220    0.000 libmpf.py:410(from_float)
    25700    0.047    0.000    0.197    0.000 eigen_gate.py:289(_period)
   754485    0.165    0.000    0.193    0.000 {built-in method builtins.getattr}
    26112    0.063    0.000    0.188    0.000 assumptions.py:424(__init__)
    26624    0.032    0.000    0.181    0.000 numbers.py:2236(__eq__)
    25856    0.012    0.000    0.170    0.000 numbers.py:2243(__ne__)
   258560    0.116    0.000    0.166    0.000 <frozen importlib._bootstrap>:404(parent)
   441600    0.119    0.000    0.163    0.000 numbers.py:2282(__hash__)
    25856    0.024    0.000    0.149    0.000 numbers.py:699(_eval_subs)
   205568    0.076    0.000    0.148    0.000 _compat.py:104(wrapped_no_args)
    51712    0.034    0.000    0.147    0.000 mul.py:466(<listcomp>)
    51712    0.034    0.000    0.146    0.000 gate_operation.py:153(_value_equality_values_)
   209669    0.056    0.000    0.143    0.000 abc.py:117(__instancecheck__)
   155136    0.045    0.000    0.140    0.000 basic.py:2045(__next__)
   103680    0.113    0.000    0.140    0.000 basic.py:113(__new__)
    26624    0.063    0.000    0.137    0.000 numbers.py:1862(__eq__)
    51968    0.052    0.000    0.137    0.000 libmpf.py:291(from_man_exp)
    25856    0.019    0.000    0.137    0.000 expr.py:908(_eval_is_extended_positive)
    25856    0.019    0.000    0.131    0.000 expr.py:911(_eval_is_extended_negative)
    26112    0.023    0.000    0.128    0.000 {built-in method builtins.sum}
    51712    0.100    0.000    0.127    0.000 numbers.py:151(mpf_norm)
    51968    0.088    0.000    0.124    0.000 symbol.py:229(_sanitize)
    77824    0.072    0.000    0.121    0.000 symbol.py:321(assumptions0)
    51712    0.068    0.000    0.113    0.000 numbers.py:2198(__mul__)

Cirq version

'1.3.0.dev20230828214840'

@andbe91 andbe91 added the kind/bug-report Something doesn't seem to work. label Aug 29, 2023
@tanujkhattar tanujkhattar added area/parameters parameter resolution, parameterized gates, operations area/performance labels Aug 29, 2023
@bichengying
Collaborator

bichengying commented Aug 29, 2023

I ran an initial profile of the two cases, one with wait_moment and one without. The time difference is indeed huge.

With wait_moment:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    25856    6.823    0.000   15.906    0.001 assumptions.py:509(_ask)
   749824    3.433    0.000    6.332    0.000 random.py:380(shuffle)
    25856    2.666    0.000   26.593    0.001 mul.py:197(flatten)
2356536/2353976    2.006    0.000    2.682    0.000 {built-in method builtins.isinstance}
   698112    2.006    0.000    2.590    0.000 random.py:239(_randbelow_with_getrandbits)
    51712    1.851    0.000    5.025    0.000 basic.py:2109(_aresame)

Without wait_moment:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
51968/256    0.713    0.000    9.407    0.037 resolve_parameters.py:136(resolve_parameters)
673280/671744    0.555    0.000    1.144    0.000 {built-in method builtins.isinstance}
51200/25600    0.484    0.000    3.276    0.000 value_equality_attr.py:73(_value_equality_eq)
    25600    0.438    0.000    0.935    0.000 raw_types.py:1036(_validate_qid_shape)
    51200    0.290    0.000    0.588    0.000 resolver.py:276(_resolve_value)
    51200    0.283    0.000    1.713    0.000 common_gates.py:330(_value_equality_values_)

It looks like the problem lies under assumptions.py, along with some unnecessary random number generation for shuffling?

@maffoo
Contributor

maffoo commented Aug 29, 2023

Resolve parameters is definitely slow, but in this case the problem seems to be with resolving cirq.Duration specifically, because it combines the various fields picos, nanos, micros, and millis into a single picos value which is stored internally. If one of those is symbolic (as here), then you end up with a sympy expression which has to be resolved, as opposed to just a single sympy symbol. Resolving sympy arithmetic expressions is known to be especially slow. I previously attempted to fix this generically by pre-compiling sympy expressions (see #5047), but we ended up not doing that because it could allow for arbitrary code execution in certain circumstances. I think we'll have to refactor Duration to do something more sensible here. (One thing to note is we allow multiple fields to be set, like Duration(nanos=10, picos=10), by adding them with appropriate coefficients; but the intent of having multiple keyword args was more to have a way of expressing units, rather than adding multiple values with different units.)
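
To make the unit combination concrete, here is a minimal sketch of the arithmetic described above. It assumes a hypothetical helper named `total_picos` and is not Cirq's actual code. With plain numbers the sum collapses to one value. If any argument were a sympy symbol, the same multiplications and additions would instead build an arithmetic expression, which is the slow case.

```python
# Picoseconds per unit, following the standard SI prefixes.
PICOS_PER_UNIT = {
    "picos": 1,
    "nanos": 1_000,
    "micros": 1_000_000,
    "millis": 1_000_000_000,
}


def total_picos(picos=0, nanos=0, micros=0, millis=0):
    """Hypothetical helper: fold all four fields into one picosecond value."""
    return (
        picos * PICOS_PER_UNIT["picos"]
        + nanos * PICOS_PER_UNIT["nanos"]
        + micros * PICOS_PER_UNIT["micros"]
        + millis * PICOS_PER_UNIT["millis"]
    )


# Mirrors the Duration(nanos=10, picos=10) example from the comment above.
combined = total_picos(nanos=10, picos=10)
print(combined)
```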

@tanujkhattar
Collaborator

tanujkhattar commented Aug 29, 2023

@maffoo One easy fix for Duration could be that, instead of storing self._picos as the only value, we convert it to a dataclass where we just store a float / sympy.Symbol for each of the 4 fields (picos/nanos/micros/millis) at the time of construction and convert to the appropriate units at the time of access. This makes sure that the parameter resolution path simply resolves (at most) 4 floats, instead of resolving a sympy expression. The former should be much faster since we will not go via the sympy parameter resolution path and will break early instead.
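
The idea can be sketched roughly as follows. This is an illustration under stated assumptions, not the change in the linked PR: `Param` is a hypothetical stand-in for sympy.Symbol, and `PerUnitDuration` with its methods is an invented name. The point is only that each field resolves through a dictionary lookup, and unit conversion is deferred until the value is read.

```python
from dataclasses import dataclass
from typing import Mapping, Union


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a sympy.Symbol: just a named placeholder."""
    name: str


Value = Union[float, Param]
# Picoseconds per unit, in the order picos, nanos, micros, millis.
_PICOS_PER = (1, 1_000, 1_000_000, 1_000_000_000)


@dataclass(frozen=True)
class PerUnitDuration:
    """Hypothetical sketch: keep each unit field separate until access."""
    picos: Value = 0
    nanos: Value = 0
    micros: Value = 0
    millis: Value = 0

    def _fields(self):
        return (self.picos, self.nanos, self.micros, self.millis)

    def is_parameterized(self) -> bool:
        return any(isinstance(v, Param) for v in self._fields())

    def resolve(self, assignments: Mapping[str, float]) -> "PerUnitDuration":
        # Each field is a plain lookup; no symbolic expression is substituted.
        # Unassigned parameters are left in place.
        resolved = [
            assignments.get(v.name, v) if isinstance(v, Param) else v
            for v in self._fields()
        ]
        return PerUnitDuration(*resolved)

    def total_picos(self) -> float:
        # Units are combined only here, once every field is numeric.
        if self.is_parameterized():
            raise ValueError("resolve all parameters before converting units")
        return sum(v * k for v, k in zip(self._fields(), _PICOS_PER))


wait = PerUnitDuration(nanos=Param("delay_ns"))
resolved = wait.resolve({"delay_ns": 1})
print(resolved.total_picos())
```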

@tanujkhattar
Collaborator

tanujkhattar commented Aug 29, 2023

As suggested above, I've sent a fix to speed up parameter resolution for cirq.Duration in #6270

On my machine, after this change, the original code snippet takes ~700ms, which is basically the same time it takes without the wait_moment.


@tanujkhattar tanujkhattar self-assigned this Aug 29, 2023
@tanujkhattar tanujkhattar changed the title from "resolve_parameters can be slow" to "resolve_parameters is slow for cirq.Duration" Aug 29, 2023