<a href="https://colab.research.google.com/github/rajeevraizada/fMRI_tutorial_Jupyter_notebooks/blob/master/v3_Numba_Interactive_mandelbrot_ipympl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setup cell: install/enable the widget-based ipympl matplotlib backend and
# line_profiler, then import the libraries used by the cells below.
!pip install ipympl
!pip install line_profiler
%load_ext line_profiler
# NOTE(review): LineProfiler is not used in the visible code.
from line_profiler import LineProfiler
%matplotlib ipympl
# Colab needs its custom widget manager enabled before it will render
# third-party widgets such as ipympl figures.
from google.colab import output
output.enable_custom_widget_manager()
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
# NOTE(review): only njit is used below; jit appears unused.
from numba import jit, njit
import time

In [None]:
# This is the function which gets called when the mouse is clicked in the figure window
def do_this_when_the_mouse_is_clicked(this_event):
    global xmin
    global xmax
    global ymin
    global ymax
    global num_grid_points
    global zoom_per_click
    mouse_x = this_event.xdata
    mouse_y = this_event.ydata
    xrange = xmax - xmin
    yrange = ymax - ymin
    ### If the click is outside the range, then reset the view
    if this_event.xdata is None: # This means we clicked outside the axis
        xmin = -2
        xmax = 0.5
        ymin = -1.25
        ymax = 1.25
    else: # We clicked inside the axis
        # Make new xmin,xmax etc. values, to zoom in on where mouse was clicked
        xmin = mouse_x - xrange/(2*zoom_per_click)
        xmax = mouse_x + xrange/(2*zoom_per_click)
        ymin = mouse_y - yrange/(2*zoom_per_click)
        ymax = mouse_y + yrange/(2*zoom_per_click)

    max_its = np.real( 5 * (round(-np.log2(xrange)))**2.5 )
    #print(max_its)
    # Make sure that max_its is an integer and between 100 and 10,000
    max_its = int(np.max((max_its,100)))
    max_its = np.min( (max_its, 10**4) )

    t0 = time.perf_counter()
    timesteps_in_bound = raj_mandel_looped(xmin,xmax,ymin,ymax,num_grid_points,max_its)
    t1 = time.perf_counter()
    plt.clf()
    cmap_custom = matplotlib.cm.gnuplot2
    cmap_custom.set_over('k')
    vals_to_plot = timesteps_in_bound + 1
    cut_off_val = 0.99* np.log(np.max(vals_to_plot))
    plt.imshow(np.log(vals_to_plot),extent=[xmin,xmax,ymin,ymax],\
                origin='lower',cmap=cmap_custom, vmax=cut_off_val )
    plt.title('Magnification: %i   Iters: %i   Time: %.3fs' \
              %((2.5/(xmax-xmin)),max_its,(t1-t0)), fontsize=14)
    plt.xlabel('Click on where you want to zoom in. Click outside axes to reset.')
    plt.draw()

@njit
def mandel_one_pixel(c, max_its):
    """Return the escape iteration of the Mandelbrot orbit z -> z**2 + c.

    Starts from z = 0 and returns the 0-based index of the first iteration
    at which |z|**2 exceeds 4. If the orbit has not escaped after
    ``max_its`` iterations, returns ``max_its - 1`` (the last index tried),
    or 0 when ``max_its <= 0``.
    """
    z = 0j
    for iter_num in range(max_its):
        z = z**2 + c
        # Compare |z|^2 against 4 (i.e. |z| > 2) to avoid a square root.
        if np.real(z)**2 + np.imag(z)**2 > 4:
            return iter_num
    # Not escaped within the budget: report the last iteration index, as the
    # original loop did. Computing it avoids reading the loop variable, which
    # is never bound when max_its <= 0.
    return max(max_its - 1, 0)

@njit
def raj_mandel_looped(xmin, xmax, ymin, ymax, num_grid_points, max_its):
    """Compute Mandelbrot escape iterations on a square grid.

    The real axis spans ``[xmin, xmax]`` and the imaginary axis spans
    ``[ymin, ymax]``, each with ``num_grid_points`` evenly spaced samples.
    Returns a float array of shape ``(num_grid_points, num_grid_points)``.
    Row ``r`` corresponds to the r-th imaginary sample and column ``k`` to
    the k-th real sample. Each entry is ``mandel_one_pixel(c, max_its)``.
    """
    reals = np.linspace(xmin, xmax, num_grid_points)
    imags = np.linspace(ymin, ymax, num_grid_points)
    counts = np.zeros((num_grid_points, num_grid_points))
    # Rows run up/down (imaginary axis) and columns run left/right (real
    # axis), so the row index selects from `imags` and the column from `reals`.
    for row in range(num_grid_points):
        for col in range(num_grid_points):
            counts[row, col] = mandel_one_pixel(complex(reals[col], imags[row]), max_its)
    return counts

# Initial view: the full Mandelbrot set.
xmin = -2
xmax = 0.5
ymin = -1.25
ymax = 1.25
zoom_per_click = 2      # each click shrinks the view width and height by this factor
num_grid_points = 600   # the image is num_grid_points x num_grid_points pixels
max_its = 100           # iteration budget for the initial (un-zoomed) view

timesteps_in_bound = raj_mandel_looped(xmin, xmax, ymin, ymax, num_grid_points, max_its)

plt.close('all')
plt.figure(figsize=(7, 7))
# Use a copy so set_over() does not mutate the shared gnuplot2 colormap.
cmap_custom = matplotlib.cm.gnuplot2.copy()
cmap_custom.set_over('k')
# +1 keeps log() finite for pixels that escaped on iteration 0.
vals_to_plot = timesteps_in_bound + 1
# Values above 99% of the max log-count are drawn in the "over" colour
# (black), which marks pixels that ran (nearly) the whole iteration budget.
cut_off_val = 0.99 * np.log(np.max(vals_to_plot))
plt.imshow(np.log(vals_to_plot), extent=[xmin, xmax, ymin, ymax],
           origin='lower', cmap=cmap_custom, vmax=cut_off_val)

# Register the zoom/reset handler for mouse clicks on this figure.
plt.connect('button_press_event', do_this_when_the_mouse_is_clicked)
plt.xlabel('Click on where you want to zoom in. Click outside axes to reset.')
plt.show()
