In [1]:
import numpy as np
from numba import jit, prange


In [2]:
@jit(nopython=True)
def go_fast(a, b):
    """Compute ``a @ b.T`` with three explicit loops.

    ``a`` is indexed as ``a[row, feature]`` and ``b`` as ``b[row, feature]``;
    the result has shape ``(a.shape[0], b.shape[0])``. The output buffer is
    created with ``np.zeros`` and every product is added straight into it.
    Compiled by numba in nopython mode.
    """
    rows_a = a.shape[0]
    rows_b = b.shape[0]
    n_feat = a.shape[1]
    out = np.zeros((rows_a, rows_b))
    for r in range(rows_a):
        for c in range(rows_b):
            for k in range(n_feat):
                out[r, c] += a[r, k] * b[c, k]
    return out


@jit(nopython=True)
def go_fast_np(a, b):
    """Compute ``a @ b.T`` one entry at a time via row-slice dot products.

    Each output element is ``a[r, :].dot(b[c, :])``; the result has shape
    ``(a.shape[0], b.shape[0])``.

    On the "why is it 10x slower than go_fast" question (unverified; profile
    to confirm, and the measured ratio varies between runs): this version
    issues one slice-and-dot call per output element. On short rows the
    per-call overhead likely outweighs the arithmetic. go_fast's innermost
    loop, by contrast, is plain scalar code the compiler can optimize directly.
    """
    rows_a = a.shape[0]
    rows_b = b.shape[0]
    out = np.zeros((rows_a, rows_b))
    for r in range(rows_a):
        for c in range(rows_b):
            out[r, c] = a[r, :].dot(b[c, :])
    return out


@jit(nopython=True, fastmath=True, parallel=True)
def go_fast_opt(a, b):
    """Parallel explicit-loop version of ``a @ b.T``.

    Rows of ``a`` are distributed across threads with ``prange``. Each output
    entry is summed in a local scalar and written exactly once, so allocating
    the result with ``np.empty`` is safe. ``fastmath=True`` permits the
    compiler to reorder the floating-point reduction, so low-order bits may
    differ from go_fast.
    """
    rows_a = a.shape[0]
    rows_b = b.shape[0]
    n_feat = a.shape[1]
    out = np.empty((rows_a, rows_b))
    for r in prange(rows_a):
        for c in range(rows_b):
            # Running sum lives in a local instead of re-reading out[r, c].
            total = 0
            for k in range(n_feat):
                total += a[r, k] * b[c, k]
            out[r, c] = total
    return out


@jit(nopython=True)
def go_fast_np_opt(a, b):
    """Compute ``a @ b.T`` with a single ``np.dot`` call.

    The whole product is delegated to numba's ``np.dot`` implementation
    instead of explicit loops. The result has shape
    ``(a.shape[0], b.shape[0])``.
    """
    return np.dot(a, b.T)


In [3]:
class testJIT():
    """Benchmark variant that applies ``@jit`` inside each method body.

    Every call to one of these methods defines a brand-new nested function
    and wraps it with ``jit``. Numba compiles a dispatcher lazily on its first
    call and caches the machine code on that dispatcher only. Each method call
    here therefore pays the compilation cost again. The per-call ``jit`` is
    intentional for this benchmark, which compares that overhead against the
    module-level and ``@staticmethod`` variants.

    All methods compute ``a @ b.T`` for 2-D ``a`` (rows_a, features) and
    ``b`` (rows_b, features).
    """

    def __init__(self) -> None:
        pass

    def go_fast(self, a, b):
        """Triple-loop ``a @ b.T``; re-jitted on every call."""
        @jit(nopython=True)
        def __go_fast(a, b):
            m = a.shape[0]
            n = b.shape[0]
            f = a.shape[1]
            trace = np.zeros((m, n))
            for i in range(m):
                for j in range(n):
                    for k in range(f):
                        trace[i, j] += a[i, k] * b[j, k]
            return trace
        return __go_fast(a, b)

    def go_fast_np(self, a, b):
        """Row-slice ``.dot`` version of ``a @ b.T``; re-jitted on every call."""
        @jit(nopython=True)
        def __go_fast_np(a, b):
            m = a.shape[0]
            n = b.shape[0]
            trace = np.zeros((m, n))
            for i in range(m):
                for j in range(n):
                    # One slice-and-dot call per output element.
                    trace[i, j] = a[i, :].dot(b[j, :])
            return trace
        return __go_fast_np(a, b)

    def go_fast_opt(self, a, b):
        """Parallel ``prange`` version of ``a @ b.T``; re-jitted on every call."""
        @jit(nopython=True, fastmath=True, parallel=True)
        def __go_fast_opt(a, b):
            m = a.shape[0]
            n = b.shape[0]
            f = a.shape[1]
            # np.empty is safe: every entry is assigned below.
            trace = np.empty((m, n))
            for i in prange(m):
                for j in range(n):
                    # Accumulate in a local scalar, write the entry once.
                    acc = 0
                    for k in range(f):
                        acc += a[i, k] * b[j, k]
                    trace[i, j] = acc
            return trace
        return __go_fast_opt(a, b)

    def go_fast_np_opt(self, a, b):
        """Single ``np.dot(a, b.T)`` call; re-jitted on every call."""
        @jit(nopython=True)
        def __go_fast_np_opt(a, b):
            return np.dot(a, b.T)
        return __go_fast_np_opt(a, b)


