In [1]:
from tda import auth, client
from dotenv import load_dotenv
from selenium import webdriver
from pathlib import Path
import time
import datetime
import psycopg2
import pandas as pd
import numpy as np
import json
import os

def connect_client(api_key, redirect_uri, token_path, webdriver_func):
    client = auth.easy_client(
        api_key = api_key,
        redirect_uri = REDIRECT_URI,
        token_path = TOKEN_PATH,
        webdriver_func = webdriver_func,
    )
    return client


def get_webdriver(path = None):
    if not path:
        path = Path(__file__).parents[1].absolute()
        path = os.path.join(path, 'chromedriver')
    return webdriver.Chrome(path)

In [2]:
load_dotenv()

# Get .env variables
TD_KEY = os.getenv('CONSUMER_KEY')
ACC_NUMBER = os.getenv('ACC_NUMBER')
REDIRECT_URI = os.getenv('REDIRECT_URI')
TOKEN_PATH = os.path.join('../tokens/token.pickle')
API_KEY = TD_KEY + '@AMER.OAUTHAP'
CLIENT = connect_client(API_KEY, REDIRECT_URI, TOKEN_PATH, get_webdriver)

In [3]:
# Open db connection
conn = psycopg2.connect("dbname = historical user=postgres")
cur = conn.cursor()

In [4]:
# Get list of supported tickers
ticker_list = pd.read_csv('../tda_supported_tickers.csv')

In [5]:
response = CLIENT.get_price_history('SPY',
                                    period_type=client.Client.PriceHistory.PeriodType.YEAR,
                                    period=client.Client.PriceHistory.Period.ONE_YEAR,
                                    frequency_type=client.Client.PriceHistory.FrequencyType.DAILY,
                                    frequency=client.Client.PriceHistory.Frequency.DAILY).json()
print(json.dumps(response, indent = 4))

{
    "candles": [
        {
            "open": 286.14,
            "high": 289.07,
            "low": 285.25,
            "close": 288.89,
            "volume": 59852324,
            "datetime": 1566968400000
        },
        {
            "open": 291.72,
            "high": 293.16,
            "low": 290.61,
            "close": 292.58,
            "volume": 57998913,
            "datetime": 1567054800000
        },
        {
            "open": 294.22,
            "high": 294.2399,
            "low": 291.42,
            "close": 292.45,
            "volume": 62961780,
            "datetime": 1567141200000
        },
        {
            "open": 290.57,
            "high": 291.58,
            "low": 289.27,
            "close": 290.74,
            "volume": 69231875,
            "datetime": 1567486800000
        },
        {
            "open": 293.14,
            "high": 294.055,
            "low": 292.31,
            "close": 294.04,
            "volume": 47003957,
            

In [6]:
ticker_list.head()

Unnamed: 0.1,Unnamed: 0,symbol,description,asset_type,main_type
0,0,QUS,SPDR MSCI USA StrategicFactors ETF,ETF,EQUITY
1,1,GVAL,Cambria Global Value ETF,ETF,EQUITY
2,2,IEUR,iShares Core MSCI Europe ETF,ETF,EQUITY
3,3,ASHR,Xtrackers Harvest CSI 300 China A-Shares ETF,ETF,EQUITY
4,4,ASHS,Xtrackers Harvest CSI 500 China A-Shares Small...,ETF,EQUITY


In [7]:
def commit_batch_wait(cursor, connection, args_str):
    args_str = ','.join(args_str)
    query = "INSERT INTO prices (symbol, datetime, open, high, low, close) VALUES " + args_str + ";"
    cursor.execute(query)
    connection.commit()
    time.sleep(60)
    return
    

In [11]:
batch_count = 0
args_str = list()

for index, row in ticker_list.iterrows():
    batch_count += 1
    
    response = CLIENT.get_price_history(row['symbol'],
                                    period_type=client.Client.PriceHistory.PeriodType.YEAR,
                                    period=client.Client.PriceHistory.Period.ONE_YEAR,
                                    frequency_type=client.Client.PriceHistory.FrequencyType.DAILY,
                                    frequency=client.Client.PriceHistory.Frequency.DAILY).json()
    try:
        for candle in response['candles']:
            date = datetime.datetime.fromtimestamp(candle['datetime'] / 1e3)
            date = date.isoformat()
            entry = cur.mogrify("(%s, %s, %s, %s, %s, %s)", (row['symbol'], date, candle['open'], candle['high'], candle['low'], candle['close']))
            entry = entry.decode('ASCII')
            args_str.append(entry)
    except:
        pass
        
    if batch_count >= 99:
        commit_batch_wait(cur, conn, args_str)
        args_str = list()
        batch_count = 0
        
commit_batch_wait(cur, conn, args_str)

In [30]:
conn.rollback()