In [101]:
import sys

sys.path.append("../")

import pandas as pd
import numpy as np
import datetime
import os
from pprint import pprint
import matplotlib.pyplot as plt
import time
import vectorbtpro as vbt
from time import time
import helpers as pth
import platform
from dotenv import load_dotenv
import scipy.stats as stats
import time
import helpers as pth
from numba import njit
import talib

theme = "light"
vbt.settings.set_theme(theme)

pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 20)

# Apply the base matplotlib style FIRST: style sheets such as "classic" ship a full
# rc set (figure size, axis-formatter limits, offset notation, ...) and would
# silently overwrite any rcParams customised before the call.
plt.style.use("classic" if theme == "light" else "dark_background")
plt.rcParams["figure.figsize"] = (12, 7)
# Disable offset / scientific notation on axes so price levels are shown verbatim.
plt.rcParams["axes.formatter.useoffset"] = False
plt.rcParams["axes.formatter.limits"] = [-1000000000, 1000000000]

if platform.system().lower() == "windows":
    base_data_path = "H:\\phitech-data\\01_raw"
else:
    from core_chains.simple.llm import make_Q_chain

    base_data_path = "../../phitech-data/01_raw"
    load_dotenv("../../sandatasci-core/credentials")
    # NOTE(review): __vsc_ipynb_file__ is presumably injected by the VS Code Jupyter
    # extension; this line will NameError when run under plain Jupyter.
    Q = make_Q_chain("gpt-4o-instance1", __vsc_ipynb_file__)

