This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
/
test_MODULE_pycuda.py
60 lines (46 loc) · 1.81 KB
/
test_MODULE_pycuda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
@brief test log(time=6s)
"""
import sys
import os
import unittest
import warnings
from pyquickhelper.loghelper import fLOG
from pyquickhelper.pycode import is_travis_or_appveyor
class TestModulesCuda(unittest.TestCase):
    """Smoke test for the optional *pycuda* dependency.

    Compiles a tiny CUDA kernel with ``pycuda.compiler.SourceModule``, runs it
    on the GPU and compares the result with the same computation in NumPy.
    The test is skipped on CI machines (as detected by
    ``is_travis_or_appveyor``) and whenever the CUDA driver, pycuda or a
    usable CUDA device is not available locally.
    """

    @unittest.skipIf(is_travis_or_appveyor() is not None, "nopycuda on CI")
    def test_cuda(self):
        """Multiply two random float32 vectors element-wise on the GPU."""
        fLOG(__file__, self._testMethodName, OutputPrint=__name__ == "__main__")

        if sys.platform.startswith("win"):
            # Without the CUDA driver DLL nothing below can work: report the
            # test as skipped instead of letting it silently pass.
            dll = os.path.join("c:\\Windows\\System32", "NVCUDA.DLL")
            if not os.path.exists(dll):
                self.skipTest("Missing DLL: " + dll)

        try:
            import pycuda.driver as drv  # pylint: disable=C0415
        except ImportError as e:
            self.skipTest(f"No pycuda installed: {e}")

        try:
            # SourceModule and the kernel launch require an active CUDA
            # context; pycuda.autoinit initialises the driver and creates one.
            import pycuda.autoinit  # pylint: disable=C0415,W0611
        except (drv.Error, RuntimeError) as e:
            self.skipTest(f"No usable CUDA device: {e}")

        import numpy  # pylint: disable=C0415
        from pycuda.compiler import SourceModule  # pylint: disable=C0415

        # NOTE(review): the clang-3.8 host compiler looks specific to the
        # original author's non-Windows setup — confirm it is still wanted.
        options = None if sys.platform.startswith("win") else [
            "-ccbin", "clang-3.8"]
        mod = SourceModule("""
        __global__ void multiply_them(float *dest, float *a, float *b)
        {
            const int i = threadIdx.x;
            dest[i] = a[i] * b[i];
        }
        """, options=options, cache_dir=".")
        multiply_them = mod.get_function("multiply_them")

        # The kernel indexes with threadIdx.x only and is launched as a single
        # block, so the vector length is also the number of threads per block.
        n = 400
        a = numpy.random.randn(n).astype(numpy.float32)  # pylint: disable=E1101
        b = numpy.random.randn(n).astype(numpy.float32)  # pylint: disable=E1101
        dest = numpy.zeros_like(a)
        multiply_them(
            drv.Out(dest), drv.In(a), drv.In(b),
            block=(n, 1, 1), grid=(1, 1))

        expected = a * b
        fLOG(dest - expected)
        # Check the GPU result against NumPy; allow a small relative
        # tolerance rather than requiring bit-identical floats.
        numpy.testing.assert_allclose(dest, expected, rtol=1e-6)
if __name__ == "__main__":
    # Executed as a script rather than collected by a runner: hand control to
    # unittest's command-line entry point for this module (module="__main__"
    # is unittest.main's default, spelled out here for clarity).
    unittest.main(module="__main__")