In [None]:
import numpy as np
import cupy as cu
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

from scipy.integrate import odeint, solve_ivp
from scipy.integrate._ivp.ivp import OdeResult

from ivp import solve_ivp as solve_ivp_cupy
from ivp.ivp import OdeResult as OdeResultCupyCustom
 
# Shared plotting state for run_and_plot: one large figure that is filled
# with 3D subplots, one per benchmarked solver run.
nrows, ncols = 4, 2  # subplot grid dimensions (rows, columns)
fig_index = 0        # slot of the most recently added subplot (0 = none yet)
fig = plt.figure(figsize=(20, 10))
def run_and_plot(name, fun):
    """Time a solver call and draw its trajectory in the next 3D subplot.

    Args:
        name: Label used in the timing message and as the subplot title.
        fun: Zero-argument callable that runs the solver and returns its
            result.

    The result is plotted according to its type:

    * ``np.ndarray``: columns 0, 1 and 2 are used as x, y and z.
    * ``OdeResult`` / ``OdeResultCupyCustom``: rows 0, 1 and 2 of ``res.y``
      are used as x, y and z. Rows exposing a ``.get()`` method are
      converted with it first (presumably CuPy device arrays being copied
      to host memory -- confirm against the custom ``ivp`` package).

    Any other result type leaves the subplot empty (title only).

    Side effects: prints the elapsed time, increments the module-level
    ``fig_index`` and adds a subplot to the module-level ``fig``.

    NOTE(review): if ``fun`` launches asynchronous GPU work, the measured
    time may not include all of it unless ``fun`` synchronizes -- confirm.
    """
    # Only fig_index is rebound here; fig, nrows and ncols are merely read.
    global fig_index

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    started = time.perf_counter()
    res = fun()
    elapsed = time.perf_counter() - started
    print(f'{name} took {elapsed}s')

    fig_index += 1
    ax = fig.add_subplot(nrows, ncols, fig_index, projection='3d')

    # Resolve the three coordinate series exactly once so a result matching
    # several of the checks below is never drawn more than once.
    coords = None
    if isinstance(res, np.ndarray):
        coords = (res[:, 0], res[:, 1], res[:, 2])
    elif isinstance(res, (OdeResult, OdeResultCupyCustom)):
        coords = (res.y[0, :], res.y[1, :], res.y[2, :])
        if hasattr(coords[0], 'get'):
            coords = tuple(series.get() for series in coords)

    if coords is not None:
        ax.plot(*coords)
    ax.set_title(name)
    

def lorenz(t, state, sigma, beta, rho):
    """Evaluate the right-hand side of the Lorenz system.

    Args:
        t: Current time. Unused by the equations, but part of the
            ``f(t, y, ...)`` call signature.
        state: Iterable of exactly three values ``(x, y, z)``.
        sigma: Coefficient of the ``(y - x)`` term in dx/dt.
        beta: Coefficient of the ``z`` damping term in dz/dt.
        rho: Offset applied to ``z`` in dy/dt.

    Returns:
        A list ``[dx, dy, dz]`` with the time derivatives of the state.
    """
    px, py, pz = state
    return [
        sigma * (py - px),       # dx/dt
        px * (rho - pz) - py,    # dy/dt
        px * py - beta * pz,     # dz/dt
    ]

: 