# Crystal Ball
AI providing buy/sell/hold actions based on candlestick data.

## Binance
https://www.kaggle.com/code/lucasmorin/getting-all-1m-data-from-binance/notebook

https://developers.binance.com/docs/binance-spot-api-docs/rest-api/market-data-endpoints#klinecandlestick-data

## Reinforcement Learning
https://www.youtube.com/playlist?list=PLMrJAkhIeNNQe1JXNvaFvURxGY4gE9k74
### Pong from Pixels
https://karpathy.github.io/2016/05/31/rl/

### Stable Baselines
https://stable-baselines3.readthedocs.io/en/master/
https://anaconda.org/conda-forge/stable-baselines3

### Policy Gradient Methods
https://youtu.be/5P7I-xPq8u8?si=hXVvvLb1S8XcWGfz

### Example
https://towardsdatascience.com/how-to-train-an-ai-to-play-any-game-f1489f3bc5c
https://github.com/guszejnovdavid/custom_game_reinforcement_learning/blob/main/custom_game_reinforcement_learning.ipynb

### Plotting
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#
https://plotly.com/python/candlestick-charts/
https://github.com/highfestiva/finplot

### Installation
~~~shell
$ conda install pytorch torchvision torchaudio pytorch-cuda pandas -c pytorch -c nvidia
$ conda install pyarrow -c conda-forge
$ #conda install conda-forge::stable-baselines3
$ conda install -c plotly plotly
$ conda install -c conda-forge nbformat
~~~

In [13]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)

device: cuda:0


In [14]:
import json
import os
import random
import subprocess
import time
from datetime import date, datetime, timedelta
from datetime import date


import requests
import pandas as pd
import numpy as np

# Base URL that every REST endpoint path in this notebook is appended to
API_BASE = 'https://api.binance.com/api/v3/'

# Column names applied, in this order, to each candlestick row returned by the
# `klines` endpoint when the JSON response is turned into a DataFrame
LABELS = [
    'open_time',
    'open',
    'high',
    'low',
    'close',
    'volume',
    'close_time',
    'quote_asset_volume',
    'number_of_trades',
    'taker_buy_base_asset_volume',
    'taker_buy_quote_asset_volume',
    'ignore'
]


In [15]:
def get_batch(symbol, interval='1m', start_time=0, limit=1000):
    """Use a GET request to retrieve a batch of candlesticks.

    The request is retried indefinitely, sleeping five minutes between
    attempts, while the connection fails or times out.

    Args:
        symbol: trading pair symbol, e.g. base and quote concatenated.
        interval: candlestick interval string passed to the API.
        start_time: `startTime` query parameter (the callers pass epoch milliseconds).
        limit: maximum number of candlesticks requested.

    Returns:
        A DataFrame with the columns in `LABELS` when the response status is
        200, otherwise an empty DataFrame.
    """

    params = {
        'symbol': symbol,
        'interval': interval,
        'startTime': start_time,
        'limit': limit
    }

    # Retry in a loop rather than by recursion so a long outage cannot exhaust
    # the call stack.
    while True:
        try:
            # timeout should also be given as a parameter to the function
            response = requests.get(f'{API_BASE}klines', params, timeout=30)
            break
        except requests.exceptions.ConnectionError:
            print('Connection error, Cooling down for 5 mins...')
        except requests.exceptions.Timeout:
            print('Timeout, Cooling down for 5 min...')
        except ConnectionResetError:
            # `requests.exceptions` has no ConnectionResetError attribute; the
            # built-in exception is the one that can actually be caught here.
            print('Connection reset by peer, Cooling down for 5 min...')
        time.sleep(5 * 60)

    if response.status_code == 200:
        return pd.DataFrame(response.json(), columns=LABELS)
    print(f'Got erroneous response back: {response}')
    return pd.DataFrame([])

