Skip to content

TensorOperations.jl compatible fast contractor for Julia, based on TBLIS, with generic strides and automatic differentiation support, within 400 lines.

License

Notifications You must be signed in to change notification settings

xrq-phys/BliContractor.jl

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

56 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

BliContractor.jl

CI

Fast tensor contractor for Julia, based on TBLIS, with high-order AD and Stride support, within 400† lines.
† Result may vary as more dispatch rules are added.

  • All these are made possible thanks to TBLIS;

Installation

] add BliContractor

This will link the Julia package against TBLIS library vendored by tblis_jll.

If one wants to use their own TBLIS build, specify their TBLIS installation root with export TBLISDIR=${PathToYourTBLIS}, start Julia and run:

] build BliContractor

BliContractor.jl will be relinked to use the user-defined TBLIS installation. Build steps as well as environment specification needs to be done only once.

Usage

From TensorOperations.jl

My implementation now contains necessary overriding of TensorOperation.jl's CPU backend method. One can directly invoke the @tensor macro (or the ncon function, etc.) and reach TBLIS backend.

using TensorOperations
using BliContractor

A = rand(10, 10, 10, 10);
B = rand(10, 10, 10, 10);
C = ones(10, 10, 10, 10);
@tensor C[i, a, k, c] = A[i, j, k, l] * B[a, l, c, j]

Supported datatypes are Float32, Float64, ComplexF32, ComplexF64, Dual{N, Float32} and Dual{N, Float64}. See also below for AD support.

As Standalone Package

The simplest API is given by contract:

using BliContractor
using ForwardDiff: Dual
At = rand( 6, 4, 5) * Dual(1.0, 1.0);
Bt = rand(10, 5, 4) * Dual(1.0, 0.0);
contract(At, "ikl", Bt, "jlk", "ij")
# or equivalently:
contract(At, Bt, "ikl", "jlk", "ij")

Index notation here is the same as TBLIS, namely the Einstein's summation rules. This contract (with exclamation mark !) is also the only subroutine with Zygote's backward derivative support (while all subroutines in this module supports ForwardDiff's forward differential).

If one's having destination tensor C preallocated, a contract! routine (which is in fact called by contract) is also available:

Ct = zeros(6, 10) * Dual(1.0, 0.0);
contract!(At, "ikl", Bt, "jlk", Ct, "ij")

Tensors can be Arrays or strided SubArrays:

Aw = view(At, :, 1:2:4, :);
# Unlike the case of BLAS,
# first dimension is not required to be 1 for performance to be nice:
Bw = view(Bt, 1:2:10, :, 1:2:4);
Cv = zeros(6, 5) * Dual(1.0, 0.0);
contract!(Aw, "ikl", Bw, "jlk", Cv, "ij")

Roadmap

  • Explicitly dispatch mixed multiplication of plain values with Duals, e.g. (Float64, Dual{Tag, Float64}) or (Dual{Tag, Float64}, Dual{Tag, Dual{Tag, Float64}}), though they are already available via type conversion;
  • Let it play well with Zygote.jl, to at least 1st order;
  • Enable 2nd order pullback for Zygote.jl.

On 2nd Derivative with Zygote.jl

Second derivative through hessian is already working on Zygote.jl's master branch, but taking pullback 2 times requires something more which is currently only available in the upstream development branch: DhairyaLGandhi/Zygote.jl/dg/iddict.

Performance

Here are two benchmark reports collected on the following system:

  • Processor Model: Intel® Xeon® Platinum 8260
  • SIMD Width: 512bits (AVX512)
  • FP Pipelines: 2 AVX512 pipelines per core
  • Frequency Throttling: Yes, even on serial execution.
  • Basic Clock-Frequency: 2.40GHz
  • OS: CentOS Linux 7
  • Julia: 1.5.2 (official)
  • OpenMP # of Threads: 4

GEMM-Incompatible Contractions

A contraction which can not be handled by BLAS' GEMM routines is tested to show superiority of TBLIS over blocked-GEMM calls launched by TensorOperations.jl.

Generic Strided Tensors

Generic-strided tensors are directly supported by TBLIS while in TensorOperations.jl they fall back to hand-written loops without BLAS assembly call. Tensor entries in this test has datatype ForwardDiff.Dual{<:Any, Float64, 1}, which mean that the each variable part (value or 1st differential) is stored with 1-element skip within each "column" (A[:, l, m, n]) of the tensor, forming a generic-strided storage.

About

TensorOperations.jl compatible fast contractor for Julia, based on TBLIS, with generic strides and automatic differentiation support, within 400 lines.

Topics

Resources

License

Stars

Watchers

Forks

Packages

No packages published