Use a typed RaggedDim sentinel for ragged dimensions instead of raw null in TensorType.dims #544

@khatchad

Description

Summary

TensorType.dims lists use raw null entries to mark ragged dimensions. The convention is documented only in implementation comments at RaggedConstant.java:431 and RaggedFromNestedRowLengths.java:180, not in the public API. It is also the only per-element encoding that isn't a typed Dimension<?> subclass.

Current State

  • DimensionType enum has three values: Constant, Symbolic, Compound.
  • Dimension<T> abstract class has three subclasses: NumericDim, SymbolicDim, CompoundDim.
  • Neither covers ragged dims; ragged-tensor generators emit raw null into the dim list:
    • RaggedConstant.java:431: for (Long i = 0L; i < R; i++) shape.add(null); // Unknown size for ragged dimensions.
    • RaggedFromNestedRowLengths.java:180: // Then K ragged dimensions (represented as null)
  • TensorType implements Iterable<Dimension<?>>, but per-element null violates the implicit non-null-element contract.
  • The TensorType Javadoc documents whole-list dims == null (shape-⊤) but does not mention per-element null.
  • The null sentinel is also implicitly coupled to TensorShapeUtil.areBroadcastable and TensorShapeUtil.getBroadcastedShapes, which special-case xDim == null as "compatible; propagate the other side." A typed RaggedDim needs equivalent instanceof RaggedDim handling in both helpers. Without it, the broadcast-on-ragged paths (testAdd66 through testAdd99, testGradient, testGradient2) throw NonBroadcastableShapesException.
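
The broadcast coupling described in the last bullet can be sketched as follows. This is a minimal illustration, not the project's code: Dim, NumericDim, RaggedDim, and the three helper methods are hypothetical stand-ins, IllegalArgumentException stands in for NonBroadcastableShapesException, and the numeric rule assumed here is the usual "equal or one side is 1" broadcasting rule.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for illustration only; not the real TensorShapeUtil or Dimension types.
final class BroadcastSketch {
  interface Dim {}

  static final class NumericDim implements Dim {
    final int value;
    NumericDim(int value) { this.value = value; }
  }

  // Typed sentinel that replaces the raw null entry.
  static final class RaggedDim implements Dim {}

  /** Mirrors the existing null special case: a ragged dim is compatible with anything. */
  static boolean compatible(Dim x, Dim y) {
    if (x instanceof RaggedDim || y instanceof RaggedDim) return true;
    int a = ((NumericDim) x).value;
    int b = ((NumericDim) y).value;
    return a == b || a == 1 || b == 1;
  }

  /** Result for one aligned pair of dims; a ragged side propagates the other side. */
  static Dim broadcastDim(Dim x, Dim y) {
    if (!compatible(x, y)) throw new IllegalArgumentException("not broadcastable");
    if (x instanceof RaggedDim) return y;
    if (y instanceof RaggedDim) return x;
    return ((NumericDim) x).value == 1 ? y : x;
  }

  /** Right-aligns the two shapes, padding the shorter one with size-1 dims. */
  static List<Dim> broadcast(List<Dim> xs, List<Dim> ys) {
    int n = Math.max(xs.size(), ys.size());
    List<Dim> out = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      int xi = i - (n - xs.size());
      int yi = i - (n - ys.size());
      Dim x = xi >= 0 ? xs.get(xi) : new NumericDim(1);
      Dim y = yi >= 0 ? ys.get(yi) : new NumericDim(1);
      out.add(broadcastDim(x, y));
    }
    return out;
  }
}
```

The point of the sketch is that the instanceof RaggedDim branches must come before any numeric comparison, exactly where the xDim == null checks sit today. Otherwise a ragged dim falls through to the numeric path and is rejected.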

Impact

Consumers iterating TensorType.dims must defensively null-check each element or risk a NullPointerException when calling dim.type(), dim.value(), or dim.equals(). TensorType itself already compensates with internal null-guards at TensorType.java:491-492 and 501-502.

Downstream tooling (e.g., Hybridize's input-signature inference) collapses null to SymbolicDim("?"), losing the structural raggedness signal. This precision regression stems directly from the representation.
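
To make the loss of precision concrete, here is a rough consumer-side sketch. Everything in it is a simplified assumption: the class bodies are stand-ins that only borrow the names used in this issue, and describeNullable and describeTyped are hypothetical helpers, not Hybridize's actual inference code.

```java
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Simplified stand-ins; only the type names come from the issue, the bodies are assumptions.
final class ConsumerSketch {
  enum DimensionType { Constant, Symbolic, Ragged }

  abstract static class Dimension<T> {
    private final T value;
    Dimension(T value) { this.value = value; }
    T value() { return value; }
    abstract DimensionType type();
  }

  static final class NumericDim extends Dimension<Integer> {
    NumericDim(Integer v) { super(v); }
    @Override DimensionType type() { return DimensionType.Constant; }
  }

  static final class SymbolicDim extends Dimension<String> {
    SymbolicDim(String name) { super(name); }
    @Override DimensionType type() { return DimensionType.Symbolic; }
  }

  static final class RaggedDim extends Dimension<Void> {
    RaggedDim() { super(null); }
    @Override DimensionType type() { return DimensionType.Ragged; }
  }

  /** Null encoding: each element needs a guard, and a ragged dim can only be rendered as "?". */
  static String describeNullable(List<Dimension<?>> dims) {
    return dims.stream()
        .map(d -> d == null ? "?" : Objects.toString(d.value()))
        .collect(Collectors.joining(", ", "[", "]"));
  }

  /** Typed encoding: no guard is needed, and raggedness survives in the output. */
  static String describeTyped(List<Dimension<?>> dims) {
    return dims.stream()
        .map(d -> d.type() == DimensionType.Ragged ? "ragged" : Objects.toString(d.value()))
        .collect(Collectors.joining(", ", "[", "]"));
  }
}
```

In this sketch, a shape holding a ragged dim and an unknown symbolic dim renders both as "?" under the null encoding, while the typed encoding keeps them distinct.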

Proposed Fix

Introduce RaggedDim extends Dimension<Void> and a new DimensionType.Ragged value. Migrate the two ragged generators to emit new RaggedDim() instead of null. Consumers can then discriminate via instanceof RaggedDim or dim.type() == DimensionType.Ragged. With no null elements left in the list, the Iterable<Dimension<?>> contract holds.
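
A minimal sketch of what the proposal might look like. The type names and the type() and value() accessors come from this issue, but the class bodies, constructors, the equals/hashCode choice, and the raggedShape helper are assumptions made for illustration; the real Dimension<T> may differ.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative only: bodies and constructors are assumed, not copied from the project.
final class RaggedDimSketch {
  enum DimensionType { Constant, Symbolic, Compound, Ragged }

  abstract static class Dimension<T> {
    private final T value;
    Dimension(T value) { this.value = value; }
    T value() { return value; }
    abstract DimensionType type();
  }

  static final class NumericDim extends Dimension<Integer> {
    NumericDim(Integer v) { super(v); }
    @Override DimensionType type() { return DimensionType.Constant; }
  }

  /** Typed sentinel for a ragged dimension; it carries no payload, hence Void. */
  static final class RaggedDim extends Dimension<Void> {
    RaggedDim() { super(null); }
    @Override DimensionType type() { return DimensionType.Ragged; }
    // Assumption: all ragged sentinels compare equal, since they hold no data.
    @Override public boolean equals(Object o) { return o instanceof RaggedDim; }
    @Override public int hashCode() { return RaggedDim.class.hashCode(); }
    @Override public String toString() { return "ragged"; }
  }

  /** Hypothetical generator: one fixed leading dim followed by k ragged dims. */
  static List<Dimension<?>> raggedShape(int batch, int k) {
    List<Dimension<?>> shape = new ArrayList<>();
    shape.add(new NumericDim(batch));
    for (int i = 0; i < k; i++) shape.add(new RaggedDim()); // was: shape.add(null)
    return shape;
  }
}
```

One consequence of Dimension<Void> worth noting: Void has no instances, so value() on a RaggedDim can only return null. Consumers would need to branch on type() or instanceof rather than inspect the value.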

Whole-list dims == null (shape-⊤) is unaffected.

Background

Surfaced by a precision audit in ponder-lab/Hybridize-Functions-Refactoring#522. Fixture: testHasLikelyTensorParameter59 (tf.RaggedTensor.from_nested_row_splits).