# TODO: No new data is available on this channel?
def all_candles_to_csv(base, quote, interval='1m'):
    """Collect a list of candlestick batches with all candlesticks of a trading pair,
    concat into a dataframe and write it to CSV.

    Existing rows are loaded from `data/{base}-{quote}.csv` when that file
    exists; new batches are then requested starting one millisecond after the
    newest `open_time` seen so far. The cleaned result is always written to
    `compressed/{base}-{quote}.parquet`; the CSV is only rewritten when at
    least one new batch was fetched.

    Args:
        base: base asset ticker, first half of the symbol.
        quote: quote asset ticker, second half of the symbol.
        interval: candlestick interval passed through to `get_batch`.

    Returns:
        Row count of the cleaned frame minus the row count loaded from disk
        when new batches were fetched, otherwise 0.
    """

    # see if there is any data saved on disk already
    try:
        batches = [pd.read_csv(f'data/{base}-{quote}.csv')]
        last_timestamp = batches[-1]['open_time'].max()
    except FileNotFoundError:
        batches = [pd.DataFrame([], columns=LABELS)]
        last_timestamp = 0
    # remember how many rows came from disk so only the new ones are counted
    old_lines = len(batches[-1].index)

    # gather all candlesticks available, starting from the last timestamp loaded from disk or 0
    # stop if the timestamp that comes back from the api is the same as the last one
    previous_timestamp = None

    while previous_timestamp != last_timestamp:
        # stop if we reached data from today
        # (timestamps are divided by 1000, i.e. treated as milliseconds, and
        # compared in the local timezone used by `date.today()`)
        if date.fromtimestamp(last_timestamp / 1000) >= date.today():
            break

        previous_timestamp = last_timestamp

        new_batch = get_batch(
            symbol=base+quote,
            interval=interval,
            start_time=last_timestamp+1
        )

        # requesting candles from the future returns empty
        # also stop in case response code was not 200
        if new_batch.empty:
            break

        last_timestamp = new_batch['open_time'].max()

        # sometimes no new trades took place yet on date.today();
        # in this case the batch is nothing new
        if previous_timestamp == last_timestamp:
            break

        batches.append(new_batch)
        last_datetime = datetime.fromtimestamp(last_timestamp / 1000)

        # progress line: '\r' rewrites the same console line, the trailing
        # spaces overwrite leftovers of a longer previous line
        covering_spaces = 20 * ' '
        print(datetime.now(), base, quote, interval, str(last_datetime)+covering_spaces, end='\r', flush=True)

    # write clean version of csv to parquet
    # (this happens on every call, also when no new batch was fetched)
    parquet_name = f'{base}-{quote}.parquet'
    full_path = f'compressed/{parquet_name}'
    df = pd.concat(batches, ignore_index=True)
    df = quick_clean(df)
    write_raw_to_parquet(df, full_path)

    # in the case that new data was gathered write it to disk
    if len(batches) > 1:
        df.to_csv(f'data/{base}-{quote}.csv', index=False)
        return len(df.index) - old_lines
    return 0

def set_dtypes(df):
    """Return a copy of a raw candlestick frame with a datetime index and proper dtypes.

    Assumes the csv is read raw without modifications; `pd.read_csv(csv_filename)`.
    The `open_time` column (milliseconds) becomes the index and is removed
    from the columns. The frame passed in is left untouched.

    Args:
        df: raw frame containing every column named in the dtype mapping below
            plus `open_time`.

    Returns:
        A new DataFrame indexed by `open_time` with converted column dtypes.
    """

    # Build the index from a converted copy of the column instead of assigning
    # back into `df`, so the caller's frame is not modified as a side effect.
    open_time = pd.to_datetime(df['open_time'], unit='ms')
    df = df.drop(columns='open_time').set_index(open_time)

    df = df.astype(dtype={
        'open': 'float64',
        'high': 'float64',
        'low': 'float64',
        'close': 'float64',
        'volume': 'float64',
        'close_time': 'datetime64[ms]',
        'quote_asset_volume': 'float64',
        'number_of_trades': 'int64',
        'taker_buy_base_asset_volume': 'float64',
        'taker_buy_quote_asset_volume': 'float64',
        'ignore': 'float64'
    })

    return df


def set_dtypes_compressed(df):
    """Return a copy of a raw candlestick frame with a `DatetimeIndex` and compact dtypes.

    Assumes the csv is read raw without modifications; `pd.read_csv(csv_filename)`.
    The `open_time` column (milliseconds) becomes the index and is removed
    from the columns. The frame passed in is left untouched.

    Args:
        df: raw frame containing `open_time` and every column named in the
            dtype mapping below.

    Returns:
        A new DataFrame indexed by `open_time` with low-memory column dtypes.
    """

    # Build the index from a converted copy of the column instead of assigning
    # back into `df`, so the caller's frame is not modified as a side effect.
    open_time = pd.to_datetime(df['open_time'], unit='ms')
    df = df.drop(columns='open_time').set_index(open_time)

    # NOTE(review): uint16 tops out at 65535 trades per candle - confirm no
    # candle in the data exceeds that before relying on `number_of_trades`.
    df = df.astype(dtype={
        'open': 'float32',
        'high': 'float32',
        'low': 'float32',
        'close': 'float32',
        'volume': 'float32',
        'number_of_trades': 'uint16',
        'quote_asset_volume': 'float32',
        'taker_buy_base_asset_volume': 'float32',
        'taker_buy_quote_asset_volume': 'float32'
    })

    return df


