Commit
Merge pull request #177 from nwittler/speed
Removing tensorflow bottlenecks
nwittler authored Apr 1, 2022
2 parents 52d0a12 + 2815cb4 commit 381bdc3
Showing 26 changed files with 740 additions and 632 deletions.
31 changes: 22 additions & 9 deletions c3/c3objs.py
@@ -1,6 +1,7 @@
"""Basic custom objects."""

import hjson
from typing import List
import numpy as np
import tensorflow as tf
from c3.utils.utils import num3str
@@ -121,6 +122,13 @@ def asdict(self) -> dict:
"symbol": self.symbol,
}

def tolist(self) -> List:
if self.length > 1:
tolist = self.get_value().numpy().tolist()
else:
tolist = [self.get_value().numpy().tolist()]
return tolist

def __add__(self, other):
out_val = copy.deepcopy(self)
out_val._set_value_extend(self.get_value() + other)
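Note: the new tolist always returns a plain Python list, wrapping scalar quantities in a one-element list. A minimal usage sketch (the Quantity constructor arguments here follow the usual c3 pattern and are assumptions, not part of this diff):

```python
# Hypothetical usage; Quantity kwargs assumed from typical c3 configs.
from c3.c3objs import Quantity

amp = Quantity(value=0.5, min_val=0.0, max_val=1.0, unit="V")   # length 1
angles = Quantity(value=[0.1, 0.2], min_val=-1.0, max_val=1.0)  # length 2

print(amp.tolist())     # [0.5]       scalar gets wrapped in a list
print(angles.tolist())  # [0.1, 0.2]  vector passes through
```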
@@ -236,7 +244,7 @@ def numpy(self) -> np.ndarray:
         # TODO should be removed to be consistent with get_value
         return self.get_value().numpy() / self.pref

-    def get_value(self, val: tf.float64 = None, dtype: tf.dtypes = None) -> tf.Tensor:
+    def get_value(self) -> tf.Tensor:
         """
         Return the value of this quantity as tensorflow.
@@ -245,13 +253,18 @@ def get_value(self, val: tf.float64 = None, dtype: tf.dtypes = None) -> tf.Tensor:
         val : tf.float64
         dtype: tf.dtypes
         """
-        if val is None:
-            val = self.value
-        if dtype is None:
-            dtype = self.value.dtype
-        value = self.scale * (val + 1) / 2 + self.offset
-        return tf.cast(value, dtype)
+        return self.scale * (self.value + 1) / 2 + self.offset
+
+    def get_other_value(self, val) -> tf.Tensor:
+        """
+        Return an arbitrary value of the same scale as this quantity as tensorflow.
+
+        Parameters
+        ----------
+        val : tf.float64
+        dtype: tf.dtypes
+        """
+        return (self.scale * (val + 1) / 2 + self.offset) / self.pref

     def set_value(self, val, extend_bounds=False):
         if extend_bounds:
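The scaling above is the core of Quantity: values are stored on the interval [-1, 1] and mapped to the physical bounds on read. A standalone sketch of that map, assuming scale = max_val - min_val and offset = min_val (the SI-prefix handling via pref is omitted):

```python
# Sketch of the affine map in get_value / get_other_value (assumptions:
# scale = max_val - min_val, offset = min_val; prefix handling omitted).
def to_physical(v: float, min_val: float, max_val: float) -> float:
    scale = max_val - min_val
    offset = min_val
    return scale * (v + 1) / 2 + offset

assert to_physical(-1.0, 2.0, 6.0) == 2.0  # lower bound
assert to_physical(+1.0, 2.0, 6.0) == 6.0  # upper bound
assert to_physical(0.0, 2.0, 6.0) == 4.0   # midpoint
```

With val and dtype gone from get_value, the hot path no longer branches on Python-level None checks or casts on every read; get_other_value covers the remaining use case of scaling an arbitrary tensor.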
@@ -289,9 +302,9 @@ def _set_value_extend(self, val) -> None:
         self.set_limits(min_val, max_val)
         self._set_value(val)

-    def get_opt_value(self) -> np.ndarray:
+    def get_opt_value(self) -> tf.Tensor:
         """Get an optimizer friendly representation of the value."""
-        return self.value.numpy().flatten()
+        return tf.reshape(self.value, (-1,))

     def set_opt_value(self, val: float) -> None:
         """Set value optimizer friendly.
37 changes: 33 additions & 4 deletions c3/experiment.py
@@ -27,6 +27,8 @@
     tf_state_to_dm,
     tf_super,
     tf_vec_to_dm,
+    _tf_matmul_n_even,
+    _tf_matmul_n_odd,
 )

 from c3.libraries.propagation import unitary_provider, state_provider
@@ -54,7 +56,7 @@ class Experiment:
"""

def __init__(self, pmap: ParameterMap = None, prop_method=None):
def __init__(self, pmap: ParameterMap = None, prop_method=None, sim_res=100e9):
self.pmap = pmap
self.opt_gates = None
self.propagators: Dict[str, tf.Tensor] = {}
@@ -67,6 +69,7 @@ def __init__(self, pmap: ParameterMap = None, prop_method=None):
         self.compute_propagators_timestamp = 0
         self.stop_partial_propagator_gradient = True
         self.evaluate = self.evaluate_legacy
+        self.sim_res = sim_res
         self.set_prop_method(prop_method)

     def set_prop_method(self, prop_method=None) -> None:
@@ -76,6 +79,8 @@ def set_prop_method(self, prop_method=None) -> None:
"""
if prop_method is None:
self.propagation = unitary_provider["pwc"]
if self.pmap is not None:
self._compute_folding_stack()
elif isinstance(prop_method, str):
try:
self.propagation = unitary_provider[prop_method]
@@ -84,6 +89,22 @@
         elif callable(prop_method):
             self.propagation = prop_method

+    def _compute_folding_stack(self):
+        self.folding_stack = {}
+        for instr in self.pmap.instructions.values():
+            n_steps = int((instr.t_end - instr.t_start) * self.sim_res)
+            if n_steps not in self.folding_stack:
+                stack = []
+                while n_steps > 1:
+                    if not n_steps % 2:  # is divisible by 2
+                        stack.append(_tf_matmul_n_even)
+                    else:
+                        stack.append(_tf_matmul_n_odd)
+                    n_steps = np.ceil(n_steps / 2)
+                self.folding_stack[
+                    int((instr.t_end - instr.t_start) * self.sim_res)
+                ] = stack
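This is the heart of the speed-up: instead of multiplying n_steps slice propagators one by one in a Python loop, the products are folded pairwise, halving the batch at each of ceil(log2(n_steps)) levels, and the even/odd decision per level is made once here rather than inside the traced TensorFlow graph. _tf_matmul_n_even and _tf_matmul_n_odd are imported from c3's tf_utils above; a minimal sketch of the semantics they presumably have (my own stand-in implementations, not the library's):

```python
import tensorflow as tf

def matmul_fold_even(matrices: tf.Tensor) -> tf.Tensor:
    # [U1, U2, U3, U4] -> [U2 @ U1, U4 @ U3]: multiply adjacent pairs,
    # preserving the right-to-left time ordering of the propagators.
    return tf.linalg.matmul(matrices[1::2], matrices[0::2])

def matmul_fold_odd(matrices: tf.Tensor) -> tf.Tensor:
    # Fold the even-length head and carry the trailing matrix through.
    return tf.concat([matmul_fold_even(matrices[:-1]), matrices[-1:]], axis=0)

def fold_all(matrices: tf.Tensor) -> tf.Tensor:
    # Reduce a [n, d, d] batch of slice propagators to the total unitary.
    while matrices.shape[0] > 1:
        fold = matmul_fold_even if matrices.shape[0] % 2 == 0 else matmul_fold_odd
        matrices = fold(matrices)
    return matrices[0]
```

Precomputing the stack keyed by step count means instructions of the same length share one plan, and tf.function never sees a data-dependent branch.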

     def enable_qasm(self) -> None:
         """
         Switch the sequencing format to QASM. Will become the default.
@@ -184,7 +205,9 @@ def make_absolute(filename: str) -> str:
             )
             instructions.append(instr)

+        self.sim_res = 100e9
         self.pmap = ParameterMap(instructions, generator=gen, model=model)
+        self.set_prop_method()

     def read_config(self, filepath: str) -> None:
         """
@@ -214,6 +237,8 @@ def from_dict(self, cfg: Dict) -> None:
         for k, v in cfg["options"].items():
             self.__dict__[k] = v
         self.pmap = pmap
+        self.sim_res = cfg.pop("sim_res", 100e9)
+        self.set_prop_method()
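sim_res, the simulation resolution in Hz, now defaults to 100e9 and survives a config round trip. A schematic fragment showing only the relevant key (the other sections of a real c3 experiment config are omitted here):

```python
# Schematic config fragment; the "model", "generator" and "instructions"
# sections of a real experiment config are omitted.
cfg = {
    "options": {"stop_partial_propagator_gradient": True},
    "sim_res": 50e9,  # simulation resolution in Hz
}
sim_res = cfg.pop("sim_res", 100e9)  # same fallback as from_dict above
assert sim_res == 50e9
```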

     def write_config(self, filepath: str) -> None:
         """
@@ -238,6 +263,7 @@ def asdict(self) -> Dict:
"overwrite_propagators": self.overwrite_propagators,
"stop_partial_propagator_gradient": self.stop_partial_propagator_gradient,
}
exp_dict["sim_res"] = self.sim_res
return exp_dict

def __str__(self) -> str:
@@ -445,7 +471,7 @@ def compute_states(self) -> Dict[Instruction, List[tf.Tensor]]:
f" Available gates are:\n {list(instructions.keys())}."
)
signal = generator.generate_signals(instr)
result = self.propagation(model, signal)
result = self.propagation(model, signal, self.folding_stack)
states[instr] = result["states"]
self.states = states
return result
@@ -479,7 +505,10 @@ def compute_propagators(self):
             )

             model.controllability = self.use_control_fields
-            result = self.propagation(model, generator, instr)
+            steps = int((instr.t_end - instr.t_start) * self.sim_res)
+            result = self.propagation(
+                model, generator, instr, self.folding_stack[steps]
+            )
             U = result["U"]
             dUs = result["dUs"]
             self.ts = result["ts"]
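compute_propagators now looks up the precomputed plan by the instruction's step count and hands it to the propagation routine. How the pwc routine presumably applies it (the routine itself is not shown in this diff; this is an assumed sketch):

```python
# Assumed application of a precomputed folding stack inside the propagation
# routine (not shown in this diff): each entry halves the batch of dUs.
def apply_folding_stack(dUs, folding_stack):
    for fold in folding_stack:  # entries: _tf_matmul_n_even / _tf_matmul_n_odd
        dUs = fold(dUs)
    return dUs[0]  # total propagator U = U_n ... U_2 U_1
```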
@@ -489,7 +518,7 @@
             framechanges = {}
             for line, ctrls in instr.comps.items():
                 # TODO calculate properly the average frequency that each qubit sees
-                offset = 0.0
+                offset = tf.constant(0.0, tf.float64)
                 for ctrl in ctrls.values():
                     if "freq_offset" in ctrl.params.keys():
                         if ctrl.params["amp"] != 0.0:
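Seeding the accumulator as a typed float64 constant keeps the frame-change arithmetic entirely in TensorFlow tensors of one dtype. A small illustration of why the explicit dtype matters (not from the diff):

```python
import tensorflow as tf

freq_offset = tf.constant(-50e6, tf.float64)

offset = tf.constant(0.0, tf.float64)  # typed accumulator, as in the diff
offset += freq_offset                  # stays tf.float64 throughout

# tf.constant(0.0) alone would default to float32, and adding it to a
# float64 tensor raises an InvalidArgumentError.
```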
