Skip to content

Conversation

@tchan102
Copy link

@tchan102 tchan102 commented Nov 4, 2025

Add optimization for Join → Repeat when concatenating identical tensors

Description

This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.

Example:
join(0, x, x, x) → repeat(x, 3, axis=0)

Key additions:

  • Implemented new rewrite function local_join_to_repeat registered under both @register_canonicalize and @register_specialize.
  • Added corresponding test test_local_join_to_repeat to verify correctness, performance, and behavior for vectors and matrices.

Related Issue

Checklist

Type of change

  • [ x] New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 ricardoV94 added graph rewriting enhancement New feature or request labels Nov 4, 2025
@ricardoV94
Copy link
Member

Let's try with @register_canonicalize only

@ricardoV94
Copy link
Member

Btw would be nice to get rid of this join (and split) symbolic axis if you would like to work on that after this PR. relevant issue: #1528

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Rewrite concatenate([x, x]) as repeat(x, 2)

2 participants