tt = testJIT()

In [4]:
class testJIT2():
    """Benchmark variant that stores jitted kernels as static methods.

    ``@jit`` runs once, when the class body executes. Each static method is
    therefore a single numba dispatcher that compiles on its first call and is
    reused afterwards. All kernels compute ``a @ b.T`` for 2-D ``a``
    (rows_a, features) and ``b`` (rows_b, features).
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    @jit(nopython=True)
    def go_fast(a, b):
        """Triple-loop ``a @ b.T`` accumulating directly into the output."""
        m = a.shape[0]
        n = b.shape[0]
        f = a.shape[1]
        trace = np.zeros((m, n))
        for i in range(m):
            for j in range(n):
                for k in range(f):
                    trace[i, j] += a[i, k] * b[j, k]
        return trace

    @staticmethod
    @jit(nopython=True)
    def go_fast_np(a, b):
        """Row-slice ``.dot`` version of ``a @ b.T`` (one dot call per entry)."""
        m = a.shape[0]
        n = b.shape[0]
        trace = np.zeros((m, n))
        for i in range(m):
            for j in range(n):
                trace[i, j] = a[i, :].dot(b[j, :])
        return trace

    @staticmethod
    @jit(nopython=True, fastmath=True, parallel=True)
    def go_fast_opt(a, b):
        """Parallel ``prange`` version of ``a @ b.T`` with a scalar accumulator."""
        m = a.shape[0]
        n = b.shape[0]
        f = a.shape[1]
        # np.empty is safe: every entry is assigned below.
        trace = np.empty((m, n))
        for i in prange(m):
            for j in range(n):
                # Accumulate in a local scalar, write the entry once.
                acc = 0
                for k in range(f):
                    acc += a[i, k] * b[j, k]
                trace[i, j] = acc
        return trace

    @staticmethod
    @jit(nopython=True)
    def go_fast_np_opt(a, b):
        """Single ``np.dot(a, b.T)`` call."""
        return np.dot(a, b.T)


tt2 = testJIT2()

In [5]:
class testJIT3:
    """Benchmark variant: private jitted static kernels behind instance wrappers.

    The ``__``-prefixed static methods are name-mangled (e.g. to
    ``_testJIT3__go_fast``) and wrapped by ``@jit`` once, when the class body
    executes. The public instance methods simply forward to them. All kernels
    compute ``a @ b.T`` for 2-D ``a`` (rows_a, features) and ``b``
    (rows_b, features).
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    @jit(nopython=True)
    def __go_fast(a, b):
        """Triple-loop ``a @ b.T`` accumulating directly into the output."""
        m = a.shape[0]
        n = b.shape[0]
        f = a.shape[1]
        trace = np.zeros((m, n))
        for i in range(m):
            for j in range(n):
                for k in range(f):
                    trace[i, j] += a[i, k] * b[j, k]
        return trace

    @staticmethod
    @jit(nopython=True)
    def __go_fast_np(a, b):
        """Row-slice ``.dot`` version of ``a @ b.T`` (one dot call per entry)."""
        m = a.shape[0]
        n = b.shape[0]
        trace = np.zeros((m, n))
        for i in range(m):
            for j in range(n):
                trace[i, j] = a[i, :].dot(b[j, :])
        return trace

    @staticmethod
    @jit(nopython=True, fastmath=True, parallel=True)
    def __go_fast_opt(a, b):
        """Parallel ``prange`` version of ``a @ b.T`` with a scalar accumulator.

        The ``return`` must come after the ``prange`` loop. Returning from
        inside it would exit after the first row and leave the remaining rows
        of the ``np.empty`` buffer uninitialized.
        """
        m = a.shape[0]
        n = b.shape[0]
        f = a.shape[1]
        # np.empty is safe: every entry is assigned below.
        trace = np.empty((m, n))
        for i in prange(m):
            for j in range(n):
                # Accumulate in a local scalar, write the entry once.
                acc = 0
                for k in range(f):
                    acc += a[i, k] * b[j, k]
                trace[i, j] = acc
        return trace

    @staticmethod
    @jit(nopython=True)
    def __go_fast_np_opt(a, b):
        """Single ``np.dot(a, b.T)`` call."""
        return np.dot(a, b.T)

    def go_fast(self, a, b):
        """Forward to the jitted triple-loop kernel."""
        return self.__go_fast(a, b)

    def go_fast_np(self, a, b):
        """Forward to the jitted row-slice ``.dot`` kernel."""
        return self.__go_fast_np(a, b)

    def go_fast_opt(self, a, b):
        """Forward to the jitted parallel kernel."""
        return self.__go_fast_opt(a, b)

    def go_fast_np_opt(self, a, b):
        """Forward to the jitted ``np.dot`` kernel."""
        return self.__go_fast_np_opt(a, b)


