Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored tdclient with session handler and exceptions #50

Merged
merged 7 commits into from Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 84 additions & 37 deletions tdameritrade/client.py
@@ -1,17 +1,34 @@
import os
import requests
import pandas as pd
from .session import TDASession
from .urls import ACCOUNTS, INSTRUMENTS, QUOTES, SEARCH, HISTORY, OPTIONCHAIN, MOVERS
from .exceptions import handle_error_response


def response_is_valid(resp):
valid_codes = [200, 201]
return resp.status_code in valid_codes


class TDClient(object):
def __init__(self, access_token=None, accountIds=None):
self._token = access_token or os.environ['ACCESS_TOKEN']
self.accountIds = accountIds or []
self.session = TDASession()
if self._token:
self.session.set_token(self._token)

def _headers(self):
return {'Authorization': 'Bearer ' + self._token}

def _request(self, method, params=None, *args, **kwargs):
resp = self.session.request('GET', method, params=params, *args, **kwargs)
if not response_is_valid(resp):
handle_error_response(resp)

return resp

# TODO: output results to self.accountIds
def accounts(self, positions=False, orders=False):
ret = {}

Expand All @@ -28,28 +45,43 @@ def accounts(self, positions=False, orders=False):

if self.accountIds:
for acc in self.accountIds:
resp = requests.get(ACCOUNTS + str(acc) + fields, headers=self._headers())
if resp.status_code == 200:
ret[acc] = resp.json()
else:
raise Exception(resp.text)
resp = self._request(ACCOUNTS + str(acc) + fields, headers=self._headers())
ret[acc] = resp.json()

else:
resp = requests.get(ACCOUNTS + fields, headers=self._headers())
if resp.status_code == 200:
for account in resp.json():
ret[account['securitiesAccount']['accountId']] = account
else:
raise Exception(resp.text)
resp = self._request(ACCOUNTS + fields, headers=self._headers())
for account in resp.json():
ret[account['securitiesAccount']['accountId']] = account

return ret

def accountsDF(self):
return pd.io.json.json_normalize(self.accounts())
return pd.json_normalize(self.accounts())

def transactions(self, acc=None, type=None, symbol=None, start_date=None, end_date=None):
if acc is None:
acc = self.accounts
transactions = ACCOUNTS + str(acc) + "/transactions"
resp = self._request(transactions,
headers=self._headers(),
params={
'type': type,
'symbol': symbol,
'startDate': start_date,
'endDate': end_date
}).json()

return resp

def transactionsDF(self, acc, **kwargs):
return pd.json_normalize(self.transactions(acc, kwargs))

def search(self, symbol, projection='symbol-search'):
return requests.get(SEARCH,
headers=self._headers(),
params={'symbol': symbol,
'projection': projection}).json()
resp = self._request(SEARCH,
headers=self._headers(),
params={'symbol': symbol,
'projection': projection}).json()
return resp

def searchDF(self, symbol, projection='symbol-search'):
ret = []
Expand All @@ -65,36 +97,40 @@ def fundamentalDF(self, symbol):
return self.searchDF(symbol, 'fundamental')

def instrument(self, cusip):
return requests.get(INSTRUMENTS + str(cusip),
headers=self._headers()).json()
resp = self._request(INSTRUMENTS + str(cusip),
headers=self._headers()).json()
return resp

def instrumentDF(self, cusip):
return pd.DataFrame(self.instrument(cusip))

def quote(self, symbols):
return requests.get(QUOTES,
headers=self._headers(),
params={'symbol': symbols.upper()}).json()
def quote(self, symbol):
resp = self._request(QUOTES,
headers=self._headers(),
params={'symbol': symbol.upper()}).json()
return resp

def quoteDF(self, symbol):
x = self.quote(symbol)
return pd.DataFrame(x).T.reset_index(drop=True)

def history(self, symbol, **kwargs):
return requests.get(HISTORY % symbol,
headers=self._headers(),
params=kwargs).json()
resp = self._request(HISTORY % symbol,
headers=self._headers(),
params=kwargs).json()
return resp

def historyDF(self, symbol, **kwargs):
x = self.history(symbol, **kwargs)
df = pd.DataFrame(x['candles'])
df['datetime'] = pd.to_datetime(df['datetime'], unit='ms')
return df

def options(self, symbol, **kwargs):
return requests.get(OPTIONCHAIN,
headers=self._headers(),
params={'symbol': symbol.upper(), **kwargs}).json()
def options(self, symbol):
resp = self._request(OPTIONCHAIN,
headers=self._headers(),
params={'symbol': symbol.upper()}).json()
return resp

def optionsDF(self, symbol):
ret = []
Expand All @@ -112,12 +148,23 @@ def optionsDF(self, symbol):
return df

def movers(self, index, direction='up', change_type='percent'):
return requests.get(MOVERS % index,
headers=self._headers(),
params={'direction': direction,
'change_type': change_type}).json()
resp = self._request(MOVERS % index,
headers=self._headers(),
params={'direction': direction,
'change_type': change_type}).json()
return resp

def saved_orders(self, account_id, json_order):
saved_orders = ACCOUNTS + account_id + "/savedorders"
resp = self._request(saved_orders,
headers=self._headers(),
json=json_order).json()
return resp

def orders(self, account_id, order):
return requests.post(ACCOUNTS + account_id + "/orders",
def orders(self, account_id, json_order):
orders = ACCOUNTS + account_id + "/orders"
resp = self._request(orders,
headers=self._headers(),
json=order)
json=json_order
).json()
return resp
53 changes: 53 additions & 0 deletions tdameritrade/exceptions.py
@@ -0,0 +1,53 @@

def handle_error_response(resp):
codes = {
400: ValidationError,
401: InvalidAuthToken,
500: ServerError,
403: Forbidden,
404: NotFound,
-1: TDAAPIError
}

raise codes[resp.status_code]()


class TDAAPIError(Exception):
response = None
data = {}
code = -1
message = "An unknown error occurred"

def __init__(self, message=None, code=None, data={}, response=None):
self.response = response
if message:
self.message = message
if code:
self.code = code
if data:
self.data = data

def __str__(self):
if self.code:
return '{}: {}'.format(self.code, self.message)
return self.data


class ValidationError(TDAAPIError):
pass


class InvalidAuthToken(TDAAPIError):
pass


class ServerError(TDAAPIError):
pass


class Forbidden(TDAAPIError):
pass


class NotFound(TDAAPIError):
pass
7 changes: 7 additions & 0 deletions tdameritrade/session.py
@@ -0,0 +1,7 @@
import requests


class TDASession(requests.Session):

def set_token(self, token):
self.headers.update({'Authorization': 'Bearer ' + token})
18 changes: 18 additions & 0 deletions tdameritrade/tests/JSONS.py
@@ -0,0 +1,18 @@
TEST_BUY_MARKET_STOCK = '''
{
"orderType": "MARKET",
"session": "NORMAL",
"duration": "DAY",
"orderStrategyType": "SINGLE",
"orderLegCollection": [
{
"instruction": "Buy",
"quantity": 15,
"instrument": {
"symbol": "XYZ",
"assetType": "EQUITY"
}
}
]
}
'''