## Install TensorTrade

In [None]:
!python3 -m pip install git+https://github.com/tensortrade-org/tensortrade.git

## Set Up Data Fetching

In [1]:
import pandas as pd
import tensortrade.env.default as default

from tensortrade.data.cdd import CryptoDataDownload
from tensortrade.feed.core import Stream, DataFeed
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.instruments import USD, BTC, ETH
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.agents import DQNAgent


%matplotlib inline

In [2]:
# Instantiate the CryptoDataDownload helper used to fetch historical price data.
cdd = CryptoDataDownload()

# Request the Coinbase / USD / BTC / "1h" dataset. The preview in the next cell
# shows one row per hour with date, open, high, low, close and volume columns.
data = cdd.fetch("Coinbase", "USD", "BTC", "1h")

In [3]:
# Show the first rows of the downloaded frame to confirm its columns and values.
data.head()

Unnamed: 0,date,open,high,low,close,volume
0,2017-07-01 11:00:00,2505.56,2513.38,2495.12,2509.17,287000.32
1,2017-07-01 12:00:00,2509.17,2512.87,2484.99,2488.43,393142.5
2,2017-07-01 13:00:00,2488.43,2488.43,2454.4,2454.43,693254.01
3,2017-07-01 14:00:00,2454.43,2473.93,2450.83,2459.35,712864.8
4,2017-07-01 15:00:00,2459.35,2475.0,2450.0,2467.83,682105.41


## Create features with the feed module

In [4]:
def rsi(price: Stream[float], period: float) -> Stream[float]:
    """Build a relative-strength-index stream from a price stream.

    Parameters
    ----------
    price : Stream[float]
        The price series to derive the indicator from.
    period : float
        Look-back length; each exponentially weighted mean uses
        ``alpha = 1 / period``.

    Returns
    -------
    Stream[float]
        ``100 * (1 - 1 / (1 + rs))``, where ``rs`` is the smoothed magnitude of
        upward moves divided by the smoothed magnitude of downward moves.
    """
    change = price.diff()

    # Split each step's change into its upward and downward magnitude.
    gains = change.clamp_min(0).abs()
    losses = change.clamp_max(0).abs()

    smoothing = 1 / period
    avg_gain = gains.ewm(alpha=smoothing).mean()
    avg_loss = losses.ewm(alpha=smoothing).mean()

    strength = avg_gain / avg_loss
    return 100*(1 - (1 + strength) ** -1)


def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:
    """Build a MACD-based feature stream from a price stream.

    Parameters
    ----------
    price : Stream[float]
        The price series to derive the indicator from.
    fast : float
        Span of the faster exponentially weighted mean.
    slow : float
        Span of the slower exponentially weighted mean.
    signal : float
        Span of the exponentially weighted mean applied to the MACD line.

    Returns
    -------
    Stream[float]
        The MACD line (fast mean minus slow mean) minus its own
        ``signal``-span exponentially weighted mean.
    """
    fast_ema = price.ewm(span=fast, adjust=False).mean()
    slow_ema = price.ewm(span=slow, adjust=False).mean()
    macd_line = fast_ema - slow_ema

    # Smooth the MACD line, then return how far the line sits from that average.
    signal_line = macd_line.ewm(span=signal, adjust=False).mean()
    return macd_line - signal_line


# Wrap every column after the first in a float stream named after that column.
raw_streams = [
    Stream.source(list(data[column]), dtype="float").rename(data[column].name)
    for column in data.columns[1:]
]

# The derived indicators below are all computed from the "close" stream.
cp = Stream.select(raw_streams, lambda stream: stream.name == "close")

# Features exposed to the agent: log return, RSI and the MACD-based signal.
features = [
    cp.log().diff().rename("lr"),
    rsi(cp, period=20).rename("rsi"),
    macd(cp, fast=10, slow=50, signal=5).rename("macd"),
]

feed = DataFeed(features)
feed.compile()

In [5]:
# Print the first five records produced by the feature feed as a sanity check.
# NOTE(review): each `feed.next()` call advances `feed`; the observer output
# further down shows the first record again, so the environment presumably
# resets the feed before use — confirm.
for i in range(5):
    print(feed.next())

{'lr': nan, 'rsi': nan, 'macd': 0.0}
{'lr': -0.008300031641449657, 'rsi': 0.0, 'macd': -1.9717171717171975}
{'lr': -0.01375743446296962, 'rsi': 0.0, 'macd': -6.082702245269603}
{'lr': 0.0020025323250756344, 'rsi': 8.795475693113076, 'macd': -7.287625162566419}
{'lr': 0.00344213459739251, 'rsi': 21.34663357024277, 'macd': -6.522181201739986}


## Set Up Trading Environment

In [6]:
# Simulated exchange whose "USD-BTC" price stream is the downloaded close price.
coinbase = Exchange("coinbase", service=execute_order)(
    Stream.source(list(data["close"]), dtype="float").rename("USD-BTC")
)

# Starting holdings on that exchange: 10,000 USD and 10 BTC, with USD as the
# first argument to Portfolio.
portfolio = Portfolio(
    USD,
    [
        Wallet(coinbase, 10000 * USD),
        Wallet(coinbase, 10 * BTC),
    ],
)

# Streams handed to the renderer: the untyped date column followed by the
# float-typed OHLCV columns, each named after its source column.
renderer_streams = [Stream.source(list(data["date"])).rename("date")]
renderer_streams += [
    Stream.source(list(data[column]), dtype="float").rename(column)
    for column in ("open", "high", "low", "close", "volume")
]
renderer_feed = DataFeed(renderer_streams)

# Assemble the default environment from the pieces built above.
env = default.create(
    portfolio=portfolio,
    action_scheme="managed-risk",
    reward_scheme="risk-adjusted",
    feed=feed,
    renderer_feed=renderer_feed,
    renderer=default.renderers.PlotlyTradingChart(),
    window_size=20,
)

In [7]:
# Pull one record from the environment's observer feed to inspect its layout:
# the output groups values under 'internal', 'external' and 'renderer' keys.
env.observer.feed.next()

{'internal': {'coinbase:/USD-BTC': 2509.17,
  'coinbase:/USD:/free': 10000.0,
  'coinbase:/USD:/locked': 0.0,
  'coinbase:/USD:/total': 10000.0,
  'coinbase:/BTC:/free': 10.0,
  'coinbase:/BTC:/locked': 0.0,
  'coinbase:/BTC:/total': 10.0,
  'coinbase:/BTC:/worth': 25091.7,
  'net_worth': 35091.7},
 'external': {'lr': nan, 'rsi': nan, 'macd': 0.0},
 'renderer': {'date': Timestamp('2017-07-01 11:00:00'),
  'open': 2505.56,
  'high': 2513.38,
  'low': 2495.12,
  'close': 2509.17,
  'volume': 287000.32}}

## Set Up and Train DQN Agent

In [8]:
# Create a DQN agent bound to the trading environment built above.
agent = DQNAgent(env)

# Short demonstration run (n_steps=200, n_episodes=2). The call is the cell's
# last expression, so its return value is displayed beneath the chart widget.
# NOTE(review): presumably that value is a reward summary and "agents/" is where
# trained models are written — confirm against the DQNAgent.train documentation.
agent.train(n_steps=200, n_episodes=2, save_path="agents/")

FigureWidget({
    'data': [{'close': array([2509.17, 2488.43, 2454.43, ..., 2553.79, 2539.82, 2542.72]),
    …

49999.36214535988