In [7]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Number of grid points along each side of the square array the fractal is computed on.
size = 400
# Maximum number of times the map z -> z*z + c is applied at each grid point.
iterations = 100

In [8]:
from numba import jit

In [9]:
# Request nopython mode explicitly: older Numba releases could otherwise fall
# back silently to the much slower object mode when type inference fails.
@jit(nopython=True)
def mandelbrot_numba(size, iterations):
    """Compute an escape-time image of the Mandelbrot set with explicit loops.

    The region [-2, 1) x (-1.5, 1.5] of the complex plane is sampled on a
    ``size`` x ``size`` grid. Row ``i`` and column ``j`` map to
    ``c = (-2 + 3/size * j) + 1j * (1.5 - 3/size * i)``. For every ``c`` the
    sequence ``z -> z*z + c`` is iterated from ``z = 0`` for at most
    ``iterations`` steps, stopping as soon as ``|z|`` exceeds 10.

    Parameters
    ----------
    size : int
        Number of grid points along each axis.
    iterations : int
        Maximum number of iterations per grid point.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(size, size)``. ``m[i, j]`` is the index of the
        last iteration at which ``|z| <= 10`` still held; points that never
        exceed the bound end up with ``iterations - 1``.
    """
    m = np.zeros((size, size))
    for i in range(size):
        # The imaginary part depends only on the row, so compute it once per row.
        imag = 1.5 - 3. / size * i
        for j in range(size):
            c = -2 + 3. / size * j + 1j * imag
            # Start from a complex zero so that z keeps one type for its whole life.
            z = 0j
            for n in range(iterations):
                if np.abs(z) > 10:
                    # The orbit has left the bounded region; m[i, j] already
                    # holds the last iteration index that was still inside.
                    break
                z = z * z + c
                m[i, j] = n
    return m

In [10]:
# Warm-up call: Numba compiles the function on its first invocation, so run it
# once here to keep that one-off compilation cost out of the timing below.
mandelbrot_numba(size, iterations)
%timeit mandelbrot_numba(size, iterations)

55.9 ms ± 73.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
def initialize(size):
    """Build the arrays used by the vectorized Mandelbrot iteration.

    Parameters
    ----------
    size : int
        Number of samples along each axis of the grid.

    Returns
    -------
    tuple of numpy.ndarray
        ``(c, z, m)`` where ``c`` is a ``(size, size)`` complex grid whose
        real part runs from -2 to 1 along the columns and whose imaginary
        part runs from -1.5 to 1.5 along the rows (both endpoints included),
        ``z`` is an independent copy of ``c``, and ``m`` is a float array of
        zeros with the same shape.
    """
    real_axis = np.linspace(-2, 1, size)
    imag_axis = np.linspace(-1.5, 1.5, size)
    real_grid, imag_grid = np.meshgrid(real_axis, imag_axis)
    plane = real_grid + 1j * imag_grid
    # z must be a separate array from c because the iteration overwrites it.
    return plane, plane.copy(), np.zeros((size, size))
def mandelbrot_numpy(c, z, m, iterations):
    """Run the vectorized Mandelbrot iteration, updating ``z`` and ``m`` in place.

    Parameters
    ----------
    c : numpy.ndarray
        Complex array of the constants added at every step.
    z : numpy.ndarray
        Complex array holding the current iterate for each point; overwritten.
    m : numpy.ndarray
        Array receiving, for each point, the index of the last iteration at
        which ``|z| <= 10`` still held; overwritten.
    iterations : int
        Number of iteration steps to perform.

    Returns
    -------
    None
        The results are written into ``z`` and ``m``.
    """
    for step in range(iterations):
        # Only points whose modulus is still within the bound keep iterating.
        active = np.abs(z) <= 10
        current = z[active]
        z[active] = current ** 2 + c[active]
        m[active] = step


In [14]:
%%timeit -n1 -r10 c, z, m = initialize(size)
# The statement on the %%timeit line is setup code: it runs before each timed
# run but is not itself timed. Fresh arrays are needed for every run because
# mandelbrot_numpy updates z and m in place, hence one loop per run (-n1).
mandelbrot_numpy(c, z, m, iterations)

293 ms ± 29.7 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)
