In [None]:
%matplotlib inline
%load_ext Cython
#http://people.duke.edu/~ccc14/sta-663-2016/18D_Cython.html

import matplotlib.pylab as plt
import random
import numpy as np

#### Matrix Multiplication


Let's write a function to multiply two square (NxN) matrices together.

First is our naive way, following the way we would do it by hand.

\begin{equation}
A =\begin{pmatrix}
 a_{11} & a_{12} & \cdots & a_{1m} \\
 a_{21} & a_{22} & \cdots & a_{2m} \\
\vdots & \vdots & \ddots & \vdots \\
 a_{n1} & a_{n2} & \cdots & a_{nm} \\
\end{pmatrix}
\end{equation}

\begin{equation}
B=\begin{pmatrix}
 b_{11} & b_{12} & \cdots & b_{1p} \\
 b_{21} & b_{22} & \cdots & b_{2p} \\
\vdots & \vdots & \ddots & \vdots \\
 b_{m1} & b_{m2} & \cdots & b_{mp} \\
\end{pmatrix}
\end{equation}

\begin{equation}
C = AB
\end{equation}

\begin{equation}
C=\begin{pmatrix}
 c_{11} & c_{12} & \cdots & c_{1p} \\
 c_{21} & c_{22} & \cdots & c_{2p} \\
\vdots & \vdots & \ddots & \vdots \\
 c_{n1} & c_{n2} & \cdots & c_{np} \\
\end{pmatrix}
\end{equation}

\begin{equation}
c_{ij} = a_{i1}b_{1j} + ... + a_{im}b_{mj} = \sum_{k=1}^m a_{ik}b_{kj}
\end{equation}

In [None]:
def mat_mul_py(A, B):
    """Multiply two matrices with the naive triple loop (pure Python).

    Parameters
    ----------
    A : row-indexable 2-D container, shape (n, m)
        Nested lists or a 2-D NumPy array; indexed as ``A[i][k]``.
    B : row-indexable 2-D container, shape (m, p)

    Returns
    -------
    list of list
        The n x p product C, with ``C[i][j] = sum_k A[i][k] * B[k][j]``.

    Raises
    ------
    ValueError
        If the number of columns of A differs from the number of rows of B.
    """
    n_rows = len(A)                        # n: rows of A (and of C)
    n_inner = len(B)                       # m: rows of B == columns of A
    n_cols = len(B[0]) if n_inner else 0   # p: columns of B (and of C)
    if n_rows and len(A[0]) != n_inner:
        raise ValueError(
            "inner dimensions do not match: A has {} columns, B has {} rows".format(
                len(A[0]), n_inner))

    # Start every element of C at 0 so the loop below can accumulate into it.
    C = [[0 for _ in range(n_cols)] for _ in range(n_rows)]
    for i in range(n_rows):              # each row of C
        for j in range(n_cols):          # each column of C
            for k in range(n_inner):     # dot product of row i of A with column j of B
                C[i][j] += A[i][k] * B[k][j]
    return C

In [None]:
size = 32

A = np.random.random((size,size)).astype(np.float32)
B = np.random.random((size,size)).astype(np.float32)

%time x = mat_mul_py(A,B)

Using what you've learned so far, modify the code below to improve the speed of your matrix multiplication function by trying to remove the yellow portions of the annotated code.

Look back at chapter 2, *Cython Fast*, for hints.

In [None]:
%%cython -a

def mat_mul_cy(A, B):
    """Naive triple-loop matrix product, compiled by Cython but still untyped.

    Same algorithm as ``mat_mul_py``: multiplies an (n, m) matrix A by an
    (m, p) matrix B and returns the n x p product as a list of lists.
    Raises ValueError if the columns of A do not match the rows of B.
    """
    n_rows = len(A)                        # n: rows of A (and of C)
    n_inner = len(B)                       # m: rows of B == columns of A
    n_cols = len(B[0]) if n_inner else 0   # p: columns of B (and of C)
    if n_rows and len(A[0]) != n_inner:
        raise ValueError(
            "inner dimensions do not match: A has {} columns, B has {} rows".format(
                len(A[0]), n_inner))

    # Start every element of C at 0 so the loop below can accumulate into it.
    C = [[0 for _ in range(n_cols)] for _ in range(n_rows)]
    for i in range(n_rows):              # each row of C
        for j in range(n_cols):          # each column of C
            for k in range(n_inner):     # dot product of row i of A with column j of B
                C[i][j] += A[i][k] * B[k][j]
    return C

