
GPUEngine.Reshape silently ignores dst argument (compute.Engine contract violation) #81

@dndungu

Description

Problem

GPUEngine.Reshape (compute/gpu_engine_memory.go:614) takes a variadic dst ...*tensor.TensorNumeric[T] parameter, matching the compute.Engine contract shared with CPUEngine.Reshape. However, the zero-copy GPUStorage fast-path at line 644 ignores dst entirely:

// GPUStorage[T]: zero-copy reshape.
if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize {
    return tensor.NewWithStorage[T](inferredShape, gs.View(gs.Len()))
}

It returns a brand-new tensor aliasing a's storage and never touches the caller-provided dst. Callers that discard the return value and rely on dst being updated, as they do with other ops in the engine, are silently left with dst's stale pre-allocated storage.
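The failure mode can be reproduced in isolation. The following is a minimal sketch using a hypothetical stand-in `Tensor` type and a hypothetical `reshapeIgnoringDst` function; neither is the zerfoo API, they only mimic the shape of the fast-path quoted above.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for tensor.TensorNumeric[float32].
// Data plays the role of the backing storage.
type Tensor struct {
	Shape []int
	Data  []float32
}

// reshapeIgnoringDst mimics the fast-path described above: it returns a new
// tensor that aliases a's data and never looks at dst.
func reshapeIgnoringDst(a *Tensor, shape []int, dst ...*Tensor) *Tensor {
	return &Tensor{Shape: shape, Data: a.Data}
}

func main() {
	src := &Tensor{Shape: []int{2, 3}, Data: []float32{1, 2, 3, 4, 5, 6}}
	dst := &Tensor{Shape: []int{3, 2}, Data: make([]float32, 6)}

	// The caller follows the usual engine pattern: pass dst, discard the result.
	_ = reshapeIgnoringDst(src, []int{3, 2}, dst)

	fmt.Println(dst.Data) // dst still holds its pre-allocated zeros
}
```

In this sketch the returned tensor carries the reshaped data, while dst keeps its own zero-filled buffer, which matches the symptom described in this issue.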

Impact

This is what caused the PatchTST GPU training convergence regression tracked in #79; see the postmortem on that issue. A single-line zerfoo workaround (use the return value) restored convergence, but any future caller that follows the same pattern as other engine ops (pass a dst, discard the return) will hit the same silent-zero trap.

Expected behavior

When dst is provided, Reshape should either:

  • (a) call SetStorage on dst[0] to alias the view, matching the zero-copy semantics and leaving the GPUStorage shared with a; or
  • (b) copy the source data into dst[0]'s existing storage, matching the caller's intent to use its pre-allocated buffer.

Option (a) preserves the zero-copy win and is consistent with how other engine ops mutate dst. Either choice must be a deliberate, documented contract; the current behavior is neither.
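To make the difference between the two options concrete, here is a rough sketch with a hypothetical stand-in `Tensor` type (a plain slice in place of GPUStorage). The names `reshapeAlias` and `reshapeCopy` are illustrative only and do not exist in the engine.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in; Data plays the role of the storage.
type Tensor struct {
	Shape []int
	Data  []float32
}

// reshapeAlias sketches option (a): dst is re-pointed at the source storage.
func reshapeAlias(a *Tensor, shape []int, dst ...*Tensor) *Tensor {
	if len(dst) > 0 && dst[0] != nil {
		dst[0].Shape = shape
		dst[0].Data = a.Data // shared with a, no copy
		return dst[0]
	}
	return &Tensor{Shape: shape, Data: a.Data}
}

// reshapeCopy sketches option (b): the source is copied into dst's buffer.
// It assumes dst already holds at least len(a.Data) elements.
func reshapeCopy(a *Tensor, shape []int, dst ...*Tensor) *Tensor {
	if len(dst) > 0 && dst[0] != nil {
		dst[0].Shape = shape
		copy(dst[0].Data, a.Data)
		return dst[0]
	}
	return &Tensor{Shape: shape, Data: a.Data}
}

func main() {
	src := &Tensor{Shape: []int{2, 2}, Data: []float32{1, 2, 3, 4}}
	aliased := &Tensor{Data: make([]float32, 4)}
	copied := &Tensor{Data: make([]float32, 4)}
	reshapeAlias(src, []int{4}, aliased)
	reshapeCopy(src, []int{4}, copied)

	src.Data[0] = 9 // a later write to the source
	fmt.Println(aliased.Data[0], copied.Data[0]) // prints "9 1"
}
```

The later write to src is visible through the aliased dst but not through the copied one, which is the behavioral difference the documented contract would need to spell out.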

Repro

See .claude/scratch/wave-7-probe-logs.txt in zerfoo main (commit 1ebdb787) for first-batch GPU probe evidence showing fc.dFlat populated by MatMul and fc.dX (dst of Reshape) left all-zero with a different storage pointer.

Suggested fix

In gpu_engine_memory.go:644, after constructing the view, if len(dst) > 0 && dst[0] != nil, update dst[0]'s storage and shape (SetStorage + shape update) and return dst[0]. Add a test in compute/ that passes a pre-allocated dst with poisoned data and asserts the returned tensor AND dst both reflect the reshaped source.
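The suggested test could look roughly like the sketch below. It runs against a hypothetical stand-in `reshape` with option (a) semantics rather than the real GPUEngine, and `checkReshapeDst` is an illustrative name; in compute/ the same assertions would presumably live in a regular `testing` test that calls GPUEngine.Reshape.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for tensor.TensorNumeric[float32].
type Tensor struct {
	Shape []int
	Data  []float32
}

// reshape is a stand-in for the fixed fast-path: it builds the view and,
// when a non-nil dst is supplied, updates dst's shape and storage.
func reshape(a *Tensor, shape []int, dst ...*Tensor) *Tensor {
	out := &Tensor{Shape: shape, Data: a.Data}
	if len(dst) > 0 && dst[0] != nil {
		dst[0].Shape, dst[0].Data = out.Shape, out.Data
		return dst[0]
	}
	return out
}

// checkReshapeDst passes a poisoned dst and verifies that both the returned
// tensor and dst reflect the reshaped source.
func checkReshapeDst() error {
	src := &Tensor{Shape: []int{2, 3}, Data: []float32{1, 2, 3, 4, 5, 6}}
	dst := &Tensor{Shape: []int{3, 2}, Data: []float32{-999, -999, -999, -999, -999, -999}}

	got := reshape(src, []int{3, 2}, dst)

	for name, t := range map[string]*Tensor{"returned": got, "dst": dst} {
		if len(t.Shape) != 2 || t.Shape[0] != 3 || t.Shape[1] != 2 {
			return fmt.Errorf("%s: unexpected shape %v", name, t.Shape)
		}
		for i, want := range src.Data {
			if t.Data[i] != want {
				return fmt.Errorf("%s: element %d = %v, want %v", name, i, t.Data[i], want)
			}
		}
	}
	return nil
}

func main() {
	if err := checkReshapeDst(); err != nil {
		fmt.Println("FAIL:", err)
		return
	}
	fmt.Println("ok")
}
```

Poisoning dst with a sentinel value instead of zeros means a failure to update dst cannot be mistaken for a legitimately zero result.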

Related
