diff --git a/setup.py b/setup.py
index 9cd6808..77a380d 100644
--- a/setup.py
+++ b/setup.py
@@ -117,7 +117,7 @@ def get_extension_tutorial(name):
 ext_modules = []
 for ext in ['dot_blas_lapack', 'dot_cython',
             'experiment_cython', 'dot_cython_omp',
-            'mul_cython_omp']:
+            'mul_cython_omp', 'td_mul_cython']:
     ext_modules.extend(get_extension_tutorial(ext))
 
 
diff --git a/td3a_cpp/tutorial/td_mul_cython.pyx b/td3a_cpp/tutorial/td_mul_cython.pyx
new file mode 100644
index 0000000..76aec48
--- /dev/null
+++ b/td3a_cpp/tutorial/td_mul_cython.pyx
@@ -0,0 +1,67 @@
+"""
+TD 2021/01/19.
+"""
+from cython.parallel import prange
+cimport cython
+from cpython cimport array
+import array
+
+import numpy as pynumpy
+cimport numpy as cnumpy
+cnumpy.import_array()
+
+
+def multiply_matrix(m1, m2):
+    """
+    Matrix multiplication with plain Python loops.
+
+    :param m1: 2D array of shape ``(ni, nk)``
+    :param m2: 2D array of shape ``(nk, nj)``
+    :return: new array of shape ``(ni, nj)`` with the dtype of *m1*
+    :raises ValueError: if the inner dimensions differ
+    """
+    if m1.shape[1] != m2.shape[0]:
+        raise ValueError(
+            "Shape mismatch: %r and %r." % (m1.shape, m2.shape))
+    m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=m1.dtype)
+    for i in range(0, m1.shape[0]):
+        for j in range(0, m2.shape[1]):
+            for k in range(0, m1.shape[1]):
+                m3[i, j] += m1[i, k] * m2[k, j]
+    return m3
+
+
+cdef void _c_multiply_matrix(double[:, :] m1, double[:, :] m2,
+                             double[:, :] m3,
+                             cython.int ni, cython.int nj, cython.int nk) nogil:
+    "Matrix multiplication with cython"
+    cdef cython.int i, j, k
+    # prange may run the iterations over i in parallel;
+    # each one only writes row i of m3.
+    for i in prange(0, ni):
+        for j in range(0, nj):
+            for k in range(0, nk):
+                m3[i, j] += m1[i, k] * m2[k, j]
+
+
+def c_multiply_matrix(double[:, :] m1, double[:, :] m2):
+    """
+    Matrix multiplication calling the cython version.
+
+    :param m1: 2D float64 array of shape ``(ni, nk)``
+    :param m2: 2D float64 array of shape ``(nk, nj)``
+    :return: new float64 array of shape ``(ni, nj)``
+    :raises ValueError: if the inner dimensions differ
+    """
+    cdef cython.int ni = m1.shape[0]
+    cdef cython.int nj = m2.shape[1]
+    cdef cython.int nk = m1.shape[1]
+    # m2 must have as many rows as m1 has columns: the loop indexes
+    # m2[k, j] for every k < nk without looking at m2.shape[0].
+    if m2.shape[0] != nk:
+        raise ValueError(
+            "Shape mismatch: (%d, %d) and (%d, %d)." % (
+                m1.shape[0], m1.shape[1], m2.shape[0], m2.shape[1]))
+    m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=pynumpy.float64)
+    _c_multiply_matrix(m1, m2, m3, ni, nj, nk)
+    return m3
diff --git a/tests/test_tutorial_td.py b/tests/test_tutorial_td.py
new file mode 100644
index 0000000..e906714
--- /dev/null
+++ b/tests/test_tutorial_td.py
@@ -0,0 +1,55 @@
+"""
+Unit tests for the matrix multiplications of ``td_mul_cython``.
+"""
+import unittest
+import timeit
+import numpy
+from numpy.testing import assert_almost_equal
+from td3a_cpp.tutorial.td_mul_cython import (
+    multiply_matrix, c_multiply_matrix)
+
+
+class TestTutorialTD(unittest.TestCase):
+
+    def test_matrix_multiply_matrix(self):
+        va = numpy.random.randn(3, 4).astype(numpy.float64)
+        vb = numpy.random.randn(4, 5).astype(numpy.float64)
+        res1 = va @ vb
+        res2 = multiply_matrix(va, vb)
+        assert_almost_equal(res1, res2)
+
+    def test_matrix_cmultiply_matrix(self):
+        va = numpy.random.randn(3, 4).astype(numpy.float64)
+        vb = numpy.random.randn(4, 5).astype(numpy.float64)
+        res1 = va @ vb
+        res2 = c_multiply_matrix(va, vb)
+        assert_almost_equal(res1, res2)
+
+    def test_matrix_multiply_shape_mismatch(self):
+        va = numpy.random.randn(3, 4).astype(numpy.float64)
+        vb = numpy.random.randn(5, 6).astype(numpy.float64)
+        self.assertRaises(ValueError, multiply_matrix, va, vb)
+        self.assertRaises(ValueError, c_multiply_matrix, va, vb)
+
+    def test_timeit(self):
+        # Kept small: every multiply_matrix call runs
+        # 30 * 50 * 40 interpreted loop iterations.
+        va = numpy.random.randn(30, 40).astype(numpy.float64)
+        vb = numpy.random.randn(40, 50).astype(numpy.float64)
+        ctx = {'va': va, 'vb': vb, 'c_multiply_matrix': c_multiply_matrix,
+               'multiply_matrix': multiply_matrix}
+        res1 = timeit.timeit('va @ vb', number=10, globals=ctx)
+        res2 = timeit.timeit(
+            'c_multiply_matrix(va, vb)', number=10, globals=ctx)
+        res3 = timeit.timeit(
+            'multiply_matrix(va, vb)', number=10, globals=ctx)
+        self.assertLess(res1, res2)  # numpy is much faster.
+        ratio1 = res2 / res1
+        # ratio1 = number of times numpy is faster
+        self.assertGreater(ratio1, 1)
+        ratio2 = res3 / res1
+        self.assertGreater(ratio2, 1)
+
+
+if __name__ == '__main__':
+    unittest.main()