In [None]:
from jax import devices, random, jit, vmap, numpy as jnp
from vectorbt import _typing as tp
from numba import njit
from jax import random
import numpy as np
import itertools
from functools import partial
import pandas as pd
import random as pyrandom

In [None]:
# Pick the first GPU as the target device for the jitted JAX kernels below.
# `devices(backend="gpu")` raises if JAX has no usable GPU backend on this machine.
gpu_devices = devices(backend="gpu")
print(gpu_devices)
gpu = gpu_devices[0]

In [None]:
@njit(cache=True)
def get_return_nb(input_value: float, output_value: float) -> float:
    """Calculate the simple return of moving from `input_value` to `output_value`.

    The change is measured relative to the magnitude of the input, so an increase
    yields a positive return even when the input is negative. A zero input yields
    0 when the output is also zero, and otherwise an infinity carrying the sign of
    the output.
    """
    if input_value != 0:
        change = (output_value - input_value) / input_value
        # Flip the sign for negative inputs so the return is relative to |input_value|.
        return -change if input_value < 0 else change
    if output_value == 0:
        return 0.
    return np.inf * np.sign(output_value)


@njit(cache=True)
def returns_1d_nb(value: tp.Array1d, init_value: float) -> tp.Array1d:
    """Calculate returns from a 1-dim series of values.

    Element ``i`` of the result is `get_return_nb(previous, value[i])`, where
    ``previous`` is ``value[i - 1]``, or ``init_value`` for the first element.

    Args:
        value: 1-dim array of values.
        init_value: the value preceding ``value[0]``.

    Returns:
        float64 array with the same shape as ``value``.
    """
    # np.float64 instead of the np.float_ alias, which NumPy 2.0 removed.
    out = np.empty(value.shape, dtype=np.float64)
    input_value = init_value
    for i in range(out.shape[0]):
        output_value = value[i]
        out[i] = get_return_nb(input_value, output_value)
        # Each step's output becomes the next step's input.
        input_value = output_value
    return out


@njit(cache=True)
def returns_nb(value: tp.Array2d, init_value: tp.Array1d) -> tp.Array2d:
    """2-dim version of `returns_1d_nb`.

    Applies `returns_1d_nb` to every column of ``value``, using ``init_value[col]``
    as the value preceding the first row of column ``col``.

    Args:
        value: 2-dim array; each column is an independent series.
        init_value: 1-dim array with (at least) one initial value per column.

    Returns:
        float64 array with the same shape as ``value``.

    Raises:
        ValueError: if ``init_value`` has fewer entries than ``value`` has columns.
    """
    # Numba does not bounds-check array indexing by default, so a short
    # init_value would otherwise be read past its end without any error.
    if init_value.shape[0] < value.shape[1]:
        raise ValueError("init_value must have one entry per column of value")
    # np.float64 instead of the np.float_ alias, which NumPy 2.0 removed.
    out = np.empty(value.shape, dtype=np.float64)
    for col in range(out.shape[1]):
        out[:, col] = returns_1d_nb(value[:, col], init_value[col])
    return out

In [None]:
def _returns_1d_jax(value: tp.Array1d, init_value: float):
    """JAX version of `returns_1d_nb`: returns of a 1-dim series of values.

    Element ``i`` of the result is the return from the previous value
    (``init_value`` for the first element) to ``value[i]``. The semantics match
    `get_return_nb`:

    - the change is divided by the magnitude of the previous value;
    - a 0 -> 0 move is a 0 return;
    - a 0 -> x move is an infinity with the sign of x.
    """
    # Prepend init_value so series[:-1] holds each step's input and series[1:] its output.
    series = jnp.concatenate((jnp.atleast_1d(init_value), value))
    inputs = series[:-1]
    outputs = series[1:]
    # Dividing by |input| equals (out - in) / in with the sign flipped for negative
    # inputs. For a zero input it yields +/-inf following the output's sign.
    returns = (outputs - inputs) / jnp.abs(inputs)
    # 0 / 0 would be NaN here; get_return_nb defines a 0 -> 0 move as a 0 return.
    return jnp.where((inputs == 0) & (outputs == 0), 0.0, returns)

returns_1d_jax = jit(_returns_1d_jax, device=gpu)


