In [209]:
import talib
from talib.abstract import Function
import numpy as np
import pandas as pd

from gym import Space
from copy import copy
from abc import abstractmethod
from typing import Union, List, Callable, Dict

from tensortrade.features import FeatureTransformer


class TAlibIndicator(FeatureTransformer):
    """Adds one or more TAlib indicators to a data frame, based on existing open, high, low, and close column values."""

    def __init__(self, indicators: List[str], lows: Union[List[float], List[int]] = None, highs: Union[List[float], List[int]] = None, **kwargs):
        indicators = self._error_check(indicators)
        self._indicator_names = [indicator.upper() for indicator in indicators]
        self._indicators = [getattr(talib, name.split('-')[0]) for name in self._indicator_names]
        # Here we get the stats for each indicator for TA-Lib
        self._stats = {indicator:self._get_stats(indicator) for indicator in self._indicator_names}
        
    def _error_check(self, a):
        err_indexes = []
        for n, i in enumerate(a):
            if i == "BBAND":
                a[n] = "BBANDS"
            elif i == "BB":
                pass
            elif i == "RIS":
                a[n] = "RSI"
            elif i == "":
                err_indexes.append(n)
            elif i == None:
                err_indexes.append(n)
        for n in sorted(err_indexes, reverse=True):
            del a[n]
        return a
    
    def _get_stats(self, indicator_name:str) -> Dict:
        """ Get the relavent indicator information.

        Parameters:
        -------
        code: code of symbol (required)
            get help information of a symbol
        """
        if indicator_name is None:
            print("Usage: help_indicator(symbol), symbol is indicator name")
            return {
                "parameters": {},
                "inputs": []
            }
        else:
            upper_code = indicator_name.upper()
            if upper_code not in talib.get_functions():
                print(f"ERROR: indicator {upper_code} not in list")
                return {
                    "parameters": {},
                    "inputs": []
                }
            else:
                func = Function(upper_code)
                parameters = dict(func.parameters)
                inputs = list(func.input_names.values())
                return {
                    "parameters": parameters,
                    "inputs": inputs
                }

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        for idx, indicator in enumerate(self._indicators):
            indicator_name = self._indicator_names[idx]
            indicator_params = self._stats[indicator_name]['parameters']
            indicator_args = [X[arg].values for arg in self._stats[indicator_name]["inputs"]]
            
            if indicator_name == 'BBANDS':
                upper, middle, lower = indicator(*indicator_args,**indicator_params)

                X["bb_upper"] = upper
                X["bb_middle"] = middle
                X["bb_lower"] = lower
            else:
                try:
                    value = indicator(*indicator_args,**indicator_params)

                    if type(value) == tuple:
                        X[indicator_name] = value[0][0]
                    else:
                        X[indicator_name] = value

                except:
                    X[indicator_name] = indicator(*indicator_args,**indicator_params)[0]

        return X


In [210]:
import pandas as pd

from tensortrade.features.scalers import MinMaxNormalizer, ComparisonNormalizer, PercentChangeNormalizer
from tensortrade.features.stationarity import FractionalDifference

ohlcv_data = pd.read_csv('./data/Coinbase_BTCUSD_1h.csv', skiprows=1)
ohlcv_data = ohlcv_data[['open','high','low','close','volume']]

In [211]:
taindicator = TAlibIndicator(indicators=["BBAND", "RSI", "EMA", "SMA", "", None])

In [208]:
taindicator.transform(ohlcv_data)

Unnamed: 0,open,high,low,close,volume,bb_upper,bb_middle,bb_lower,RSI,EMA,SMA
0,8051.00,8056.83,8021.23,8035.88,492394.56,,,,,,
1,7975.89,8070.00,7975.89,8051.00,2971610.86,,,,,,
2,7964.62,7987.82,7964.61,7975.89,970521.83,,,,,,
3,7984.02,7993.97,7958.29,7964.62,1692336.84,,,,,,
4,7941.71,7986.99,7937.01,7984.02,774064.91,8071.279007,8002.282,7933.284993,,,
5,7970.19,7973.23,7937.50,7941.71,1177321.50,8056.754357,7983.448,7910.141643,,,
6,7986.62,7988.54,7965.61,7970.19,805106.70,7995.903870,7967.286,7938.668130,,,
7,7960.00,7989.62,7959.74,7986.62,743583.37,8001.686938,7969.432,7937.177062,,,
8,7966.43,7977.48,7959.74,7960.00,1325771.25,8001.517268,7968.508,7935.498732,,,
9,7993.54,7995.81,7957.55,7966.43,1584534.00,7994.162988,7964.990,7935.817012,,,


In [158]:
get_indicator_stats("rsi")

['price']
{'timeperiod': 14}
