
Add dedicated TensorGenerator subclasses for #437/ReadDataFallback ops #449

@khatchad

Summary

ReadDataFallback (added by #196 for #437) classifies XML-modeled tensor-producing APIs that don't yet have a dedicated TensorGenerator subclass as generic tensor sources with ⊤ shape and UNKNOWN dtype. That's enough to satisfy "is this a tensor source?" checks (e.g., Hybridize's getHasTensorParameter()) but loses precision: every op it covers has a well-defined output shape and dtype that the analysis could be capturing.

This issue tracks adding per-op TensorGenerator subclasses with proper precision, so the fallback's ⊤ result becomes the exception rather than the norm. Once every op in scope has its own subclass, ReadDataFallback itself can be removed.
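To make the precision gap concrete, here is a minimal Python model (not the Java code) of what a consumer sees. TOP, UNKNOWN, TensorType, is_tensor_source, and is_precise are hypothetical stand-ins for the ⊤ shape, the UNKNOWN dtype, and a source check in the spirit of getHasTensorParameter(); the point is only that both results pass the source check while just one carries usable shape/dtype facts.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-ins: None models the ⊤ ("any shape") element,
# and the string "UNKNOWN" models the unknown dtype.
TOP = None
UNKNOWN = "UNKNOWN"


@dataclass(frozen=True)
class TensorType:
    shape: Optional[Tuple[int, ...]]  # TOP (None) when nothing is known
    dtype: str


def fallback_type() -> TensorType:
    """What a ReadDataFallback-style source reports: a tensor, nothing more."""
    return TensorType(shape=TOP, dtype=UNKNOWN)


def rank_type() -> TensorType:
    """What a dedicated tf.rank generator could report: a scalar int32."""
    return TensorType(shape=(), dtype="int32")


def is_tensor_source(types) -> bool:
    """A 'source' check only needs *some* tensor type, so ⊤ already passes."""
    return len(types) > 0


def is_precise(t: TensorType) -> bool:
    """True only when both the shape and the dtype are known."""
    return t.shape is not TOP and t.dtype != UNKNOWN


for t in (fallback_type(), rank_type()):
    print(t, is_tensor_source({t}), is_precise(t))
```

Both calls print True for the source check; only the tf.rank result is precise, which is the gap the per-op subclasses close.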

Ops Currently Routed Through ReadDataFallback

These are the ops surfaced by wala/ML#437's investigation that hit the fallback; the analysis recognizes each one as a tensor source, but only with ⊤ shape and UNKNOWN dtype:

| Op | Output shape (semantic) | Output dtype |
|---|---|---|
| tf.rank(t) | () (scalar) | int32 |
| tf.stack(values, axis) | values[0].shape[:axis] + (len(values),) + values[0].shape[axis:] | inherits from values[0] |
| tf.linspace(start, stop, num) | (num,) for scalar start/stop | start.dtype (float32 when start is a Python float) |
| tf.range(start, limit, delta) | (ceil((limit - start) / delta),) | inherits from args (Range is already in the dispatch; partial precision) |
| tf.boolean_mask(tensor, mask) | (n,) + tensor.shape[mask.ndim:], where n is the number of True entries in mask (not known statically) | inherits from tensor |
| tf.einsum(equation, *inputs) | derived from equation | inherits from inputs |
| tf.math.top_k(input, k) | tuple of (values, indices), both of shape input.shape[:-1] + (k,) | values: input.dtype; indices: int32 |
| tf.meshgrid(*xi) | list of len(xi) tensors, each of rank len(xi) (first two dimensions swapped under the default indexing='xy') | inherits from xi[0] |
| tf.image.extract_patches(images, sizes, strides, rates, padding) | (batch, out_rows, out_cols, size_rows * size_cols * depth), with out_rows/out_cols derived from strides, rates, and padding | inherits from images |
| tf.math.exp(x) / tf.math.exp2(x) | x.shape | x.dtype |
| tf.math.argmax(input, axis) | input.shape with axis removed | int64 (default) or int32 |
| tf.math.argmin(input, axis) | same as argmax | same as argmax |

... (full list per wala/ML#380's 22-op enumeration)
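Most rows of the shape column are small pure functions over static shapes, which is roughly what each getDefaultShapes override has to compute. The sketch below encodes several rows in Python for brevity (the real generators are Java); the function names and the convention that None marks a statically unknown dimension are hypothetical, negative axes wrap the way TensorFlow's do, and extract_patches_shape assumes NHWC images with 4-element sizes/strides/rates.

```python
import math
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None = statically unknown dimension


def stack_shape(value_shapes: Sequence[Shape], axis: int = 0) -> Shape:
    """tf.stack: insert a new dimension of size len(values) at `axis`."""
    base = value_shapes[0]
    if axis < 0:
        axis += len(base) + 1  # the result has rank R + 1
    return base[:axis] + (len(value_shapes),) + base[axis:]


def top_k_shapes(input_shape: Shape, k: int) -> Tuple[Shape, Shape]:
    """tf.math.top_k: values and indices share input.shape[:-1] + (k,)."""
    out = input_shape[:-1] + (k,)
    return out, out


def arg_reduce_shape(input_shape: Shape, axis: int) -> Shape:
    """tf.math.argmax / tf.math.argmin: drop the reduced axis."""
    if axis < 0:
        axis += len(input_shape)
    return input_shape[:axis] + input_shape[axis + 1:]


def boolean_mask_shape(tensor_shape: Shape, mask_rank: int) -> Shape:
    """tf.boolean_mask (axis=0): the masked dimensions collapse into one
    whose size is the number of True entries, unknown statically."""
    return (None,) + tensor_shape[mask_rank:]


def range_shape(start: float, limit: float, delta: float = 1) -> Shape:
    """tf.range with statically known scalar arguments."""
    return (max(0, math.ceil((limit - start) / delta)),)


def meshgrid_shapes(lengths: Sequence[int], indexing: str = "xy") -> List[Shape]:
    """tf.meshgrid: len(xi) outputs of rank len(xi); 'xy' swaps the first two dims."""
    dims = list(lengths)
    if indexing == "xy" and len(dims) >= 2:
        dims[0], dims[1] = dims[1], dims[0]
    return [tuple(dims) for _ in lengths]


def extract_patches_shape(images_shape: Shape, sizes: Sequence[int],
                          strides: Sequence[int], rates: Sequence[int],
                          padding: str) -> Shape:
    """tf.image.extract_patches on NHWC images (conv-style output size)."""
    batch, rows, cols, depth = images_shape
    spatial = []
    for n, k, s, r in ((rows, sizes[1], strides[1], rates[1]),
                       (cols, sizes[2], strides[2], rates[2])):
        extent = (k - 1) * r + 1  # patch extent after dilation by `rate`
        if padding == "VALID":
            spatial.append(math.ceil((n - extent + 1) / s))
        else:  # "SAME"
            spatial.append(math.ceil(n / s))
    return (batch, spatial[0], spatial[1], sizes[1] * sizes[2] * depth)


print(stack_shape([(2, 3)] * 4, axis=1))  # (2, 4, 3)
```

Keeping unknown dimensions as None (rather than collapsing the whole result to ⊤) is what lets tf.boolean_mask stay partially precise even though its leading dimension is data-dependent.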

Approach

For each op:

  1. Add a subclass <Op> extends TensorGenerator (or extends TensorGeneratorWithSingleInput / similar shared base if patterns emerge).
  2. Override getDefaultShapes to compute the precise output shape from the input parameter(s).
  3. Override getDefaultDTypes to return the right dtype (often inherited from input, sometimes fixed like int32/int64 for argmax/rank).
  4. Add a dispatch entry: else if (isType(calledFunction, OP.getDeclaringClass())) return new <Op>(source);.
  5. Add an Ariadne-side test: tf2_test_<op>.py + TestTensorflow2Model.test<Op> asserting the precise shape/dtype (per the test-fixture three-check protocol).
  6. Once the subclass is in place, ReadDataFallback no longer fires for that op.
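The six steps can be modeled in a few lines to show how a dedicated subclass displaces the fallback. This Python sketch is not the Java code: the class and method names only mirror TensorGenerator, getDefaultShapes, and getDefaultDTypes; the string-keyed DISPATCH table is a hypothetical stand-in for the isType(...) else-if chain; and, unlike the real factory (which falls back only for XML-modeled tensor-producing APIs), it sends every unregistered function to the fallback.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

# None models the ⊤ ("unknown") shape; "UNKNOWN" models the unknown dtype.
Shape = Optional[Tuple[int, ...]]


@dataclass(frozen=True)
class TensorType:
    shape: Shape
    dtype: str


class TensorGenerator:
    """Hypothetical Python stand-in for the Java TensorGenerator base class."""

    def __init__(self, source: str) -> None:
        self.source = source  # the call site being modeled

    def get_default_shapes(self) -> List[Shape]:
        raise NotImplementedError

    def get_default_dtypes(self) -> List[str]:
        raise NotImplementedError

    def tensor_types(self) -> List[TensorType]:
        # Modeling choice: pair every shape with every dtype.
        return [TensorType(s, d)
                for s in self.get_default_shapes()
                for d in self.get_default_dtypes()]


class Rank(TensorGenerator):
    """Steps 1-3 for tf.rank: always a scalar int32."""

    def get_default_shapes(self) -> List[Shape]:
        return [()]

    def get_default_dtypes(self) -> List[str]:
        return ["int32"]


class ReadDataFallback(TensorGenerator):
    """Catch-all: a tensor with ⊤ shape and UNKNOWN dtype."""

    def get_default_shapes(self) -> List[Shape]:
        return [None]

    def get_default_dtypes(self) -> List[str]:
        return ["UNKNOWN"]


# Step 4: one entry per dedicated subclass.
DISPATCH: Dict[str, Callable[[str], TensorGenerator]] = {
    "tf.rank": Rank,
}


def make_generator(called_function: str, source: str) -> TensorGenerator:
    # Step 6: ops with a dedicated subclass never reach the fallback.
    return DISPATCH.get(called_function, ReadDataFallback)(source)
```

With this model, make_generator("tf.rank", "t").tensor_types() yields a single scalar int32 type, while make_generator("tf.einsum", "e") keeps producing the ⊤/UNKNOWN pair until an Einsum entry is registered.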

Acceptance Criteria

  • All ops in the table above have dedicated TensorGenerator subclasses in com.ibm.wala.cast.python.ml.client/.
  • Each op has Ariadne-side tests asserting the precise shape/dtype.
  • ReadDataFallback no longer fires for any test in Ariadne's com.ibm.wala.cast.python.ml.test suite (verifiable by running the suite with LOGGER.fine enabled and checking for read_data-pattern fallback log lines).
  • ReadDataFallback, getPropertyReadMemberNames, and getTensorflowReadDataPropertyNames removed from TensorGeneratorFactory.
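The third criterion can be checked mechanically once the suite's FINE-level output is captured to a file. A minimal sketch, assuming that capture already exists: the issue does not quote the message ReadDataFallback logs, so the regular expression is left to the caller (copy it from the actual LOGGER.fine call), and fallback_hits is a hypothetical helper, not part of Ariadne.

```python
import re
from typing import Iterable, List


def fallback_hits(log_lines: Iterable[str], pattern: str) -> List[str]:
    """Return the captured log lines that match the fallback's message pattern."""
    regex = re.compile(pattern)
    return [line.rstrip("\n") for line in log_lines if regex.search(line)]
```

Called as fallback_hits(open("test-output.log"), PATTERN), an empty result means no test exercised the fallback; a non-empty one lists exactly which call sites still need a dedicated subclass.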

Cross-Refs
