Summary
`ReadDataFallback` (added by #196 for #437) classifies XML-modeled, tensor-producing APIs that do not yet have a dedicated `TensorGenerator` subclass as generic tensor sources with ⊤ shape and `UNKNOWN` dtype. That is enough to satisfy "is this a tensor source?" checks (e.g., Hybridize's `getHasTensorParameter()`), but it loses precision: the ops it covers all have well-defined output shapes and dtypes that the analysis could be capturing.
This issue tracks adding per-op `TensorGenerator` subclasses with proper precision, so that the fallback's ⊤ result becomes the exception rather than the norm. Once every op in scope has its own subclass, `ReadDataFallback` itself can be removed.
Ops Currently Routed Through ReadDataFallback
These are the ops that wala/ML#437's investigation found hitting the fallback. Each one reaches the analysis, but only with ⊤ shape and `UNKNOWN` dtype:
| Op | Output shape (semantic) | Output dtype |
| --- | --- | --- |
| `tf.rank(t)` | `()` (scalar) | `int32` |
| `tf.stack(values, axis)` | `values[0].shape[:axis] + (len(values),) + values[0].shape[axis:]` | inherits from `values[0]` |
| `tf.linspace(start, stop, num)` | `(num,)` | `float32` (or `start.dtype` if specified) |
| `tf.range(start, limit, delta)` | 1-D, length `max(0, ceil((limit - start) / delta))` | inherits from the arguments (`Range` already exists in dispatch; partial precision) |
| `tf.boolean_mask(tensor, mask)` | `(n,) + tensor.shape[mask.ndim:]`, where `n` (the number of `True` entries in `mask`) is data-dependent | inherits from `tensor` |
| `tf.einsum(equation, *inputs)` | derived from `equation` | inherits from `inputs` |
| `tf.math.top_k(input, k)` | tuple `(values, indices)`, both of shape `input.shape[:-1] + (k,)` | `values`: `input.dtype`; `indices`: `int32` |
| `tf.meshgrid(*xi)` | tuple of `len(xi)` tensors | inherits from `xi[0]` |
| `tf.image.extract_patches(images, sizes, strides, rates, padding)` | `(batch, out_rows, out_cols, sizes[1] * sizes[2] * depth)`, with `out_rows`/`out_cols` derived from the input size, `sizes`, `strides`, `rates`, and `padding` | inherits from `images` |
| `tf.math.exp(x)` / `tf.math.exp2(x)` | `x.shape` | `x.dtype` |
| `tf.math.argmax(input, axis)` | `input.shape` with `axis` removed | `int64` (default) or `int32` (via `output_type`) |
| `tf.math.argmin(input, axis)` | same as `argmax` | same as `argmax` |
| ... (full list per wala/ML#380's 22-op enumeration) | | |
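To make the "Output shape" column concrete, here is a minimal sketch of a few of these rules in Java. It assumes a hypothetical shape representation (a `List<Integer>`, with `null` for a dimension whose size is not statically known) rather than Ariadne's actual shape/dimension classes, and the class and method names are illustrative only. The `einsum` rule handles only the explicit `->` form without `...`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical shape rules; a null dimension means "size not statically known". */
public class ShapeRulesSketch {

  /** tf.stack: insert len(values) at axis; negative axes wrap over rank + 1. */
  static List<Integer> stack(List<Integer> valueShape, int numValues, int axis) {
    int outRank = valueShape.size() + 1;
    int a = axis < 0 ? axis + outRank : axis;
    if (a < 0 || a >= outRank) throw new IllegalArgumentException("bad axis " + axis);
    List<Integer> out = new ArrayList<>(valueShape);
    out.add(a, numValues);
    return out;
  }

  /** tf.math.top_k: last dimension replaced by k (same for values and indices). */
  static List<Integer> topK(List<Integer> inputShape, int k) {
    if (inputShape.isEmpty()) throw new IllegalArgumentException("top_k needs rank >= 1");
    List<Integer> out = new ArrayList<>(inputShape.subList(0, inputShape.size() - 1));
    out.add(k);
    return out;
  }

  /** tf.math.argmax / argmin: the reduced axis is dropped. */
  static List<Integer> argMax(List<Integer> inputShape, int axis) {
    int a = axis < 0 ? axis + inputShape.size() : axis;
    if (a < 0 || a >= inputShape.size()) throw new IllegalArgumentException("bad axis " + axis);
    List<Integer> out = new ArrayList<>(inputShape);
    out.remove(a); // remove(int index), not remove(Object)
    return out;
  }

  /** tf.boolean_mask with the default axis 0: unknown count, then trailing dims. */
  static List<Integer> booleanMask(List<Integer> tensorShape, int maskRank) {
    List<Integer> out = new ArrayList<>();
    out.add(null); // n = number of True entries, data-dependent
    out.addAll(tensorShape.subList(maskRank, tensorShape.size()));
    return out;
  }

  /** tf.einsum, explicit "->" form only; no "..." and no broadcasting checks. */
  static List<Integer> einsum(String equation, List<List<Integer>> inputShapes) {
    String eq = equation.replace(" ", "");
    String[] sides = eq.split("->", -1); // -1 keeps an empty (scalar) output side
    if (sides.length != 2 || eq.contains("...")) {
      throw new IllegalArgumentException("unsupported: " + equation);
    }
    String[] operands = sides[0].split(",");
    if (operands.length != inputShapes.size()) {
      throw new IllegalArgumentException("operand count mismatch");
    }
    Map<Character, Integer> sizeOf = new HashMap<>();
    for (int i = 0; i < operands.length; i++) {
      List<Integer> shape = inputShapes.get(i);
      if (operands[i].length() != shape.size()) {
        throw new IllegalArgumentException("rank mismatch in operand " + i);
      }
      // The first non-null size seen for a label wins.
      for (int d = 0; d < shape.size(); d++) sizeOf.putIfAbsent(operands[i].charAt(d), shape.get(d));
    }
    List<Integer> out = new ArrayList<>();
    for (char label : sides[1].toCharArray()) {
      if (!sizeOf.containsKey(label)) throw new IllegalArgumentException("unknown label " + label);
      out.add(sizeOf.get(label));
    }
    return out;
  }

  public static void main(String[] args) {
    System.out.println(stack(Arrays.asList(2, 3), 4, 1));    // [2, 4, 3]
    System.out.println(topK(Arrays.asList(5, 10), 3));       // [5, 3]
    System.out.println(argMax(Arrays.asList(5, 10), -1));    // [5]
    System.out.println(booleanMask(Arrays.asList(4, 7), 1)); // [null, 7]
    System.out.println(einsum("ij,jk->ik",
        Arrays.asList(Arrays.asList(2, 3), Arrays.asList(3, 5)))); // [2, 5]
  }
}
```

The `null` entry from `booleanMask` illustrates that "precise" does not always mean fully static: the rank and trailing dimensions are still recoverable even when one size is data-dependent.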
Approach
For each op:
- Add a subclass `<Op> extends TensorGenerator` (or `extends TensorGeneratorWithSingleInput` / a similar shared base class if patterns emerge).
- Override `getDefaultShapes` to compute the precise output shape from the input parameter(s).
- Override `getDefaultDTypes` to return the right dtype (often inherited from the input, sometimes fixed, e.g. `int32` for `rank` and `int64` for `argmax`).
- Add a dispatch entry: `else if (isType(calledFunction, OP.getDeclaringClass())) return new <Op>(source);`.
- Add an Ariadne-side test: `tf2_test_<op>.py` plus `TestTensorflow2Model.test<Op>`, asserting the precise shape/dtype (per the test-fixture three-check protocol).
- Once the subclass is in place, `ReadDataFallback` no longer fires for that op.
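The steps above can be sketched with self-contained stand-ins. Everything in this block is hypothetical: `TensorGeneratorSketch`, `DType`, `FallbackSketch`, and the string-keyed `dispatch` method only mimic the overall shape of the real `TensorGenerator` / `TensorGeneratorFactory` code, whose signatures (including the `isType(calledFunction, ...)` check) differ. It shows both dtype patterns from the list, a fixed dtype (`Rank`) and one inherited from the input (`Exp`), with unmatched calls still landing on a ⊤/`UNKNOWN` fallback.

```java
import java.util.List;
import java.util.Set;

public class PerOpGeneratorSketch {

  /** Hypothetical dtype set; UNKNOWN is what the fallback reports. */
  enum DType { INT32, INT64, FLOAT32, UNKNOWN }

  /** Hypothetical stand-in for TensorGenerator. A null shape set means ⊤ (any shape). */
  abstract static class TensorGeneratorSketch {
    abstract Set<List<Integer>> getDefaultShapes();
    abstract Set<DType> getDefaultDTypes();
  }

  /** Stand-in for ReadDataFallback: a tensor source, but with ⊤ shape and UNKNOWN dtype. */
  static final class FallbackSketch extends TensorGeneratorSketch {
    @Override Set<List<Integer>> getDefaultShapes() { return null; }
    @Override Set<DType> getDefaultDTypes() { return Set.of(DType.UNKNOWN); }
  }

  /** tf.rank(t): fixed output, a scalar int32, whatever t is. */
  static final class Rank extends TensorGeneratorSketch {
    @Override Set<List<Integer>> getDefaultShapes() { return Set.of(List.<Integer>of()); }
    @Override Set<DType> getDefaultDTypes() { return Set.of(DType.INT32); }
  }

  /** tf.math.exp(x): shape and dtype both inherited from the single input x. */
  static final class Exp extends TensorGeneratorSketch {
    private final List<Integer> xShape;
    private final DType xDType;

    Exp(List<Integer> xShape, DType xDType) {
      this.xShape = xShape;
      this.xDType = xDType;
    }

    @Override Set<List<Integer>> getDefaultShapes() { return Set.of(xShape); }
    @Override Set<DType> getDefaultDTypes() { return Set.of(xDType); }
  }

  /** One branch per op with its own subclass; anything else hits the fallback. */
  static TensorGeneratorSketch dispatch(String calledFunction, List<Integer> argShape, DType argDType) {
    if (calledFunction.equals("tf.rank")) return new Rank();
    else if (calledFunction.equals("tf.math.exp")) return new Exp(argShape, argDType);
    return new FallbackSketch();
  }

  public static void main(String[] args) {
    for (String fn : List.of("tf.rank", "tf.math.exp", "tf.einsum")) {
      TensorGeneratorSketch g = dispatch(fn, List.of(2, 3), DType.FLOAT32);
      // Prints: "tf.rank: [[]] [INT32]", "tf.math.exp: [[2, 3]] [FLOAT32]", "tf.einsum: null [UNKNOWN]"
      System.out.println(fn + ": " + g.getDefaultShapes() + " " + g.getDefaultDTypes());
    }
  }
}
```

In this sketch each new branch shrinks the set of calls that reach the final `return new FallbackSketch()`; once no in-scope op reaches it, that line is the piece that can be deleted.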
Acceptance Criteria
- Per-op `TensorGenerator` subclasses in `com.ibm.wala.cast.python.ml.client/`.
- `ReadDataFallback` no longer fires for any test in Ariadne's `com.ibm.wala.cast.python.ml.test` suite (verifiable by running the suite with `LOGGER.fine` enabled and checking for `read_data`-pattern fallback log lines).
- `ReadDataFallback`, `getPropertyReadMemberNames`, and `getTensorflowReadDataPropertyNames` removed from `TensorGeneratorFactory`.
Cross-Refs
- #437 (`tf.rank`/`tf.stack`/`tf.linspace`/`tf.range`/etc.: `TensorTypeAnalysis` lacks entries despite unchanged XML): the regression that motivated `ReadDataFallback`.
- #196 (ponder-lab/ML#196, `reshape()` propagation): adds `ReadDataFallback`; this issue tracks its replacement.
- #380 (`read_data`/`read_dataset` marker allocations in `tensorflow.xml`): separate but related; it covers inlining the `read_data` XML markers, and the analyzer-side work here is the next step after that.
- #448 (`TensorGeneratorFactory.getFunction()` and `dispatchByPropertyName` exposed by #437): a `getFunction()` design refactor; orthogonal, but it would make the dispatch/fallback boundary cleaner.
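The second acceptance criterion's log check could be automated instead of grepped by hand. A minimal sketch, assuming `LOGGER` is a `java.util.logging.Logger` (suggested by the `LOGGER.fine` call) and that fallback log lines contain the substring `read_data`. The logger name and the simulated messages below are hypothetical placeholders, not actual Ariadne output.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;

public class FallbackLogCheckSketch {

  /** Collects the raw (unformatted) message of every record at FINE or above. */
  static final class CollectingHandler extends Handler {
    final List<String> messages = new ArrayList<>();

    @Override public void publish(LogRecord record) {
      if (isLoggable(record)) messages.add(record.getMessage());
    }
    @Override public void flush() {}
    @Override public void close() {}
  }

  /** Enables FINE on the logger and attaches a FINE-level collector to it. */
  static CollectingHandler attach(Logger logger) {
    CollectingHandler handler = new CollectingHandler();
    handler.setLevel(Level.FINE);
    logger.setLevel(Level.FINE);
    logger.addHandler(handler);
    return handler;
  }

  /** Returns the collected messages containing the fallback pattern. */
  static List<String> fallbackHits(CollectingHandler handler, String pattern) {
    List<String> hits = new ArrayList<>();
    for (String m : handler.messages) {
      if (m != null && m.contains(pattern)) hits.add(m);
    }
    return hits;
  }

  public static void main(String[] args) {
    // Hypothetical logger name: substitute whatever logger TensorGeneratorFactory really uses.
    Logger logger = Logger.getLogger("com.ibm.wala.cast.python.ml.client.TensorGeneratorFactory");
    CollectingHandler handler = attach(logger);
    // Stand-in for running the suite: one ordinary line and one simulated fallback line.
    logger.fine("dispatching tf.rank to Rank");
    logger.fine("read_data fallback for tf.einsum");
    System.out.println(fallbackHits(handler, "read_data")); // [read_data fallback for tf.einsum]
  }
}
```

In a real check, the handler would be attached before the suite runs and the criterion met when `fallbackHits` comes back empty afterwards.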