In [102]:
%%html
<!-- Notebook-wide CSS: render tables with class "dataframe" (pandas' HTML output) at 9pt. -->
<style>
.dataframe {
    font-size: 9pt; /* Adjust font size as needed */
}
</style>

In [103]:
symbols = ["MES", "6B"]
# 1-minute Sierra Chart bars for both symbols, 2024-08-01 through 2024-12-01.
df = pth.SierraChartData.pull(symbols, timeframe="1min", start="2024-08-01", end="2024-12-01")
df

100%|##########| 2/2 [00:03<00:00,  1.54s/it, symbol=6B]



<helpers.SierraChartData at 0x17dfc9340>

In [104]:
# Inspect the raw MES frame: dtypes and per-column non-null counts
# (timestamps whose values are all NaN show up as a lower non-null count).
df.data["MES"].info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 119631 entries, 2024-08-01 00:00:00+00:00 to 2024-12-01 23:59:00+00:00
Data columns (total 11 columns):
 #   Column       Non-Null Count   Dtype  
---  ------       --------------   -----  
 0   open         119331 non-null  float64
 1   high         119331 non-null  float64
 2   low          119331 non-null  float64
 3   close        119331 non-null  float64
 4   volume       119331 non-null  float64
 5   #_of_trades  119331 non-null  float64
 6   ohlc_avg     119331 non-null  float64
 7   hlc_avg      119331 non-null  float64
 8   hl_avg       119331 non-null  float64
 9   bid_volume   119331 non-null  float64
 10  ask_volume   119331 non-null  float64
dtypes: float64(11)
memory usage: 11.0 MB


In [105]:
# Per-field wide frames (one column per symbol). dropna() keeps only the
# timestamps where every symbol has a value for that field.
high = df.get("High").dropna()
low = df.get("Low").dropna()
close = df.get("Close").dropna()
close

symbol,MES,6B
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-08-01 00:00:00+00:00,5714.75,1.2858
2024-08-01 00:01:00+00:00,5715.25,1.2858
2024-08-01 00:02:00+00:00,5715.50,1.2858
2024-08-01 00:03:00+00:00,5714.50,1.2857
2024-08-01 00:05:00+00:00,5713.75,1.2857
...,...,...
2024-12-01 23:55:00+00:00,6115.00,1.2689
2024-12-01 23:56:00+00:00,6115.25,1.2689
2024-12-01 23:57:00+00:00,6115.75,1.2690
2024-12-01 23:58:00+00:00,6116.00,1.2693


### Pandas Implementation

In [106]:
def get_mid_price(high, low):
    """Bar mid-price ``(high + low) / 2``, element-wise for arrays/Series."""
    bar_sum = high + low
    return bar_sum / 2


def get_atr(high, low, close, period):
    """Average True Range with Wilder smoothing (EWM, ``alpha = 1 / period``).

    True range is the largest of ``|high - low|``, ``|high - prev_close|`` and
    ``|low - prev_close|``. The row-wise max skips NaN, so on the first bar
    (no previous close) it falls back to ``|high - low|``. The first
    ``period - 1`` ATR values are NaN (``min_periods=period``).

    Parameters
    ----------
    high, low, close : pd.Series
        Aligned price series.
    period : int
        Smoothing period.

    Returns
    -------
    pd.Series
    """
    prev_close = close.shift()
    candidates = [
        (high - low).abs(),
        (high - prev_close).abs(),
        (low - prev_close).abs(),
    ]
    true_range = pd.concat(candidates, axis=1).max(axis=1)
    return true_range.ewm(alpha=1 / period, adjust=False, min_periods=period).mean()


def get_basic_bands(med_price, atr, multiplier):
    """Basic (un-ratcheted) SuperTrend bands ``med_price +/- multiplier * atr``.

    Returns ``(upper, lower)``; works element-wise on scalars, ndarrays or Series.
    """
    offset = multiplier * atr
    return med_price + offset, med_price - offset


def get_final_bands(close, upper, lower):
    """Resolve SuperTrend direction and final (ratcheted) bands bar by bar.

    Reference pandas implementation. It deliberately loops with ``.iloc`` so it
    can be benchmarked against array-based versions.

    For every bar ``i >= 1``:

    * close above the previous bar's upper band -> direction +1;
    * close below the previous bar's lower band -> direction -1;
    * otherwise the direction is carried forward and the active band is
      ratcheted: in an uptrend the lower band may not fall, in a downtrend the
      upper band may not rise.

    Parameters
    ----------
    close, upper, lower : pd.Series
        Aligned series of equal length. ``upper``/``lower`` are the basic bands;
        they are copied, so the caller's series are left untouched.

    Returns
    -------
    trend, direction, long, short : pd.Series
        ``trend`` is the active band (NaN on bar 0). ``direction`` is +1/-1
        (bar 0 defaults to +1). ``long``/``short`` hold the band only while in
        that direction and are NaN otherwise.
    """
    # Work on copies so the ratchet below does not clobber the caller's bands.
    upper = upper.copy()
    lower = lower.copy()

    trend = pd.Series(np.full(close.shape, np.nan), index=close.index)
    direction = pd.Series(np.full(close.shape, 1), index=close.index)
    long = pd.Series(np.full(close.shape, np.nan), index=close.index)
    short = pd.Series(np.full(close.shape, np.nan), index=close.index)

    for i in range(1, close.shape[0]):
        # Breakouts are measured against the *previous* bar's (ratcheted) bands.
        if close.iloc[i] > upper.iloc[i - 1]:
            direction.iloc[i] = 1
        elif close.iloc[i] < lower.iloc[i - 1]:
            direction.iloc[i] = -1
        else:
            direction.iloc[i] = direction.iloc[i - 1]
            # Ratchet: the trailing band only moves in the trend's favour.
            if direction.iloc[i] > 0 and lower.iloc[i] < lower.iloc[i - 1]:
                lower.iloc[i] = lower.iloc[i - 1]
            if direction.iloc[i] < 0 and upper.iloc[i] > upper.iloc[i - 1]:
                upper.iloc[i] = upper.iloc[i - 1]

        if direction.iloc[i] > 0:
            trend.iloc[i] = long.iloc[i] = lower.iloc[i]
        else:
            trend.iloc[i] = short.iloc[i] = upper.iloc[i]

    return trend, direction, long, short


def supertrend(high, low, close, period=7, multiplier=3):
    """SuperTrend indicator, pandas reference pipeline.

    mid-price -> ATR -> basic bands -> bar-by-bar final bands.

    Returns ``(trend, direction, long, short)`` as produced by ``get_final_bands``.
    """
    band_upper, band_lower = get_basic_bands(
        get_mid_price(high, low),
        get_atr(high, low, close, period=period),
        multiplier=multiplier,
    )
    return get_final_bands(close, band_upper, band_lower)

In [None]:
%%timeit
ticker = "MES"
supert, superd, superl, supers = supertrend(high[ticker], low[ticker], close[ticker])

In [108]:
# Overlay the SuperTrend bands on the close. Naming each band series gives it
# its own legend entry (unnamed series are plotted as unlabeled traces).
fig = close[ticker].vbt.plot()
supers.rename("supertrend short").vbt.plot(fig=fig)
superl.rename("supertrend long").vbt.plot(fig=fig)
fig

FigureWidget({
    'data': [{'name': 'MES',
              'showlegend': True,
              'type': 'scatter',
              'uid': '12cca9a1-1aa9-4f72-8cec-20ade66a23a6',
              'x': array([datetime.datetime(2024, 8, 1, 0, 0, tzinfo=datetime.timezone.utc),
                          datetime.datetime(2024, 8, 1, 0, 1, tzinfo=datetime.timezone.utc),
                          datetime.datetime(2024, 8, 1, 0, 2, tzinfo=datetime.timezone.utc), ...,
                          datetime.datetime(2024, 12, 1, 23, 57, tzinfo=datetime.timezone.utc),
                          datetime.datetime(2024, 12, 1, 23, 58, tzinfo=datetime.timezone.utc),
                          datetime.datetime(2024, 12, 1, 23, 59, tzinfo=datetime.timezone.utc)],
                         dtype=object),
              'y': array([5714.75, 5715.25, 5715.5 , ..., 6115.75, 6116.  , 6115.75])},
             {'showlegend': False,
              'type': 'scatter',
              'uid': 'e6cc956d-cd34-451b-b4b4-5841c04350b1'

### NumPy + Numba

In [121]:
def get_atr_np(high, low, close, period):
    """Average True Range on plain 1-D ndarrays (NumPy counterpart of ``get_atr``).

    Parameters
    ----------
    high, low, close : np.ndarray
        Aligned 1-D price arrays.
    period : int
        Smoothing period passed to ``vbt.nb.wwm_mean_1d_nb`` (presumably Wilder
        smoothing, alpha = 1/period, mirroring ``get_atr`` — TODO confirm warm-up).

    Returns
    -------
    np.ndarray
    """
    prev_close = vbt.nb.fshift_1d_nb(close)
    tr0 = np.abs(high - low)
    tr1 = np.abs(high - prev_close)
    tr2 = np.abs(low - prev_close)
    # np.fmax ignores a NaN operand (NaN only if both are NaN), matching the
    # skipna row-max in get_atr. On the first bar the shifted close is NaN, so
    # TR falls back to |high - low| instead of becoming NaN and delaying the
    # smoothing warm-up by one bar.
    tr = np.fmax(np.fmax(tr0, tr1), tr2)
    return vbt.nb.wwm_mean_1d_nb(tr, period)


@njit
def get_final_bands_nb(close, upper, lower):
    """Numba version of the bar-by-bar SuperTrend band resolution.

    For every bar ``i >= 1``, a close above the previous upper band sets
    direction +1 and a close below the previous lower band sets direction -1.
    Otherwise the direction is carried forward and the active band is ratcheted
    (the lower band may not fall in an uptrend; the upper band may not rise in a
    downtrend).

    Parameters
    ----------
    close, upper, lower : np.ndarray
        Aligned 1-D float arrays. The bands are copied, so the caller's arrays
        are not modified.

    Returns
    -------
    trend, direction, long, short : np.ndarray
        ``trend`` is the active band (NaN on bar 0). ``direction`` is +1/-1
        (bar 0 defaults to +1). ``long``/``short`` hold the band only while in
        that direction and are NaN otherwise.
    """
    # Separate names (rather than rebinding the arguments) keep Numba's type
    # inference simple if the inputs have a non-contiguous layout.
    upper_band = upper.copy()
    lower_band = lower.copy()

    trend = np.full(close.shape, np.nan)
    direction = np.full(close.shape, 1)
    long = np.full(close.shape, np.nan)
    short = np.full(close.shape, np.nan)

    for i in range(1, close.shape[0]):
        if close[i] > upper_band[i - 1]:
            direction[i] = 1
        elif close[i] < lower_band[i - 1]:  # was `direction[i] < ...`: compared the flag to a price
            direction[i] = -1
        else:
            direction[i] = direction[i - 1]
            if direction[i] > 0 and lower_band[i] < lower_band[i - 1]:
                lower_band[i] = lower_band[i - 1]
            if direction[i] < 0 and upper_band[i] > upper_band[i - 1]:
                upper_band[i] = upper_band[i - 1]

        if direction[i] > 0:
            trend[i] = long[i] = lower_band[i]
        else:
            trend[i] = short[i] = upper_band[i]

    return trend, direction, long, short


def faster_supertrend(high, low, close, period=7, multiplier=3):
    """SuperTrend on plain ndarrays: same pipeline as ``supertrend``.

    Uses ``get_atr_np`` for the ATR and the Numba-compiled ``get_final_bands_nb``
    for the bar-by-bar loop.

    Returns ``(trend, direction, long, short)`` arrays.
    """
    band_upper, band_lower = get_basic_bands(
        get_mid_price(high, low),
        get_atr_np(high, low, close, period),
        multiplier,
    )
    return get_final_bands_nb(close, band_upper, band_lower)

In [None]:
%%timeit
faster_supertrend(
    high[ticker].values, low[ticker].values, close[ticker].values
)

array([          nan,           nan,           nan, ..., 6118.02512365,
       6118.25367742, 6118.00315207])

In [123]:
def faster_supertrend_talib(high, low, close, period=7, multiplier=3):
    """SuperTrend using TA-Lib's MEDPRICE and ATR, plus ``get_final_bands_nb``.

    Inputs are 1-D float arrays (as TA-Lib expects).

    NOTE(review): TA-Lib's ATR warm-up may differ from ``get_atr``/``get_atr_np``;
    compare outputs before mixing results across implementations.

    Returns ``(trend, direction, long, short)`` arrays.
    """
    band_upper, band_lower = get_basic_bands(
        talib.MEDPRICE(high, low),
        talib.ATR(high, low, close, period),
        multiplier,
    )
    return get_final_bands_nb(close, band_upper, band_lower)

In [124]:
%%timeit
faster_supertrend_talib(high[ticker].values, low[ticker].values, close[ticker].values)

1.07 ms ± 13.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
