
Broadcasting issue in MMM operation #534

Closed
mtsokol opened this issue May 8, 2024 · 2 comments · Fixed by #553
Labels
bug Something isn't working

Comments

@mtsokol
Member

mtsokol commented May 8, 2024

Hi @willow-ahrens,

Here's reproduction code for a broadcasting issue that I found while implementing SDDMM:

using Finch

LEN = 10;
a_raw = rand(LEN, LEN - 5) * 10;
b_raw = rand(LEN, LEN - 5) * 10;
c_raw = rand(LEN, LEN) * 10;

a = lazy(swizzle(Tensor(a_raw), 1, 2));
b = lazy(swizzle(Tensor(b_raw), 1, 2));
c = lazy(swizzle(Tensor(c_raw), 1, 2));

# gives a result that doesn't match actual below:
plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]);
# works correctly:
# plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :];

result = compute(plan);

actual = reshape(c_raw, 10, 10, 1) .* reshape(a_raw, 10, 1, 5) .* reshape(b_raw, 1, 10, 5);
# other notation
# actual = broadcast(*, broadcast(*, reshape(c_raw, 10, 10, 1), reshape(a_raw, 10, 1, 5)), reshape(b_raw, 1, 10, 5));

isequal(result, actual)  # returns false, but should be true
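For context: indexing a lazy tensor with nothing inserts a singleton dimension, so the three operands above are intended to line up exactly like the reshape calls used for actual. A minimal plain-Julia sketch of the intended shapes (not Finch code, just the Base analogue):

# Plain-Julia analogue of the nothing-indexing above (shapes only):
c3 = reshape(c_raw, 10, 10, 1)   # c[:, :, nothing] -> (10, 10, 1)
a3 = reshape(a_raw, 10, 1, 5)    # a[:, nothing, :] -> (10, 1, 5)
b3 = reshape(b_raw, 1, 10, 5)    # b[nothing, :, :] -> (1, 10, 5)
size(c3 .* a3 .* b3)             # broadcasts to (10, 10, 5)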
mtsokol added the bug label on May 8, 2024
@mtsokol
Member Author

mtsokol commented May 9, 2024

@willow-ahrens Here's the debug-mode output for the .* broadcasting, which works correctly:

plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :];
compute(plan, verbose=true);
Executing:
:(function var"##compute#378"(prgm)
      begin
          V = (((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[2]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_2 = (((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_3 = ((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A0 = V::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A2 = V_2::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A4 = V_3::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A6 = Tensor(Dense(Dense(Dense(Element{0.0, Float64}()))))::Tensor{DenseLevel{Int64, DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}}
          @finch mode = :fast begin
                  A6 .= 0.0
                  for i20 = _
                      for i12 = _
                          for i11 = _
                              A6[i11, i12, i20] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i11, i12], A2[i11, i20]), A4[i12, i20])
                          end
                      end
                  end
                  return A6
              end
          return (A6,)
      end
  end)
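
For reference, the @finch block above is equivalent to the following plain-Julia triple loop, assuming A0, A2, A4 bind to c_raw, a_raw, b_raw respectively (an inference from the index pattern, not something the verbose output states):

# Plain-Julia restatement of the correct kernel: every operand
# keeps both of its loop indices.
A6 = zeros(10, 10, 5)
for i20 in 1:5, i12 in 1:10, i11 in 1:10
    A6[i11, i12, i20] = c_raw[i11, i12] * a_raw[i11, i20] * b_raw[i12, i20]
end
A6 ≈ actual  # true under the assumed bindings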

And here's the debug-mode output for the explicit broadcast(...) form (used internally by finch-tensor) that we're debugging:

plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]);
compute(plan, verbose=true);
Executing:
:(function var"##compute#421"(prgm)
      begin
          V = (((((((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[2]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_2 = (((((((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_3 = ((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A0 = V::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A2 = V_2::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A5 = V_3::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A7 = Tensor(Dense(Dense(Dense(Element{0.0, Float64}()))))::Tensor{DenseLevel{Int64, DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}}
          @finch mode = :fast begin
                  A7 .= 0.0
                  for i31 = _
                      for i30 = _
                          for i21 = _
                              A7[i21, i30, i31] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i21, 1], A2[i21, 1]), A5[i30, i31])
                          end
                      end
                  end
                  return A7
              end
          return (A7,)
      end
  end)

I think the key difference (found with https://www.diffchecker.com) is:

A6[i11, i12, i20] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i11, i12], A2[i11, i20]), A4[i12, i20])
vs
A7[i21, i30, i31] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i21, 1], A2[i21, 1]), A5[i30, i31])

For some reason, with broadcast(...), a 1 was placed there instead of the index. WDYT?
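
Concretely, under the same assumed bindings (A0 = c_raw, A2 = a_raw, A5 = b_raw), the buggy kernel amounts to this sketch, which only ever reads the first column of c_raw and a_raw:

# Plain-Julia restatement of the buggy kernel: the second index of
# A0 and A2 is pinned to 1 instead of ranging over i30 and i31.
wrong = zeros(10, 10, 5)
for i31 in 1:5, i30 in 1:10, i21 in 1:10
    wrong[i21, i30, i31] = c_raw[i21, 1] * a_raw[i21, 1] * b_raw[i30, i31]
end
# The correct kernel would read c_raw[i21, i30] and a_raw[i21, i31].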

@hameerabbasi
Collaborator

I think it's during broadcasting that an index would be replaced with 1, right @willow-ahrens? That would mean that the "broadcast indices" are somehow at an incorrect location.
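
(Replacing an index with 1 is exactly how broadcasting handles extent-1 dimensions in Base Julia, so the substitution itself is legitimate; it is only wrong when applied to a dimension that isn't actually singleton. A minimal Base-Julia illustration:)

# Base Julia broadcasting indexes a size-1 dimension at 1:
x = rand(10, 1); y = rand(10, 5)
z = x .* y                   # x's second dimension is singleton
all(z .== x[:, 1] .* y)      # true: x[i, 1] is reused for every column
# The generated kernel above applies the same substitution to
# dimensions of extent 10, dropping a real loop index.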
