Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_extension_tutorial(name):
ext_modules = []
for ext in ['dot_blas_lapack', 'dot_cython',
'experiment_cython', 'dot_cython_omp',
'mul_cython_omp']:
'mul_cython_omp', 'td_mul_cython']:
ext_modules.extend(get_extension_tutorial(ext))


Expand Down
42 changes: 42 additions & 0 deletions td3a_cpp/tutorial/td_mul_cython.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
TD 2021/01/19.
"""
from cython.parallel import prange
cimport cython
from cpython cimport array
import array

import numpy as pynumpy
cimport numpy as cnumpy
cnumpy.import_array()


def multiply_matrix(m1, m2):
"Matrix multiplication"
m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=m1.dtype)
for i in range(0, m1.shape[0]):
for j in range(0, m2.shape[1]):
for k in range(0, m1.shape[1]):
m3[i, j] += m1[i, k] * m2[k, j]
return m3


cdef void _c_multiply_matrix(double[:, :] m1, double[:, :] m2,
double[:, :] m3,
cython.int ni, cython.int nj, cython.int nk) nogil:
"Matrix multiplication wuth cython"
cdef cython.int i, j, k
for i in prange(0, ni):
for j in range(0, nj):
for k in range(0, nk):
m3[i, j] += m1[i, k] * m2[k, j]


def c_multiply_matrix(double[:, :] m1, double[:, :] m2):
"Matrix multiplication calling the cython version"
m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=pynumpy.float64)
cdef cython.int ni = m1.shape[0]
cdef cython.int nj = m2.shape[1]
cdef cython.int nk = m1.shape[1]
_c_multiply_matrix(m1, m2, m3, ni, nj, nk)
return m3
46 changes: 46 additions & 0 deletions tests/test_tutorial_td.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
"""
import unittest
import timeit
import numpy
from numpy.testing import assert_almost_equal
from td3a_cpp.tutorial.td_mul_cython import (
multiply_matrix, c_multiply_matrix)


class TestTutorialTD(unittest.TestCase):

def test_matrix_multiply_matrix(self):
va = numpy.random.randn(3, 4).astype(numpy.float64)
vb = numpy.random.randn(4, 5).astype(numpy.float64)
res1 = va @ vb
res2 = multiply_matrix(va, vb)
assert_almost_equal(res1, res2)

def test_matrix_cmultiply_matrix(self):
va = numpy.random.randn(3, 4).astype(numpy.float64)
vb = numpy.random.randn(4, 5).astype(numpy.float64)
res1 = va @ vb
res2 = c_multiply_matrix(va, vb)
assert_almost_equal(res1, res2)

def test_timeit(self):
va = numpy.random.randn(300, 400).astype(numpy.float64)
vb = numpy.random.randn(400, 500).astype(numpy.float64)
ctx = {'va': va, 'vb': vb, 'c_multiply_matrix': c_multiply_matrix,
'multiply_matrix': multiply_matrix}
res1 = timeit.timeit('va @ vb', number=10, globals=ctx)
res2 = timeit.timeit(
'c_multiply_matrix(va, vb)', number=10, globals=ctx)
res3 = timeit.timeit(
'multiply_matrix(va, vb)', number=10, globals=ctx)
self.assertLess(res1, res2) # numpy is much faster.
ratio1 = res2 / res1
# ratio1 = number of times numpy is faster
self.assertGreater(ratio1, 1)
ratio2 = res3 / res1
self.assertGreater(ratio2, 1)


if __name__ == '__main__':
unittest.main()