In [None]:
%time x = mat_mul_cy(A,B)

In [None]:
#%load mat_mul_cy_fast.py
# NOTE: %%cython has to be the first line so remove the load line as well

def mat_mul_cy_fast(A, B, C):
    """Placeholder for the optimised Cython matrix multiplication.

    The worked answer lives in ``mat_mul_cy_fast.py`` (see the commented
    ``%load`` line at the top of this cell). This stub does no arithmetic:
    it only prints a bold red reminder and returns None.
    """
    reminder = "\033[1m\033[91m\n\nLoad the py file for the answer\n\n\033[0m"
    print(reminder)

In [None]:
size = 512

A = np.random.random((size,size)).astype(np.float32)
B = np.random.random((size,size)).astype(np.float32)
C = np.zeros((A.shape[0], B.shape[1])).astype(np.float32)

%time mat_mul_cy_fast(A,B,C)

## Sometimes it's best to just use the predefined functions, though...

In [None]:
#A = np.random.random((size,size)).astype(np.float32)
#B = np.random.random((size,size)).astype(np.float32)
#C = np.zeros((A.shape[0], A.shape[1])).astype(np.float32)
#%time C = A * B

Again, look through the code below and try to remove as many of the yellow portions as possible to increase the speed of the code.

In [None]:
%%cython -a

cimport cython

### import and use C functions
#cdef extern from "complex.h":
#    double cabs(double complex)

def mandel(x, y, max_iters):
    """Escape-time iteration count for the point c = x + iy.

    Starting from z = 0, repeatedly applies z <- z*z + c and returns the index
    of the first iteration after which |z|**2 >= 4 (equivalently |z| >= 2).
    If the orbit never reaches that bound within ``max_iters`` steps,
    ``max_iters`` is returned.
    """
    c = complex(x, y)
    z = 0.0j
    for step in range(max_iters):
        z = z*z + c
        # Squared magnitude avoids a sqrt; the commented-out C alternative is
        # `if cabs(z) >= 2:` using the extern declaration at the top of the cell.
        mag_sq = z.real*z.real + z.imag*z.imag
        if mag_sq >= 4:
            return step
    return max_iters

def create_fractal(xmin, xmax, ymin, ymax, image, iters):
    """Fill ``image`` in place with Mandelbrot escape counts.

    With ``(height, width) = image.shape``, pixel column ``col`` maps to the
    real part ``xmin + col * (xmax - xmin) / width`` and pixel row ``row`` maps
    to the imaginary part ``ymin + row * (ymax - ymin) / height``. Each pixel
    ``image[row, col]`` receives ``mandel(real, imag, iters)``. Returns None.
    """
    height, width = image.shape

    step_real = (xmax - xmin)/width
    step_imag = (ymax - ymin)/height

    for col in range(width):
        real_part = xmin + col*step_real
        for row in range(height):
            imag_part = ymin + row*step_imag
            image[row, col] = mandel(real_part, imag_part, iters)
            

In [None]:
gimage = np.zeros((1080, 1920), dtype=np.uint32)
xmin, xmax, ymin, ymax = [-2.0, 1.0, -1.0, 1.0]
iters = 100

%time create_fractal(xmin, xmax, ymin, ymax, gimage, iters)

plt.figure(figsize=(15,15))
plt.grid(False)
plt.imshow(gimage, cmap='viridis')
plt.show()

In [None]:
#%load fractal.py
# NOTE: %%cython has to be the first line so remove the load line as well

def create_fractal_cython(xmin, xmax, ymin, ymax, gimage, iters):
    """Placeholder for the optimised Cython fractal generator.

    The worked answer lives in ``fractal.py`` (see the commented ``%load``
    line at the top of this cell). This stub leaves ``gimage`` untouched:
    it only prints a bold red reminder and returns None.
    """
    reminder = "\033[1m\033[91m\n\nLoad the py file for the answer\n\n\033[0m"
    print(reminder)

In [None]:
gimage = np.zeros((1080, 1920), dtype=np.uint32)
xmin, xmax, ymin, ymax = [-2.0, 1.0, -1.0, 1.0]
iters = 50

%time create_fractal_cython(xmin, xmax, ymin, ymax, gimage, iters)

plt.figure(figsize=(16,9))
plt.grid(False)
plt.imshow(gimage, cmap='viridis')
plt.show()