def assert_integrity(df):
    """Make sure no rows have empty cells and no duplicate timestamps exist.

    Args:
        df: frame with an `open_time` column.

    Raises:
        AssertionError: if any cell is NaN or any `open_time` value repeats.
    """

    # Raise explicitly instead of using `assert`, so the check still runs when
    # Python is started with -O.
    # `any(axis=1)` flags a row as soon as a single cell is missing; the
    # previous `all(axis=1)` only caught rows that were empty in every column.
    if df.isna().any(axis=1).any():
        raise AssertionError('dataframe contains rows with empty cells')
    if df['open_time'].duplicated().any():
        raise AssertionError('dataframe contains duplicate open_time values')


def quick_clean(df):
    """Clean a raw dataframe: drop duplicate timestamps and sort oldest first.

    Args:
        df: raw frame with an `open_time` column.

    Returns:
        A new DataFrame without duplicate `open_time` rows (the first
        occurrence is kept), sorted by ascending `open_time`.

    Raises:
        AssertionError: propagated from `assert_integrity` when the cleaned
            frame still fails the integrity check.
    """

    # drop dupes, keeping the first occurrence of every timestamp
    df = df.drop_duplicates(subset='open_time', keep='first')

    # sort by timestamp, oldest first; `sort_values` returns a new frame, so
    # the result has to be kept (it used to be discarded, and sorted descending)
    df = df.sort_values(by='open_time', ascending=True)

    # just a doublecheck
    assert_integrity(df)

    return df


def write_raw_to_parquet(df, full_path):
    """Filter and compress a raw candlestick frame, then store it as parquet.

    Args:
        df: raw frame with at least `open_time`, `close_time` and `ignore`
            columns plus the columns `set_dtypes_compressed` converts.
        full_path: destination path of the parquet file.
    """

    # keep only candlesticks whose open/close times are exactly 59999 apart;
    # anything else does not span a full minute and is considered unreliable
    full_minute = (df['open_time'] - df['close_time']) == -59999
    trimmed = df[full_minute]

    # after filtering, `close_time` carries no information; `ignore` never did
    trimmed = trimmed.drop(['close_time', 'ignore'], axis=1)

    compact = set_dtypes_compressed(trimmed)

    # cut every pair off at the start of the current day
    today = str(date.today())
    compact = compact[compact.index < today]

    compact.to_parquet(full_path)


def groom_data(dirname='data'):
    """Go through the data folder and perform a quick clean on all csv files.

    Every `*.csv` file in `dirname` is read, passed through `quick_clean` and
    written back to the same path.

    Args:
        dirname: folder containing the csv files.
    """

    for filename in os.listdir(dirname):
        if filename.endswith('.csv'):
            full_path = f'{dirname}/{filename}'
            # index=False keeps the file layout identical to what
            # `all_candles_to_csv` writes; without it every groom would add an
            # extra unnamed index column to the csv.
            quick_clean(pd.read_csv(full_path)).to_csv(full_path, index=False)


def compress_data(dirname='data'):
    """Convert every csv file in `dirname` into a parquet file under `compressed/`.

    Args:
        dirname: folder containing the raw csv files.
    """

    os.makedirs('compressed', exist_ok=True)

    csv_names = [name for name in os.listdir(dirname) if name.endswith('.csv')]
    for csv_name in csv_names:
        raw = pd.read_csv(f'{dirname}/{csv_name}')

        # same file name, different extension and folder
        parquet_name = csv_name.replace('.csv', '.parquet')
        write_raw_to_parquet(raw, f'compressed/{parquet_name}')


In [16]:
# Fetch the exchange metadata and keep its list of tradable symbols as a frame.
# A timeout keeps the cell from hanging forever, and raise_for_status() turns an
# HTTP error into a clear exception instead of a KeyError on 'symbols'.
exchange_info = requests.get(f'{API_BASE}exchangeInfo', timeout=30)
exchange_info.raise_for_status()
all_symbols = pd.DataFrame(exchange_info.json()['symbols'])

In [17]:
# Maps a human-readable asset name to its ticker; the tickers are the base
# assets for which candlestick data is collected
dict_ticker = {
    'Bitcoin Cash':'BCH',
    'Binance Coin':'BNB',
    'Bitcoin':'BTC',
    'EOS.IO':'EOS',
    'Ethereum Classic':'ETC',
    'Ethereum':'ETH',
    'Litecoin':'LTC',
    'Monero':'XMR',
    'TRON':'TRX',
    'Stellar':'XLM',
    'Cardano':'ADA',
    'IOTA':'IOTA',
    'Maker':'MKR',
    'Dogecoin':'DOGE'
}

