Rewrite concatenate([x, x]) as repeat(x, 2) #1710

@ricardoV94

Description

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Stack three copies of x along a new leading axis; this compiles to a Join.
out = pt.concatenate([x[None], x[None], x[None]], axis=0)
fn = pytensor.function([x], out, trust_input=True)
fn.dprint()
# Join [id A] 1
#  ├─ 0 [id B]
#  ├─ ExpandDims{axis=0} [id C] 0
#  │  └─ x [id D]
#  ├─ ExpandDims{axis=0} [id C] 0
#  │  └─ ···
#  └─ ExpandDims{axis=0} [id C] 0
#     └─ ···

# Same result expressed as a repeat along the new axis; this compiles to an Alloc.
alt_out = pt.repeat(x[None], 3, axis=0)
alt_fn = pytensor.function([x], alt_out, trust_input=True)
alt_fn.dprint()
# Alloc [id A] 1
#  ├─ x [id B]
#  ├─ 3 [id C]
#  └─ Shape_i{0} [id D] 0
#     └─ x [id B]

x_test = pt.random.uniform(size=(100,)).eval()
# Timings from an IPython session (%timeit is an IPython magic).
%timeit fn(x_test)  # 5.58 μs ± 496 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit alt_fn(x_test)  # 4.01 μs ± 529 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
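
A quick way to sanity-check when such a rewrite would be valid is to compare the two forms in plain NumPy. The sketch below is only an illustration and not part of the issue: it assumes NumPy semantics match PyTensor's for these ops, and the variable names are made up. It suggests that the equivalence holds when the repeated input has length 1 along the join axis (as with `x[None]` above), while joining along an existing axis of length greater than 1 behaves like a tile and not like a repeat.

```python
import numpy as np

x = np.arange(4.0)

# Case from the snippet above: every input has length 1 along the join axis,
# so joining identical copies equals repeating (or broadcasting) along that axis.
stacked = np.concatenate([x[None], x[None], x[None]], axis=0)
repeated = np.repeat(x[None], 3, axis=0)
broadcast = np.broadcast_to(x, (3, x.shape[0]))

same_as_repeat = np.array_equal(stacked, repeated)
same_as_broadcast = np.array_equal(stacked, broadcast)

# Counter-case: joining along an existing axis of length > 1.
# The result interleaves whole copies (tile), not individual elements (repeat).
joined = np.concatenate([x, x])
matches_tile = np.array_equal(joined, np.tile(x, 2))
matches_repeat = np.array_equal(joined, np.repeat(x, 2))

print(same_as_repeat, same_as_broadcast, matches_tile, matches_repeat)
```

Under these assumptions, a rewrite would likely need to check that the inputs to the join are the same variable and that the join axis has static length 1 before substituting a repeat.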
