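Status: Open

Description

Concatenating several identical `x[None]` copies of a vector along axis 0
compiles to a `Join` over repeated `ExpandDims` nodes. The equivalent
`pt.repeat(x[None], 3, axis=0)` compiles to a single `Alloc` and runs faster
in the timings below. This suggests that the `Join` pattern could be rewritten
into the `Alloc` form.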
import statistics
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt
def _report_timing(label, func, arg, repeat=7):
    """Print the mean ± std. dev. per-call time of ``func(arg)``.

    This is a plain-Python stand-in for IPython's ``%timeit`` magic, which is
    not valid syntax in a ``.py`` file. ``timeit.Timer.autorange`` picks the
    number of loops per run, and the measurement is repeated ``repeat`` times.
    """
    timer = timeit.Timer(lambda: func(arg))
    number, _ = timer.autorange()
    per_call_us = [
        total / number * 1e6 for total in timer.repeat(repeat=repeat, number=number)
    ]
    mean_us = statistics.mean(per_call_us)
    # stdev needs at least two samples; report zero spread for a single run.
    std_us = statistics.stdev(per_call_us) if len(per_call_us) > 1 else 0.0
    print(
        f"{label}: {mean_us:.2f} µs ± {std_us:.2f} µs per loop "
        f"(mean ± std. dev. of {repeat} runs, {number:,} loops each)"
    )


x = pt.vector("x")

# Stack three copies of ``x`` along a new leading axis with concatenate.
out = pt.concatenate([x[None], x[None], x[None]], axis=0)
fn = pytensor.function([x], out, trust_input=True)
fn.dprint()
# Join [id A] 1
# ├─ 0 [id B]
# ├─ ExpandDims{axis=0} [id C] 0
# │ └─ x [id D]
# ├─ ExpandDims{axis=0} [id C] 0
# │ └─ ···
# └─ ExpandDims{axis=0} [id C] 0
#   └─ ···

# The same stacking expressed with repeat, which compiles to a single Alloc.
alt_out = pt.repeat(x[None], 3, axis=0)
alt_fn = pytensor.function([x], alt_out, trust_input=True)
alt_fn.dprint()
# Alloc [id A] 1
# ├─ x [id B]
# ├─ 3 [id C]
# └─ Shape_i{0} [id D] 0
#   └─ x [id B]

# trust_input=True disables input validation, so x_test must already match
# x's dtype (presumably both use floatX here; confirm if floatX is changed).
x_test = pt.random.uniform(size=(100,)).eval()

# Comparing speed only makes sense if both graphs compute the same array.
np.testing.assert_array_equal(fn(x_test), alt_fn(x_test))

_report_timing("concatenate (Join)", fn, x_test)
# IPython %timeit reported: 5.58 μs ± 496 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
_report_timing("repeat (Alloc)", alt_fn, x_test)
# IPython %timeit reported: 4.01 μs ± 529 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)