In [18]:
# For every ticker, list the quote assets it trades against whose name contains 'USD'
for ticker in dict_ticker.values():
    quote_assets = all_symbols.loc[all_symbols.baseAsset == ticker, 'quoteAsset'].unique()
    print([quote_asset for quote_asset in quote_assets if 'USD' in quote_asset])

['USDT', 'USDC', 'TUSD', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'USDS', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'USDS', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'USDC', 'TUSD', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'USDP', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'BUSD']
['USDT', 'TUSD', 'USDC', 'BUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'TUSD', 'USDC', 'BUSD', 'FDUSD']
['USDT', 'BUSD', 'FDUSD']
['USDT', 'BUSD']
['USDT', 'USDC', 'BUSD', 'TUSD', 'FDUSD']


In [19]:
# Quote asset paired with every ticker when building the list of trading pairs
quote = 'BUSD'

In [20]:
# One (base, quote) tuple per ticker, in the order the tickers are declared
all_pairs = [(ticker, quote) for ticker in dict_ticker.values()]

In [21]:
all_pairs

[('BCH', 'BUSD'),
 ('BNB', 'BUSD'),
 ('BTC', 'BUSD'),
 ('EOS', 'BUSD'),
 ('ETC', 'BUSD'),
 ('ETH', 'BUSD'),
 ('LTC', 'BUSD'),
 ('XMR', 'BUSD'),
 ('TRX', 'BUSD'),
 ('XLM', 'BUSD'),
 ('ADA', 'BUSD'),
 ('IOTA', 'BUSD'),
 ('MKR', 'BUSD'),
 ('DOGE', 'BUSD')]

In [22]:
# the csv and parquet output folders must exist before anything is written
for folder in ('data', 'compressed'):
    os.makedirs(folder, exist_ok=True)

# bring every pair up to date and report how many rows were added
n_count = len(all_pairs)
for n, (base, quote) in enumerate(all_pairs, 1):
    new_lines = all_candles_to_csv(base=base, quote=quote)
    message = (
        f'{datetime.now()} {n}/{n_count} Wrote {new_lines} new lines to file for {base}-{quote}'
        if new_lines > 0
        else f'{datetime.now()} {n}/{n_count} Already up to date with {base}-{quote}'
    )
    print(message)

2025-02-05 19:06:54.206969 1/14 Already up to date with BCH-BUSD
2025-02-05 19:06:57.116453 2/14 Already up to date with BNB-BUSD
2025-02-05 19:06:59.993701 3/14 Already up to date with BTC-BUSD
2025-02-05 19:07:02.561299 4/14 Already up to date with EOS-BUSD
2025-02-05 19:07:05.200681 5/14 Already up to date with ETC-BUSD
2025-02-05 19:07:08.104281 6/14 Already up to date with ETH-BUSD
2025-02-05 19:07:10.870078 7/14 Already up to date with LTC-BUSD
2025-02-05 19:07:13.299854 8/14 Already up to date with XMR-BUSD
2025-02-05 19:07:15.932879 9/14 Already up to date with TRX-BUSD
2025-02-05 19:07:18.442430 10/14 Already up to date with XLM-BUSD
2025-02-05 19:07:21.098272 11/14 Already up to date with ADA-BUSD
2025-02-05 19:07:23.217265 12/14 Already up to date with IOTA-BUSD
2025-02-05 19:07:25.430887 13/14 Already up to date with MKR-BUSD
2025-02-05 19:07:27.772429 14/14 Already up to date with DOGE-BUSD


0.33
0.33


# Deep Q Network stuff below

In [4]:
#import pyarrow as pa
#import pyarrow.parquet as pq
import pyarrow.dataset as ds
#import pandas as pd

from typing import Any
#from typing import List

import torch
import pyarrow.dataset
#import pyarrow.dataset as ds
from torch.utils.data import Dataset

class BinanceDataset(Dataset):
    """Torch dataset that reads rows from a pyarrow dataset.

    Indexing accepts either a single row position, or a `(start, count)` tuple
    to read `count` consecutive rows beginning at `start`.
    """

    def __init__(self, data: "pyarrow.dataset.Dataset") -> None:
        # underlying pyarrow dataset; rows are only read on access
        self.data = data

    def __len__(self) -> int:
        """Return the number of rows in the underlying dataset."""
        return self.data.count_rows()

    def __getitem__(self, index) -> dict[str, list[Any]]:
        """Return the requested rows as a dict mapping column name to a list of values.

        Args:
            index: an integer row position, or a `(start, count)` tuple.

        Returns:
            A dict with one list per column. When the requested window runs
            past the last row, only the rows that exist are returned.
        """
        if isinstance(index, tuple):
            start, count = index
        else:
            start, count = index, 1

        # Clamp the window to the dataset so a request that overlaps the end
        # yields a shorter result instead of asking for rows that do not exist.
        stop = min(start + count, len(self))
        rows = self.data.take(range(start, stop)).to_pydict()
        return rows

# Open the BTC-BUSD parquet file lazily and wrap it for row access
btc_busd_source = pyarrow.dataset.dataset('compressed/BTC-BUSD.parquet', format='parquet')
dataset = BinanceDataset(data=btc_busd_source)

# quick sanity check: number of rows and available columns
print(len(dataset))
print(dataset[0].keys())

2226200
dict_keys(['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'open_time'])


In [6]:
# plot data for inspection
import plotly.graph_objects as go

import pandas as pd
from datetime import datetime
import random

# Get a random sequence of 1440 consecutive rows from the dataset.
# The start is limited so the whole window fits; a start anywhere up to the
# last row would leave fewer than 1440 rows, or run past the end of the data.
window = 1440
start = random.randint(0, max(0, len(dataset) - window))
df = dataset[start, window]  # dict of column name -> list of values

fig = go.Figure(data=[go.Candlestick(x=df['open_time'], open=df['open'], high=df['high'], low=df['low'], close=df['close'])])

fig.show()

In [7]:
# A functor class calculating a commission as a percentage of the net price,
# optionally bounded by a minimum and a maximum commission.
class Commission:
    """Callable brokerage fee: a percentage of the net price, clamped to limits."""

    def __init__(self, percentage: float, min_commission: float = 0, max_commission: float = float('inf')):
        self.__percentage = percentage  # fee in percent of the net price
        self.__min_commission = min_commission  # lower bound of the fee
        self.__max_commission = max_commission  # upper bound of the fee

    def __call__(self, net_price: float) -> float:
        """Return the commission owed for a trade worth `net_price`."""
        fee = net_price * self.__percentage/100
        # raise to the minimum first, then cap at the maximum
        return min(max(fee, self.__min_commission), self.__max_commission)

# calculate the max amount of an asset that can be purchased given liquidity, price, min fraction and commission
def max_buy(liquidity: float, price: float, min_fraction: float = 1, commission: "Commission | None" = None) -> tuple[float, float]:
    """Return the largest affordable purchase as `(amount, total_cost)`.

    The amount is a whole multiple of `min_fraction`. Starting from the
    largest multiple affordable without commission, the amount is lowered one
    `min_fraction` at a time until the cost including commission fits within
    `liquidity`.

    Args:
        liquidity: cash available for the purchase.
        price: price of one unit of the asset; must be positive.
        min_fraction: smallest tradable fraction of the asset; must be positive.
        commission: optional callable mapping the net cost to a fee.

    Returns:
        `(amount, total_cost)`; `(0.0, 0.0)` when not even one `min_fraction`
        can be afforded, so no fee is ever charged for buying nothing and the
        amount is never negative.

    Raises:
        ValueError: if `price` or `min_fraction` is not positive.
    """
    if price <= 0:
        raise ValueError(f'price must be positive, got {price}')
    if min_fraction <= 0:
        raise ValueError(f'min_fraction must be positive, got {min_fraction}')

    # buy max: division rounding down to 'min_fraction', assuming no commission
    # Counting whole fractions as an integer avoids the drift that repeatedly
    # subtracting a float `min_fraction` from the amount would accumulate.
    units = int((liquidity / price) // min_fraction)

    while units > 0:
        amount = units / (1 / min_fraction)

        # calculate total price, adding commission if present
        total = amount * price
        if commission is not None:
            total = total + commission(total)

        # if we can afford to buy 'amount' return
        if total <= liquidity:
            return amount, total

        # otherwise try one 'min_fraction' less
        units -= 1

    return 0.0, 0.0

# calculate max amount of asset that can be sold given holding, price and commission
def max_sell(holding: float, price: float, commission: "Commission | None" = None) -> tuple[float, float]:
    """Return the result of selling the entire holding as `(amount, proceeds)`.

    Args:
        holding: number of units of the asset currently held.
        price: price of one unit of the asset.
        commission: optional callable mapping the gross proceeds to a fee.

    Returns:
        `(amount, proceeds)` where `proceeds` is the gross value minus the
        commission; `(0.0, 0.0)` when nothing is held, so no fee is charged
        for a sale that does not take place.
    """
    if holding <= 0:
        return 0.0, 0.0

    amount = holding # sell all
    total = amount * price

    # deduct commission if present
    if commission is not None:
        total = total - commission(total)

    return amount, total

from enum import Enum

# Portfolio management is calculating the relative distribution of assets in the portfolio - cash is one asset...

# Available agent actions
class Action(Enum):
    """Trading decisions the agent can take at each step of the environment."""
    BUY = 0   # the environment buys the maximum affordable amount
    KEEP = 1  # no trade is made
    SELL = 2  # the environment sells the entire holding

import random
from collections import deque

# Crypto environment for RL model
class Crypto:
    """Trading environment stepping through a candlestick dataset.

    The state is a sliding window of the last `seq_len` steps holding the
    candlestick keys 'open', 'high', 'low', 'close' and 'volume' plus the
    'liquidity' (cash) and 'holding' (asset units) at each step.
    """

    def __init__(self, dataset: Dataset, seq_len: int, investment: float, min_fraction: float = 0.01, commission: Commission = None) -> None:
        self.__dataset = dataset
        self.__seq_len = seq_len # time-series length
        self.__investment = investment # cash investment
        self.__min_fraction = min_fraction  # minimum fraction of the asset that can be traded
        self.__commission = commission # brokerage fee callable applied when buying and selling assets
        self.__keys = ['open', 'high', 'low', 'close', 'volume' ] # relevant candlestick keys
        self.reinit()

    def __advance_state(self, liquidity: float, holding: float, count: int = 1) -> None:
        """Read up to `count` rows at the current index and append them to the state."""
        # read n indices and advance index by n
        candlesticks = self.__dataset[self.__index, count]
        n = len(candlesticks[self.__keys[0]])
        self.__index += n

        # add relevant keys to the sequence
        for key in self.__keys:
            self.__state[key].extend(candlesticks[key])

        # add liquidity and holding to the sequence (expand to same length as candlestick data)
        self.__state['liquidity'].extend([liquidity] * n) # cash currently held
        self.__state['holding'].extend([holding] * n) # number of assets currently held

    def reinit(self) -> None:
        """Reset to a random start position with the initial investment as cash."""
        # create new state with candlestick keys + liquidity and holding
        self.__state = { key: deque(maxlen=self.__seq_len) for key in self.__keys + ['liquidity', 'holding'] }

        # reinitialize the start index, make sure we get a full sequence
        self.__index = random.randint(0, len(self.__dataset) - self.__seq_len)

        # get a full sequence of data
        self.__advance_state(liquidity=self.__investment, holding=0, count=self.__seq_len)

    def is_terminal(self) -> bool:
        """Return True once every row of the dataset has been consumed."""
        return self.__index >= len(self.__dataset)

    def observe(self) -> dict:
        """Return a snapshot of the current state.

        Each series is copied, so an observation kept by the caller (for
        example in a replay buffer) does not change when the environment
        advances.
        """
        return {key: deque(series, maxlen=series.maxlen) for key, series in self.__state.items()}

    # Take an action in this environment, advance to next state and return
    # a tuple containing; the new state, the reward and the "terminality" of the new state
    def act(self, action: Action) -> tuple[dict, float, bool]:
        """Apply `action`, advance one step and return `(state, reward, terminal)`.

        The reward is the change in portfolio value (holding valued at the
        closing price, plus cash) between the step before and the step after
        the action.

        Raises:
            ValueError: if `action` is not one of the `Action` members.
        """
        if self.is_terminal():
            # no action can be taken: report an unchanged state and no reward,
            # keeping the three-element shape that callers unpack
            return self.observe(), 0.0, True

        # get opening and closing price + liquidity and holding in the current iteration
        # NOTE(review): trades are priced at the open of the newest candle in
        # the window, whose close is already part of the state - confirm this
        # is the intended execution price.
        open_price = self.__state['open'][-1]
        close_price = self.__state['close'][-1]

        liquidity = self.__state['liquidity'][-1]
        holding = self.__state['holding'][-1]

        # calculate value before action
        value_before_action = holding * close_price + liquidity

        # take action
        if action == Action.BUY:
            # calculate amount to buy and price to pay
            amount, price = max_buy(liquidity=liquidity, price=open_price, min_fraction=self.__min_fraction, commission=self.__commission)
            # add the amount of the asset to the current holding
            holding += amount
            # subtract the total price from the cash holding
            liquidity -= price
        elif action == Action.SELL:
            amount, price = max_sell(holding=holding, price=open_price, commission=self.__commission)
            holding -= amount
            liquidity += price
        elif action != Action.KEEP:
            # should never happen; fail loudly rather than silently treating
            # an unknown action as KEEP
            raise ValueError(f'unknown action: {action!r}')

        # advance the time series
        self.__advance_state(liquidity=liquidity, holding=holding)

        # get new closing price
        close_price = self.__state['close'][-1]

        # calculate value after action
        value_after_action = holding * close_price + liquidity

        reward = value_after_action - value_before_action

        return self.observe(), reward, self.is_terminal()

In [10]:
# Build the trading environment on top of the BTC-BUSD candlesticks
btcdata = BinanceDataset(data=pyarrow.dataset.dataset('compressed/BTC-BUSD.parquet', format='parquet'))
commission = Commission(percentage=1.0, min_commission=2.5)
environment = Crypto(dataset=btcdata, seq_len=1440, investment=10000, min_fraction=0.01, commission=commission)

# Smoke test: take ten random actions and keep a running sum of the rewards
total = 0

for _ in range(10):
    action = random.choice(list(Action))
    state, reward, done = environment.act(action=action)
    total += reward
    print('action:', action, 'reward:', reward, 'total:', total)

print(state.keys())

action: Action.BUY reward: -72.42991430664006 total: -72.42991430664006
action: Action.KEEP reward: 32.61822265624869 total: -39.81169165039137
action: Action.BUY reward: -33.28260253906228 total: -73.09429418945365
action: Action.SELL reward: -68.446846630859 total: -141.54114082031265
action: Action.KEEP reward: 0.0 total: -141.54114082031265
action: Action.KEEP reward: 0.0 total: -141.54114082031265
action: Action.BUY reward: -190.13086464843764 total: -331.6720054687503
action: Action.BUY reward: 83.69383789062522 total: -247.97816757812507
action: Action.BUY reward: -26.24090820312449 total: -274.21907578124956
action: Action.SELL reward: -76.05160273437468 total: -350.27067851562424
dict_keys(['open', 'high', 'low', 'close', 'volume', 'liquidity', 'holding'])


In [None]:
# DDQN
# Actor-Critic reinforcement learner
# Advantage Actor-Critic Network: https://www.youtube.com/watch?v=wDVteayWWvU&list=PLMrJAkhIeNNQe1JXNvaFvURxGY4gE9k74&index=7&t=1011s

from torch import nn
from collections import OrderedDict
# TODO: batchnorm + dropout + view in forward()
class QFunction(nn.Module):
    """1D-convolutional network mapping a state sequence to one Q-value per action.

    Args:
        input_dim: length of the input sequence.
        input_channels: number of channels (series) in the input.
        output_dim: number of outputs, one per action.

    Input to `forward` is a tensor of shape (batch, input_channels, input_dim);
    the output has shape (batch, output_dim).
    """

    def __init__(self, input_dim, input_channels, output_dim):
        super().__init__()

        # 1d convolutions: each set doubles the channels and halves the length
        self.conv = nn.Sequential(OrderedDict([
            ('conv1', self.__conv_layer_set(in_channels=input_channels, out_channels=input_channels*2)),
            #nn.DropOut(p=0.1),
            ('conv2', self.__conv_layer_set(in_channels=input_channels*2, out_channels=input_channels*4)),
            #nn.DropOut(p=0.1),
            ('conv3', self.__conv_layer_set(in_channels=input_channels*4, out_channels=input_channels*8)),
            #nn.DropOut(p=0.1)
        ]))

        # After three pooling steps the sequence is input_dim // 8 long with
        # input_channels * 8 channels; the flattened size is their product.
        flat_features = (input_channels * 8) * (input_dim // 8)

        # fully connected
        # The head ends in a plain linear layer: Q-values are regressed
        # against rewards, so they must not be squashed into [0, 1] by a
        # softmax or skewed by a trailing activation.
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Linear(512, output_dim)
        )

    def __conv_layer_set(self, in_channels, out_channels):
        """Build one convolution -> max-pool -> activation -> batch-norm stage."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv1d', nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)),
            ('maxpool', nn.MaxPool1d(kernel_size=2)),
            ('leakyrelu', nn.LeakyReLU()),
            ('batch_norm', nn.BatchNorm1d(num_features=out_channels))
        ]))
        return conv_layer

    def forward(self, x):
        """Return the Q-values for a batch of states."""
        out = self.conv(x)
        # flatten everything but the batch dimension; the batch size is taken
        # from the input itself rather than from an attribute
        out = out.flatten(start_dim=1)
        return self.fc(out)

In [7]:
import copy

# Online Q-network plus a target network that starts as an exact copy of it.
# input_channels=7: the environment state holds the five candlestick series
# plus 'liquidity' and 'holding'.
# NOTE(review): `seq_len` must already be defined and match the `seq_len` the
# environment was created with.
qFunction = QFunction(input_dim=seq_len, input_channels=7, output_dim=len(Action))
qFunctionTarget = copy.deepcopy(qFunction)
qFunctionTarget.load_state_dict(qFunction.state_dict())

<All keys matched successfully>

In [None]:
# JUST TESTING - REMOVE!
import numpy as np

# every agent action, in declaration order, and an example sampling distribution
actions = list(Action)
probabilities = [0.2, 0.5, 0.3]
np.random.choice(actions, 1, p=probabilities)

array([<Action.SELL: 2>], dtype=object)

In [None]:
from collections import deque

def _state_to_tensor(state):
    """Stack an environment observation into a float tensor of shape (1, channels, length)."""
    channels = [list(series) for series in state.values()]
    return torch.tensor(channels, dtype=torch.float32).unsqueeze(0)


def train(nGames, lGames, batchSize, replaySize, learning_rate, syncFreq, epsilon=.1, gamma=.9):
    """Train `qFunction` with experience replay and a periodically synced target network.

    Args:
        nGames: number of epochs (games) to play.
        lGames: maximum number of steps per epoch.
        batchSize: number of transitions sampled per training step.
        replaySize: capacity of the replay buffer; once full, the oldest
            transitions are dropped when new ones are inserted.
        learning_rate: learning rate of the Adam optimizer.
        syncFreq: the target network is overwritten with the online network's
            weights whenever the global step count is a multiple of this.
        epsilon: probability of taking a random action instead of the greedy one.
        gamma: discount applied to the bootstrapped future reward.

    Returns:
        List with the loss of every training step.
    """
    replay = deque(maxlen=replaySize) # initialize the experience replay buffer

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(qFunction.parameters(), lr=learning_rate)

    all_actions = list(Action)

    j        = 0  # global step counter across all epochs
    losses   = []

    for _ in range(nGames):

        # start every epoch from a fresh random position with fresh cash
        game = environment # Environment
        game.reinit()

        # Observations are converted to tensors right away: this is the form
        # the network consumes, and it also detaches the stored transition
        # from any state object the environment keeps mutating.
        state = _state_to_tensor(game.observe())
        done = game.is_terminal()

        # start playing the game
        i = 0
        while not done:
            # agent taking an action with epsilon-greedy strategy
            if np.random.random() < epsilon:
                action = random.choice(all_actions)
            else:
                # Select the action with the largest Q-value; the output index
                # is interpreted as the Action value
                with torch.no_grad():
                    qValues = qFunction(state)
                action = Action(int(torch.argmax(qValues, dim=1).item()))

            # advancing the state of the environment
            observation, reward, done = game.act(action)
            new_state = _state_to_tensor(observation)

            # the epoch also ends after lGames steps
            i += 1
            done = done or i >= lGames

            # saving the (state, action, new_state, reward, done) values in the experience replay buffer
            replay.append((state, action.value, new_state, float(reward), float(done)))

            # once the buffer holds more than one batch we train the network
            if len(replay) > batchSize:
                miniBatch = random.sample(replay, batchSize)
                states, action_ids, new_states, rewards, dones = zip(*miniBatch)

                stateBatch     = torch.cat(states)
                actionBatch    = torch.tensor(action_ids, dtype=torch.long)
                rewardBatch    = torch.tensor(rewards, dtype=torch.float32)
                new_stateBatch = torch.cat(new_states)
                doneBatch      = torch.tensor(dones, dtype=torch.float32)

                # use the target network to bootstrap
                with torch.no_grad():
                    new_stateQ = qFunctionTarget(new_stateBatch)

                # target: R + gamma * max Q(new_state), without the bootstrap
                # term for transitions that ended the epoch
                Y = rewardBatch + gamma * (1 - doneBatch) * torch.max(new_stateQ, dim=1)[0]

                # Q-value the online network assigns to the action that was taken
                stateQ = qFunction(stateBatch).gather(dim=1, index=actionBatch.unsqueeze(dim=1)).squeeze(dim=1)

                # compute the loss of the model and backpropagate
                loss = loss_fn(stateQ, Y)
                optimizer.zero_grad()
                loss.backward()
                # keeping track of the losses
                losses.append(loss.item())
                optimizer.step()

                # periodically copy parameters to the target network
                if j % syncFreq == 0:
                    qFunctionTarget.load_state_dict(qFunction.state_dict())

            # advance the state of the game
            state = new_state

            j += 1 # advance j
    return losses