Merge pull request #1309 from spcl/port-tensor-related-library-nodes
Cherry-picks tensor-related libraries/replacements
tbennun committed Jul 16, 2023
2 parents 531b0ae + b52aaba commit a2cb57f
Showing 27 changed files with 1,080 additions and 62 deletions.
90 changes: 79 additions & 11 deletions dace/frontend/python/replacements.py
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace

import ast
@@ -778,19 +778,29 @@ def _transpose(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, inpname: str, a
if axes == (1, 0): # Special case for 2D transposition
acc1 = state.add_read(inpname)
acc2 = state.add_write(outname)
import dace.libraries.blas # Avoid import loop
tasklet = dace.libraries.blas.Transpose('_Transpose_', restype)
import dace.libraries.standard # Avoid import loop
tasklet = dace.libraries.standard.Transpose('_Transpose_', restype)
state.add_node(tasklet)
state.add_edge(acc1, None, tasklet, '_inp', Memlet.from_array(inpname, arr1))
state.add_edge(tasklet, '_out', acc2, None, Memlet.from_array(outname, arr2))
else:
state.add_mapped_tasklet(
"_transpose_", {"_i{}".format(i): "0:{}".format(s)
for i, s in enumerate(arr1.shape)},
dict(_in=Memlet.simple(inpname, ", ".join("_i{}".format(i) for i, _ in enumerate(arr1.shape)))),
"_out = _in",
dict(_out=Memlet.simple(outname, ", ".join("_i{}".format(axes[i]) for i, _ in enumerate(arr1.shape)))),
external_edges=True)
else: # Tensor transpose
modes = len(arr1.shape)
idx = axes.index(0)
# Special case of tensor transposition: matrix transpose + reshape
if axes[idx:] == list(range(modes - idx)) and axes[:idx] == list(range(axes[-1] + 1, modes)):
rows = data._prod([arr1.shape[axes[i]] for i in range(idx, len(arr1.shape))])
cols = data._prod([arr1.shape[axes[i]] for i in range(idx)])
matrix = _ndarray_reshape(pv, sdfg, state, inpname, [rows, cols])
trans_matrix = _transpose(pv, sdfg, state, matrix)
return _ndarray_reshape(pv, sdfg, state, trans_matrix, [arr1.shape[i] for i in axes])

read = state.add_read(inpname)
write = state.add_write(outname)
from dace.libraries.standard import TensorTranspose
tasklet = TensorTranspose('_TensorTranspose', axes or list(range(len(arr1.shape))))
state.add_node(tasklet)
state.add_edge(read, None, tasklet, '_inp_tensor', Memlet.from_array(inpname, arr1))
state.add_edge(tasklet, '_out_tensor', write, None, Memlet.from_array(outname, arr2))

return outname
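A minimal usage sketch (an editorial assumption, not part of this commit): after this change, a general axis permutation written with numpy.transpose inside a @dace.program is lowered through the TensorTranspose library node from dace.libraries.standard (or, in the special case above, through a reshape plus a 2D transpose).

import dace
import numpy as np

@dace.program
def permute(A: dace.float64[4, 5, 6]):
    # Non-2D permutation: exercises the tensor-transpose branch above
    return np.transpose(A, axes=(2, 0, 1))

A = np.random.rand(4, 5, 6)
assert np.allclose(permute(A), np.transpose(A, axes=(2, 0, 1)))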

@@ -4539,6 +4549,64 @@ def _inv(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, inp_op: str):
return out_arr[0]


@oprepo.replaces('dace.tensordot')
@oprepo.replaces('numpy.tensordot')
def _tensordot(pv: 'ProgramVisitor',
               sdfg: SDFG,
               state: SDFGState,
               op_a: str,
               op_b: str,
               axes: Union[int, Sequence[int]] = 2,
               out_axes: Sequence[int] = None):

    # NOTE: `out_axes` is a non-standard extension to `numpy.tensordot`, allowing transposition of the output

    for op in (op_a, op_b):
        if not isinstance(op, str) or not op in sdfg.arrays.keys():
            raise SyntaxError()

    arr_a = sdfg.arrays[op_a]
    arr_b = sdfg.arrays[op_b]

    if isinstance(axes, Integral):
        left_axes = list(range(len(arr_a.shape) - axes, len(arr_a.shape)))
        right_axes = list(range(0, axes))
    else:
        left_axes = axes[0]
        right_axes = axes[1]

    # Some validation (more detailed validation is done inside the TensorDot library node)
    if any(a >= len(arr_a.shape) or a < 0 for a in left_axes):
        raise ValueError("Axes for left tensor are out-of-bounds.")
    if any(a >= len(arr_b.shape) or a < 0 for a in right_axes):
        raise ValueError("Axes for right tensor are out-of-bounds.")
    if len(left_axes) != len(right_axes):
        raise ValueError("The input tensors must have the same number of contracting modes.")
    if any(arr_a.shape[l] != arr_b.shape[r] for l, r in zip(left_axes, right_axes)):
        raise ValueError("The input tensors' contracting modes must have the same length.")

    dot_shape = [s for i, s in enumerate(arr_a.shape) if i not in left_axes]
    dot_shape.extend([s for i, s in enumerate(arr_b.shape) if i not in right_axes])

    if out_axes:
        if list(sorted(out_axes)) != list(range(len(dot_shape))):
            raise ValueError("Output axes is not a permutation of the output's modes.")
        dot_shape = [dot_shape[i] for i in out_axes]

    op_c, arr_c = sdfg.add_temp_transient(dot_shape, arr_a.dtype, storage=arr_a.storage)

    from dace.libraries.linalg import TensorDot
    a = state.add_read(op_a)
    b = state.add_read(op_b)
    c = state.add_write(op_c)
    tasklet = TensorDot("_TensorDot_", left_axes, right_axes, out_axes)
    state.add_edge(a, None, tasklet, '_left_tensor', Memlet.from_array(op_a, arr_a))
    state.add_edge(b, None, tasklet, '_right_tensor', Memlet.from_array(op_b, arr_b))
    state.add_edge(tasklet, '_out_tensor', c, None, Memlet.from_array(op_c, arr_c))

    return op_c
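A minimal usage sketch (an editorial assumption, not part of this commit): with the replacement above registered, numpy.tensordot can be written directly inside a @dace.program and lowers to the new TensorDot library node. The out_axes keyword noted in the code is a DaCe-specific extension for permuting the result.

import dace
import numpy as np

@dace.program
def contract(A: dace.float64[3, 4, 5], B: dace.float64[4, 5, 6]):
    # Default axes=2 contracts the last two modes of A with the first two modes of B
    return np.tensordot(A, B)

A = np.random.rand(3, 4, 5)
B = np.random.rand(4, 5, 6)
assert np.allclose(contract(A, B), np.tensordot(A, B))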


# CuPy replacements


3 changes: 1 addition & 2 deletions dace/libraries/blas/nodes/__init__.py
@@ -1,11 +1,10 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from .matmul import MatMul
from .dot import Dot
from .gemv import Gemv
from .gemm import Gemm
from .ger import Ger
from .batched_matmul import BatchedMatMul
from .transpose import Transpose

from .axpy import Axpy
from .einsum import Einsum
2 changes: 1 addition & 1 deletion dace/libraries/lapack/environments/__init__.py
@@ -1,2 +1,2 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from .cusolverdn import *
2 changes: 1 addition & 1 deletion dace/libraries/lapack/environments/cusolverdn.py
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library


2 changes: 1 addition & 1 deletion dace/libraries/lapack/include/dace_cusolverdn.h
@@ -1,4 +1,4 @@
// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
// Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
#pragma once

#include <cuda_runtime.h>
2 changes: 1 addition & 1 deletion dace/libraries/linalg/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from dace.library import register_library
from .nodes import *

2 changes: 2 additions & 0 deletions dace/libraries/linalg/environments/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
from .cutensor import *
39 changes: 39 additions & 0 deletions dace/libraries/linalg/environments/cutensor.py
@@ -0,0 +1,39 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library


@dace.library.environment
class cuTensor:

    cmake_minimum_version = None
    cmake_packages = ["CUDA"]
    cmake_variables = {}
    cmake_includes = []
    cmake_libraries = ["cutensor"]
    cmake_compile_flags = []
    cmake_link_flags = ["-L -lcutensor"]
    cmake_files = []

    headers = {'frame': ["../include/dace_cutensor.h"], 'cuda': ["../include/dace_cutensor.h"]}
    state_fields = ["dace::linalg::CuTensorHandle cutensor_handle;"]
    init_code = ""
    finalize_code = ""
    dependencies = []

    @staticmethod
    def handle_setup_code(node):
        location = node.location
        if not location or "gpu" not in node.location:
            location = 0
        else:
            try:
                location = int(location["gpu"])
            except ValueError:
                raise ValueError("Invalid GPU identifier: {}".format(location))

        code = """\
const int __dace_cuda_device = {location};
cutensorHandle_t &__dace_cutensor_handle = __state->cutensor_handle.Get(__dace_cuda_device);
// cutensorSetStream(__dace_cutensor_handle, __dace_current_stream);\n"""

        return code.format(location=location)
66 changes: 66 additions & 0 deletions dace/libraries/linalg/include/dace_cutensor.h
@@ -0,0 +1,66 @@
// Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
#pragma once

#include <cuda_runtime.h>
#include <cutensor.h>

#include <cstddef> // size_t
#include <stdexcept> // std::runtime_error
#include <string> // std::to_string
#include <unordered_map>

namespace dace {

namespace linalg {

static void CheckCuTensorError(cutensorStatus_t const& status) {
  if (status != CUTENSOR_STATUS_SUCCESS) {
    throw std::runtime_error("cuTENSOR failed with error code: " + std::string(cutensorGetErrorString(status)));
  }
}

static cutensorHandle_t CreateCuTensorHandle(int device) {
  if (cudaSetDevice(device) != cudaSuccess) {
    throw std::runtime_error("Failed to set CUDA device.");
  }
  cutensorHandle_t handle;
  CheckCuTensorError(cutensorInit(&handle));
  return handle;
}

/**
 * cuTENSOR wrapper class for DaCe. Once constructed, the class can be used to
 * get or create a cuTENSOR library handle (cutensorHandle_t) for a given
 * GPU ID. The class is constructed when the cuTENSOR DaCe library is used.
 **/
class CuTensorHandle {
 public:
  CuTensorHandle() = default;
  CuTensorHandle(CuTensorHandle const&) = delete;

  cutensorHandle_t& Get(int device) {
    auto f = handles_.find(device);
    if (f == handles_.end()) {
      // Lazily construct a new cuTENSOR handle if the specified key does not
      // yet exist
      auto handle = CreateCuTensorHandle(device);
      f = handles_.emplace(device, handle).first;
    }
    return f->second;
  }

  ~CuTensorHandle() {
    // NOTE: It seems that the cuTENSOR API is missing a method of destroying a cuTENSOR handle
    // for (auto& h : handles_) {
    //   CheckCuTensorError(cutensorDestroy(h.second));
    // }
  }

  CuTensorHandle& operator=(CuTensorHandle const&) = delete;

  std::unordered_map<int, cutensorHandle_t> handles_;
};

} // namespace linalg

} // namespace dace
3 changes: 2 additions & 1 deletion dace/libraries/linalg/nodes/__init__.py
@@ -1,4 +1,5 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from .inv import Inv
from .solve import Solve
from .cholesky import Cholesky
from .tensordot import TensorDot
7 changes: 4 additions & 3 deletions dace/libraries/linalg/nodes/cholesky.py
@@ -1,11 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import copy
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace import dtypes, Memlet
from dace.libraries.blas import Transpose

from dace import Memlet
from dace.libraries.lapack import Potrf
from dace.libraries.standard import Transpose
from dace.transformation.transformation import ExpandTransformation
from dace.libraries.lapack import environments
from dace.libraries.blas import environments as blas_environments
7 changes: 4 additions & 3 deletions dace/libraries/linalg/nodes/solve.py
@@ -1,13 +1,14 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import copy
import dace
import dace.library
import dace.properties
import dace.sdfg.nodes
import numpy as np

from dace import Memlet
from dace.libraries.blas.nodes import Transpose
from dace.libraries.lapack.nodes import Getrf, Getrs
from dace.libraries.lapack import Getrf, Getrs
from dace.libraries.standard import Transpose
from dace.transformation.transformation import ExpandTransformation
from dace.libraries.lapack import environments
from dace.libraries.blas import environments as blas_environments
