-
Notifications
You must be signed in to change notification settings - Fork 116
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1309 from spcl/port-tensor-related-library-nodes
Cherry-picks tensor-related libraries/replacements
- Loading branch information
Showing
27 changed files
with
1,080 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. | ||
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. | ||
from .matmul import MatMul | ||
from .dot import Dot | ||
from .gemv import Gemv | ||
from .gemm import Gemm | ||
from .ger import Ger | ||
from .batched_matmul import BatchedMatMul | ||
from .transpose import Transpose | ||
|
||
from .axpy import Axpy | ||
from .einsum import Einsum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. | ||
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. | ||
from .cusolverdn import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. | ||
from .cutensor import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. | ||
import dace.library | ||
|
||
|
||
@dace.library.environment | ||
class cuTensor: | ||
|
||
cmake_minimum_version = None | ||
cmake_packages = ["CUDA"] | ||
cmake_variables = {} | ||
cmake_includes = [] | ||
cmake_libraries = ["cutensor"] | ||
cmake_compile_flags = [] | ||
cmake_link_flags = ["-L -lcutensor"] | ||
cmake_files = [] | ||
|
||
headers = {'frame': ["../include/dace_cutensor.h"], 'cuda': ["../include/dace_cutensor.h"]} | ||
state_fields = ["dace::linalg::CuTensorHandle cutensor_handle;"] | ||
init_code = "" | ||
finalize_code = "" | ||
dependencies = [] | ||
|
||
@staticmethod | ||
def handle_setup_code(node): | ||
location = node.location | ||
if not location or "gpu" not in node.location: | ||
location = 0 | ||
else: | ||
try: | ||
location = int(location["gpu"]) | ||
except ValueError: | ||
raise ValueError("Invalid GPU identifier: {}".format(location)) | ||
|
||
code = """\ | ||
const int __dace_cuda_device = {location}; | ||
cutensorHandle_t &__dace_cutensor_handle = __state->cutensor_handle.Get(__dace_cuda_device); | ||
// cutensorSetStream(__dace_cutensor_handle, __dace_current_stream);\n""" | ||
|
||
return code.format(location=location) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. | ||
#pragma once | ||
|
||
#include <cuda_runtime.h> | ||
#include <cutensor.h> | ||
|
||
#include <cstddef> // size_t | ||
#include <stdexcept> // std::runtime_error | ||
#include <string> // std::to_string | ||
#include <unordered_map> | ||
|
||
namespace dace { | ||
|
||
namespace linalg { | ||
|
||
static void CheckCuTensorError(cutensorStatus_t const& status) { | ||
if (status != CUTENSOR_STATUS_SUCCESS) { | ||
throw std::runtime_error("cuTENSOR failed with error code: " + std::string(cutensorGetErrorString(status))); | ||
} | ||
} | ||
|
||
static cutensorHandle_t CreateCuTensorHandle(int device) { | ||
if (cudaSetDevice(device) != cudaSuccess) { | ||
throw std::runtime_error("Failed to set CUDA device."); | ||
} | ||
cutensorHandle_t handle; | ||
CheckCuTensorError(cutensorInit(&handle)); | ||
return handle; | ||
} | ||
|
||
/** | ||
* cuTENSOR wrapper class for DaCe. Once constructed, the class can be used to | ||
* get or create a cuTENSOR library handle (cutensorHandle_t) for a given | ||
* GPU ID. The class is constructed when the cuTENSOR DaCe library is used. | ||
**/ | ||
class CuTensorHandle { | ||
public: | ||
CuTensorHandle() = default; | ||
CuTensorHandle(CuTensorHandle const&) = delete; | ||
|
||
cutensorHandle_t& Get(int device) { | ||
auto f = handles_.find(device); | ||
if (f == handles_.end()) { | ||
// Lazily construct new cuSolverDn handle if the specified key does not | ||
// yet exist | ||
auto handle = CreateCuTensorHandle(device); | ||
f = handles_.emplace(device, handle).first; | ||
} | ||
return f->second; | ||
} | ||
|
||
~CuTensorHandle() { | ||
// NOTE: It seems that the cuTENSOR API is missing a method of destroying a cuTENSOR handle | ||
// for (auto& h : handles_) { | ||
// CheckCuTensorError(cutensorDestroy(h.second)); | ||
// } | ||
} | ||
|
||
CuTensorHandle& operator=(CuTensorHandle const&) = delete; | ||
|
||
std::unordered_map<int, cutensorHandle_t> handles_; | ||
}; | ||
|
||
} // namespace linalg | ||
|
||
} // namespace dace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. | ||
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. | ||
from .inv import Inv | ||
from .solve import Solve | ||
from .cholesky import Cholesky | ||
from .tensordot import TensorDot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.