tt3 = testJIT3()


In [6]:
# Benchmark inputs: 1000 and 500 rows sharing 20 features, filled with 0..N-1 as float64.
a = np.arange(1000 * 20, dtype=float).reshape(1000, 20)
b = np.arange(500 * 20, dtype=float).reshape(500, 20)


In [7]:
# Time the module-level kernels on the (1000, 20) x (500, 20) inputs.
# NOTE(review): the inline "#..." timings come from a different run and do
# not match this cell's captured output. None of these kernels is called
# before this cell, so the first %timeit call of each also triggers numba's
# lazy JIT compilation. The "1 loop each" entries may reflect that; consider
# calling each kernel once before timing.
%timeit go_fast(a, b)
#6.64 ms ± 50.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit go_fast_opt(a, b)
#1.2 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit go_fast_np(a, b)
#33.4 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit go_fast_np_opt(a, b)
#1.1 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


5.69 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
655 µs ± 8.64 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
11.9 ms ± 7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# testJIT re-applies @jit to a freshly defined nested function on every call,
# so each timed call recompiles its kernel. These timings are likely
# dominated by compilation rather than by the matrix product itself.
%timeit tt.go_fast(a, b)
%timeit tt.go_fast_opt(a, b)
%timeit tt.go_fast_np(a, b)
%timeit tt.go_fast_np_opt(a, b)

87.2 ms ± 237 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
276 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
94.9 ms ± 487 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
88.3 ms ± 963 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
# testJIT2's @jit decorators run once in the class body, so each static
# method is a single dispatcher reused across calls (compiled on first use).
%timeit tt2.go_fast(a, b)
%timeit tt2.go_fast_opt(a, b)
%timeit tt2.go_fast_np(a, b)
%timeit tt2.go_fast_np_opt(a, b)

5.56 ms ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
662 µs ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
12 ms ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
126 µs ± 466 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
# testJIT3: instance methods forwarding to name-mangled jitted static kernels.
# NOTE(review): the go_fast_opt benchmark is disabled for tt3. Confirm that
# testJIT3.__go_fast_opt runs and returns the full (m, n) result before
# re-enabling it.
%timeit tt3.go_fast(a, b)
# %timeit tt3.go_fast_opt(a, b)
%timeit tt3.go_fast_np(a, b)
%timeit tt3.go_fast_np_opt(a, b)

5.57 ms ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12 ms ± 7.09 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 µs ± 2.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
# Small inputs: 3 and 4 rows sharing 4 features, filled with 0..N-1 as float64.
a = np.arange(3 * 4, dtype=float).reshape(3, 4)
b = np.arange(4 * 4, dtype=float).reshape(4, 4)


In [12]:

# Time the module-level kernels again on the small (3, 4) x (4, 4) inputs.
# The dtype and layout match the earlier inputs, so no recompilation is
# expected. At this size, per-call overhead presumably dominates (the
# parallel go_fast_opt is the slowest in the captured output).
# NOTE(review): the inline "#..." timings come from a different run and are
# far larger than this cell's captured output.
%timeit go_fast(a, b)
#922 µs ± 1.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit go_fast_opt(a, b)
#79.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit go_fast_np(a, b)
#431 µs ± 3.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit go_fast_np_opt(a, b)
#14.6 µs ± 291 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

632 ns ± 17.5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
4.12 µs ± 52 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
836 ns ± 7.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
786 ns ± 2.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [13]:
# testJIT on the small inputs: every call still recompiles its nested kernel,
# so the timings stay in the tens of milliseconds regardless of input size.
%timeit tt.go_fast(a, b)
%timeit tt.go_fast_opt(a, b)
%timeit tt.go_fast_np(a, b)
%timeit tt.go_fast_np_opt(a, b)

82.1 ms ± 238 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
277 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
83 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
58.5 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
# testJIT2 on the small inputs: the class-level dispatchers are reused, so
# timings track the module-level kernels.
%timeit tt2.go_fast(a, b)
%timeit tt2.go_fast_opt(a, b)
%timeit tt2.go_fast_np(a, b)
%timeit tt2.go_fast_np_opt(a, b)

634 ns ± 5.36 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
4.24 µs ± 73 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
850 ns ± 10.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
804 ns ± 1.81 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [15]:
# testJIT3 on the small inputs (instance wrapper -> jitted static kernel).
# NOTE(review): the go_fast_opt benchmark is disabled for tt3. Confirm that
# testJIT3.__go_fast_opt runs and returns the full (m, n) result before
# re-enabling it.
%timeit tt3.go_fast(a, b)
# %timeit tt3.go_fast_opt(a, b)
%timeit tt3.go_fast_np(a, b)
%timeit tt3.go_fast_np_opt(a, b)

715 ns ± 9.92 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
928 ns ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
893 ns ± 3.54 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
