
Use dlpack for array interop #10

Merged: rejuvyesh merged 29 commits into main on Feb 21, 2022
Conversation


@rejuvyesh (Owner, Author) commented:

```julia
julia> judge(median(results), median(no_dlpack[1]))
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "pytorchhub" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "bs=16" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+18.40% => regression)
                          "functorch" => TrialJudgement(+15.16% => regression)
                          "jl" => TrialJudgement(+585.51% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+10.65% => regression)
                          "functorch" => TrialJudgement(+244.91% => regression)
                          "jl" => TrialJudgement(+426.57% => regression)
          "bs=32" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+136.70% => regression)
                          "functorch" => TrialJudgement(+147.69% => regression)
                          "jl" => TrialJudgement(+841.88% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+122.57% => regression)
                          "functorch" => TrialJudgement(+132.32% => regression)
                          "jl" => TrialJudgement(+213.17% => regression)
          "bs=8" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+400.93% => regression)
                          "functorch" => TrialJudgement(+391.76% => regression)
                          "jl" => TrialJudgement(+699.24% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+427.55% => regression)
                          "functorch" => TrialJudgement(+403.57% => regression)
                          "jl" => TrialJudgement(+471.85% => regression)
          "bs=1" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+1582.03% => regression)
                          "functorch" => TrialJudgement(+1324.41% => regression)
                          "jl" => TrialJudgement(+1247.94% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+7.27% => regression)
                          "functorch" => TrialJudgement(+7.67% => regression)
                          "jl" => TrialJudgement(-14.63% => improvement)
  "pytorchmlp" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "bs=16" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-12.30% => improvement)
                          "functorch" => TrialJudgement(-7.23% => improvement)
                          "jl" => TrialJudgement(+19.95% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-1.89% => invariant)
                          "functorch" => TrialJudgement(-2.81% => invariant)
                          "jl" => TrialJudgement(+13.12% => regression)
          "bs=32" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-17.17% => improvement)
                          "functorch" => TrialJudgement(-8.52% => improvement)
                          "jl" => TrialJudgement(+27.38% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-2.17% => invariant)
                          "functorch" => TrialJudgement(-4.80% => invariant)
                          "jl" => TrialJudgement(+19.82% => regression)
          "bs=8" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-15.03% => improvement)
                          "functorch" => TrialJudgement(-6.87% => improvement)
                          "jl" => TrialJudgement(+9.58% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-0.59% => invariant)
                          "functorch" => TrialJudgement(-2.66% => invariant)
                          "jl" => TrialJudgement(+6.15% => regression)
          "bs=1" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-14.15% => improvement)
                          "functorch" => TrialJudgement(-7.62% => improvement)
                          "jl" => TrialJudgement(+2.07% => invariant)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-4.48% => invariant)
                          "functorch" => TrialJudgement(-4.81% => invariant)
                          "jl" => TrialJudgement(-0.07% => invariant)

I might be doing this wrong, but it doesn't look beneficial on CPU? Need a better machine to evaluate.
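For context, this is the standard BenchmarkTools pattern behind a comparison like the one above; `suite`, `results`, and `baseline` here are placeholder names (the `no_dlpack[1]` above presumably holds a previously saved baseline group):

```julia
using BenchmarkTools

# Build a benchmark suite with nested keys, like the groups above.
suite = BenchmarkGroup()
suite["pytorchmlp"] = BenchmarkGroup()
suite["pytorchmlp"]["bs=1"] = BenchmarkGroup()
suite["pytorchmlp"]["bs=1"]["forward"] = @benchmarkable sum(rand(32, 32) * rand(32))

tune!(suite)
results  = run(suite)   # trials for the dlpack branch
baseline = run(suite)   # rerun (or load) the no-dlpack baseline

# `judge` compares the medians pairwise and classifies each entry as
# :regression, :improvement, or :invariant, as printed above.
judge(median(results), median(baseline))
```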

@rejuvyesh (Owner, Author) commented:

Likely some issue with GC.@preserve: the near-zero garbage values in the failures below look like buffers that Julia's GC has already freed while Python still held views into them. Need to figure out a MWE.

```julia
julia> for i in 1:10; TestEnv.activate() do; include("test/test_pytorch.jl"); end; end
Precompiling project...
  1 dependency successfully precompiled in 2 seconds (16 already precompiled)
Test Summary: | Pass  Total
dlpack        |    4      4
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
Test Summary: | Pass  Fail  Total
linear        |   14     6     20
ERROR: LoadError: Some tests did not pass: 14 passed, 6 failed, 0 errored, 0 broken.
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:58
```

versus: all tests pass when the GC is disabled:

```julia
julia> for i in 1:10; GC.enable(false); TestEnv.activate() do; include("test/test_pytorch.jl"); end; GC.enable(true) end
```
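For reference, a minimal sketch of the suspected failure mode and the `GC.@preserve` pattern that guards against it. These are hypothetical functions, not the package code, with `unsafe_wrap` standing in for the Python/DLPack side holding a raw pointer:

```julia
# Buggy pattern: a Julia array's buffer is shared through a raw pointer
# with no reference keeping the array alive.
function shared_view_bug()
    x = rand(Float32, 2, 3)
    dims = size(x)
    p = pointer(x)
    # With no remaining uses of `x`, the GC is free to collect it here;
    # the external consumer would then read freed (garbage) memory.
    GC.gc()
    copy(unsafe_wrap(Array, p, dims))
end

# Guarded pattern: root `x` for the full lifetime of the pointer use.
function shared_view_ok()
    x = rand(Float32, 2, 3)
    GC.@preserve x begin
        p = pointer(x)
        copy(unsafe_wrap(Array, p, size(x)))
    end
end
```

Any call that hands `pointer(x)` to Python would need to stay inside the `GC.@preserve x` block for as long as the Python side uses it; the denormal `Float32` values in the failures above are consistent with reading reclaimed memory.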

Review threads on src/jax.jl (outdated, resolved); one is shown below:

```julia
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x  # fixed typo: the hunk read `mayebecontiguous`, leaving StridedArray on the copying fallback
function maybecontiguous(x::FillArrays.AbstractFill)
# ... (rest of the hunk not shown)
```
@rejuvyesh (Owner, Author) commented:
This can't be the best way to handle FillArrays?
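A minimal sketch of one way to complete that method, assuming densifying is acceptable (this is not necessarily the actual body used in the PR):

```julia
using FillArrays

# An AbstractFill stores just a value and a size; it owns no contiguous
# buffer that dlpack could borrow, so the safe fallback is to densify.
maybecontiguous(x::FillArrays.AbstractFill) = collect(x)
```

A copy-free route would need the Python side to construct the constant tensor itself from the fill value (e.g. something like `torch.full`), bypassing `maybecontiguous` for this case entirely.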

@rejuvyesh marked this pull request as ready for review on February 21, 2022 at 05:17.
@rejuvyesh merged commit 7b65d86 into main on Feb 21, 2022.