def returns_jax(v, i):
    """2-dim version of `returns_1d_jax`, computed column by column with `vmap`.

    Args:
        v: 2-dim array; each column is an independent series.
        i: initial value(s). Either a scalar shared by all columns, or a 1-dim
            array with one entry per column (as `returns_nb` takes).

    Returns:
        Array with the same shape as ``v``.
    """
    init_values = jnp.broadcast_to(jnp.asarray(i), (v.shape[1],))
    # Map over columns of v (axis 1) and entries of init_values (axis 0).
    return vmap(returns_1d_jax, in_axes=(1, 0), out_axes=1)(v, init_values)

In [None]:
# These functions generate random-walk test data. There is one NumPy version and one JAX version, so each benchmark uses data produced by its own library (for fairness).
def random_walk_nb(start, scale, steps, seed, n_portfolio=None):
    """Generate a Gaussian random walk of length `steps` with NumPy.

    The first row equals `start`. Every later row adds standard-normal noise
    multiplied by `scale`, drawn from `np.random.default_rng(seed)`. With a truthy
    `n_portfolio` the result has `n_portfolio` columns; otherwise it is 1-dim.
    """
    generator = np.random.default_rng(seed)
    if n_portfolio:
        noise_shape = (steps - 1, n_portfolio)
    else:
        noise_shape = (steps - 1,)
    increments = scale * generator.normal(size=noise_shape)
    # Prepend the starting level so the cumulative sum begins at `start`.
    increments = np.insert(increments, 0, start, axis=0)
    return increments.cumsum(axis=0)

# nb_data = random_walk_nb(100, 1, 100000, 1)

def random_walk_jax(start, scale, steps, seed, n_portfolio=None):
    """JAX counterpart of `random_walk_nb`: a Gaussian random walk of length `steps`.

    The first row equals `start`. Every later row adds standard-normal noise
    multiplied by `scale`, drawn with `jax.random` from `PRNGKey(seed)`. With a
    truthy `n_portfolio` the result has `n_portfolio` columns; otherwise it is 1-dim.
    """
    prng_key = random.PRNGKey(seed)
    column_dims = (n_portfolio,) if n_portfolio else ()
    increments = scale * random.normal(prng_key, shape=(steps - 1, *column_dims))
    # Prepend the starting level so the cumulative sum begins at `start`.
    increments = jnp.insert(increments, 0, start, axis=0)
    return increments.cumsum(axis=0)

# jax_data = random_walk_jax(100, 1, 100000, 1)

In [None]:
# Sanity check for the 1-d JAX returns: each return should be the difference divided by the *previous* value (arr[:-1]).
# arr = jnp.array([100, 110, 100, 120])
# numerator = jnp.diff(arr)
# print("differences", numerator)
# denominator = arr[:-1]
# print("denominators", denominator)
# returns = numerator / denominator
# print("returns", returns)

In [None]:
# %timeit returns_1d_jax(jax_data, 1.0)

In [None]:
# %timeit returns_1d_nb(nb_data, 1.0)

In [None]:
n_candles = [2048]
n_portfolios = [2 ** i for i in range(2, 12)]
n_loops = 5

timings = {"index": [], "jax": [], "nb": []}
testing = True
timing = True
for n_candle, n_portfolio in itertools.product(n_candles, n_portfolios):
    size = n_candle*n_portfolio
    print(f"\ncalculate ({n_candle}, {n_portfolio}) returns (size {size})\n")
    timings["index"].append(n_portfolio)
    def jax_fun(walk=None):
        if walk is None:
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_jax(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
        returns = returns_jax(walk, 1.0).block_until_ready()
        return returns
    one = np.ones((n_portfolio,))
    def nb_fun(walk=None):
        if walk is None:
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_nb(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
        returns = returns_nb(walk, one)
        return returns
    if timing:
        print("jax:")
        jax_timings = %timeit -o jax_fun()
        timings["jax"].append(jax_timings.average)
        print("nb:")
        nb_timings = %timeit -o nb_fun()
        timings["nb"].append(nb_timings.average)
    if testing:
        print("mean absolute error between jax and numba returns:")
        print("jnp.abs(nb_test - jax_test).mean() =")
        maes = []
        for i in range(n_loops):
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_jax(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
            np_walk = np.array(walk)
            jax_test = jax_fun(walk)
            nb_test = nb_fun(np_walk)
            mae = jnp.abs(nb_test - jax_test).mean()
            maes.append(mae)
        print(maes)
if timing:
    df = pd.DataFrame.from_dict(timings)
    df.index = df["index"]
    df = df.drop(columns="index")
    df.plot(title="vectorbt returns accelerator benchmark",  xlabel="n_portfolios", ylabel="time (lower is better)")