diff --git a/src/qibotn/backends/cutensornet.py b/src/qibotn/backends/cutensornet.py index c8341bd..1d38520 100644 --- a/src/qibotn/backends/cutensornet.py +++ b/src/qibotn/backends/cutensornet.py @@ -1,19 +1,9 @@ -import cuquantum # pylint: disable=import-error import numpy as np from qibo.backends.numpy import NumpyBackend from qibo.config import raise_error from qibo.result import QuantumState -CUDA_TYPES = { - "complex64": ( - cuquantum.cudaDataType.CUDA_C_32F, - cuquantum.ComputeType.COMPUTE_32F, - ), - "complex128": ( - cuquantum.cudaDataType.CUDA_C_64F, - cuquantum.ComputeType.COMPUTE_64F, - ), -} +CUDA_TYPES = {} class CuTensorNet(NumpyBackend): # pragma: no cover @@ -75,6 +65,18 @@ def __init__(self, runcard): self.supports_multigpu = True self.handle = self.cutn.create() + global CUDA_TYPES + CUDA_TYPES = { + "complex64": ( + self.cuquantum.cudaDataType.CUDA_C_32F, + self.cuquantum.ComputeType.COMPUTE_32F, + ), + "complex128": ( + self.cuquantum.cudaDataType.CUDA_C_64F, + self.cuquantum.ComputeType.COMPUTE_64F, + ), + } + def apply_gate(self, gate, state, nqubits): # pragma: no cover raise_error(NotImplementedError, "QiboTN cannot apply gates directly.")