# Crystal Ball
AI providing buy/sell/hold actions based on candlestick data.

## Binance
https://www.kaggle.com/code/lucasmorin/getting-all-1m-data-from-binance/notebook

https://developers.binance.com/docs/binance-spot-api-docs/rest-api/market-data-endpoints#klinecandlestick-data

## Reinforcement Learning
https://www.youtube.com/playlist?list=PLMrJAkhIeNNQe1JXNvaFvURxGY4gE9k74
### Pong from Pixels
https://karpathy.github.io/2016/05/31/rl/

### Stable Baselines
https://stable-baselines3.readthedocs.io/en/master/
https://anaconda.org/conda-forge/stable-baselines3

### Policy Gradient Methods
https://youtu.be/5P7I-xPq8u8?si=hXVvvLb1S8XcWGfz

### Example
https://towardsdatascience.com/how-to-train-an-ai-to-play-any-game-f1489f3bc5c
https://github.com/guszejnovdavid/custom_game_reinforcement_learning/blob/main/custom_game_reinforcement_learning.ipynb

### Plotting
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#
https://plotly.com/python/candlestick-charts/
https://github.com/highfestiva/finplot

### Installation
Install [Visual Studio Code](https://code.visualstudio.com/download) and [ANACONDA Distribution](https://www.anaconda.com/download).
Create a virtual environment and install packages below. 
Useful extensions include Jupyter, Python
~~~shell
# pytorch
$ conda install -c pytorch pytorch torchvision torchaudio
# tensorboard
$ conda install tensorboard
# cuda (if using nVidia GPU)
$ conda install -c nvidia pytorch-cuda
# data
$ conda install pandas pyarrow plotly nbformat
$ #conda install pyarrow -c conda-forge
$ #conda install conda-forge::stable-baselines3
$ #conda install -c plotly plotly
$ #conda install -c conda-forge nbformat
~~~

In [3]:
import json
import os
import random
import subprocess
import time
from datetime import date, datetime, timedelta
from datetime import date


import requests
import pandas as pd
import numpy as np

API_BASE = 'https://api.binance.com/api/v3/'

LABELS = [
    'open_time',
    'open',
    'high',
    'low',
    'close',
    'volume',
    'close_time',
    'quote_asset_volume',
    'number_of_trades',
    'taker_buy_base_asset_volume',
    'taker_buy_quote_asset_volume',
    'ignore'
]


In [4]:
def get_batch(symbol, interval='1m', start_time=0, limit=1000):
    """Use a GET request to retrieve a batch of candlesticks. Process the JSON into a pandas
    dataframe and return it. If not successful, return an empty dataframe.
    """

    params = {
        'symbol': symbol,
        'interval': interval,
        'startTime': start_time,
        'limit': limit
    }
    try:
        # timeout should also be given as a parameter to the function
        response = requests.get(f'{API_BASE}klines', params, timeout=30)
    except requests.exceptions.ConnectionError:
        print('Connection error, Cooling down for 5 mins...')
        time.sleep(5 * 60)
        return get_batch(symbol, interval, start_time, limit)
    
    except requests.exceptions.Timeout:
        print('Timeout, Cooling down for 5 min...')
        time.sleep(5 * 60)
        return get_batch(symbol, interval, start_time, limit)
    
    except requests.exceptions.ConnectionResetError:
        print('Connection reset by peer, Cooling down for 5 min...')
        time.sleep(5 * 60)
        return get_batch(symbol, interval, start_time, limit)

    if response.status_code == 200:
        return pd.DataFrame(response.json(), columns=LABELS)
    print(f'Got erroneous response back: {response}')
    return pd.DataFrame([])

# TODO: No new data is available on this channel?
def all_candles_to_csv(base, quote, interval='1m'):
    """Collect a list of candlestick batches with all candlesticks of a trading pair,
    concat into a dataframe and write it to CSV.
    """

    # see if there is any data saved on disk already
    try:
        batches = [pd.read_csv(f'data/{base}-{quote}.csv')]
        last_timestamp = batches[-1]['open_time'].max()
    except FileNotFoundError:
        batches = [pd.DataFrame([], columns=LABELS)]
        last_timestamp = 0
    old_lines = len(batches[-1].index)

    # gather all candlesticks available, starting from the last timestamp loaded from disk or 0
    # stop if the timestamp that comes back from the api is the same as the last one
    previous_timestamp = None

    while previous_timestamp != last_timestamp:
        # stop if we reached data from today
        if date.fromtimestamp(last_timestamp / 1000) >= date.today():
            break

        previous_timestamp = last_timestamp

        new_batch = get_batch(
            symbol=base+quote,
            interval=interval,
            start_time=last_timestamp+1
        )

        # requesting candles from the future returns empty
        # also stop in case response code was not 200
        if new_batch.empty:
            break

        last_timestamp = new_batch['open_time'].max()

        # sometimes no new trades took place yet on date.today();
        # in this case the batch is nothing new
        if previous_timestamp == last_timestamp:
            break

        batches.append(new_batch)
        last_datetime = datetime.fromtimestamp(last_timestamp / 1000)

        covering_spaces = 20 * ' '
        print(datetime.now(), base, quote, interval, str(last_datetime)+covering_spaces, end='\r', flush=True)

    # write clean version of csv to parquet
    parquet_name = f'{base}-{quote}.parquet'
    full_path = f'compressed/{parquet_name}'
    df = pd.concat(batches, ignore_index=True)
    df = quick_clean(df)
    write_raw_to_parquet(df, full_path)

    # in the case that new data was gathered write it to disk
    if len(batches) > 1:
        df.to_csv(f'data/{base}-{quote}.csv', index=False)
        return len(df.index) - old_lines
    return 0

def set_dtypes(df):
    """
    set datetimeindex and convert all columns in pd.df to their proper dtype
    assumes csv is read raw without modifications; pd.read_csv(csv_filename)"""

    df['open_time'] = pd.to_datetime(df['open_time'], unit='ms')
    df = df.set_index('open_time', drop=True)

    df = df.astype(dtype={
        'open': 'float64',
        'high': 'float64',
        'low': 'float64',
        'close': 'float64',
        'volume': 'float64',
        'close_time': 'datetime64[ms]',
        'quote_asset_volume': 'float64',
        'number_of_trades': 'int64',
        'taker_buy_base_asset_volume': 'float64',
        'taker_buy_quote_asset_volume': 'float64',
        'ignore': 'float64'
    })

    return df


def set_dtypes_compressed(df):
    """Create a `DatetimeIndex` and convert all critical columns in pd.df to a dtype with low
    memory profile. Assumes csv is read raw without modifications; `pd.read_csv(csv_filename)`."""

    df['open_time'] = pd.to_datetime(df['open_time'], unit='ms')
    df = df.set_index('open_time', drop=True)

    df = df.astype(dtype={
        'open': 'float32',
        'high': 'float32',
        'low': 'float32',
        'close': 'float32',
        'volume': 'float32',
        'number_of_trades': 'uint16',
        'quote_asset_volume': 'float32',
        'taker_buy_base_asset_volume': 'float32',
        'taker_buy_quote_asset_volume': 'float32'
    })

    return df


def assert_integrity(df):
    """make sure no rows have empty cells or duplicate timestamps exist"""

    assert df.isna().all(axis=1).any() == False
    assert df['open_time'].duplicated().any() == False


def quick_clean(df):
    """clean a raw dataframe"""

    # drop dupes
    dupes = df['open_time'].duplicated().sum()
    if dupes > 0:
        df = df[df['open_time'].duplicated() == False]

    # sort by timestamp, oldest first
    df.sort_values(by=['open_time'], ascending=False)

    # just a doublcheck
    assert_integrity(df)

    return df


def write_raw_to_parquet(df, full_path):
    """takes raw df and writes a parquet to disk"""

    # some candlesticks do not span a full minute
    # these points are not reliable and thus filtered
    df = df[~(df['open_time'] - df['close_time'] != -59999)]

    # `close_time` column has become redundant now, as is the column `ignore`
    df = df.drop(['close_time', 'ignore'], axis=1)

    df = set_dtypes_compressed(df)

    # give all pairs the same nice cut-off
    df = df[df.index < str(date.today())]

    df.to_parquet(full_path)


def groom_data(dirname='data'):
    """go through data folder and perform a quick clean on all csv files"""

    for filename in os.listdir(dirname):
        if filename.endswith('.csv'):
            full_path = f'{dirname}/{filename}'
            quick_clean(pd.read_csv(full_path)).to_csv(full_path)


def compress_data(dirname='data'):
    """go through data folder and rewrite csv files to parquets"""

    os.makedirs('compressed', exist_ok=True)
    for filename in os.listdir(dirname):
        if filename.endswith('.csv'):
            full_path = f'{dirname}/{filename}'

            df = pd.read_csv(full_path)

            new_filename = filename.replace('.csv', '.parquet')
            new_full_path = f'compressed/{new_filename}'
            write_raw_to_parquet(df, new_full_path)


In [5]:
import pandas as pd

all_symbols = pd.DataFrame(requests.get(f'{API_BASE}exchangeInfo').json()['symbols'])

In [6]:
dict_ticker = {
    'Bitcoin Cash':'BCH',
    'Binance Coin':'BNB',
    'Bitcoin':'BTC',
    'EOS.IO':'EOS',
    'Ethereum Classic':'ETC',
    'Ethereum':'ETH',
    'Litecoin':'LTC',
    'Monero':'XMR',
    'TRON':'TRX',
    'Stellar':'XLM',
    'Cardano':'ADA',
    'IOTA':'IOTA',
    'Maker':'MKR',
    'Dogecoin':'DOGE'
}

In [7]:
for a in dict_ticker:
    quoteAssetsa = all_symbols[all_symbols.baseAsset == dict_ticker[a]].quoteAsset.unique()
    USDquoteAssetsa = [qA for qA in quoteAssetsa if 'USD' in qA]
    print(USDquoteAssetsa)

['USDT', 'USDC', 'TUSD', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'USDS', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'USDS', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'USDC', 'TUSD', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'BUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'BUSD', 'FDUSD', 'USDC']
['USDT', 'BUSD', 'USDC']
['USDT', 'USDC', 'BUSD', 'TUSD', 'FDUSD']


In [8]:
quote = 'BUSD'

In [9]:
all_pairs = [(dict_ticker[a],quote) for a in dict_ticker]

In [10]:
all_pairs

[('BCH', 'BUSD'),
 ('BNB', 'BUSD'),
 ('BTC', 'BUSD'),
 ('EOS', 'BUSD'),
 ('ETC', 'BUSD'),
 ('ETH', 'BUSD'),
 ('LTC', 'BUSD'),
 ('XMR', 'BUSD'),
 ('TRX', 'BUSD'),
 ('XLM', 'BUSD'),
 ('ADA', 'BUSD'),
 ('IOTA', 'BUSD'),
 ('MKR', 'BUSD'),
 ('DOGE', 'BUSD')]

In [11]:
# make sure data folders exist
os.makedirs('data', exist_ok=True)
os.makedirs('compressed', exist_ok=True)

# do a full update on all pairs
n_count = len(all_pairs)
for n, pair in enumerate(all_pairs, 1):
    base, quote = pair
    new_lines = all_candles_to_csv(base=base, quote=quote)
    if new_lines > 0:
        print(f'{datetime.now()} {n}/{n_count} Wrote {new_lines} new lines to file for {base}-{quote}')
    else:
        print(f'{datetime.now()} {n}/{n_count} Already up to date with {base}-{quote}')

2025-03-20 18:17:37.483265 1/14 Already up to date with BCH-BUSD
2025-03-20 18:17:40.429089 2/14 Already up to date with BNB-BUSD
2025-03-20 18:17:43.433477 3/14 Already up to date with BTC-BUSD
2025-03-20 18:17:46.061443 4/14 Already up to date with EOS-BUSD
2025-03-20 18:17:48.802254 5/14 Already up to date with ETC-BUSD
2025-03-20 18:17:51.807187 6/14 Already up to date with ETH-BUSD
2025-03-20 18:17:54.654603 7/14 Already up to date with LTC-BUSD
2025-03-20 18:17:57.165996 8/14 Already up to date with XMR-BUSD
2025-03-20 18:17:59.865955 9/14 Already up to date with TRX-BUSD
2025-03-20 18:18:02.488731 10/14 Already up to date with XLM-BUSD
2025-03-20 18:18:05.276881 11/14 Already up to date with ADA-BUSD
2025-03-20 18:18:07.555359 12/14 Already up to date with IOTA-BUSD
2025-03-20 18:18:09.895286 13/14 Already up to date with MKR-BUSD
2025-03-20 18:18:12.273382 14/14 Already up to date with DOGE-BUSD


# Deep Q Network (DQN)
## Create dataset class and instantiate a BTC-BUSD dataset object

In [1]:
# CUDA device setup
import torch
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print('device:', device)

device: cuda:0


In [2]:
#import pyarrow as pa
#import pyarrow.parquet as pq
import pyarrow.dataset as ds
#import pandas as pd

from typing import Any
#from typing import List

import torch
import pyarrow.dataset
#import pyarrow.dataset as ds
from torch.utils.data import Dataset

class BinanceDataset(Dataset):
    def __init__(self, data: pyarrow.dataset.Dataset) -> None:
        self.data = data

    def __len__(self) -> None:
        return self.data.count_rows()

    def __getitem__(self, index) -> tuple[list[Any], Any]:
        if type(index) is tuple:
            start, count = index
        else:
            start = index
            count = 1
        rows = self.data.take(range(start, start + count, 1)).to_pydict()
        return rows

dataset = BinanceDataset(data=pyarrow.dataset.dataset('compressed/BTC-BUSD.parquet', format='parquet'))

print(len(dataset))
print(dataset[0].keys())

2226200
dict_keys(['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'open_time'])


## Plot a random sequence (24h) of data

In [3]:
# plot data for inspection
import plotly.graph_objects as go

import pandas as pd
from datetime import datetime
import random

# Get a random sequence from the dataset
df = dataset[random.randint(0, len(dataset) - 1), 1440]

fig = go.Figure(data=[go.Candlestick(x=df['open_time'], open=df['open'], high=df['high'], low=df['low'], close=df['close'])])

fig.show()

In [4]:
# A functor class calculating calculating a commission based on a percentage of the net_price,
# optionally applying minimum and maximum limits.
class Commission:
    def __init__(self, percentage: float, min_commission: float = 0, max_commission: float = float('inf')):
        self.__percentage = percentage
        self.__min_commission = min_commission
        self.__max_commission = max_commission
    def __call__(self, net_price: float) -> float:
        commission = net_price * self.__percentage/100
        commission = max(commission, self.__min_commission)
        commission = min(commission, self.__max_commission)
        return commission

# calculate the max amount of an asset than can purchased given liquidity, price, min fraction and commission
def max_buy(liquidity: float, price: float, min_fraction: float = 1, commission: Commission = None) -> tuple[float, float]:
    # buy max: division rounding down to 'min_fraction', assuming no commission
    amount = (liquidity / price)//min_fraction/(1/min_fraction)

    while True:
        # calculate total price
        total = amount * price

        # add commission if present
        if commission is not None:
            total = total + commission(total)

        # if we can afford to buy 'amount' return
        if total <= liquidity:
            return amount, total

        # reduce 'amount' by 'min_fraction' and try again
        amount -= min_fraction

# calculate max amount of asset that can be sold given holding, price and commission
def max_sell(holding: float, price: float, commission: Commission = None) -> tuple[float, float]:
    amount = holding # sell all
    total = amount * price

    # add commission if present
    if commission is not None:
        total = total - commission(total)

    return amount, total

from enum import Enum

# Portfolio management is calculating the relative distribution of assets in the portfolio - cash is one asset...

# Available agent actions
class Action(Enum):
    BUY = 0
    KEEP = 1 # TODO: HOLD
    SELL = 2

import random
from collections import deque

# Crypto environment for RL model
class Crypto:
    def __init__(self, dataset: Dataset, seq_len: int, epoch_size: int, investment: float, min_fraction: float = 0.01, commission: Commission = None) -> None:
        self.__dataset = dataset
        self.__seq_len = seq_len # time-series length
        self.__epoch_size = epoch_size # max number of iterations in an epoch
        self.__investment = investment # cash investment
        self.__min_fraction = min_fraction  # minimum fraction of the asset that can be traded
        self.__commission = commission # brokerage fee (percentage) added when buying and selling assets
        self.__keys = ['open', 'high', 'low', 'close', 'volume' ] # relevant candlestick keys
        self.new_epoch()

    def __advance_state(self, count: int = 1):
        # read n indicies and advance index by n
        candlesticks = self.__dataset[self.__index, count]
        n = len(candlesticks[self.__keys[0]])
        self.__index += n

        # add relevant keys to the sequence
        for key in self.__keys:
            self.__state[key].extend(candlesticks[key])

        return n    

    def __update_holding(self, liquidity: float, holding: float, count: int = 1):
        # add liquidity and holding to the sequence (expand to same length as candlestick data)
        self.__state['liquidity'].extend([liquidity] * count) # cash currently held
        self.__state['holding'].extend([holding] * count) # number of assets currently held

    def new_epoch(self) -> None:
        # create new state with candlestick keys + liquidity and holding
        self.__state = { key: deque(maxlen=self.__seq_len) for key in self.__keys + ['liquidity', 'holding'] }

        # reinitialize the start index, make sure we get a full sequence
        self.__index = random.randint(0, len(self.__dataset) - self.__seq_len)
        self.__i = 0
        self.__transactions_buy = 0
        self.__transactions_sell = 0
        self.__transactions_hold = 0

        # get a full sequence of data
        self.__advance_state(count=self.__seq_len)
        self.__update_holding(liquidity=self.__investment, holding=0, count=self.__seq_len)

    def is_terminal(self) -> bool:
        return not ((self.__index < len(self.__dataset)) and (self.__i < self.__epoch_size))

    def observe(self) -> dict:
        return self.__state

    def iterations(self) -> int:
        return self.__i

    def value(self) -> float:
        close_price = self.__state['close'][-1]
        liquidity = self.__state['liquidity'][-1]
        holding = self.__state['holding'][-1]
        value = holding * close_price + liquidity
        return value

    def transactions(self) -> int:
        return { 'buy': self.__transactions_buy, 'hold': self.__transactions_hold, 'sell': self.__transactions_sell, 'all': self.__transactions_buy + self.__transactions_sell }
    
    # Take an action in this environment, advance to next state and return
    # a tuple containing; the new state, the reward and the "terminality" of the new state
    def act(self, action: Action) -> tuple[dict, float, bool]:
        if self.is_terminal():
            return 0 # return no reward if no action can be taken

        # increase iteration counter
        self.__i += 1

        # get closing price, liquidity and holding before action
        close_price = self.__state['close'][-1]

        liquidity = self.__state['liquidity'][-1]
        holding = self.__state['holding'][-1]

        # calculate value before action
        value_before_action = holding * close_price + liquidity

        # advance the time series
        n = self.__advance_state()

        # get opening price for action
        open_price = self.__state['open'][-1]

        # take action
        if action == Action.BUY:
            # calculate amount to buy and price to pay
            amount, price = max_buy(liquidity=liquidity, price=open_price, min_fraction=self.__min_fraction, commission=self.__commission)

            if amount > 0:
                self.__transactions_buy += 1

            # add the amount of the asset to the current holding
            holding += amount
            # subtract the total price from the cash holding
            liquidity -= price
        elif action == Action.KEEP:
            # do nothing
            self.__transactions_hold += 1
            pass
        elif action == Action.SELL:
            amount, price = max_sell(holding=holding, price=open_price, commission=self.__commission)

            if amount > 0:
                self.__transactions_sell += 1

            holding -= amount
            liquidity += price
        else:
            # should never happen
            pass

        self.__update_holding(liquidity=liquidity, holding=holding, count=n)

        # get closing price after action
        close_price = self.__state['close'][-1]

        # calculate value after action
        value_after_action = holding * close_price + liquidity

        reward = value_after_action - value_before_action

        return self.observe(), reward, self.is_terminal()

# JUST TESTING - REMOVE!
total = 0

for i in range(10):
    action = random.choice(list(Action))
    state, reward, done = environment.act(action=action)
    total += reward
    print('action:', action, 'reward:', reward, 'total:', total)

print(state.keys())

In [5]:
# DDQN reinforcement learner
# Advantage Actor-Critic Network: https://www.youtube.com/watch?v=wDVteayWWvU&list=PLMrJAkhIeNNQe1JXNvaFvURxGY4gE9k74&index=7&t=1011s

from torch import nn
from collections import OrderedDict

# NOTE: Should not use dropout and batch normalization together
# TODO: Normalize input
# Input tensor shape: (batches, input_channels, input_dim)
# Output tensor shape: (batches, actions)
class QFunction(nn.Module):
    def __init__(self, input_dim: int, input_channels: int, output_dim: int, conv_layers: int = 3, fc_layers: int = 1, fc_size: int = 1024):
        super(QFunction, self).__init__()

        # 1d convolutions
        self.conv = nn.Sequential(OrderedDict(
            [ ('conv' + str(n + 1), self.__conv_layer_set(in_channels=input_channels*(2**n), out_channels=input_channels*(2**(n+1)))) for n in range(conv_layers) ]))
        """
        self.conv = nn.Sequential(OrderedDict([
            ('conv1', self.__conv_layer_set(in_channels=input_channels, out_channels=input_channels*2)),
            #('dropout1', nn.Dropout(p=0.1)),
            ('conv2', self.__conv_layer_set(in_channels=input_channels*2, out_channels=input_channels*4)),
            #('dropout2', nn.Dropout(p=0.1)),
            ('conv3', self.__conv_layer_set(in_channels=input_channels*4, out_channels=input_channels*8)),
            #('dropout3', nn.Dropout(p=0.1)),
        ]))
        """

        # fully connected
        self.fc = nn.Sequential(OrderedDict(
            [ ('fc1', self.__fc_layer_set(in_features=input_dim*input_channels, out_features=fc_size)) ] +
            [ ('fc' + str(n + 2), self.__fc_layer_set(in_features=fc_size, out_features=fc_size)) for n in range(fc_layers) ] +
            [ ('linear', nn.Linear(in_features=fc_size, out_features=output_dim)),
              ('softmax', nn.Softmax(dim=1)) ]))
        """
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', self.__fc_layer_set(in_features=input_dim*input_channels, out_features=1024)),
            #('dropout4', nn.Dropout(p=0.1)),
            ('fc2', self.__fc_layer_set(in_features=1024, out_features=512)),
            #('dropout5', nn.Dropout(p=0.1)),
            ('linear', nn.Linear(in_features=512, out_features=output_dim)),
            ('softmax', nn.Softmax(dim=1))
        ]))
        """

    def __conv_layer_set(self, in_channels: int, out_channels: int) -> nn.Module:
        conv_layer = nn.Sequential(OrderedDict([
            ('conv1d', nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)),
            ('maxpool', nn.MaxPool1d(kernel_size=2)),
            ('leaky_relu', nn.LeakyReLU()),
            ('batch_norm', nn.BatchNorm1d(num_features=out_channels))
        ]))
        return conv_layer
    
    def __fc_layer_set(self, in_features: int, out_features: int) -> nn.Module:
        fc_layer = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(in_features=in_features, out_features=out_features)),
            ('leaky_relu', nn.LeakyReLU()),
            ('batch_norm', nn.BatchNorm1d(num_features=out_features))
        ]))
        return fc_layer

    def forward(self, x):
        #print('x.shape #0:', x.shape)
        x = self.conv(x)
        #print('x.shape #1:', x.shape)
        x = x.view(x.size(0), -1)
        #print('x.shape #2:', x.shape)
        x = self.fc(x)
        #print('x.shape #3:', x.shape)
        return x

