Skip to content

Fix static shape inference in pt.linalg.kron and add regression test#1898

Merged
jessegrabowski merged 1 commit intopymc-devs:mainfrom
ayulockedin:fix-kron-static-shape
Feb 23, 2026
Merged

Fix static shape inference in pt.linalg.kron and add regression test#1898
jessegrabowski merged 1 commit intopymc-devs:mainfrom
ayulockedin:fix-kron-static-shape

Conversation

@ayulockedin
Copy link
Contributor

Description

This PR resolves the issue where pt.linalg.kron destroys static shape information, returning (None, None) even when input shapes are fully known.

The previous implementation relied on a vector-wise symbolic multiplication of shapes:
out_shape = tuple(a.shape * b.shape)

This forced the underlying Reshape Op to treat the entire shape vector as a single symbolic entity, which prevented the ShapeFeature from constant-folding individual dimensions into their static values.

The Fix

I implemented element-wise symbolic multiplication for the output shape:
[a.shape[i] * b.shape[i] for i in range(a.ndim)]

This provides the shape inference engine with enough granularity to resolve static constants (e.g., 4 * 3 = 12) at compile time while maintaining the symbolic integrity required for downstream operations like clone_replace.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ayulockedin
Copy link
Contributor Author

@jessegrabowski can you have a look at this when u have a moment thx

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solution looks good. Just needs some cleanup.

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
@ayulockedin ayulockedin force-pushed the fix-kron-static-shape branch from 302e014 to 6f0e23c Compare February 23, 2026 10:39
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, thanks

@jessegrabowski jessegrabowski merged commit 33d9421 into pymc-devs:main Feb 23, 2026
66 checks passed
@jessegrabowski jessegrabowski added enhancement New feature or request shape inference linalg Linear algebra labels Feb 23, 2026
pytest.skip("Sum of shp0 and shp1 must be more than 2")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))

x = tensor(dtype="floatX", shape=shp0)
Copy link
Member

@ricardoV94 ricardoV94 Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a bit short-sighted. It's useful to make sure these shape Ops work correctly with non static shapes. In the future it could be accidentally relying on x.type.shape internally and fail with non-static shapes and we may not notice it.

The separate test for static shape would have been fine, and we do that for other Ops.

Copy link
Contributor Author

@ayulockedin ayulockedin Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 should i open another PR to get both non-static and static shapes tests implemented to avoid future hassles, cuz it is right that for long term it would be better

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request linalg Linear algebra shape inference

Projects

None yet

Development

Successfully merging this pull request may close these issues.

pt.linalg.kron destroys static shape information

3 participants