In [1]:
import ccxt
import pandas as pd

# Column labels applied positionally to each row returned by fetch_ohlcv.
OHLCV_COLUMNS = ['timestamp', 'open', 'high', 'low', 'close', 'volume']

# Build the ccxt binanceusdm client and load its market list before querying.
binance = ccxt.binanceusdm()
binance.load_markets()

# Request at most 1000 daily ('1d') candles for BTC/USDT and tabulate them.
ohlcv = binance.fetch_ohlcv(symbol='BTC/USDT', timeframe='1d', limit=1000)
df = pd.DataFrame(data=ohlcv, columns=OHLCV_COLUMNS)

In [2]:
# Echo the OHLCV frame as the cell output; the rendered table elides the
# middle rows of a long frame, so only the head and tail are shown.
df

Unnamed: 0,timestamp,open,high,low,close,volume
0,1590019200000,9508.00,9573.00,8812.20,9060.00,461439.193
1,1590105600000,9060.00,9259.38,8920.00,9166.40,276008.073
2,1590192000000,9166.43,9300.00,9076.90,9176.41,201589.660
3,1590278400000,9176.32,9294.44,8674.00,8711.37,339643.378
4,1590364800000,8711.37,8977.00,8623.38,8895.65,302980.106
...,...,...,...,...,...,...
995,1675987200000,21793.30,21933.80,21405.00,21618.60,439986.497
996,1676073600000,21618.60,21891.90,21594.40,21851.30,189249.667
997,1676160000000,21851.20,22080.90,21618.00,21771.10,276787.093
998,1676246400000,21771.00,21888.00,21338.00,21766.20,511595.693


In [3]:
import sys 
sys.path.append('..')
import jaxbt.backtest as jbt
import jax
import jax.numpy as jnp

@jax.jit
def f(param: jax.Array, bt: jbt.Backtest, idx: int):
    """Order function evaluated at step ``idx`` of a backtest.

    Args:
        param: parameter array; only ``param[0]`` is read, as the order size.
        bt: backtest state; ``bt.position`` and ``bt.close`` are indexed by ``idx``.
        idx: index of the current step.

    Returns:
        A tuple ``(order_type, size, price)`` where

        * ``order_type`` is ``jbt.OrderType.LIMIT_BUY`` when the current
          position is exactly 0, otherwise ``jbt.OrderType.MARKET_SELL``;
        * ``size`` is 0 when the current position is >= 1, otherwise ``param[0]``;
        * ``price`` is ``bt.close[idx]`` (presumably consumed as the order
          price by jaxbt -- confirm against the library).
    """
    position = bt.position[idx]

    def _limit_buy(_):
        return jbt.OrderType.LIMIT_BUY

    def _market_sell(_):
        return jbt.OrderType.MARKET_SELL

    # Flat (position == 0) -> limit buy; any non-zero position -> market sell.
    order_type = jax.lax.cond(position == 0., _limit_buy, _market_sell, None)

    def _no_size(_):
        return 0.

    def _param_size(_):
        return param[0]

    # Position already at or above 1 -> size 0; otherwise size comes from param[0].
    order_size = jax.lax.cond(position >= 1., _no_size, _param_size, None)

    return order_type, order_size, bt.close[idx]

@jax.jit
def loss(param: jax.Array):
    """Objective to minimise: the negated sum of the backtest's ``pl``.

    Runs ``jbt.backtest_from_order_func`` on the notebook-level ``df`` with
    the order function ``f`` bound to ``param``.

    Args:
        param: parameter array forwarded to ``f``.

    Returns:
        ``-result.pl.sum()`` for the backtest result.
    """
    def order_func(bt, idx):
        # Bind the trainable parameter so the backtest sees a (bt, idx) callable.
        return f(param, bt, idx)

    backtest = jbt.backtest_from_order_func(df, order_func)
    total_pl = backtest.pl.sum()
    return -total_pl

# Gradient of `loss` with respect to its first (and only) argument;
# differentiating argument 0 is jax.grad's default.
grad_fun = jax.grad(loss)

# Probe the gradient at param = [5.0].
print(grad_fun(jnp.array([5.0])))

@jax.jit
def train(epoch, params, lr=0.01):
    """Apply ``epoch`` gradient-descent updates to ``params``.

    Each iteration replaces ``params`` with ``params - lr * grad_fun(params)``.
    No clipping or bounds are applied to the parameters.

    Args:
        epoch: upper bound of the ``fori_loop`` (the loop starts at 0).
        params: initial parameter array.
        lr: step size multiplying the gradient.

    Returns:
        The parameter array after the final update.
    """
    def descend(_, current):
        # The loop counter is not needed; only the carried parameters matter.
        return current - lr * grad_fun(current)

    return jax.lax.fori_loop(0, epoch, descend, params)

# Start from param = [0.1] and run `train` with epoch=100 (default lr).
init_params = jnp.array([0.1])
result_params = train(100, init_params)
# Bare expression so the notebook echoes the trained parameters.
result_params

[0.50496995]


Array([-1.2157357], dtype=float32)

In [4]:
# Retrain from `init_params` for 1, 11, 21, ..., 91 epochs and report, for each
# run: the resulting parameters, the gradient there, and the negated loss.
for i in range(1, 100, 10):
    result_params = train(i, init_params)
    grad_at_result = grad_fun(result_params)
    negated_loss = -loss(result_params)
    print(result_params, grad_at_result, negated_loss)

[0.0949503] [0.50496995] 0.45702305
[0.04445332] [0.50496995] 0.48252255
[-0.01617823] [1.5184261] 1.5429919
[-0.16802081] [1.5184261] 1.7735536
[-0.31986335] [1.5184261] 2.004115
[-0.47170588] [1.5184261] 2.234677
[-0.6235487] [1.5184261] 2.465239
[-0.7753915] [1.5184261] 2.6958015
[-0.92723435] [1.5184261] 2.9263637
[-1.0790771] [1.5184261] 3.1569254