In [6]:
import copy

q_fn = QFunction(input_dim=120, input_channels=7, conv_layers=1, fc_layers=0, fc_size=256, output_dim=len(Action)).to(device=device)
q_fn_target = copy.deepcopy(q_fn)
q_fn_target.load_state_dict(q_fn.state_dict())
q_fn_target.eval()

QFunction(
  (conv): Sequential(
    (conv1): Sequential(
      (conv1d): Conv1d(7, 14, kernel_size=(3,), stride=(1,), padding=(1,))
      (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (leaky_relu): LeakyReLU(negative_slope=0.01)
      (batch_norm): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fc): Sequential(
    (fc1): Sequential(
      (linear): Linear(in_features=840, out_features=256, bias=True)
      (leaky_relu): LeakyReLU(negative_slope=0.01)
      (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (linear): Linear(in_features=256, out_features=3, bias=True)
    (softmax): Softmax(dim=1)
  )
)

In [7]:
actions = [action for action in (Action)]

#import numpy as np
#probabilities = [0.2, 0.5, 0.3]
#print(np.random.choice(actions, 1, p=probabilities))

#x = torch.randn(2, 7, 1440)
#y = q_fn(x)

#print(y.shape)
#print(y)

#x = torch.randn(1, 32)
#bn = nn.BatchNorm1d(num_features=32)
#bn(x)


# check dimension of state
#print(make_tensor([environment.observe()]).shape)


In [8]:
def make_tensor(states):
    keys = [ 'open', 'high', 'low', 'close', 'volume', 'liquidity', 'holding' ]
    tensor = nn.functional.normalize(torch.tensor([ [ state[key] for key in keys ] for state in states ], dtype=torch.float).to(device=device))
    #tensor = torch.tensor([ [ state[key] for key in keys ] for state in states ], dtype=torch.float).to(device=device)
    return tensor


In [None]:
import numpy as np

from collections import deque
#import torch.utils.tensorboard

# environment: environment
# epochs: number of epochs to train
# batch_size: training batch size
# replay_size: size of experience replay buffer, once full old items will be dropped when new ones are insterted
# learning_rate: learning rate - duh
# exploration_rate: epsilon in epsilon-greedy algorithm
# discount: gamma in quality (reward) calculation
# sync_rate: target network synchronization rate
# TODO: episode, bootstrap experience replay buffer, collect in class, proper count of target sync + normalize input! + check randomness of episode start!
def train(environment, epochs: int, batch_size: int, replay_size: int, learning_rate:float, exploration_rate: float = .1, discount:float = .9, sync_rate: int = 50, writer = None):
    # initialize the experience replay buffer
    replay = deque(maxlen=replay_size)

    # loss function and optimizer
    loss_fn = nn.MSELoss().to(device=device)
    optimizer = torch.optim.Adam(q_fn.parameters(), lr=learning_rate)

    # exploration rate and discount of future rewards
    epsilon = exploration_rate
    gamma = discount

    # initialize losses list
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0

        environment.new_epoch()
        state = environment.observe()

        while not environment.is_terminal():
            # agent taking an action with epsilon-greedy strategy
            if np.random.random() < epsilon:
                action = np.random.choice(actions)
            else:
                # Select the action with the largest Q-value
                with torch.no_grad():
                    q_values = q_fn_target(make_tensor([state])).cpu().data.numpy()
                action = Action(np.argmax(q_values))

            # advancing the state of the board
            new_state, reward, done = environment.act(action)

            # saving the (state, action, new_state, reward) values in the experience replay buffer
            replay.append((state, action, new_state, reward, done))

            # when we get to the batch size we train the network
            if len(replay) > batch_size:
                batch = random.sample(replay, batch_size)

                # TODO: Optimize! Extract all in one loop!
                batch_states = make_tensor([state for state, _, _, _, _ in batch])
                batch_actions = torch.tensor([action.value for _, action, _, _, _ in batch], dtype=torch.float).to(device=device)
                new_stateBatch = make_tensor([state for _, _, state, _, _ in batch])
                batch_rewards = torch.tensor([reward for _, _, _, reward, _ in batch], dtype=torch.float).to(device=device)
                batch_done = torch.tensor([+(done) for _, _, _, _, done in batch], dtype=torch.int).to(device=device)

                # use the target network to bootstrap
                with torch.no_grad():
                    new_stateQ = q_fn_target(new_stateBatch)

                # Actual reward, discounted by gamma
                # compute the predition of the networks for R + gamma*MaxRewards
                Y = batch_rewards + gamma * (1 - batch_done)*torch.max(new_stateQ,dim=1)[0]

                # Estimate of future reward
                # Compute the discounted rewards using the first original network
                stateQ = q_fn(batch_states).gather(dim=1, index=batch_actions.long().unsqueeze(dim=1)).squeeze()

                # compute the loss of the model and backpropagate
                loss = loss_fn(stateQ, Y.detach())
                optimizer.zero_grad()
                loss.backward()

                # keeping track of the losses
                epoch_loss += loss.item()

                optimizer.step()

                # periodically copy parameters to the target network
                if environment.iterations() % sync_rate == 0:
                    print('update target network')
                    q_fn_target.load_state_dict(q_fn.state_dict())

            # advance the state
            state = new_state

        if environment.iterations() > 0:
            epoch_avg_loss = epoch_loss / environment.iterations()
        else:
            epoch_avg_loss = epoch_loss

        if writer is not None:
            writer.add_scalar('Loss/train', epoch_avg_loss, epoch)
            writer.add_scalar('Reward/train', environment.value(), epoch)
            writer.add_scalars('Transactions/train', {'all': environment.transactions()['all'], 'buy': environment.transactions()['buy'], 'hold': environment.transactions()['hold'], 'sell': environment.transactions()['sell']}, epoch)
            #writer.add_scalar('Transactions/train', environment.transactions(), epoch)

        print('epoch: ', epoch, 'loss: ', epoch_avg_loss)

        losses.append(epoch_avg_loss)

    return losses

In [10]:
# load models
import os

q_fn = torch.load(os.path.join('models', 'q_fn'), weights_only=False)
q_fn_target = torch.load(os.path.join('models', 'q_fn_target'), weights_only=False)


In [11]:
import datetime
log_dir = "runs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

import torch
from torch.utils.tensorboard import SummaryWriter

# create tensorboard log writer
writer = SummaryWriter(log_dir=log_dir)

# creating the environment
btcdata = BinanceDataset(data=pyarrow.dataset.dataset('compressed/BTC-BUSD.parquet', format='parquet'))
commission = Commission(percentage=0.1, min_commission=2.5)
environment = Crypto(dataset=btcdata, seq_len=120, epoch_size=360, investment=10000, min_fraction=0.01, commission=commission)

# start training
losses = train(environment=environment, epochs=500, batch_size=512, replay_size=10000, learning_rate=0.00001, sync_rate=100, writer=writer)

# flush the writer
writer.flush()
writer.close()

epoch:  0 loss:  0.0
update target network
update target network
epoch:  1 loss:  21.612644116083782
update target network
update target network
update target network
epoch:  2 loss:  29.822107945548165
update target network
update target network
update target network
epoch:  3 loss:  25.979192580117118
update target network
update target network
update target network
epoch:  4 loss:  22.805673413806492
update target network
update target network
update target network
epoch:  5 loss:  20.904155574904546
update target network
update target network
update target network
epoch:  6 loss:  19.822545199924043
update target network
update target network
update target network
epoch:  7 loss:  24.84331144226922
update target network
update target network
update target network
epoch:  8 loss:  30.0245238410102
update target network
update target network
update target network
epoch:  9 loss:  31.382758055792916
update target network
update target network
update target network
epoch:  10 loss:  30

In [12]:
# save models
import os

if not os.path.exists('models'):
    os.makedirs('models')

torch.save(q_fn, os.path.join('models', 'q_fn'))
torch.save(q_fn_target, os.path.join('models', 'q_fn_target'))
