From 32a9442644b259eb6b4504468b3531170fd1a6d9 Mon Sep 17 00:00:00 2001
From: Kevin Sheppard
Date: Mon, 16 Sep 2019 00:22:59 +0100
Subject: [PATCH] MAINT: Use isort and black

Use isort to sort imports
Use black to format code
---
 .github/PULL_REQUEST_TEMPLATE.md              |   2 +-
 .travis.yml                                   |   8 +-
 ci/pypi-install.sh                            |  10 +-
 pandas_datareader/__init__.py                 |  76 +-
 pandas_datareader/_utils.py                   |   5 +-
 pandas_datareader/_version.py                 | 154 ++--
 pandas_datareader/av/__init__.py              |  61 +-
 pandas_datareader/av/forex.py                 |  55 +-
 pandas_datareader/av/quotes.py                |  44 +-
 pandas_datareader/av/sector.py                |   8 +-
 pandas_datareader/av/time_series.py           |  62 +-
 pandas_datareader/bankofcanada.py             |  19 +-
 pandas_datareader/base.py                     | 120 +--
 pandas_datareader/compat/__init__.py          |  40 +-
 pandas_datareader/conftest.py                 |   3 +-
 pandas_datareader/data.py                     | 462 +++++++----
 pandas_datareader/econdb.py                   |  42 +-
 pandas_datareader/enigma.py                   |  46 +-
 pandas_datareader/eurostat.py                 |  18 +-
 pandas_datareader/famafrench.py               |  89 +-
 pandas_datareader/fred.py                     |  29 +-
 pandas_datareader/iex/__init__.py             |  30 +-
 pandas_datareader/iex/daily.py                |  60 +-
 pandas_datareader/iex/deep.py                 | 107 +--
 pandas_datareader/iex/market.py               |  18 +-
 pandas_datareader/iex/ref.py                  |  18 +-
 pandas_datareader/iex/stats.py                |  92 ++-
 pandas_datareader/iex/tops.py                 |  35 +-
 pandas_datareader/io/jsdmx.py                 |  29 +-
 pandas_datareader/io/sdmx.py                  |  78 +-
 pandas_datareader/io/util.py                  |   4 +-
 pandas_datareader/moex.py                     | 119 +--
 pandas_datareader/nasdaq_trader.py            |  79 +-
 pandas_datareader/oecd.py                     |  14 +-
 pandas_datareader/quandl.py                   |  98 ++-
 pandas_datareader/robinhood.py                | 109 ++-
 pandas_datareader/stooq.py                    |  21 +-
 pandas_datareader/tests/io/test_jsdmx.py      | 144 +++-
 pandas_datareader/tests/io/test_sdmx.py       |  23 +-
 pandas_datareader/tests/test_bankofcanada.py  |  55 +-
 pandas_datareader/tests/test_base.py          |   8 +-
 pandas_datareader/tests/test_data.py          |   3 +-
 pandas_datareader/tests/test_econdb.py        |  54 +-
 pandas_datareader/tests/test_enigma.py        |  19 +-
 pandas_datareader/tests/test_famafrench.py    | 185 ++++-
 pandas_datareader/tests/test_fred.py          |  35 +-
 pandas_datareader/tests/test_iex.py           |  36 +-
 pandas_datareader/tests/test_iex_daily.py     | 100 ++-
 pandas_datareader/tests/test_moex.py          |  11 +-
 pandas_datareader/tests/test_nasdaq.py        |   9 +-
 pandas_datareader/tests/test_oecd.py          | 257 ++++--
 pandas_datareader/tests/test_robinhood.py     |  11 +-
 pandas_datareader/tests/test_stooq.py         |  14 +-
 pandas_datareader/tests/test_tsp.py           |  13 +-
 pandas_datareader/tests/test_wb.py            | 269 ++++---
 pandas_datareader/tests/yahoo/test_options.py | 102 ++-
 pandas_datareader/tests/yahoo/test_yahoo.py   | 312 ++++---
 pandas_datareader/tiingo.py                   | 133 +--
 pandas_datareader/tsp.py                      |  42 +-
 pandas_datareader/wb.py                       | 759 ++++++++++++++----
 pandas_datareader/yahoo/actions.py            |  20 +-
 pandas_datareader/yahoo/components.py         |  19 +-
 pandas_datareader/yahoo/daily.py              | 170 ++--
 pandas_datareader/yahoo/fx.py                 |  50 +-
 pandas_datareader/yahoo/options.py            | 268 ++++---
 pandas_datareader/yahoo/quotes.py             |  27 +-
 setup.cfg                                     |   6 +-
 67 files changed, 3478 insertions(+), 1940 deletions(-)

diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index ec4a6995..010d6dff 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,4 +1,4 @@
 - [ ] closes #xxxx
 - [ ] tests added / passed
-- [ ] passes `git diff upstream/master -u -- "*.py" | flake8 --diff`
+- [ ] passes `black --check pandas_datareader`
 - [ ] added entry to docs/source/whatsnew/vLATEST.txt
diff --git a/.travis.yml b/.travis.yml
index c665049e..ca050bde 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -35,15 +35,17 @@ matrix:
 
 install:
   - source ci/pypi-install.sh;
-  - pip install codecov coveralls beautifulsoup4 flake8
   - pip list
   - python setup.py install
 
 script:
+  - if [[ -n "${TEST_TYPE+x}" ]]; then export MARKERS="-m ${TEST_TYPE}"; fi
   - pytest -s -r xX "${MARKERS}" --cov-config .coveragerc --cov=pandas_datareader --cov-report xml:/tmp/cov-datareader.xml --junitxml=/tmp/datareader.xml
-  - flake8 --version
-  - flake8 pandas_datareader
+  - |
+    if [[ "$TRAVIS_PYTHON_VERSION" != 2.7 ]]; then
+      black --check pandas_datareader
+    fi
 
 after_script:
   - |
diff --git a/ci/pypi-install.sh b/ci/pypi-install.sh
index 4d2ea62e..c14f9f2c 100644
--- a/ci/pypi-install.sh
+++ b/ci/pypi-install.sh
@@ -1,15 +1,19 @@
 #!/usr/bin/env bash
 
-echo "PyPI install"
-
 pip install pip --upgrade
-pip install numpy=="$NUMPY" pytz python-dateutil coverage setuptools html5lib lxml pytest pytest-cov wrapt
+pip install numpy=="$NUMPY" pytz python-dateutil coverage setuptools html5lib lxml pytest pytest-cov wrapt codecov coveralls beautifulsoup4 isort
+
+if [[ "$TRAVIS_PYTHON_VERSION" != 2.7 ]]; then
+  pip install black
+fi
+
 if [[ "$PANDAS" == "MASTER" ]]; then
   PRE_WHEELS="https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com"
   pip install --pre --upgrade --timeout=60 -f "$PRE_WHEELS" pandas
 else
   pip install pandas=="$PANDAS"
 fi
+
 if [[ "$DOCBUILD" ]]; then
   pip install sphinx ipython matplotlib sphinx_rtd_theme doctr
 fi
diff --git a/pandas_datareader/__init__.py b/pandas_datareader/__init__.py
index 1246c2d7..67525138 100644
--- a/pandas_datareader/__init__.py
+++ b/pandas_datareader/__init__.py
@@ -1,24 +1,60 @@
 from ._version import get_versions
-from .data import (DataReader, Options, get_components_yahoo,
-                   get_dailysummary_iex, get_data_enigma, get_data_famafrench,
-                   get_data_fred, get_data_moex, get_data_quandl,
-                   get_data_stooq, get_data_yahoo, get_data_yahoo_actions,
-                   get_iex_book, get_iex_symbols, get_last_iex,
-                   get_markets_iex, get_nasdaq_symbols, get_quote_yahoo,
-                   get_recent_iex, get_records_iex, get_summary_iex,
-                   get_tops_iex, get_data_tiingo, get_iex_data_tiingo,
-                   get_data_alphavantage)
+from .data import (
+    DataReader,
+    Options,
+    get_components_yahoo,
+    get_dailysummary_iex,
+    get_data_alphavantage,
+    get_data_enigma,
+    get_data_famafrench,
+    get_data_fred,
+    get_data_moex,
+    get_data_quandl,
+    get_data_stooq,
+    get_data_tiingo,
+    get_data_yahoo,
+    get_data_yahoo_actions,
+    get_iex_book,
+    get_iex_data_tiingo,
+    get_iex_symbols,
+    get_last_iex,
+    get_markets_iex,
+    get_nasdaq_symbols,
+    get_quote_yahoo,
+    get_recent_iex,
+    get_records_iex,
+    get_summary_iex,
+    get_tops_iex,
+)
 
-__version__ = get_versions()['version']
+__version__ = get_versions()["version"]
 del get_versions
 
-__all__ = ['__version__', 'get_components_yahoo', 'get_data_enigma',
-           'get_data_famafrench', 'get_data_yahoo',
-           'get_data_yahoo_actions', 'get_quote_yahoo',
-           'get_iex_book', 'get_iex_symbols', 'get_last_iex',
-           'get_markets_iex', 'get_recent_iex', 'get_records_iex',
-           'get_summary_iex', 'get_tops_iex',
-           'get_nasdaq_symbols', 'get_data_quandl', 'get_data_moex',
-           'get_data_fred', 'get_dailysummary_iex',
-           'get_data_stooq', 'DataReader', 'Options',
-           'get_data_tiingo', 'get_iex_data_tiingo', 'get_data_alphavantage']
+__all__ = [
+    "__version__",
+    "get_components_yahoo",
+    "get_data_enigma",
+    "get_data_famafrench",
+    "get_data_yahoo",
+    "get_data_yahoo_actions",
+    "get_quote_yahoo",
+    "get_iex_book",
+    "get_iex_symbols",
+    "get_last_iex",
+    "get_markets_iex",
+    "get_recent_iex",
+    "get_records_iex",
+    "get_summary_iex",
+    "get_tops_iex",
+    "get_nasdaq_symbols",
+    "get_data_quandl",
+    "get_data_moex",
+    "get_data_fred",
+    "get_dailysummary_iex",
+    "get_data_stooq",
+    "DataReader",
+    "Options",
+    "get_data_tiingo",
+    "get_iex_data_tiingo",
+    "get_data_alphavantage",
+]
diff --git a/pandas_datareader/_utils.py b/pandas_datareader/_utils.py
index b6058620..8f1a5edf 100644
--- a/pandas_datareader/_utils.py
+++ b/pandas_datareader/_utils.py
@@ -1,7 +1,8 @@
 import datetime as dt
 
-import requests
 from pandas import to_datetime
 
+import requests
+
 from pandas_datareader.compat import is_number
 
 
@@ -33,7 +34,7 @@ def _sanitize_dates(start, end):
     if end is None:
         end = dt.datetime.today()
     if start > end:
-        raise ValueError('start must be an earlier date than end')
+        raise ValueError("start must be an earlier date than end")
     return start, end
 
 
diff --git a/pandas_datareader/_version.py b/pandas_datareader/_version.py
index 1f0230d8..da7d0e17 100644
--- a/pandas_datareader/_version.py
+++ b/pandas_datareader/_version.py
@@ -1,4 +1,3 @@
-
 # This file helps to compute a version number in source trees obtained from
 # git-archive tarball (such as those provided by githubs download-from-tag
 # feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ class NotThisMethod(Exception):
 
 def register_vcs_handler(vcs, method):  # decorator
     """Decorator to mark a method as the handler for a particular VCS."""
+
     def decorate(f):
         """Store f in HANDLERS[vcs][method]."""
         if vcs not in HANDLERS:
             HANDLERS[vcs] = {}
         HANDLERS[vcs][method] = f
         return f
+
     return decorate
 
 
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
-                env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
     """Call the given command(s)."""
     assert isinstance(commands, list)
     p = None
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
         try:
             dispcmd = str([c] + args)
             # remember shell=False, so use git.cmd on windows, not just git
-            p = subprocess.Popen([c] + args, cwd=cwd, env=env,
-                                 stdout=subprocess.PIPE,
-                                 stderr=(subprocess.PIPE if hide_stderr
-                                         else None))
+            p = subprocess.Popen(
+                [c] + args,
+                cwd=cwd,
+                env=env,
+                stdout=subprocess.PIPE,
+                stderr=(subprocess.PIPE if hide_stderr else None),
+            )
             break
         except EnvironmentError:
             e = sys.exc_info()[1]
@@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
     for i in range(3):
         dirname = os.path.basename(root)
         if dirname.startswith(parentdir_prefix):
-            return {"version": dirname[len(parentdir_prefix):],
-                    "full-revisionid": None,
-                    "dirty": False, "error": None, "date": None}
+            return {
+                "version": dirname[len(parentdir_prefix) :],
+                "full-revisionid": None,
+                "dirty": False,
+                "error": None,
+                "date": None,
+            }
         else:
             rootdirs.append(root)
             root = os.path.dirname(root)  # up a level
 
     if verbose:
-        print("Tried directories %s but none started with prefix %s" %
-              (str(rootdirs), parentdir_prefix))
+        print(
+            "Tried directories %s but none started with prefix %s"
+            % (str(rootdirs), parentdir_prefix)
+        )
     raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
 
 
@@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
     # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
     # just "foo-1.0". If we see a "tag: " prefix, prefer those.
     TAG = "tag: "
-    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+    tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
     if not tags:
         # Either we're using git < 1.8.3, or there really are no tags. We use
         # a heuristic: assume all version tags have a digit. The old git %d
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
         # between branches and tags. By ignoring refnames without digits, we
         # filter out many common branch names like "release" and
         # "stabilization", as well as "HEAD" and "master".
-        tags = set([r for r in refs if re.search(r'\d', r)])
+        tags = set([r for r in refs if re.search(r"\d", r)])
         if verbose:
             print("discarding '%s', no digits" % ",".join(refs - tags))
     if verbose:
@@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
     for ref in sorted(tags):
         # sorting will prefer e.g. "2.0" over "2.0rc1"
         if ref.startswith(tag_prefix):
-            r = ref[len(tag_prefix):]
+            r = ref[len(tag_prefix) :]
             if verbose:
                 print("picking %s" % r)
-            return {"version": r,
-                    "full-revisionid": keywords["full"].strip(),
-                    "dirty": False, "error": None,
-                    "date": date}
+            return {
+                "version": r,
+                "full-revisionid": keywords["full"].strip(),
+                "dirty": False,
+                "error": None,
+                "date": date,
+            }
     # no suitable tags, so version is "0+unknown", but full hex is still there
     if verbose:
         print("no suitable tags, using unknown + full revision id")
-    return {"version": "0+unknown",
-            "full-revisionid": keywords["full"].strip(),
-            "dirty": False, "error": "no suitable tags", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": keywords["full"].strip(),
+        "dirty": False,
+        "error": "no suitable tags",
+        "date": None,
+    }
 
 
 @register_vcs_handler("git", "pieces_from_vcs")
@@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     if sys.platform == "win32":
         GITS = ["git.cmd", "git.exe"]
 
-    out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
-                          hide_stderr=True)
+    out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
     if rc != 0:
         if verbose:
             print("Directory %s not under git control" % root)
@@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
     # if there isn't one, this yields HEX[-dirty] (no NUM)
-    describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
-                                          "--always", "--long",
-                                          "--match", "%s*" % tag_prefix],
-                                   cwd=root)
+    describe_out, rc = run_command(
+        GITS,
+        [
+            "describe",
+            "--tags",
+            "--dirty",
+            "--always",
+            "--long",
+            "--match",
+            "%s*" % tag_prefix,
+        ],
+        cwd=root,
+    )
     # --long was added in git-1.5.5
     if describe_out is None:
         raise NotThisMethod("'git describe' failed")
@@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     dirty = git_describe.endswith("-dirty")
     pieces["dirty"] = dirty
     if dirty:
-        git_describe = git_describe[:git_describe.rindex("-dirty")]
+        git_describe = git_describe[: git_describe.rindex("-dirty")]
 
     # now we have TAG-NUM-gHEX or HEX
 
     if "-" in git_describe:
         # TAG-NUM-gHEX
-        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+        mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
         if not mo:
             # unparseable. Maybe git-describe is misbehaving?
-            pieces["error"] = ("unable to parse git-describe output: '%s'"
-                               % describe_out)
+            pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
             return pieces
 
         # tag
@@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
             if verbose:
                 fmt = "tag '%s' doesn't start with prefix '%s'"
                 print(fmt % (full_tag, tag_prefix))
-            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
-                               % (full_tag, tag_prefix))
+            pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+                full_tag,
+                tag_prefix,
+            )
             return pieces
-        pieces["closest-tag"] = full_tag[len(tag_prefix):]
+        pieces["closest-tag"] = full_tag[len(tag_prefix) :]
 
         # distance: number of commits since tag
         pieces["distance"] = int(mo.group(2))
@@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     else:
         # HEX: no tags
         pieces["closest-tag"] = None
-        count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
-                                    cwd=root)
+        count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
         pieces["distance"] = int(count_out)  # total number of commits
 
     # commit date: see ISO-8601 comment in git_versions_from_keywords()
-    date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
-                       cwd=root)[0].strip()
+    date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
+        0
+    ].strip()
     pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
 
     return pieces
@@ -330,8 +355,7 @@ def render_pep440(pieces):
                 rendered += ".dirty"
     else:
         # exception #1
-        rendered = "0+untagged.%d.g%s" % (pieces["distance"],
-                                          pieces["short"])
+        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
         if pieces["dirty"]:
             rendered += ".dirty"
     return rendered
@@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
 def render(pieces, style):
     """Render the given version pieces into the requested style."""
     if pieces["error"]:
-        return {"version": "unknown",
-                "full-revisionid": pieces.get("long"),
-                "dirty": None,
-                "error": pieces["error"],
-                "date": None}
+        return {
+            "version": "unknown",
+            "full-revisionid": pieces.get("long"),
+            "dirty": None,
+            "error": pieces["error"],
+            "date": None,
+        }
 
     if not style or style == "default":
         style = "pep440"  # the default
@@ -469,9 +495,13 @@ def render(pieces, style):
     else:
         raise ValueError("unknown style '%s'" % style)
 
-    return {"version": rendered, "full-revisionid": pieces["long"],
-            "dirty": pieces["dirty"], "error": None,
-            "date": pieces.get("date")}
+    return {
+        "version": rendered,
+        "full-revisionid": pieces["long"],
+        "dirty": pieces["dirty"],
+        "error": None,
+        "date": pieces.get("date"),
+    }
 
 
 def get_versions():
@@ -485,8 +515,7 @@ def get_versions():
     verbose = cfg.verbose
 
     try:
-        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
-                                          verbose)
+        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
     except NotThisMethod:
         pass
 
@@ -495,13 +524,16 @@ def get_versions():
         # versionfile_source is the relative path from the top of the source
         # tree (where the .git directory might live) to this file. Invert
         # this to find the root from __file__.
-        for i in cfg.versionfile_source.split('/'):
+        for i in cfg.versionfile_source.split("/"):
             root = os.path.dirname(root)
     except NameError:
-        return {"version": "0+unknown", "full-revisionid": None,
-                "dirty": None,
-                "error": "unable to find root of source tree",
-                "date": None}
+        return {
+            "version": "0+unknown",
+            "full-revisionid": None,
+            "dirty": None,
+            "error": "unable to find root of source tree",
+            "date": None,
+        }
 
     try:
         pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -515,6 +547,10 @@ def get_versions():
     except NotThisMethod:
         pass
 
-    return {"version": "0+unknown", "full-revisionid": None,
-            "dirty": None,
-            "error": "unable to compute version", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
diff --git a/pandas_datareader/av/__init__.py b/pandas_datareader/av/__init__.py
index 808dd490..dfe6147b 100644
--- a/pandas_datareader/av/__init__.py
+++ b/pandas_datareader/av/__init__.py
@@ -1,11 +1,11 @@
 import os
 
-from pandas_datareader.base import _BaseReader
-from pandas_datareader._utils import RemoteDataError
-
 import pandas as pd
 
-AV_BASE_URL = 'https://www.alphavantage.co/query'
+from pandas_datareader._utils import RemoteDataError
+from pandas_datareader.base import _BaseReader
+
+AV_BASE_URL = "https://www.alphavantage.co/query"
 
 
 class AlphaVantage(_BaseReader):
@@ -16,20 +16,36 @@ class AlphaVantage(_BaseReader):
     -----
     See `Alpha Vantage `__
     """
-    _format = 'json'
-
-    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
-                 pause=0.1, session=None, api_key=None):
-        super(AlphaVantage, self).__init__(symbols=symbols, start=start,
-                                           end=end, retry_count=retry_count,
-                                           pause=pause, session=session)
+
+    _format = "json"
+
+    def __init__(
+        self,
+        symbols=None,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        session=None,
+        api_key=None,
+    ):
+        super(AlphaVantage, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        )
         if api_key is None:
-            api_key = os.getenv('ALPHAVANTAGE_API_KEY')
+            api_key = os.getenv("ALPHAVANTAGE_API_KEY")
         if not api_key or not isinstance(api_key, str):
-            raise ValueError('The AlphaVantage API key must be provided '
-                             'either through the api_key variable or '
-                             'through the environment variable '
-                             'ALPHAVANTAGE_API_KEY')
+            raise ValueError(
+                "The AlphaVantage API key must be provided "
+                "either through the api_key variable or "
+                "through the environment variable "
+                "ALPHAVANTAGE_API_KEY"
+            )
         self.api_key = api_key
 
     @property
@@ -39,10 +55,7 @@ def url(self):
 
     @property
     def params(self):
-        return {
-            'function': self.function,
-            'apikey': self.api_key
-        }
+        return {"function": self.function, "apikey": self.api_key}
 
     @property
     def function(self):
@@ -56,12 +69,14 @@ def data_key(self):
 
     def _read_lines(self, out):
         try:
-            df = pd.DataFrame.from_dict(out[self.data_key], orient='index')
+            df = pd.DataFrame.from_dict(out[self.data_key], orient="index")
         except KeyError:
             if "Error Message" in out:
-                raise ValueError("The requested symbol {} could not be "
-                                 "retrived. Check valid ticker"
-                                 ".".format(self.symbols))
+                raise ValueError(
+                    "The requested symbol {} could not be "
+                    "retrived. Check valid ticker"
+                    ".".format(self.symbols)
+                )
             else:
                 raise RemoteDataError()
         df = df[sorted(df.columns)]
diff --git a/pandas_datareader/av/forex.py b/pandas_datareader/av/forex.py
index b4df0577..d0a195b4 100644
--- a/pandas_datareader/av/forex.py
+++ b/pandas_datareader/av/forex.py
@@ -1,8 +1,7 @@
-from pandas_datareader.av import AlphaVantage
+import pandas as pd
 
 from pandas_datareader._utils import RemoteDataError
-
-import pandas as pd
+from pandas_datareader.av import AlphaVantage
 
 
 class AVForexReader(AlphaVantage):
@@ -27,15 +26,20 @@ class AVForexReader(AlphaVantage):
         Alpha Vantage API key . If not provided the environmental variable
         ALPHAVANTAGE_API_KEY is read. The API key is *required*.
     """
-    def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
-                 api_key=None):
-
-        super(AVForexReader, self).__init__(symbols=symbols,
-                                            start=None, end=None,
-                                            retry_count=retry_count,
-                                            pause=pause,
-                                            session=session,
-                                            api_key=api_key)
+
+    def __init__(
+        self, symbols=None, retry_count=3, pause=0.1, session=None, api_key=None
+    ):
+
+        super(AVForexReader, self).__init__(
+            symbols=symbols,
+            start=None,
+            end=None,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=api_key,
+        )
         self.from_curr = {}
         self.to_curr = {}
         self.optional_params = {}
@@ -45,28 +49,27 @@ def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
             self.symbols = symbols
         try:
             for pair in self.symbols:
-                self.from_curr[pair] = pair.split('/')[0]
-                self.to_curr[pair] = pair.split('/')[1]
+                self.from_curr[pair] = pair.split("/")[0]
+                self.to_curr[pair] = pair.split("/")[1]
         except Exception as e:
             print(e)
-            raise ValueError("Please input a currency pair "
-                             "formatted 'FROM/TO' or a list of "
-                             "currency symbols")
+            raise ValueError(
+                "Please input a currency pair "
+                "formatted 'FROM/TO' or a list of "
+                "currency symbols"
+            )
 
     @property
     def function(self):
-        return 'CURRENCY_EXCHANGE_RATE'
+        return "CURRENCY_EXCHANGE_RATE"
 
     @property
     def data_key(self):
-        return 'Realtime Currency Exchange Rate'
+        return "Realtime Currency Exchange Rate"
 
     @property
     def params(self):
-        params = {
-            'apikey': self.api_key,
-            'function': self.function
-        }
+        params = {"apikey": self.api_key, "function": self.function}
         params.update(self.optional_params)
         return params
 
@@ -74,9 +77,9 @@ def read(self):
         result = []
         for pair in self.symbols:
             self.optional_params = {
-                'from_currency': self.from_curr[pair],
-                'to_currency': self.to_curr[pair],
-                }
+                "from_currency": self.from_curr[pair],
+                "to_currency": self.to_curr[pair],
+            }
             data = super(AVForexReader, self).read()
             result.append(data)
         df = pd.concat(result, axis=1)
@@ -85,7 +88,7 @@ def read(self):
 
     def _read_lines(self, out):
         try:
-            df = pd.DataFrame.from_dict(out[self.data_key], orient='index')
+            df = pd.DataFrame.from_dict(out[self.data_key], orient="index")
         except KeyError:
             raise RemoteDataError()
         df.sort_index(ascending=True, inplace=True)
diff --git a/pandas_datareader/av/quotes.py b/pandas_datareader/av/quotes.py
index ae2e1657..8b9b16a6 100644
--- a/pandas_datareader/av/quotes.py
+++ b/pandas_datareader/av/quotes.py
@@ -1,7 +1,7 @@
-from pandas_datareader.av import AlphaVantage
-
-import pandas as pd
 import numpy as np
+import pandas as pd
+
+from pandas_datareader.av import AlphaVantage
 
 
 class AVQuotesReader(AlphaVantage):
@@ -22,8 +22,10 @@ class AVQuotesReader(AlphaVantage):
     session : Session, default None
         requests.sessions.Session instance to be used
     """
-    def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
-                 api_key=None):
+
+    def __init__(
+        self, symbols=None, retry_count=3, pause=0.1, session=None, api_key=None
+    ):
         if isinstance(symbols, str):
             syms = [symbols]
         elif isinstance(symbols, list):
@@ -31,27 +33,30 @@ def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
                 raise ValueError("Up to 100 symbols at once are allowed.")
             else:
                 syms = symbols
-        super(AVQuotesReader, self).__init__(symbols=syms,
-                                             start=None, end=None,
-                                             retry_count=retry_count,
-                                             pause=pause,
-                                             session=session,
-                                             api_key=api_key)
+        super(AVQuotesReader, self).__init__(
+            symbols=syms,
+            start=None,
+            end=None,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=api_key,
+        )
 
     @property
     def function(self):
-        return 'BATCH_STOCK_QUOTES'
+        return "BATCH_STOCK_QUOTES"
 
     @property
     def data_key(self):
-        return 'Stock Quotes'
+        return "Stock Quotes"
 
     @property
     def params(self):
         return {
-            'symbols': ','.join(self.symbols),
-            'function': self.function,
-            'apikey': self.api_key,
+            "symbols": ",".join(self.symbols),
+            "function": self.function,
+            "apikey": self.api_key,
         }
 
     def _read_lines(self, out):
@@ -61,14 +66,13 @@ def _read_lines(self, out):
             df = pd.DataFrame(quote, index=[0])
             df.columns = [col[3:] for col in df.columns]
             df.set_index("symbol", inplace=True)
-            df["price"] = df["price"].astype('float64')
+            df["price"] = df["price"].astype("float64")
             try:
-                df["volume"] = df["volume"].astype('int64')
+                df["volume"] = df["volume"].astype("int64")
             except ValueError:
                 df["volume"] = [np.nan * len(self.symbols)]
             result.append(df)
         if len(result) != len(self.symbols):
-            raise ValueError("Not all symbols downloaded. Check valid "
-                             "ticker(s).")
+            raise ValueError("Not all symbols downloaded. Check valid " "ticker(s).")
         else:
             return pd.concat(result)
diff --git a/pandas_datareader/av/sector.py b/pandas_datareader/av/sector.py
index 43a2f3c8..46d445ca 100644
--- a/pandas_datareader/av/sector.py
+++ b/pandas_datareader/av/sector.py
@@ -1,7 +1,7 @@
 import pandas as pd
 
-from pandas_datareader.av import AlphaVantage
 from pandas_datareader._utils import RemoteDataError
+from pandas_datareader.av import AlphaVantage
 
 
 class AVSectorPerformanceReader(AlphaVantage):
@@ -25,9 +25,10 @@ class AVSectorPerformanceReader(AlphaVantage):
         Alpha Vantage API key . If not provided the environmental variable
         ALPHAVANTAGE_API_KEY is read. The API key is *required*.
     """
+
     @property
     def function(self):
-        return 'SECTOR'
+        return "SECTOR"
 
     def _read_lines(self, out):
         if "Information" in out:
@@ -35,7 +36,6 @@ def _read_lines(self, out):
         else:
             out.pop("Meta Data")
             df = pd.DataFrame(out)
-            columns = ["RT", "1D", "5D", "1M", "3M", "YTD", "1Y", "3Y", "5Y",
-                       "10Y"]
+            columns = ["RT", "1D", "5D", "1M", "3M", "YTD", "1Y", "3Y", "5Y", "10Y"]
             df.columns = columns
             return df
diff --git a/pandas_datareader/av/time_series.py b/pandas_datareader/av/time_series.py
index 13b6ef2d..8513bdad 100644
--- a/pandas_datareader/av/time_series.py
+++ b/pandas_datareader/av/time_series.py
@@ -1,7 +1,7 @@
-from pandas_datareader.av import AlphaVantage
-
 from datetime import datetime
 
+from pandas_datareader.av import AlphaVantage
+
 
 class AVTimeSeriesReader(AlphaVantage):
     """
@@ -29,24 +29,38 @@ class AVTimeSeriesReader(AlphaVantage):
         AlphaVantage API key . If not provided the environmental variable
         ALPHAVANTAGE_API_KEY is read. The API key is *required*.
     """
+
     _FUNC_TO_DATA_KEY = {
-            "TIME_SERIES_DAILY": "Time Series (Daily)",
-            "TIME_SERIES_DAILY_ADJUSTED": "Time Series (Daily)",
-            "TIME_SERIES_WEEKLY": "Weekly Time Series",
-            "TIME_SERIES_WEEKLY_ADJUSTED": "Weekly Adjusted Time Series",
-            "TIME_SERIES_MONTHLY": "Monthly Time Series",
-            "TIME_SERIES_MONTHLY_ADJUSTED": "Monthly Adjusted Time Series",
-            "TIME_SERIES_INTRADAY": "Time Series (1min)"
+        "TIME_SERIES_DAILY": "Time Series (Daily)",
+        "TIME_SERIES_DAILY_ADJUSTED": "Time Series (Daily)",
+        "TIME_SERIES_WEEKLY": "Weekly Time Series",
+        "TIME_SERIES_WEEKLY_ADJUSTED": "Weekly Adjusted Time Series",
+        "TIME_SERIES_MONTHLY": "Monthly Time Series",
+        "TIME_SERIES_MONTHLY_ADJUSTED": "Monthly Adjusted Time Series",
+        "TIME_SERIES_INTRADAY": "Time Series (1min)",
     }
 
-    def __init__(self, symbols=None, function="TIME_SERIES_DAILY",
-                 start=None, end=None, retry_count=3, pause=0.1,
-                 session=None, chunksize=25, api_key=None):
-        super(AVTimeSeriesReader, self).__init__(symbols=symbols, start=start,
-                                                 end=end,
-                                                 retry_count=retry_count,
-                                                 pause=pause, session=session,
-                                                 api_key=api_key)
+    def __init__(
+        self,
+        symbols=None,
+        function="TIME_SERIES_DAILY",
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        session=None,
+        chunksize=25,
+        api_key=None,
+    ):
+        super(AVTimeSeriesReader, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=api_key,
+        )
 
         self._func = function
 
@@ -60,7 +74,7 @@ def output_size(self):
         possible.
         """
         delta = datetime.now() - self.start
-        return 'full' if delta.days > 80 else 'compact'
+        return "full" if delta.days > 80 else "compact"
 
     @property
     def data_key(self):
@@ -72,7 +86,7 @@ def params(self):
             "symbol": self.symbols,
             "function": self.function,
             "apikey": self.api_key,
-            "outputsize": self.output_size
+            "outputsize": self.output_size,
         }
         if self.function == "TIME_SERIES_INTRADAY":
             p.update({"interval": "1min"})
@@ -82,15 +96,15 @@ def _read_lines(self, out):
         data = super(AVTimeSeriesReader, self)._read_lines(out)
         # reverse since alphavantage returns descending by date
         data = data[::-1]
-        start_str = self.start.strftime('%Y-%m-%d')
-        end_str = self.end.strftime('%Y-%m-%d')
+        start_str = self.start.strftime("%Y-%m-%d")
+        end_str = self.end.strftime("%Y-%m-%d")
         data = data.loc[start_str:end_str]
         if data.empty:
             raise ValueError("Please input a valid date range")
         else:
             for column in data.columns:
-                if column == 'volume':
-                    data[column] = data[column].astype('int64')
+                if column == "volume":
+                    data[column] = data[column].astype("int64")
                 else:
-                    data[column] = data[column].astype('float64')
+                    data[column] = data[column].astype("float64")
             return data
diff --git a/pandas_datareader/bankofcanada.py b/pandas_datareader/bankofcanada.py
index c8ca75b1..0b0ea004 100644
--- a/pandas_datareader/bankofcanada.py
+++ b/pandas_datareader/bankofcanada.py
@@ -1,8 +1,7 @@
 from __future__ import unicode_literals
 
-from pandas_datareader.compat import string_types
-
 from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import string_types
 
 
 class BankOfCanadaReader(_BaseReader):
@@ -12,26 +11,28 @@ class BankOfCanadaReader(_BaseReader):
     -----
     See `Bank of Canada `__"""
 
-    _URL = 'http://www.bankofcanada.ca/valet/observations'
+    _URL = "http://www.bankofcanada.ca/valet/observations"
 
     @property
     def url(self):
         """API URL"""
         if not isinstance(self.symbols, string_types):
-            raise ValueError('data name must be string')
+            raise ValueError("data name must be string")
 
-        return '{0}/{1}/csv'.format(self._URL, self.symbols)
+        return "{0}/{1}/csv".format(self._URL, self.symbols)
 
     @property
     def params(self):
         """Parameters to use in API calls"""
-        return {'start_date': self.start.strftime('%Y-%m-%d'),
-                'end_date': self.end.strftime('%Y-%m-%d')}
+        return {
+            "start_date": self.start.strftime("%Y-%m-%d"),
+            "end_date": self.end.strftime("%Y-%m-%d"),
+        }
 
     @staticmethod
     def _sanitize_response(response):
         """
         Clean up the response string
         """
-        data = response.text.split('OBSERVATIONS')[1]
-        return data.split('ERRORS')[0].strip()
+        data = response.text.split("OBSERVATIONS")[1]
+        return data.split("ERRORS")[0].strip()
diff --git a/pandas_datareader/base.py b/pandas_datareader/base.py
index 7e63839d..ddac3ce4 100644
--- a/pandas_datareader/base.py
+++ b/pandas_datareader/base.py
@@ -1,17 +1,18 @@
 import time
 import warnings
 
-import numpy as np
-
-import requests
-from pandas import DataFrame
-from pandas import read_csv, concat
+import numpy as np
+from pandas import DataFrame, concat, read_csv
 from pandas.io.common import urlencode
 
-from pandas_datareader.compat import bytes_to_str, string_types, binary_type, \
-    StringIO
+import requests
 
-from pandas_datareader._utils import (RemoteDataError, SymbolWarning,
-                                      _sanitize_dates, _init_session)
+from pandas_datareader._utils import (
+    RemoteDataError,
+    SymbolWarning,
+    _init_session,
+    _sanitize_dates,
+)
+from pandas_datareader.compat import StringIO, binary_type, bytes_to_str, string_types
 
 
 class _BaseReader(object):
@@ -36,10 +37,19 @@ class _BaseReader(object):
     """
 
     _chunk_size = 1024 * 1024
-    _format = 'string'
-
-    def __init__(self, symbols, start=None, end=None, retry_count=3,
-                 pause=0.1, timeout=30, session=None, freq=None):
+    _format = "string"
+
+    def __init__(
+        self,
+        symbols,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        timeout=30,
+        session=None,
+        freq=None,
+    ):
 
         self.symbols = symbols
 
@@ -80,9 +90,9 @@ def read(self):
 
     def _read_one_data(self, url, params):
         """ read one data from specified URL """
-        if self._format == 'string':
+        if self._format == "string":
             out = self._read_url_as_StringIO(url, params=params)
-        elif self._format == 'json':
+        elif self._format == "json":
             out = self._get_response(url, params=params).json()
         else:
             raise NotImplementedError(self._format)
@@ -97,8 +107,10 @@ def _read_url_as_StringIO(self, url, params=None):
         out = StringIO()
         if len(text) == 0:
             service = self.__class__.__name__
-            raise IOError("{} request returned no data; check URL for invalid "
-                          "inputs: {}".format(service, self.url))
+            raise IOError(
+                "{} request returned no data; check URL for invalid "
+                "inputs: {}".format(service, self.url)
+            )
         if isinstance(text, binary_type):
             out.write(bytes_to_str(text))
         else:
@@ -125,11 +137,9 @@ def _get_response(self, url, params=None, headers=None):
 
         # initial attempt + retry
         pause = self.pause
-        last_response_text = ''
+        last_response_text = ""
         for i in range(self.retry_count + 1):
-            response = self.session.get(url,
-                                        params=params,
-                                        headers=headers)
+            response = self.session.get(url, params=params, headers=headers)
             if response.status_code == requests.codes.ok:
                 return response
 
@@ -140,8 +150,8 @@ def _get_response(self, url, params=None, headers=None):
             # Increase time between subsequent requests, per subclass.
             pause *= self.pause_multiplier
             # Get a new breadcrumb if necessary, in case ours is invalidated
-            if isinstance(params, list) and 'crumb' in params:
-                params['crumb'] = self._get_crumb(self.retry_count)
+            if isinstance(params, list) and "crumb" in params:
+                params["crumb"] = self._get_crumb(self.retry_count)
 
             # If our output error function returns True, exit the loop.
             if self._output_error(response):
@@ -149,9 +159,9 @@ def _get_response(self, url, params=None, headers=None):
 
         if params is not None and len(params) > 0:
             url = url + "?" + urlencode(params)
-        msg = 'Unable to read URL: {0}'.format(url)
+        msg = "Unable to read URL: {0}".format(url)
         if last_response_text:
-            msg += '\nResponse Text:\n{0}'.format(last_response_text)
+            msg += "\nResponse Text:\n{0}".format(last_response_text)
 
         raise RemoteDataError(msg)
 
@@ -169,8 +179,7 @@ def _output_error(self, out):
         return False
 
     def _read_lines(self, out):
-        rs = read_csv(out, index_col=0, parse_dates=True,
-                      na_values=('-', 'null'))[::-1]
+        rs = read_csv(out, index_col=0, parse_dates=True, na_values=("-", "null"))[::-1]
         # Needed to remove blank space character in header names
         rs.columns = list(map(lambda x: x.strip(), rs.columns.values.tolist()))
 
@@ -180,11 +189,12 @@ def _read_lines(self, out):
             rs = rs[:-1]
         # Get rid of unicode characters in index name.
         try:
-            rs.index.name = rs.index.name.decode(
-                'unicode_escape').encode('ascii', 'ignore')
+            rs.index.name = rs.index.name.decode("unicode_escape").encode(
+                "ascii", "ignore"
+            )
         except AttributeError:
             # Python 3 string has no decode method.
-            rs.index.name = rs.index.name.encode('ascii', 'ignore').decode()
+            rs.index.name = rs.index.name.encode("ascii", "ignore").decode()
 
         return rs
 
@@ -192,12 +202,24 @@ def _read_lines(self, out):
 class _DailyBaseReader(_BaseReader):
     """ Base class for Google / Yahoo daily reader """
 
-    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
-                 pause=0.1, session=None, chunksize=25):
-        super(_DailyBaseReader, self).__init__(symbols=symbols,
-                                               start=start, end=end,
-                                               retry_count=retry_count,
-                                               pause=pause, session=session)
+    def __init__(
+        self,
+        symbols=None,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        session=None,
+        chunksize=25,
+    ):
+        super(_DailyBaseReader, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        )
         self.chunksize = chunksize
 
     def _get_params(self, *args, **kwargs):
@@ -207,8 +229,7 @@ def read(self):
         """Read data"""
         # If a single symbol, (e.g., 'GOOG')
         if isinstance(self.symbols, (string_types, int)):
-            df = self._read_one_data(self.url,
-                                     params=self._get_params(self.symbols))
+            df = self._read_one_data(self.url, params=self._get_params(self.symbols))
         # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
         elif isinstance(self.symbols, DataFrame):
             df = self._dl_mult_symbols(self.symbols.index)
@@ -223,11 +244,10 @@ def _dl_mult_symbols(self, symbols):
         for sym_group in _in_chunks(symbols, self.chunksize):
             for sym in sym_group:
                 try:
-                    stocks[sym] = self._read_one_data(self.url,
-                                                      self._get_params(sym))
+                    stocks[sym] = self._read_one_data(self.url, self._get_params(sym))
                     passed.append(sym)
                 except IOError:
-                    msg = 'Failed to read symbol: {0!r}, replacing with NaN.'
+                    msg = "Failed to read symbol: {0!r}, replacing with NaN."
                     warnings.warn(msg.format(sym), SymbolWarning)
                     failed.append(sym)
 
@@ -241,7 +261,7 @@ def _dl_mult_symbols(self, symbols):
                 for sym in failed:
                     stocks[sym] = df_na
                 result = concat(stocks).unstack(level=0)
-                result.columns.names = ['Attributes', 'Symbols']
+                result.columns.names = ["Attributes", "Symbols"]
                 return result
             except AttributeError:
                 # cannot construct a panel with just 1D nans indicating no data
@@ -253,16 +273,14 @@ def _in_chunks(seq, size):
     """
     Return sequence in 'chunks' of size defined by size
    """
-    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
+    return (seq[pos : pos + size] for pos in range(0, len(seq), size))
 
 
 class _OptionBaseReader(_BaseReader):
-
     def __init__(self, symbol, session=None):
         """ Instantiates options_data with a ticker saved as symbol """
         self.symbol = symbol.upper()
-        super(_OptionBaseReader, self).__init__(symbols=symbol,
-                                                session=session)
+        super(_OptionBaseReader, self).__init__(symbols=symbol, session=session)
 
     def get_options_data(self, month=None, year=None, expiry=None):
         """
@@ -288,16 +306,18 @@ def get_put_data(self, month=None, year=None, expiry=None):
         """
         raise NotImplementedError
 
-    def get_near_stock_price(self, above_below=2, call=True, put=False,
-                             month=None, year=None, expiry=None):
+    def get_near_stock_price(
+        self, above_below=2, call=True, put=False, month=None, year=None, expiry=None
+    ):
         """
         ***Experimental***
         Returns a data frame of options that are near the current stock price.
         """
         raise NotImplementedError
 
-    def get_forward_data(self, months, call=True, put=False, near=False,
-                         above_below=2):  # pragma: no cover
+    def get_forward_data(
+        self, months, call=True, put=False, near=False, above_below=2
+    ):  # pragma: no cover
         """
         ***Experimental***
         Gets either call, put, or both data for months starting in the current
diff --git a/pandas_datareader/compat/__init__.py b/pandas_datareader/compat/__init__.py
index 1d364077..a342b0e7 100644
--- a/pandas_datareader/compat/__init__.py
+++ b/pandas_datareader/compat/__init__.py
@@ -1,17 +1,18 @@
 # flake8: noqa
+from distutils.version import LooseVersion
+import sys
+
 import pandas as pd
 import pandas.io.common as com
-import sys
-from distutils.version import LooseVersion
 
 PY3 = sys.version_info >= (3, 0)
 
 PANDAS_VERSION = LooseVersion(pd.__version__)
 
-PANDAS_0190 = (PANDAS_VERSION >= LooseVersion('0.19.0'))
-PANDAS_0200 = (PANDAS_VERSION >= LooseVersion('0.20.0'))
-PANDAS_0210 = (PANDAS_VERSION >= LooseVersion('0.21.0'))
-PANDAS_0230 = (PANDAS_VERSION >= LooseVersion('0.23.0'))
+PANDAS_0190 = PANDAS_VERSION >= LooseVersion("0.19.0")
+PANDAS_0200 = PANDAS_VERSION >= LooseVersion("0.20.0")
+PANDAS_0210 = PANDAS_VERSION >= LooseVersion("0.21.0")
+PANDAS_0230 = PANDAS_VERSION >= LooseVersion("0.23.0")
 
 if PANDAS_0190:
     from pandas.api.types import is_number
@@ -23,8 +24,7 @@
 if PANDAS_0200:
     from pandas.util.testing import assert_raises_regex
 
-    def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
-                               compression=None):
+    def get_filepath_or_buffer(filepath_or_buffer, encoding=None, compression=None):
 
         # Dictionaries are no longer considered valid inputs
         # for "get_filepath_or_buffer" starting in pandas >= 0.20.0
@@ -32,9 +32,13 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
             return filepath_or_buffer, encoding, compression
 
         return com.get_filepath_or_buffer(
-            filepath_or_buffer, encoding=encoding, compression=None)
+            filepath_or_buffer, encoding=encoding, compression=None
+        )
+
+
 else:
     from pandas.util.testing import assertRaisesRegexp as assert_raises_regex
+
     get_filepath_or_buffer = com.get_filepath_or_buffer
 
 if PANDAS_0190:
@@ -46,28 +50,28 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
     from urllib.error import HTTPError
     from functools import reduce
 
-    string_types = str,
+    string_types = (str,)
     binary_type = bytes
     from io import StringIO
 
     def str_to_bytes(s, encoding=None):
-        return s.encode(encoding or 'ascii')
-
+        return s.encode(encoding or "ascii")
 
     def bytes_to_str(b, encoding=None):
-        return b.decode(encoding or 'utf-8')
+        return b.decode(encoding or "utf-8")
+
+
 else:
     from urllib2 import HTTPError
     from cStringIO import StringIO
+
     reduce = reduce
     binary_type = str
-    string_types = basestring,
-
+    string_types = (basestring,)
 
     def bytes_to_str(b, encoding=None):
         return b
 
-
     def str_to_bytes(s, encoding=None):
         return s
 
@@ -84,6 +88,6 @@ def concat(*args, **kwargs):
     """
     Shim to wokr around sort keyword
     """
-    if not PANDAS_0230 and 'sort' in kwargs:
-        del kwargs['sort']
+    if not PANDAS_0230 and "sort" in kwargs:
+        del kwargs["sort"]
     return pd.concat(*args, **kwargs)
diff --git a/pandas_datareader/conftest.py b/pandas_datareader/conftest.py
index 79f3bc66..df71d4bd 100644
--- a/pandas_datareader/conftest.py
+++ b/pandas_datareader/conftest.py
@@ -21,7 +21,7 @@ def datapath(request):
     ValueError
         If the path doesn't exist and the --strict-data-files option is set.
     """
-    BASE_PATH = os.path.join(os.path.dirname(__file__), 'tests')
+    BASE_PATH = os.path.join(os.path.dirname(__file__), "tests")
 
     def deco(*args):
         path = os.path.join(BASE_PATH, *args)
@@ -33,4 +33,5 @@ def deco(*args):
             msg = "Could not find {}."
             pytest.skip(msg.format(path))
         return path
+
     return deco
diff --git a/pandas_datareader/data.py b/pandas_datareader/data.py
index b7b80abb..6ff45f16 100644
--- a/pandas_datareader/data.py
+++ b/pandas_datareader/data.py
@@ -12,39 +12,54 @@
 from pandas_datareader.econdb import EcondbReader
 from pandas_datareader.enigma import EnigmaReader
 from pandas_datareader.eurostat import EurostatReader
-from pandas_datareader.exceptions import DEP_ERROR_MSG, \
-    ImmediateDeprecationError
+from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError
 from pandas_datareader.famafrench import FamaFrenchReader
 from pandas_datareader.fred import FredReader
 from pandas_datareader.iex.daily import IEXDailyReader
 from pandas_datareader.iex.deep import Deep as IEXDeep
-from pandas_datareader.iex.tops import LastReader as IEXLasts, \
-    TopsReader as IEXTops
+from pandas_datareader.iex.tops import LastReader as IEXLasts, TopsReader as IEXTops
 from pandas_datareader.moex import MoexReader
 from pandas_datareader.nasdaq_trader import get_nasdaq_symbols
 from pandas_datareader.oecd import OECDReader
 from pandas_datareader.quandl import QuandlReader
-from pandas_datareader.robinhood import RobinhoodHistoricalReader, \
-    RobinhoodQuoteReader
+from pandas_datareader.robinhood import RobinhoodHistoricalReader, RobinhoodQuoteReader
 from pandas_datareader.stooq import StooqDailyReader
-from pandas_datareader.tiingo import (TiingoDailyReader, TiingoQuoteReader,
-                                      TiingoIEXHistoricalReader)
-from pandas_datareader.yahoo.actions import (YahooActionReader, YahooDivReader)
-from pandas_datareader.yahoo.components import _get_data as \
-    get_components_yahoo
+from pandas_datareader.tiingo import (
+    TiingoDailyReader,
+    TiingoIEXHistoricalReader,
+    TiingoQuoteReader,
+)
+from pandas_datareader.yahoo.actions import YahooActionReader, YahooDivReader
+from pandas_datareader.yahoo.components import _get_data as get_components_yahoo
 from pandas_datareader.yahoo.daily import YahooDailyReader
 from pandas_datareader.yahoo.options import Options as YahooOptions
 from pandas_datareader.yahoo.quotes import YahooQuotesReader
 
-__all__ = ['get_components_yahoo', 'get_data_enigma', 'get_data_famafrench',
-           'get_data_fred', 'get_data_moex',
-           'get_data_quandl', 'get_data_yahoo', 'get_data_yahoo_actions',
-           'get_nasdaq_symbols', 'get_quote_yahoo',
-           'get_tops_iex', 'get_summary_iex', 'get_records_iex',
-           'get_recent_iex', 'get_markets_iex', 'get_last_iex',
-           'get_iex_symbols', 'get_iex_book', 'get_dailysummary_iex',
-           'get_data_stooq', 'get_data_robinhood',
-           'get_quotes_robinhood', 'DataReader']
+__all__ = [
+    "get_components_yahoo",
+    "get_data_enigma",
+    "get_data_famafrench",
+    "get_data_fred",
+    "get_data_moex",
+    "get_data_quandl",
+    "get_data_yahoo",
+    "get_data_yahoo_actions",
+    "get_nasdaq_symbols",
+    "get_quote_yahoo",
+    "get_tops_iex",
+    "get_summary_iex",
+    "get_records_iex",
+    "get_recent_iex",
+    "get_markets_iex",
+    "get_last_iex",
+    "get_iex_symbols",
+    "get_iex_book",
+    "get_dailysummary_iex",
+    "get_data_stooq",
+    "get_data_robinhood",
+    "get_quotes_robinhood",
+    "DataReader",
+]
 
 
 def get_data_alphavantage(*args, **kwargs):
@@ -139,6 +154,7 @@ def get_markets_iex(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.market import MarketReader
+
     return MarketReader(*args, **kwargs).read()
 
 
@@ -157,6 +173,7 @@ def get_dailysummary_iex(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.stats import DailySummaryReader
+
     return DailySummaryReader(*args, **kwargs).read()
 
 
@@ -175,6 +192,7 @@ def get_summary_iex(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.stats import MonthlySummaryReader
+
     return MonthlySummaryReader(*args, **kwargs).read()
 
 
@@ -189,6 +207,7 @@ def get_records_iex(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.stats import RecordsReader
+
     return RecordsReader(*args, **kwargs).read()
 
 
@@ -203,6 +222,7 @@ def get_recent_iex(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.stats import RecentReader
+
     return RecentReader(*args, **kwargs).read()
 
 
@@ -216,6 +236,7 @@ def get_iex_symbols(*args, **kwargs):
     :return: DataFrame
     """
     from pandas_datareader.iex.ref import SymbolsReader
+
     return SymbolsReader(*args, **kwargs).read()
 
 
@@ -241,8 +262,16 @@ def get_iex_book(*args, **kwargs):
     return IEXDeep(*args, **kwargs).read()
 
 
-def DataReader(name, data_source=None, start=None, end=None,
-               retry_count=3, pause=0.1, session=None, access_key=None):
+def DataReader(
+    name,
+    data_source=None,
+    start=None,
+    end=None,
+    retry_count=3,
+    pause=0.1,
+    session=None,
+    access_key=None,
+):
     """
     Imports data from a number of online sources.
@@ -291,170 +320,327 @@ def DataReader(name, data_source=None, start=None, end=None,
         ff = DataReader("6_Portfolios_2x3", "famafrench")
         ff = DataReader("F-F_ST_Reversal_Factor", "famafrench")
     """
-    expected_source = ["yahoo", "iex", "iex-tops", "iex-last",
-                       "iex-last", "bankofcanada", "stooq", "iex-book",
-                       "enigma", "fred", "famafrench", "oecd", "eurostat",
-                       "nasdaq", "quandl", "moex", 'robinhood',
-                       "tiingo", "yahoo-actions", "yahoo-dividends",
-                       "av-forex", "av-daily", "av-daily-adjusted",
-                       "av-weekly", "av-weekly-adjusted", "av-monthly",
-                       "av-monthly-adjusted", "av-intraday", "econdb"]
+    expected_source = [
+        "yahoo",
+        "iex",
+        "iex-tops",
+        "iex-last",
+        "iex-last",
+        "bankofcanada",
+        "stooq",
+        "iex-book",
+        "enigma",
+        "fred",
+        "famafrench",
+        "oecd",
+        "eurostat",
+        "nasdaq",
+        "quandl",
+        "moex",
+        "robinhood",
+        "tiingo",
+        "yahoo-actions",
+        "yahoo-dividends",
+        "av-forex",
+        "av-daily",
+        "av-daily-adjusted",
+        "av-weekly",
+        "av-weekly-adjusted",
+        "av-monthly",
+        "av-monthly-adjusted",
+        "av-intraday",
+        "econdb",
+    ]
 
     if data_source not in expected_source:
         msg = "data_source=%r is not implemented" % data_source
         raise NotImplementedError(msg)
 
     if data_source == "yahoo":
-        return YahooDailyReader(symbols=name, start=start, end=end,
-                                adjust_price=False, chunksize=25,
-                                retry_count=retry_count, pause=pause,
-                                session=session).read()
+        return YahooDailyReader(
+            symbols=name,
+            start=start,
+            end=end,
+            adjust_price=False,
+            chunksize=25,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "iex":
-        return IEXDailyReader(symbols=name, start=start, end=end,
-                              chunksize=25, api_key=access_key,
-                              retry_count=retry_count, pause=pause,
-                              session=session).read()
+        return IEXDailyReader(
+            symbols=name,
+            start=start,
+            end=end,
+            chunksize=25,
+            api_key=access_key,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "iex-tops":
-        return IEXTops(symbols=name, start=start, end=end,
-                       retry_count=retry_count, pause=pause,
-                       session=session).read()
+        return IEXTops(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "iex-last":
-        return IEXLasts(symbols=name, start=start, end=end,
-                        retry_count=retry_count, pause=pause,
-                        session=session).read()
+        return IEXLasts(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "bankofcanada":
-        return BankOfCanadaReader(symbols=name, start=start, end=end,
-                                  retry_count=retry_count, pause=pause,
-                                  session=session).read()
+        return BankOfCanadaReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "stooq":
-        return StooqDailyReader(symbols=name,
-                                chunksize=25,
-                                retry_count=retry_count, pause=pause,
-                                session=session).read()
+        return StooqDailyReader(
+            symbols=name,
+            chunksize=25,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "iex-book":
-        return IEXDeep(symbols=name, service="book", start=start, end=end,
-                       retry_count=retry_count, pause=pause,
-                       session=session).read()
+        return IEXDeep(
+            symbols=name,
+            service="book",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "enigma":
         return EnigmaReader(dataset_id=name, api_key=access_key).read()
     elif data_source == "fred":
-        return FredReader(symbols=name, start=start, end=end,
-                          retry_count=retry_count, pause=pause,
-                          session=session).read()
+        return FredReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "famafrench":
-        return FamaFrenchReader(symbols=name, start=start, end=end,
-                                retry_count=retry_count, pause=pause,
-                                session=session).read()
+        return FamaFrenchReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "oecd":
-        return OECDReader(symbols=name, start=start, end=end,
-                          retry_count=retry_count, pause=pause,
-                          session=session).read()
+        return OECDReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "eurostat":
-        return EurostatReader(symbols=name, start=start, end=end,
-                              retry_count=retry_count, pause=pause,
-                              session=session).read()
-    elif data_source == 'nasdaq':
-        if name != 'symbols':
-            raise ValueError("Only the string 'symbols' is supported for "
-                             "Nasdaq, not %r" % (name,))
+        return EurostatReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
+    elif data_source == "nasdaq":
+        if name != "symbols":
+            raise ValueError(
+                "Only the string 'symbols' is supported for " "Nasdaq, not %r" % (name,)
+            )
         return get_nasdaq_symbols(retry_count=retry_count, pause=pause)
     elif data_source == "quandl":
-        return QuandlReader(symbols=name, start=start, end=end,
-                            retry_count=retry_count, pause=pause,
-                            session=session, api_key=access_key).read()
+        return QuandlReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "moex":
-        return MoexReader(symbols=name, start=start, end=end,
-                          retry_count=retry_count, pause=pause,
-                          session=session).read()
-    elif data_source == 'robinhood':
-        return RobinhoodHistoricalReader(symbols=name, start=start, end=end,
-                                         retry_count=retry_count, pause=pause,
-                                         session=session).read()
-    elif data_source == 'tiingo':
-        return TiingoDailyReader(symbols=name, start=start, end=end,
-                                 retry_count=retry_count, pause=pause,
-                                 session=session,
-                                 api_key=access_key).read()
+        return MoexReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
+    elif data_source == "robinhood":
+        return RobinhoodHistoricalReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
+    elif data_source == "tiingo":
+        return TiingoDailyReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "yahoo-actions":
-        return YahooActionReader(symbols=name, start=start, end=end,
-                                 retry_count=retry_count, pause=pause,
-                                 session=session).read()
+        return YahooActionReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     elif data_source == "yahoo-dividends":
-        return YahooDivReader(symbols=name, start=start, end=end,
-                              adjust_price=False, chunksize=25,
-                              retry_count=retry_count, pause=pause,
-                              session=session, interval='d').read()
+        return YahooDivReader(
+            symbols=name,
+            start=start,
+            end=end,
+            adjust_price=False,
+            chunksize=25,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            interval="d",
+        ).read()
     elif data_source == "av-forex":
-        return AVForexReader(symbols=name, retry_count=retry_count,
-                             pause=pause, session=session,
-                             api_key=access_key).read()
+        return AVForexReader(
+            symbols=name,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-daily":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_DAILY", start=start,
-                                  end=end, retry_count=retry_count,
-                                  pause=pause, session=session,
-                                  api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_DAILY",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-daily-adjusted":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_DAILY_ADJUSTED",
-                                  start=start, end=end,
-                                  retry_count=retry_count, pause=pause,
-                                  session=session, api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_DAILY_ADJUSTED",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-weekly":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_WEEKLY", start=start,
-                                  end=end, retry_count=retry_count,
-                                  pause=pause, session=session,
-                                  api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_WEEKLY",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-weekly-adjusted":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_WEEKLY_ADJUSTED",
-                                  start=start, end=end,
-                                  retry_count=retry_count, pause=pause,
-                                  session=session, api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_WEEKLY_ADJUSTED",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-monthly":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_MONTHLY", start=start,
-                                  end=end, retry_count=retry_count,
-                                  pause=pause, session=session,
-                                  api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_MONTHLY",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-monthly-adjusted":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_MONTHLY_ADJUSTED",
-                                  start=start, end=end,
-                                  retry_count=retry_count, pause=pause,
-                                  session=session, api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_MONTHLY_ADJUSTED",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "av-intraday":
-        return AVTimeSeriesReader(symbols=name,
-                                  function="TIME_SERIES_INTRADAY",
-                                  start=start, end=end,
-                                  retry_count=retry_count, pause=pause,
-                                  session=session, api_key=access_key).read()
+        return AVTimeSeriesReader(
+            symbols=name,
+            function="TIME_SERIES_INTRADAY",
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            api_key=access_key,
+        ).read()
     elif data_source == "econdb":
-        return EcondbReader(symbols=name, start=start, end=end,
-                            retry_count=retry_count, pause=pause,
-                            session=session).read()
+        return EcondbReader(
+            symbols=name,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        ).read()
     else:
         msg = "data_source=%r is not implemented" % data_source
@@ -463,11 +649,15 @@ def DataReader(name, data_source=None, start=None, end=None,
 
 def Options(symbol, data_source=None, session=None):
     if data_source is None:
-        warnings.warn("Options(symbol) is deprecated, use Options(symbol,"
-                      " data_source) instead", FutureWarning, stacklevel=2)
+        warnings.warn(
+            "Options(symbol) is deprecated, use Options(symbol,"
+            " data_source) instead",
+            FutureWarning,
+            stacklevel=2,
+        )
         data_source = "yahoo"
     if data_source == "yahoo":
-        raise ImmediateDeprecationError(DEP_ERROR_MSG.format('Yahoo Options'))
+        raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Yahoo Options"))
         return YahooOptions(symbol, session=session)
     else:
         raise NotImplementedError("currently only yahoo supported")
diff --git a/pandas_datareader/econdb.py b/pandas_datareader/econdb.py
index 99312f2f..11277488 100644
--- a/pandas_datareader/econdb.py
+++ b/pandas_datareader/econdb.py
@@ -7,44 +7,50 @@
 class EcondbReader(_BaseReader):
     """Get data for the given name from Econdb."""
 
-    _URL = 'https://www.econdb.com/api/series/'
+    _URL = "https://www.econdb.com/api/series/"
     _format = None
-    _show = 'labels'
+    _show = "labels"
 
     @property
     def url(self):
         """API URL"""
         if not isinstance(self.symbols, str):
-            raise ValueError('data name must be string')
+            raise ValueError("data name must be string")
 
-        return ('{0}?{1}&format=json&page_size=500&expand=both'
-                .format(self._URL, self.symbols))
+        return "{0}?{1}&format=json&page_size=500&expand=both".format(
+            self._URL, self.symbols
+        )
 
     def read(self):
         """ read one data from specified URL """
-        results = requests.get(self.url).json()['results']
-        df = pd.DataFrame({'dates': []}).set_index('dates')
+        results = requests.get(self.url).json()["results"]
+        df = pd.DataFrame({"dates": []}).set_index("dates")
 
-        if self._show == 'labels':
-            def show_func(x): return x.split(':')[1]
-        elif self._show == 'codes':
-            def show_func(x): return x.split(':')[0]
+        if self._show == "labels":
+
+            def show_func(x):
+                return x.split(":")[1]
+
+        elif self._show == "codes":
+
+            def show_func(x):
+                return x.split(":")[0]
 
         for entry in results:
-            series = (pd.DataFrame(entry['data'])[['dates', 'values']]
-                      .set_index('dates'))
+            series = pd.DataFrame(entry["data"])[["dates", "values"]].set_index("dates")
 
-            head = entry['additional_metadata']
+            head = entry["additional_metadata"]
             if head != "":  # this additional metadata is not blank
                 series.columns = pd.MultiIndex.from_tuples(
                     [[show_func(x) for x in head.values()]],
-                    names=[show_func(x) for x in head.keys()])
+                    names=[show_func(x) for x in head.keys()],
+                )
 
             if not df.empty:
-                df = df.join(series, how='outer')
+                df = df.join(series, how="outer")
             else:
                 df = series
 
-        df.index = pd.to_datetime(df.index, errors='ignore')
-        df.index.name = 'TIME_PERIOD'
+        df.index = pd.to_datetime(df.index, errors="ignore")
+        df.index.name = "TIME_PERIOD"
         df = df.truncate(self.start, self.end)
 
         return df
diff --git a/pandas_datareader/enigma.py b/pandas_datareader/enigma.py
index 6513b223..d2442293 100644
--- a/pandas_datareader/enigma.py
+++ b/pandas_datareader/enigma.py
@@ -42,24 +42,29 @@ class EnigmaReader(_BaseReader):
     >>> df = EnigmaReader(dataset_id='bedaf052-5fcd-4758-8d27-048ce8746c6a',
     ...                   api_key='INSERT_API_KEY').read()
     """
-    def __init__(self,
-                 dataset_id=None,
-                 api_key=None,
-                 retry_count=5,
-                 pause=.75,
-                 session=None,
-                 base_url="https://public.enigma.com/api"):
-
-        super(EnigmaReader, self).__init__(symbols=[],
-                                           retry_count=retry_count,
-                                           pause=pause, session=session)
+
+    def __init__(
+        self,
+        dataset_id=None,
+        api_key=None,
+        retry_count=5,
+        pause=0.75,
+        session=None,
+        base_url="https://public.enigma.com/api",
+    ):
+
+        super(EnigmaReader, self).__init__(
+            symbols=[], retry_count=retry_count, pause=pause, session=session
+        )
         if api_key is None:
-            self._api_key = os.getenv('ENIGMA_API_KEY')
+            self._api_key = os.getenv("ENIGMA_API_KEY")
             if self._api_key is None:
-                raise ValueError("Please provide an Enigma API key or set "
-                                 "the ENIGMA_API_KEY environment variable\n"
-                                 "If you do not have an API key, you can get "
-                                 "one here: http://public.enigma.com/signup")
+                raise ValueError(
+                    "Please provide an Enigma API key or set "
+                    "the ENIGMA_API_KEY environment variable\n"
+                    "If you do not have an API key, you can get "
+                    "one here: http://public.enigma.com/signup"
+                )
         else:
             self._api_key = api_key
 
@@ -67,11 +72,12 @@ def __init__(self,
         if not isinstance(self._dataset_id, string_types):
             raise ValueError(
                 "The Enigma dataset_id must be a string (ex: "
-                "'bedaf052-5fcd-4758-8d27-048ce8746c6a')")
+                "'bedaf052-5fcd-4758-8d27-048ce8746c6a')"
+            )
 
         headers = {
-            'Authorization': 'Bearer {0}'.format(self._api_key),
-            'User-Agent': 'pandas-datareader',
+            "Authorization": "Bearer {0}".format(self._api_key),
+            "User-Agent": "pandas-datareader",
         }
         self.session.headers.update(headers)
         self._base_url = base_url
@@ -111,7 +117,7 @@ def _get(self, url):
     def get_current_snapshot_id(self, dataset_id):
         """Get ID of the most current snapshot of a dataset"""
         dataset_metadata = self.get_dataset_metadata(dataset_id)
-        return dataset_metadata['current_snapshot']['id']
+        return dataset_metadata["current_snapshot"]["id"]
 
     def get_dataset_metadata(self, dataset_id):
         """Get the Dataset Model of this EnigmaReader's dataset
diff --git a/pandas_datareader/eurostat.py b/pandas_datareader/eurostat.py
index 8777fcb3..fd59037b 100644
--- a/pandas_datareader/eurostat.py
+++ b/pandas_datareader/eurostat.py
@@ -2,34 +2,32 @@
 
 import pandas as pd
 
-from pandas_datareader.io.sdmx import read_sdmx, _read_sdmx_dsd
-from pandas_datareader.compat import string_types
 from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import string_types
+from pandas_datareader.io.sdmx import _read_sdmx_dsd, read_sdmx
 
 
 class EurostatReader(_BaseReader):
     """Get data for the given name from Eurostat."""
 
-    _URL = 'http://ec.europa.eu/eurostat/SDMX/diss-web/rest'
+    _URL = "http://ec.europa.eu/eurostat/SDMX/diss-web/rest"
 
     @property
     def url(self):
         """API URL"""
         if not isinstance(self.symbols, string_types):
-            raise ValueError('data name must be string')
+            raise ValueError("data name must be string")
 
-        q = '{0}/data/{1}/?startperiod={2}&endperiod={3}'
-        return q.format(self._URL, self.symbols,
-                        self.start.year, self.end.year)
+        q = "{0}/data/{1}/?startperiod={2}&endperiod={3}"
+        return q.format(self._URL, self.symbols, self.start.year, self.end.year)
 
     @property
     def dsd_url(self):
         """API DSD URL"""
         if not isinstance(self.symbols, string_types):
-            raise ValueError('data name must be string')
+            raise ValueError("data name must be string")
 
-        return '{0}/datastructure/ESTAT/DSD_{1}'.format(
-            self._URL, self.symbols)
+        return "{0}/datastructure/ESTAT/DSD_{1}".format(self._URL, self.symbols)
 
     def 
_read_one_data(self, url, params): resp_dsd = self._get_response(self.dsd_url) diff --git a/pandas_datareader/famafrench.py b/pandas_datareader/famafrench.py index 6aa6230a..eb2b9174 100644 --- a/pandas_datareader/famafrench.py +++ b/pandas_datareader/famafrench.py @@ -1,17 +1,16 @@ import datetime as dt import re import tempfile - from zipfile import ZipFile from pandas import read_csv, to_datetime -from pandas_datareader.compat import lmap, StringIO from pandas_datareader.base import _BaseReader +from pandas_datareader.compat import StringIO, lmap -_URL = 'http://mba.tuck.dartmouth.edu/pages/faculty/ken.french/' -_URL_PREFIX = 'ftp/' -_URL_SUFFIX = '_CSV.zip' +_URL = "http://mba.tuck.dartmouth.edu/pages/faculty/ken.french/" +_URL_PREFIX = "ftp/" +_URL_SUFFIX = "_CSV.zip" def get_available_datasets(**kwargs): @@ -27,13 +26,13 @@ def get_available_datasets(**kwargs): ------- A list of valid inputs for get_data_famafrench. """ - return FamaFrenchReader(symbols='', **kwargs).get_available_datasets() + return FamaFrenchReader(symbols="", **kwargs).get_available_datasets() def _parse_date_famafrench(x): x = x.strip() try: - return dt.datetime.strptime(x, '%Y%m') + return dt.datetime.strptime(x, "%Y%m") except Exception: pass return to_datetime(x) @@ -50,7 +49,7 @@ class FamaFrenchReader(_BaseReader): @property def url(self): """API URL""" - return ''.join([_URL, _URL_PREFIX, self.symbols, _URL_SUFFIX]) + return "".join([_URL, _URL_PREFIX, self.symbols, _URL_SUFFIX]) def _read_zipfile(self, url): raw = self._get_response(url).content @@ -58,7 +57,7 @@ def _read_zipfile(self, url): with tempfile.TemporaryFile() as tmpf: tmpf.write(raw) - with ZipFile(tmpf, 'r') as zf: + with ZipFile(tmpf, "r") as zf: data = zf.open(zf.namelist()[0]).read().decode() return data @@ -77,40 +76,42 @@ def read(self): def _read_one_data(self, url, params): - params = {'index_col': 0, - 'parse_dates': [0], - 'date_parser': _parse_date_famafrench} + params = { + "index_col": 0, + "parse_dates": [0], + "date_parser": _parse_date_famafrench, + } # headers in these files are not valid - if self.symbols.endswith('_Breakpoints'): + if self.symbols.endswith("_Breakpoints"): - if self.symbols.find('-') > -1: - c = ['<=0', '>0'] + if self.symbols.find("-") > -1: + c = ["<=0", ">0"] else: - c = ['Count'] + c = ["Count"] r = list(range(0, 105, 5)) - params['names'] = ['Date'] + c + list(zip(r, r[1:])) + params["names"] = ["Date"] + c + list(zip(r, r[1:])) - if self.symbols != 'Prior_2-12_Breakpoints': - params['skiprows'] = 1 + if self.symbols != "Prior_2-12_Breakpoints": + params["skiprows"] = 1 else: - params['skiprows'] = 3 + params["skiprows"] = 3 doc_chunks, tables = [], [] data = self._read_zipfile(url) - for chunk in data.split(2 * '\r\n'): + for chunk in data.split(2 * "\r\n"): if len(chunk) < 800: - doc_chunks.append(chunk.replace('\r\n', ' ').strip()) + doc_chunks.append(chunk.replace("\r\n", " ").strip()) else: tables.append(chunk) datasets, table_desc = {}, [] for i, src in enumerate(tables): - match = re.search(r'^\s*,', src, re.M) # the table starts there + match = re.search(r"^\s*,", src, re.M) # the table starts there start = 0 if not match else match.start() - df = read_csv(StringIO('Date' + src[start:]), **params) + df = read_csv(StringIO("Date" + src[start:]), **params) try: idx_name = df.index.name # hack for pandas 0.16.2 df = df.to_period(df.index.inferred_freq[:1]) @@ -120,17 +121,17 @@ def _read_one_data(self, url, params): df = df.truncate(self.start, self.end) datasets[i] = df - title = 
src[:start].replace('\r\n', ' ').strip() - shape = '({0} rows x {1} cols)'.format(*df.shape) - table_desc.append('{0} {1}'.format(title, shape).strip()) + title = src[:start].replace("\r\n", " ").strip() + shape = "({0} rows x {1} cols)".format(*df.shape) + table_desc.append("{0} {1}".format(title, shape).strip()) - descr = '{0}\n{1}\n\n'.format(self.symbols.replace('_', ' '), - len(self.symbols) * '-') + descr = "{0}\n{1}\n\n".format( + self.symbols.replace("_", " "), len(self.symbols) * "-" + ) if doc_chunks: - descr += ' '.join(doc_chunks).replace(2 * ' ', ' ') + '\n\n' - table_descr = map(lambda x: '{0:3} : {1}'.format(*x), - enumerate(table_desc)) - datasets['DESCR'] = descr + '\n'.join(table_descr) + descr += " ".join(doc_chunks).replace(2 * " ", " ") + "\n\n" + table_descr = map(lambda x: "{0:3} : {1}".format(*x), enumerate(table_desc)) + datasets["DESCR"] = descr + "\n".join(table_descr) return datasets @@ -146,15 +147,21 @@ def get_available_datasets(self): try: from lxml.html import document_fromstring except ImportError: - raise ImportError("Please install lxml if you want to use the " - "get_datasets_famafrench function") + raise ImportError( + "Please install lxml if you want to use the " + "get_datasets_famafrench function" + ) - response = self.session.get(_URL + 'data_library.html') + response = self.session.get(_URL + "data_library.html") root = document_fromstring(response.content) - datasets = [e.attrib['href'] for e in root.findall('.//a') - if 'href' in e.attrib] - datasets = [ds for ds in datasets if ds.startswith(_URL_PREFIX) - and ds.endswith(_URL_SUFFIX)] + datasets = [ + e.attrib["href"] for e in root.findall(".//a") if "href" in e.attrib + ] + datasets = [ + ds + for ds in datasets + if ds.startswith(_URL_PREFIX) and ds.endswith(_URL_SUFFIX) + ] - return lmap(lambda x: x[len(_URL_PREFIX):-len(_URL_SUFFIX)], datasets) + return lmap(lambda x: x[len(_URL_PREFIX) : -len(_URL_SUFFIX)], datasets) diff --git a/pandas_datareader/fred.py b/pandas_datareader/fred.py index 52bffd68..699c91c3 100644 --- a/pandas_datareader/fred.py +++ b/pandas_datareader/fred.py @@ -1,8 +1,7 @@ -from pandas_datareader.compat import is_list_like - from pandas import concat, read_csv from pandas_datareader.base import _BaseReader +from pandas_datareader.compat import is_list_like class FredReader(_BaseReader): @@ -40,16 +39,26 @@ def _read(self): def fetch_data(url, name): """Utillity to fetch data""" resp = self._read_url_as_StringIO(url) - data = read_csv(resp, index_col=0, parse_dates=True, - header=None, skiprows=1, names=["DATE", name], - na_values='.') + data = read_csv( + resp, + index_col=0, + parse_dates=True, + header=None, + skiprows=1, + names=["DATE", name], + na_values=".", + ) try: return data.truncate(self.start, self.end) except KeyError: # pragma: no cover - if data.iloc[3].name[7:12] == 'Error': - raise IOError("Failed to get the data. Check that " - "{0!r} is a valid FRED series.".format(name)) + if data.iloc[3].name[7:12] == "Error": + raise IOError( + "Failed to get the data. 
Check that " + "{0!r} is a valid FRED series.".format(name) + ) raise - df = concat([fetch_data(url, n) for url, n in zip(urls, names)], - axis=1, join='outer') + + df = concat( + [fetch_data(url, n) for url, n in zip(urls, names)], axis=1, join="outer" + ) return df diff --git a/pandas_datareader/iex/__init__.py b/pandas_datareader/iex/__init__.py index 195e02f7..221009b1 100644 --- a/pandas_datareader/iex/__init__.py +++ b/pandas_datareader/iex/__init__.py @@ -2,8 +2,8 @@ import pandas as pd from pandas.io.common import urlencode -from pandas_datareader.base import _BaseReader +from pandas_datareader.base import _BaseReader # Data provided for free by IEX # Data is furnished in compliance with the guidelines promulgated in the IEX @@ -17,14 +17,19 @@ class IEX(_BaseReader): Serves as the base class for all IEX API services. """ - _format = 'json' + _format = "json" - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): - super(IEX, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): + super(IEX, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): @@ -36,8 +41,7 @@ def service(self): def url(self): """API URL""" qstring = urlencode(self._get_params(self.symbols)) - return "https://api.iextrading.com/1.0/{}?{}".format(self.service, - qstring) + return "https://api.iextrading.com/1.0/{}?{}".format(self.service, qstring) def read(self): """Read data""" @@ -51,9 +55,9 @@ def read(self): def _get_params(self, symbols): p = {} if isinstance(symbols, list): - p['symbols'] = ','.join(symbols) + p["symbols"] = ",".join(symbols) elif isinstance(symbols, str): - p['symbols'] = symbols + p["symbols"] = symbols return p def _output_error(self, out): @@ -69,7 +73,7 @@ def _output_error(self, out): for key, string in content.items(): e = "IEX Output error encountered: {}".format(string) - if key == 'error': + if key == "error": raise Exception(e) def _read_lines(self, out): diff --git a/pandas_datareader/iex/daily.py b/pandas_datareader/iex/daily.py index d16c69a1..cb860522 100644 --- a/pandas_datareader/iex/daily.py +++ b/pandas_datareader/iex/daily.py @@ -2,9 +2,9 @@ import json import os +from dateutil.relativedelta import relativedelta import pandas as pd -from dateutil.relativedelta import relativedelta from pandas_datareader.base import _DailyBaseReader # Data provided for free by IEX @@ -45,32 +45,48 @@ class IEXDailyReader(_DailyBaseReader): IEX Cloud Secret Token """ - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None, chunksize=25, api_key=None): + def __init__( + self, + symbols=None, + start=None, + end=None, + retry_count=3, + pause=0.1, + session=None, + chunksize=25, + api_key=None, + ): if api_key is None: - api_key = os.getenv('IEX_API_KEY') + api_key = os.getenv("IEX_API_KEY") if not api_key or not isinstance(api_key, str): - raise ValueError('The IEX Cloud API key must be provided either ' - 'through the api_key variable or through the ' - ' environment variable IEX_API_KEY') + raise ValueError( + "The IEX Cloud API key must be provided either " + "through the api_key variable or through the " + " environment variable IEX_API_KEY" + ) # Support for sandbox environment (testing purposes) if os.getenv("IEX_SANDBOX") == "enable": 
             self.sandbox = True
         else:
             self.sandbox = False
         self.api_key = api_key
-        super(IEXDailyReader, self).__init__(symbols=symbols, start=start,
-                                             end=end, retry_count=retry_count,
-                                             pause=pause, session=session,
-                                             chunksize=chunksize)
+        super(IEXDailyReader, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+            chunksize=chunksize,
+        )

     @property
     def url(self):
         """API URL"""
         if self.sandbox is True:
-            return 'https://sandbox.iexapis.com/stable/stock/market/batch'
+            return "https://sandbox.iexapis.com/stable/stock/market/batch"
         else:
-            return 'https://cloud.iexapis.com/stable/stock/market/batch'
+            return "https://cloud.iexapis.com/stable/stock/market/batch"

     @property
     def endpoint(self):
@@ -80,20 +96,20 @@ def _get_params(self, symbol):
         chart_range = self._range_string_from_date()
         if isinstance(symbol, list):
-            symbolList = ','.join(symbol)
+            symbolList = ",".join(symbol)
         else:
             symbolList = symbol
         params = {
             "symbols": symbolList,
             "types": self.endpoint,
             "range": chart_range,
-            "token": self.api_key
+            "token": self.api_key,
         }
         return params

     def _range_string_from_date(self):
         delta = relativedelta(self.start, datetime.datetime.now())
-        years = (delta.years * -1)
+        years = delta.years * -1
         if 5 <= years <= 15:
             return "max"
         if 2 <= years < 5:
@@ -113,14 +129,12 @@ def _range_string_from_date(self):
             return "1y"

         else:
-            raise ValueError(
-                "Invalid date specified. Must be within past 15 years.")
+            raise ValueError("Invalid date specified. Must be within past 15 years.")

     def read(self):
         """Read data"""
         try:
-            return self._read_one_data(self.url,
-                                       self._get_params(self.symbols))
+            return self._read_one_data(self.url, self._get_params(self.symbols))
         finally:
             self.close()

@@ -138,12 +152,12 @@ def _read_lines(self, out):
             df.set_index("date", inplace=True)
             values = ["open", "high", "low", "close", "volume"]
             df = df[values]
-            sstart = self.start.strftime('%Y-%m-%d')
-            send = self.end.strftime('%Y-%m-%d')
+            sstart = self.start.strftime("%Y-%m-%d")
+            send = self.end.strftime("%Y-%m-%d")
             df = df.loc[sstart:send]
             result.update({symbol: df})
         if len(result) > 1:
             result = pd.concat(result).unstack(level=0)
-            result.columns.names = ['Attributes', 'Symbols']
+            result.columns.names = ["Attributes", "Symbols"]
             return result
         return result[self.symbols]
diff --git a/pandas_datareader/iex/deep.py b/pandas_datareader/iex/deep.py
index 9dcc27af..02304ae2 100644
--- a/pandas_datareader/iex/deep.py
+++ b/pandas_datareader/iex/deep.py
@@ -1,6 +1,7 @@
-from pandas_datareader.iex import IEX
 from datetime import datetime

+from pandas_datareader.iex import IEX
+
 # Data provided for free by IEX
 # Data is furnished in compliance with the guidelines promulgated in the IEX
 # API terms of service and manual
@@ -22,17 +23,30 @@ class Deep(IEX):
     Also provides last trade price and size information. Routed executions
     are not reported.
     """
-    def __init__(self, symbols=None, service=None, start=None, end=None,
-                 retry_count=3, pause=0.1, session=None):
+
+    def __init__(
+        self,
+        symbols=None,
+        service=None,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        session=None,
+    ):
         if isinstance(symbols, str):
             symbols = symbols.lower()
         else:
             symbols = [s.lower() for s in symbols]

-        super(Deep, self).__init__(symbols=symbols,
-                                   start=start, end=end,
-                                   retry_count=retry_count,
-                                   pause=pause, session=session)
+        super(Deep, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        )
         self.sub = service

     @property
@@ -51,15 +65,15 @@ def _read_lines(self, out):
         # Runs appropriate output functions per the service being accessed.
         fmap = {
-            'book': '_pass',
-            'op-halt-status': '_convert_tstamp',
-            'security-event': '_convert_tstamp',
-            'ssr-status': '_convert_tstamp',
-            'system-event': '_read_system_event',
-            'trades': '_pass',
-            'trade-breaks': '_convert_tstamp',
-            'trading-status': '_read_trading_status',
-            None: '_pass',
+            "book": "_pass",
+            "op-halt-status": "_convert_tstamp",
+            "security-event": "_convert_tstamp",
+            "ssr-status": "_convert_tstamp",
+            "system-event": "_read_system_event",
+            "trades": "_pass",
+            "trade-breaks": "_convert_tstamp",
+            "trading-status": "_read_trading_status",
+            None: "_pass",
         }

         if self.sub in fmap:
@@ -71,12 +85,12 @@ def _read_system_event(self, out):
         # Map the response code to a string output per the API docs.
         # Per: https://www.iextrading.com/developer/docs/#system-event-message
         smap = {
-            'O': 'Start of messages',
-            'S': 'Start of system hours',
-            'R': 'Start of regular market hours',
-            'M': 'End of regular market hours',
-            'E': 'End of system hours',
-            'C': 'End of messages'
+            "O": "Start of messages",
+            "S": "Start of system hours",
+            "R": "Start of regular market hours",
+            "M": "End of regular market hours",
+            "E": "End of system hours",
+            "C": "End of messages",
         }
         tid = out["systemEvent"]
         out["eventResponse"] = smap[tid]
@@ -90,34 +104,33 @@ def _pass(out):

     def _read_trading_status(self, out):
         # Reference: https://www.iextrading.com/developer/docs/#trading-status
         smap = {
-            'H': 'Trading halted across all US equity markets',
-            'O': 'Trading halt released into an Order Acceptance Period '
-                 '(IEX-listed securities only)',
-            'P': 'Trading paused and Order Acceptance Period on IEX '
-                 '(IEX-listed securities only)',
-            'T': 'Trading on IEX'
+            "H": "Trading halted across all US equity markets",
+            "O": "Trading halt released into an Order Acceptance Period "
+            "(IEX-listed securities only)",
+            "P": "Trading paused and Order Acceptance Period on IEX "
+            "(IEX-listed securities only)",
+            "T": "Trading on IEX",
         }
         rmap = {
             # Trading Halt Reasons
-            'T1': 'Halt News Pending',
-            'IPO1': 'IPO/New Issue Not Yet Trading',
-            'IPOD': 'IPO/New Issue Deferred',
-            'MCB3': 'Market-Wide Circuit Breaker Level 3 - Breached',
-            'NA': 'Reason Not Available',
-
+            "T1": "Halt News Pending",
+            "IPO1": "IPO/New Issue Not Yet Trading",
+            "IPOD": "IPO/New Issue Deferred",
+            "MCB3": "Market-Wide Circuit Breaker Level 3 - Breached",
+            "NA": "Reason Not Available",
             # Order Acceptance Period Reasons
-            'T2': 'Halt News Dissemination',
-            'IPO2': 'IPO/New Issue Order Acceptance Period',
-            'IPO3': 'IPO Pre-Launch Period',
-            'MCB1': 'Market-Wide Circuit Breaker Level 1 - Breached',
-            'MCB2': 'Market-Wide Circuit Breaker Level 2 - Breached'
+            "T2": "Halt News Dissemination",
+            "IPO2": "IPO/New Issue Order Acceptance Period",
+            "IPO3": "IPO Pre-Launch Period",
+            "MCB1": "Market-Wide Circuit Breaker Level 1 - Breached",
+            "MCB2": "Market-Wide Circuit Breaker Level 2 - Breached",
         }

         for ticker, data in out.items():
-            if data['status'] in smap:
-                data['statusText'] = smap[data['status']]
+            if data["status"] in smap:
+                data["statusText"] = smap[data["status"]]

-            if data['reason'] in rmap:
-                data['reasonText'] = rmap[data['reason']]
+            if data["reason"] in rmap:
+                data["reasonText"] = rmap[data["reason"]]

             out[ticker] = data

@@ -126,15 +139,15 @@ def _read_trading_status(self, out):
     @staticmethod
     def _convert_tstamp(out):
         # Searches for top-level timestamp attributes or within dictionaries
-        if 'timestamp' in out:
+        if "timestamp" in out:
             # Convert UNIX to datetime object
             f = float(out["timestamp"])
-            out["timestamp"] = datetime.fromtimestamp(f/1000)
+            out["timestamp"] = datetime.fromtimestamp(f / 1000)
         else:
             for ticker, data in out.items():
-                if 'timestamp' in data:
+                if "timestamp" in data:
                     f = float(data["timestamp"])
-                    data["timestamp"] = datetime.fromtimestamp(f/1000)
+                    data["timestamp"] = datetime.fromtimestamp(f / 1000)
                     out[ticker] = data

         return out
diff --git a/pandas_datareader/iex/market.py b/pandas_datareader/iex/market.py
index 16fa4452..6cd1a879 100644
--- a/pandas_datareader/iex/market.py
+++ b/pandas_datareader/iex/market.py
@@ -16,12 +16,18 @@ class MarketReader(IEX):
     Market data is captured by the IEX system between approximately 7:45 a.m.
     and 5:15 p.m. ET.
     """
-    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
-                 pause=0.1, session=None):
-        super(MarketReader, self).__init__(symbols=symbols,
-                                           start=start, end=end,
-                                           retry_count=retry_count,
-                                           pause=pause, session=session)
+
+    def __init__(
+        self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+    ):
+        super(MarketReader, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        )

     @property
     def service(self):
diff --git a/pandas_datareader/iex/ref.py b/pandas_datareader/iex/ref.py
index 995905ed..93908fc4 100644
--- a/pandas_datareader/iex/ref.py
+++ b/pandas_datareader/iex/ref.py
@@ -16,12 +16,18 @@ class SymbolsReader(IEX):
     Returns symbols IEX supports for trading. Updated daily as of 7:45 a.m. ET.
""" - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): - super(SymbolsReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): + super(SymbolsReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): diff --git a/pandas_datareader/iex/stats.py b/pandas_datareader/iex/stats.py index 873d6bec..a0c5cf34 100644 --- a/pandas_datareader/iex/stats.py +++ b/pandas_datareader/iex/stats.py @@ -5,7 +5,6 @@ from pandas_datareader.exceptions import UnstableAPIWarning from pandas_datareader.iex import IEX - # Data provided for free by IEX # Data is furnished in compliance with the guidelines promulgated in the IEX # API terms of service and manual @@ -17,16 +16,25 @@ class DailySummaryReader(IEX): """ Daily statistics from IEX for a day or month """ - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): + + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): import warnings - warnings.warn('Daily statistics is not working due to issues with the ' - 'IEX API', UnstableAPIWarning) + + warnings.warn( + "Daily statistics is not working due to issues with the " "IEX API", + UnstableAPIWarning, + ) self.curr_date = start - super(DailySummaryReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + super(DailySummaryReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): @@ -37,7 +45,7 @@ def _get_params(self, symbols): p = {} if self.curr_date is not None: - p['date'] = self.curr_date.strftime('%Y%m%d') + p["date"] = self.curr_date.strftime("%Y%m%d") return p @@ -59,16 +67,21 @@ def read(self): class MonthlySummaryReader(IEX): """Monthly statistics from IEX""" - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): + + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): self.curr_date = start - self.date_format = '%Y%m' + self.date_format = "%Y%m" - super(MonthlySummaryReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, - session=session) + super(MonthlySummaryReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): @@ -79,7 +92,7 @@ def _get_params(self, symbols): p = {} if self.curr_date is not None: - p['date'] = self.curr_date.strftime(self.date_format) + p["date"] = self.curr_date.strftime(self.date_format) return p @@ -94,8 +107,7 @@ def read(self): dfs = [] # Build list of all dates within the given range - lrange = [x for x in (self.start + timedelta(n) - for n in range(tlen.days))] + lrange = [x for x in (self.start + timedelta(n) for n in range(tlen.days))] mrange = [] for dt in lrange: @@ -109,7 +121,7 @@ def read(self): # We may not return data if this was a weekend/holiday: if not tdf.empty: - tdf['date'] = date.strftime(self.date_format) + tdf["date"] = date.strftime(self.date_format) dfs.append(tdf) # We may not return any data if we failed to specify useful 
parameters: @@ -120,12 +132,18 @@ class RecordsReader(IEX): """ Total matched volume information from IEX """ - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): - super(RecordsReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): + super(RecordsReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): @@ -155,12 +173,18 @@ class RecentReader(IEX): * litVolume: refers to the number of lit shares traded on IEX (single-counted). """ - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): - super(RecentReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): + super(RecentReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): diff --git a/pandas_datareader/iex/tops.py b/pandas_datareader/iex/tops.py index 410fd2be..10a6c1ec 100644 --- a/pandas_datareader/iex/tops.py +++ b/pandas_datareader/iex/tops.py @@ -17,12 +17,17 @@ class TopsReader(IEX): on IEX's displayed limit order book. """ - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None): - super(TopsReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + def __init__( + self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None + ): + super(TopsReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) @property def service(self): @@ -39,13 +44,19 @@ class LastReader(IEX): Last provides trade data for executions on IEX. Provides last sale price, size and time. """ + # todo: Eventually we'll want to implement WebSockets as an option. 
-    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
-                 pause=0.1, session=None):
-        super(LastReader, self).__init__(symbols=symbols,
-                                         start=start, end=end,
-                                         retry_count=retry_count,
-                                         pause=pause, session=session)
+    def __init__(
+        self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+    ):
+        super(LastReader, self).__init__(
+            symbols=symbols,
+            start=start,
+            end=end,
+            retry_count=retry_count,
+            pause=pause,
+            session=session,
+        )

     @property
     def service(self):
diff --git a/pandas_datareader/io/jsdmx.py b/pandas_datareader/io/jsdmx.py
index e993effb..4fc0dd69 100644
--- a/pandas_datareader/io/jsdmx.py
+++ b/pandas_datareader/io/jsdmx.py
@@ -1,8 +1,8 @@
 # pylint: disable-msg=E1101,W0613,W0603
 from __future__ import unicode_literals
-from collections import OrderedDict

+from collections import OrderedDict
 import itertools
 import sys
@@ -32,7 +32,7 @@ def read_jsdmx(path_or_buf):
         import simplejson as json
     except ImportError:
         if sys.version_info[:2] < (2, 7):
-            raise ImportError('simplejson is required in python 2.6')
+            raise ImportError("simplejson is required in python 2.6")
         import json

     if isinstance(jdata, dict):
@@ -40,11 +40,11 @@ def read_jsdmx(path_or_buf):
     else:
         data = json.loads(jdata, object_pairs_hook=OrderedDict)

-    structure = data['structure']
-    index = _parse_dimensions(structure['dimensions']['observation'])
-    columns = _parse_dimensions(structure['dimensions']['series'])
+    structure = data["structure"]
+    index = _parse_dimensions(structure["dimensions"]["observation"])
+    columns = _parse_dimensions(structure["dimensions"]["series"])

-    dataset = data['dataSets']
+    dataset = data["dataSets"]
     if len(dataset) != 1:
         raise ValueError("length of 'dataSets' must be 1")
     dataset = dataset[0]
@@ -58,20 +58,19 @@ def _get_indexer(index):
     if index.nlevels == 1:
         return [str(i) for i in range(len(index))]
     else:
-        it = itertools.product(*[range(
-            len(level)) for level in index.levels])
-        return [':'.join(map(str, i)) for i in it]
+        it = itertools.product(*[range(len(level)) for level in index.levels])
+        return [":".join(map(str, i)) for i in it]


 def _parse_values(dataset, index, columns):
     size = len(index)
-    series = dataset['series']
+    series = dataset["series"]

     values = []
     # for s_key, s_value in iteritems(series):
     for s_key in _get_indexer(columns):
         try:
-            observations = series[s_key]['observations']
+            observations = series[s_key]["observations"]
             observed = []
             for o_key in _get_indexer(index):
                 try:
@@ -90,14 +89,14 @@ def _parse_dimensions(dimensions):
     arrays = []
     names = []
     for key in dimensions:
-        values = [v['name'] for v in key['values']]
+        values = [v["name"] for v in key["values"]]

-        role = key.get('role', None)
-        if role in ('time', 'TIME_PERIOD'):
+        role = key.get("role", None)
+        if role in ("time", "TIME_PERIOD"):
             values = pd.DatetimeIndex(values)

         arrays.append(values)
-        names.append(key['name'])
+        names.append(key["name"])
     midx = pd.MultiIndex.from_product(arrays, names=names)
     if len(arrays) == 1 and isinstance(midx, pd.MultiIndex):
         # Fix for pandas >= 0.21
diff --git a/pandas_datareader/io/sdmx.py b/pandas_datareader/io/sdmx.py
index 4295c319..55f185f0 100644
--- a/pandas_datareader/io/sdmx.py
+++ b/pandas_datareader/io/sdmx.py
@@ -1,34 +1,33 @@
 from __future__ import unicode_literals

 import collections
+from io import BytesIO
 import time
 import zipfile
-from io import BytesIO

 import pandas as pd

-from pandas_datareader.io.util import _read_content
 from pandas_datareader.compat import HTTPError, str_to_bytes
+from pandas_datareader.io.util import _read_content

+_STRUCTURE = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/structure}"
+_MESSAGE = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/message}"
+_GENERIC = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic}"
+_COMMON = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/common}"
+_XML = "{http://www.w3.org/XML/1998/namespace}"

-_STRUCTURE = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/structure}'
-_MESSAGE = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/message}'
-_GENERIC = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic}'
-_COMMON = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/common}'
-_XML = '{http://www.w3.org/XML/1998/namespace}'
-
-_DATASET = _MESSAGE + 'DataSet'
-_SERIES = _GENERIC + 'Series'
-_SERIES_KEY = _GENERIC + 'SeriesKey'
-_OBSERVATION = _GENERIC + 'Obs'
-_VALUE = _GENERIC + 'Value'
-_OBSDIMENSION = _GENERIC + 'ObsDimension'
-_OBSVALUE = _GENERIC + 'ObsValue'
-_CODE = _STRUCTURE + 'Code'
-_TIMEDIMENSION = _STRUCTURE + 'TimeDimension'
+_DATASET = _MESSAGE + "DataSet"
+_SERIES = _GENERIC + "Series"
+_SERIES_KEY = _GENERIC + "SeriesKey"
+_OBSERVATION = _GENERIC + "Obs"
+_VALUE = _GENERIC + "Value"
+_OBSDIMENSION = _GENERIC + "ObsDimension"
+_OBSVALUE = _GENERIC + "ObsValue"
+_CODE = _STRUCTURE + "Code"
+_TIMEDIMENSION = _STRUCTURE + "TimeDimension"


-def read_sdmx(path_or_buf, dtype='float64', dsd=None):
+def read_sdmx(path_or_buf, dtype="float64", dsd=None):
     """
     Convert a SDMX-XML string to pandas object
@@ -49,14 +48,15 @@ def read_sdmx(path_or_buf, dtype='float64', dsd=None):
     xdata = _read_content(path_or_buf)

     import xml.etree.ElementTree as ET
+
     root = ET.fromstring(xdata)

     try:
-        structure = _get_child(root, _MESSAGE + 'Structure')
+        structure = _get_child(root, _MESSAGE + "Structure")
     except ValueError:
         # get zipped path
-        result = list(root.iter(_COMMON + 'Text'))[1].text
-        if not result.startswith('http'):
+        result = list(root.iter(_COMMON + "Text"))[1].text
+        if not result.startswith("http"):
             raise ValueError(result)

         for _ in range(60):
@@ -68,11 +68,13 @@ def read_sdmx(path_or_buf, dtype='float64', dsd=None):
                 time.sleep(1)
                 continue

-        msg = ('Unable to download zipped data within 60 secs, '
-               'please download it manually from: {0}')
+        msg = (
+            "Unable to download zipped data within 60 secs, "
+            "please download it manually from: {0}"
+        )
         raise ValueError(msg.format(result))

-    idx_name = structure.get('dimensionAtObservation')
+    idx_name = structure.get("dimensionAtObservation")
     dataset = _get_child(root, _DATASET)

     keys = []
@@ -141,8 +143,7 @@ def _construct_index(keys, dsd=None):
         except KeyError:
             values[name] = [value]

-    midx = pd.MultiIndex.from_arrays([values[name] for name in names],
-                                     names=names)
+    midx = pd.MultiIndex.from_arrays([values[name] for name in names], names=names)
     return midx


@@ -151,7 +152,7 @@ def _parse_observations(observations):
     for observation in observations:
         obsdimension = _get_child(observation, _OBSDIMENSION)
         obsvalue = _get_child(observation, _OBSVALUE)
-        results.append((obsdimension.get('value'), obsvalue.get('value')))
+        results.append((obsdimension.get("value"), obsvalue.get("value")))
     # return list of key/value tuple, eg: [(key, value), ...]
     return results


@@ -159,7 +160,7 @@ def _parse_series_key(series):
     serieskey = _get_child(series, _SERIES_KEY)
     key_values = serieskey.iter(_VALUE)
-    keys = [(key.get('id'), key.get('value')) for key in key_values]
+    keys = [(key.get("id"), key.get("value")) for key in key_values]
     # return list of key/value tuple, eg: [(key, value), ...]
     return keys

@@ -169,11 +170,11 @@ def _get_child(element, key):
     if len(elements) == 1:
         return elements[0]
     elif len(elements) == 0:
-        raise ValueError("Element {0} contains "
-                         "no {1}".format(element.tag, key))
+        raise ValueError("Element {0} contains " "no {1}".format(element.tag, key))
     else:
-        raise ValueError("Element {0} contains "
-                         "multiple {1}".format(element.tag, key))
+        raise ValueError(
+            "Element {0} contains " "multiple {1}".format(element.tag, key)
+        )


 _NAME_EN = ".//{0}Name[@{1}lang='en']".format(_COMMON, _XML)
@@ -184,7 +185,7 @@ def _get_english_name(element):
     return name


-SDMXCode = collections.namedtuple('SDMXCode', ['codes', 'ts'])
+SDMXCode = collections.namedtuple("SDMXCode", ["codes", "ts"])


 def _read_sdmx_dsd(path_or_buf):
@@ -204,12 +205,13 @@ def _read_sdmx_dsd(path_or_buf):
     xdata = _read_content(path_or_buf)

     import xml.etree.cElementTree as ET
+
     root = ET.fromstring(xdata)

-    structure = _get_child(root, _MESSAGE + 'Structures')
-    codes = _get_child(structure, _STRUCTURE + 'Codelists')
+    structure = _get_child(root, _MESSAGE + "Structures")
+    codes = _get_child(structure, _STRUCTURE + "Codelists")
     # concepts = _get_child(structure, _STRUCTURE + 'Concepts')
-    datastructures = _get_child(structure, _STRUCTURE + 'DataStructures')
+    datastructures = _get_child(structure, _STRUCTURE + "DataStructures")

     code_results = {}
     for codelist in codes:
@@ -217,7 +219,7 @@ def _read_sdmx_dsd(path_or_buf):
         codelist_name = _get_english_name(codelist)
         mapper = {}
         for code in codelist.iter(_CODE):
-            code_id = code.get('id')
+            code_id = code.get("id")
             name = _get_english_name(code)
             mapper[code_id] = name
         # codeobj = SDMXCode(id=codelist_id, name=codelist_name, mapper=mapper)
@@ -225,7 +227,7 @@ def _read_sdmx_dsd(path_or_buf):
         code_results[codelist_name] = mapper

     times = list(datastructures.iter(_TIMEDIMENSION))
-    times = [t.get('id') for t in times]
+    times = [t.get("id") for t in times]
     result = SDMXCode(codes=code_results, ts=times)
     return result
diff --git a/pandas_datareader/io/util.py b/pandas_datareader/io/util.py
index 25f6168e..92a3ae8b 100644
--- a/pandas_datareader/io/util.py
+++ b/pandas_datareader/io/util.py
@@ -19,11 +19,11 @@ def _read_content(path_or_buf):
             exists = False

         if exists:
-            with open(filepath_or_buffer, 'r') as fh:
+            with open(filepath_or_buffer, "r") as fh:
                 data = fh.read()
         else:
             data = filepath_or_buffer
-    elif hasattr(filepath_or_buffer, 'read'):
+    elif hasattr(filepath_or_buffer, "read"):
         data = filepath_or_buffer.read()
     else:
         data = filepath_or_buffer
diff --git a/pandas_datareader/moex.py b/pandas_datareader/moex.py
index e439efdb..ff3aa68d 100644
--- a/pandas_datareader/moex.py
+++ b/pandas_datareader/moex.py
@@ -5,9 +5,7 @@
 import pandas as pd

 from pandas_datareader.base import _DailyBaseReader
-from pandas_datareader.compat import (binary_type, concat, is_list_like,
-                                      StringIO)
-
+from pandas_datareader.compat import StringIO, binary_type, concat, is_list_like


 class MoexReader(_DailyBaseReader):
@@ -54,39 +52,45 @@ def __init__(self, *args, **kwargs):
         self.__engines, self.__markets = {}, {}  # dicts for engines and markets

     __url_metadata = "https://iss.moex.com/iss/securities/{symbol}.csv"
-    __url_data = "https://iss.moex.com/iss/history/engines/{engine}/" \
-                 "markets/{market}/securities/{symbol}.csv"
+    __url_data = (
+        "https://iss.moex.com/iss/history/engines/{engine}/"
+        "markets/{market}/securities/{symbol}.csv"
+    )

     @property
     def url(self):
         """Return a list of API URLs per symbol"""

         if not self.__engines or not self.__markets:
-            raise Exception("Accessing url property before invocation "
-                            "of read() or _get_metadata() methods")
+            raise Exception(
+                "Accessing url property before invocation "
+                "of read() or _get_metadata() methods"
+            )

-        return [self.__url_data.format(
-            engine=self.__engines[s],
-            market=self.__markets[s],
-            symbol=s) for s in self.symbols]
+        return [
+            self.__url_data.format(
+                engine=self.__engines[s], market=self.__markets[s], symbol=s
+            )
+            for s in self.symbols
+        ]

     def _get_params(self, start):
         """Return a dict for REST API GET request parameters"""
         params = {
-            'iss.only': 'history',
-            'iss.dp': 'point',
-            'iss.df': '%Y-%m-%d',
-            'iss.tf': '%H:%M:%S',
-            'iss.dft': '%Y-%m-%d %H:%M:%S',
-            'iss.json': 'extended',
-            'callback': 'JSON_CALLBACK',
-            'from': start,
-            'till': self.end_dt.strftime('%Y-%m-%d'),
-            'limit': 100,
-            'start': 1,
-            'sort_order': 'TRADEDATE',
-            'sort_order_desc': 'asc'
+            "iss.only": "history",
+            "iss.dp": "point",
+            "iss.df": "%Y-%m-%d",
+            "iss.tf": "%H:%M:%S",
+            "iss.dft": "%Y-%m-%d %H:%M:%S",
+            "iss.json": "extended",
+            "callback": "JSON_CALLBACK",
+            "from": start,
+            "till": self.end_dt.strftime("%Y-%m-%d"),
+            "limit": 100,
+            "start": 1,
+            "sort_order": "TRADEDATE",
+            "sort_order_desc": "asc",
         }
         return params

@@ -96,33 +100,36 @@ def _get_metadata(self):
         markets, engines = {}, {}
         for symbol in self.symbols:
-            response = self._get_response(
-                self.__url_metadata.format(symbol=symbol)
-            )
+            response = self._get_response(self.__url_metadata.format(symbol=symbol))
             text = self._sanitize_response(response)
             if len(text) == 0:
                 service = self.__class__.__name__
-                raise IOError("{} request returned no data; check URL for invalid "
-                              "inputs: {}".format(service, self.__url_metadata))
+                raise IOError(
+                    "{} request returned no data; check URL for invalid "
+                    "inputs: {}".format(service, self.__url_metadata)
+                )
             if isinstance(text, binary_type):
-                text = text.decode('windows-1251')
+                text = text.decode("windows-1251")

-            header_str = 'secid;boardid;'
+            header_str = "secid;boardid;"
             get_data = False
             for s in text.splitlines():
                 if s.startswith(header_str):
                     get_data = True
                     continue
-                if get_data and s != '':
-                    fields = s.split(';')
+                if get_data and s != "":
+                    fields = s.split(";")
                     markets[symbol], engines[symbol] = fields[5], fields[7]
                     break
             if symbol not in markets or symbol not in engines:
-                raise IOError("{} request returned no metadata: {}\n"
-                              "Typo in the security symbol `{}`?".format(
-                                  self.__class__.__name__,
-                                  self.__url_metadata.format(symbol=symbol),
-                                  symbol))
+                raise IOError(
+                    "{} request returned no metadata: {}\n"
+                    "Typo in the security symbol `{}`?".format(
+                        self.__class__.__name__,
+                        self.__url_metadata.format(symbol=symbol),
+                        symbol,
+                    )
+                )
         return markets, engines

     def read(self):
@@ -140,21 +147,22 @@ def read(self):
             while True:  # read in a loop with small date intervals
                 if len(out_list) > 0:
                     if date_column is None:
-                        date_column = out_list[0].split(';').index('TRADEDATE')
+                        date_column = out_list[0].split(";").index("TRADEDATE")

                     # get the last downloaded date
-                    start_str = out_list[-1].split(';', 4)[date_column]
-                    start = dt.datetime.strptime(start_str, '%Y-%m-%d').date()
+                    start_str = out_list[-1].split(";", 4)[date_column]
+                    start = dt.datetime.strptime(start_str, "%Y-%m-%d").date()
                 else:
-                    start_str = self.start.strftime('%Y-%m-%d')
+                    start_str = self.start.strftime("%Y-%m-%d")
                     start = self.start

                 if start >= self.end or start >= dt.date.today():
                     break

                 params = self._get_params(start_str)
-                strings_out = self._read_url_as_String(urls[i], params) \
-                    .splitlines()[2:]
+                strings_out = self._read_url_as_String(
+                    urls[i], params
+                ).splitlines()[2:]
                 strings_out = list(filter(lambda x: x.strip(), strings_out))

                 if len(out_list) == 0:
@@ -165,13 +173,13 @@ def read(self):
                     out_list += strings_out[1:]  # remove a CSV head line
                     if len(strings_out) < 100:  # all data recevied
-                        break
                         break

-            str_io = StringIO('\r\n'.join(out_list))
-            dfs.append(self._read_lines(str_io))  # add a new DataFrame
+            str_io = StringIO("\r\n".join(out_list))
+            dfs.append(self._read_lines(str_io))  # add a new DataFrame
         finally:
             self.close()

         if len(dfs) > 1:
-            return concat(dfs, axis=0, join='outer', sort=True)
+            return concat(dfs, axis=0, join="outer", sort=True)
         else:
             return dfs[0]

@@ -182,14 +190,21 @@ def _read_url_as_String(self, url, params=None):
         text = self._sanitize_response(response)
         if len(text) == 0:
             service = self.__class__.__name__
-            raise IOError("{} request returned no data; check URL for invalid "
-                          "inputs: {}".format(service, self.url))
+            raise IOError(
+                "{} request returned no data; check URL for invalid "
+                "inputs: {}".format(service, self.url)
+            )
         if isinstance(text, binary_type):
-            text = text.decode('windows-1251')
+            text = text.decode("windows-1251")

         return text

     def _read_lines(self, input):
         """Return a pandas DataFrame from input"""
-        return pd.read_csv(input, index_col='TRADEDATE', parse_dates=True,
-                           sep=';', na_values=('-', 'null'))
+        return pd.read_csv(
+            input,
+            index_col="TRADEDATE",
+            parse_dates=True,
+            sep=";",
+            na_values=("-", "null"),
+        )
diff --git a/pandas_datareader/nasdaq_trader.py b/pandas_datareader/nasdaq_trader.py
index 78cffe7e..fd645504 100644
--- a/pandas_datareader/nasdaq_trader.py
+++ b/pandas_datareader/nasdaq_trader.py
@@ -1,34 +1,36 @@
 from ftplib import FTP, all_errors
+import time
+import warnings

 from pandas import read_csv
+
 from pandas_datareader._utils import RemoteDataError
 from pandas_datareader.compat import StringIO
-import time
-import warnings
-
-_NASDAQ_TICKER_LOC = '/SymbolDirectory/nasdaqtraded.txt'
-_NASDAQ_FTP_SERVER = 'ftp.nasdaqtrader.com'
-_TICKER_DTYPE = [('Nasdaq Traded', bool),
-                 ('Symbol', str),
-                 ('Security Name', str),
-                 ('Listing Exchange', str),
-                 ('Market Category', str),
-                 ('ETF', bool),
-                 ('Round Lot Size', float),
-                 ('Test Issue', bool),
-                 ('Financial Status', str),
-                 ('CQS Symbol', str),
-                 ('NASDAQ Symbol', str),
-                 ('NextShares', bool)]
-_CATEGORICAL = ('Listing Exchange', 'Financial Status')
-
-_DELIMITER = '|'
+_NASDAQ_TICKER_LOC = "/SymbolDirectory/nasdaqtraded.txt"
+_NASDAQ_FTP_SERVER = "ftp.nasdaqtrader.com"
+_TICKER_DTYPE = [
+    ("Nasdaq Traded", bool),
+    ("Symbol", str),
+    ("Security Name", str),
+    ("Listing Exchange", str),
+    ("Market Category", str),
+    ("ETF", bool),
+    ("Round Lot Size", float),
+    ("Test Issue", bool),
+    ("Financial Status", str),
+    ("CQS Symbol", str),
+    ("NASDAQ Symbol", str),
+    ("NextShares", bool),
+]
+_CATEGORICAL = ("Listing Exchange", "Financial Status")
+
+_DELIMITER = "|"
 _ticker_cache = None


 def _bool_converter(item):
-    return item == 'Y'
+    return item == "Y"


 def _download_nasdaq_symbols(timeout):
@@ -39,38 +41,43 @@ def _download_nasdaq_symbols(timeout):
         ftp_session = FTP(_NASDAQ_FTP_SERVER, timeout=timeout)
         ftp_session.login()
     except all_errors as err:
-        raise RemoteDataError('Error connecting to %r: %s' %
-                              (_NASDAQ_FTP_SERVER, err))
+        raise RemoteDataError("Error connecting to %r: %s" % (_NASDAQ_FTP_SERVER, err))

     lines = []
     try:
-        ftp_session.retrlines('RETR ' + _NASDAQ_TICKER_LOC, lines.append)
+        ftp_session.retrlines("RETR " + _NASDAQ_TICKER_LOC, lines.append)
     except all_errors as err:
-        raise RemoteDataError('Error downloading from %r: %s' %
-                              (_NASDAQ_FTP_SERVER, err))
+        raise RemoteDataError(
+            "Error downloading from %r: %s" % (_NASDAQ_FTP_SERVER, err)
+        )
     finally:
         ftp_session.close()

     # Sanity Checking
-    if not lines[-1].startswith('File Creation Time:'):
-        raise RemoteDataError('Missing expected footer. Found %r' % lines[-1])
+    if not lines[-1].startswith("File Creation Time:"):
+        raise RemoteDataError("Missing expected footer. Found %r" % lines[-1])

     # Convert Y/N to True/False.
-    converter_map = dict((col, _bool_converter) for col, t in _TICKER_DTYPE
-                         if t is bool)
+    converter_map = dict(
+        (col, _bool_converter) for col, t in _TICKER_DTYPE if t is bool
+    )

     # For pandas >= 0.20.0, the Python parser issues a warning if
     # both a converter and dtype are specified for the same column.
     # However, this measure is probably temporary until the read_csv
     # behavior is better formalized.
     with warnings.catch_warnings(record=True):
-        data = read_csv(StringIO('\n'.join(lines[:-1])), '|',
-                        dtype=_TICKER_DTYPE, converters=converter_map,
-                        index_col=1)
+        data = read_csv(
+            StringIO("\n".join(lines[:-1])),
+            "|",
+            dtype=_TICKER_DTYPE,
+            converters=converter_map,
+            index_col=1,
+        )

     # Properly cast enumerations
     for cat in _CATEGORICAL:
-        data[cat] = data[cat].astype('category')
+        data[cat] = data[cat].astype("category")

     return data

@@ -87,12 +94,12 @@ def get_nasdaq_symbols(retry_count=3, timeout=30, pause=None):
     global _ticker_cache

     if timeout < 0:
-        raise ValueError('timeout must be >= 0, not %r' % (timeout,))
+        raise ValueError("timeout must be >= 0, not %r" % (timeout,))

     if pause is None:
         pause = timeout / 3
     elif pause < 0:
-        raise ValueError('pause must be >= 0, not %r' % (pause,))
+        raise ValueError("pause must be >= 0, not %r" % (pause,))

     if _ticker_cache is None:
         while retry_count > 0:
diff --git a/pandas_datareader/oecd.py b/pandas_datareader/oecd.py
index 9d9d85e4..94364352 100644
--- a/pandas_datareader/oecd.py
+++ b/pandas_datareader/oecd.py
@@ -1,34 +1,34 @@
 import pandas as pd

-from pandas_datareader.io import read_jsdmx
 from pandas_datareader.base import _BaseReader
 from pandas_datareader.compat import string_types
+from pandas_datareader.io import read_jsdmx


 class OECDReader(_BaseReader):
     """Get data for the given name from OECD."""

-    _format = 'json'
+    _format = "json"

     @property
     def url(self):
         """API URL"""
-        url = 'http://stats.oecd.org/SDMX-JSON/data'
+        url = "http://stats.oecd.org/SDMX-JSON/data"

         if not isinstance(self.symbols, string_types):
-            raise ValueError('data name must be string')
+            raise ValueError("data name must be string")

         # API: https://data.oecd.org/api/sdmx-json-documentation/
-        return '{0}/{1}/all/all?'.format(url, self.symbols)
+        return "{0}/{1}/all/all?".format(url, self.symbols)

     def _read_lines(self, out):
         """ read one data from specified URL """
         df = read_jsdmx(out)
         try:
             idx_name = df.index.name  # hack for pandas 0.16.2
-            df.index = pd.to_datetime(df.index, errors='ignore')
+            df.index = pd.to_datetime(df.index, errors="ignore")
             for col in df:
-                df[col] = pd.to_numeric(df[col], errors='ignore')
+                df[col] = pd.to_numeric(df[col], errors="ignore")
             df = df.sort_index()
             df = df.truncate(self.start, self.end)
             df.index.name = idx_name
diff --git a/pandas_datareader/quandl.py b/pandas_datareader/quandl.py
index 54cacb5f..990d60ed 100644
--- a/pandas_datareader/quandl.py
+++ b/pandas_datareader/quandl.py
@@ -44,30 +44,40 @@ class QuandlReader(_DailyBaseReader):

     _BASE_URL = "https://www.quandl.com/api/v3/datasets/"

-    def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1,
-                 session=None, chunksize=25, api_key=None):
-        super(QuandlReader, self).__init__(symbols, start, end, retry_count,
-                                           pause, session, chunksize)
+    def __init__(
+        self,
+        symbols,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        session=None,
+        chunksize=25,
+        api_key=None,
+    ):
+        super(QuandlReader, self).__init__(
+            symbols, start, end, retry_count, pause, session, chunksize
+        )
         if api_key is None:
-            api_key = os.getenv('QUANDL_API_KEY')
+            api_key = os.getenv("QUANDL_API_KEY")
         if not api_key or not isinstance(api_key, str):
-            raise ValueError('The Quandl API key must be provided either '
-                             'through the api_key variable or through the '
-                             'environmental variable QUANDL_API_KEY.')
+            raise ValueError(
+                "The Quandl API key must be provided either "
+                "through the api_key variable or through the "
+                "environmental variable QUANDL_API_KEY."
+            )
         self.api_key = api_key

     @property
     def url(self):
         """API URL"""
-        symbol = self.symbols if isinstance(self.symbols, str) else \
-            self.symbols[0]
+        symbol = self.symbols if isinstance(self.symbols, str) else self.symbols[0]
         mm = self._fullmatch(r"([A-Z0-9]+)(([/\.])([A-Z0-9_]+))?", symbol)
-        assert mm, ("Symbol '%s' must conform to Quandl convention 'DB/SYM'" %
-                    symbol)
-        datasetname = 'WIKI'
+        assert mm, "Symbol '%s' must conform to Quandl convention 'DB/SYM'" % symbol
+        datasetname = "WIKI"
         if not mm.group(2):
             # bare symbol:
-            datasetname = 'WIKI'  # default; symbol stays itself
+            datasetname = "WIKI"  # default; symbol stays itself
         elif mm.group(3) == "/":
             # --- normal Quandl DB/SYM convention:
             symbol = mm.group(4)
@@ -77,15 +87,16 @@ def url(self):
             symbol = mm.group(1)
             datasetname = self._db_from_countrycode(mm.group(4))
         params = {
-            'start_date': self.start.strftime('%Y-%m-%d'),
-            'end_date': self.end.strftime('%Y-%m-%d'),
-            'order': "asc",
-            'api_key': self.api_key
+            "start_date": self.start.strftime("%Y-%m-%d"),
+            "end_date": self.end.strftime("%Y-%m-%d"),
+            "order": "asc",
+            "api_key": self.api_key,
         }
-        paramstring = '&'.join(['%s=%s' % (k, v) for k, v in params.items()])
-        url = '{url}{dataset}/{symbol}.csv?{params}'
-        return url.format(url=self._BASE_URL, dataset=datasetname,
-                          symbol=symbol, params=paramstring)
+        paramstring = "&".join(["%s=%s" % (k, v) for k, v in params.items()])
+        url = "{url}{dataset}/{symbol}.csv?{params}"
+        return url.format(
+            url=self._BASE_URL, dataset=datasetname, symbol=symbol, params=paramstring
+        )

     def _fullmatch(self, regex, string, flags=0):
         """Emulate python-3.4 re.fullmatch()."""
@@ -93,27 +104,28 @@ def _fullmatch(self, regex, string, flags=0):

     _COUNTRYCODE_TO_DATASET = dict(
         # https://www.quandl.com/data/EURONEXT-Euronext-Stock-Exchange
-        BE='EURONEXT',
+        BE="EURONEXT",
         # https://www.quandl.com/data/HKEX-Hong-Kong-Exchange
-        CN='HKEX',
+        CN="HKEX",
         # https://www.quandl.com/data/SSE-Boerse-Stuttgart
-        DE='SSE',
-        FR='EURONEXT',
+        DE="SSE",
+        FR="EURONEXT",
         # https://www.quandl.com/data/NSE-National-Stock-Exchange-of-India
-        IN='NSE',
+        IN="NSE",
         # https://www.quandl.com/data/TSE-Tokyo-Stock-Exchange
-        JP='TSE',
-        NL='EURONEXT',
-        PT='EURONEXT',
+        JP="TSE",
+        NL="EURONEXT",
+        PT="EURONEXT",
         # https://www.quandl.com/data/LSE-London-Stock-Exchange
-        UK='LSE',
+        UK="LSE",
         # https://www.quandl.com/data/WIKI-Wiki-EOD-Stock-Prices
-        US='WIKI',
+        US="WIKI",
     )

     def _db_from_countrycode(self, code):
-        assert code in self._COUNTRYCODE_TO_DATASET, \
+        assert code in self._COUNTRYCODE_TO_DATASET, (
             "No Quandl dataset known for country code '%s'" % code
+        )
         return self._COUNTRYCODE_TO_DATASET[code]

     def _get_params(self, symbol):
@@ -122,13 +134,15 @@ def _get_params(self, symbol):
     def read(self):
         """Read data"""
         df = super(QuandlReader, self).read()
-        df.rename(columns=lambda n: n.replace(' ', '')
-                  .replace('.', '')
-                  .replace('/', '')
-                  .replace('%', '')
-                  .replace('(', '')
-                  .replace(')', '')
-                  .replace("'", '')
-                  .replace('-', ''),
-                  inplace=True)
+        df.rename(
+            columns=lambda n: n.replace(" ", "")
+            .replace(".", "")
+            .replace("/", "")
+            .replace("%", "")
+            .replace("(", "")
+            .replace(")", "")
+            .replace("'", "")
+            .replace("-", ""),
+            inplace=True,
+        )
         return df
diff --git a/pandas_datareader/robinhood.py b/pandas_datareader/robinhood.py
index adca1c40..639a3e56 100644
--- a/pandas_datareader/robinhood.py
+++ b/pandas_datareader/robinhood.py
@@ -1,8 +1,7 @@
 import pandas as pd

 from pandas_datareader.base import _BaseReader
-from pandas_datareader.exceptions import (ImmediateDeprecationError,
-                                          DEP_ERROR_MSG)
+from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError


 class RobinhoodQuoteReader(_BaseReader):
@@ -29,14 +28,24 @@ class RobinhoodQuoteReader(_BaseReader):
     freq : None
         Quotes are near real-time and so this value is ignored
     """
-    _format = 'json'

-    def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
-                 timeout=30, session=None, freq=None):
+    _format = "json"
+
+    def __init__(
+        self,
+        symbols,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        timeout=30,
+        session=None,
+        freq=None,
+    ):
         raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Robinhood"))
-        super(RobinhoodQuoteReader, self).__init__(symbols, start, end,
-                                                   retry_count, pause,
-                                                   timeout, session, freq)
+        super(RobinhoodQuoteReader, self).__init__(
+            symbols, start, end, retry_count, pause, timeout, session, freq
+        )
         if isinstance(self.symbols, str):
             self.symbols = [self.symbols]
         self._max_symbols = 1630
@@ -45,8 +54,10 @@ def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,

     def _validate_symbols(self):
         if len(self.symbols) > self._max_symbols:
-            raise ValueError('A maximum of {0} symbols are supported '
-                             'in a single call.'.format(self._max_symbols))
+            raise ValueError(
+                "A maximum of {0} symbols are supported "
+                "in a single call.".format(self._max_symbols)
+            )

     def _get_crumb(self, *args):
         pass
@@ -54,23 +65,23 @@ def _get_crumb(self, *args):
     @property
     def url(self):
         """API URL"""
-        return 'https://api.robinhood.com/quotes/'
+        return "https://api.robinhood.com/quotes/"

     @property
     def params(self):
         """Parameters to use in API calls"""
-        symbols = ','.join(self.symbols)
-        return {'symbols': symbols}
+        symbols = ",".join(self.symbols)
+        return {"symbols": symbols}

     def _process_json(self):
         res = pd.DataFrame(self._json_results)
-        return res.set_index('symbol').T
+        return res.set_index("symbol").T

     def _read_lines(self, out):
-        if 'next' in out:
-            self._json_results.extend(out['results'])
-            return self._read_one_data(out['next'])
-        self._json_results.extend(out['results'])
+        if "next" in out:
+            self._json_results.extend(out["results"])
+            return self._read_one_data(out["next"])
+        self._json_results.extend(out["results"])
         return self._process_json()

@@ -114,26 +125,42 @@ class RobinhoodHistoricalReader(RobinhoodQuoteReader):
     * 5minute: day, week
     * 10minute: day, week
     """
-    _format = 'json'

-    def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
-                 timeout=30, session=None, freq=None, interval='day',
-                 span='year'):
+    _format = "json"
+
+    def __init__(
+        self,
+        symbols,
+        start=None,
+        end=None,
+        retry_count=3,
+        pause=0.1,
+        timeout=30,
+        session=None,
+        freq=None,
+        interval="day",
+        span="year",
+    ):
         raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Robinhood"))
-        super(RobinhoodHistoricalReader, self).__init__(symbols, start, end,
-                                                        retry_count, pause,
-                                                        timeout, session, freq)
-        interval_span = {'day': ['year'],
-                         'week': ['5year'],
-                         '10minute': ['day', 'week'],
-                         '5minute': ['day', 'week']}
+        super(RobinhoodHistoricalReader, self).__init__(
+            symbols, start, end, retry_count, pause, timeout, session, freq
+        )
+        interval_span = {
+            "day": ["year"],
+            "week": ["5year"],
+            "10minute": ["day", "week"],
+            "5minute": ["day", "week"],
+        }
         if interval not in interval_span:
-            raise ValueError('Interval must be one of '
-                             '{0}'.format(', '.join(interval_span.keys())))
+            raise ValueError(
+                "Interval must be one of " "{0}".format(", ".join(interval_span.keys()))
+            )
         valid_spans = interval_span[interval]
         if span not in valid_spans:
-            raise ValueError('For interval {0}, span must '
-                             'be in: {1}'.format(interval, valid_spans))
+            raise ValueError(
+                "For interval {0}, span must "
+                "be in: {1}".format(interval, valid_spans)
+            )
         self.interval = interval
         self.span = span
         self._max_symbols = 75
@@ -143,23 +170,21 @@ def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
     @property
     def url(self):
         """API URL"""
-        return 'https://api.robinhood.com/quotes/historicals/'
+        return "https://api.robinhood.com/quotes/historicals/"

     @property
     def params(self):
         """Parameters to use in API calls"""
-        symbols = ','.join(self.symbols)
-        pars = {'symbols': symbols,
-                'interval': self.interval,
-                'span': self.span}
+        symbols = ",".join(self.symbols)
+        pars = {"symbols": symbols, "interval": self.interval, "span": self.span}
         return pars

     def _process_json(self):
         df = []
         for sym in self._json_results:
-            vals = pd.DataFrame(sym['historicals'])
-            vals['begins_at'] = pd.to_datetime(vals['begins_at'])
-            vals['symbol'] = sym['symbol']
-            df.append(vals.set_index(['symbol', 'begins_at']))
+            vals = pd.DataFrame(sym["historicals"])
+            vals["begins_at"] = pd.to_datetime(vals["begins_at"])
+            vals["symbol"] = sym["symbol"]
+            df.append(vals.set_index(["symbol", "begins_at"]))
         return pd.concat(df, 0)
diff --git a/pandas_datareader/stooq.py b/pandas_datareader/stooq.py
index 35fe8317..0895a467 100644
--- a/pandas_datareader/stooq.py
+++ b/pandas_datareader/stooq.py
@@ -33,25 +33,24 @@ class StooqDailyReader(_DailyBaseReader):
     @property
     def url(self):
         """API URL"""
-        return 'https://stooq.com/q/d/l/'
+        return "https://stooq.com/q/d/l/"

-    def _get_params(self, symbol, country='US'):
+    def _get_params(self, symbol, country="US"):
         symbol_parts = symbol.split(".")
-        if not symbol.startswith('^'):
+        if not symbol.startswith("^"):
             if len(symbol_parts) == 1:
                 symbol = ".".join([symbol, country])
-            elif symbol_parts[1].lower() == 'pl':
+            elif symbol_parts[1].lower() == "pl":
                 symbol = symbol_parts[0]
             else:
-                if symbol_parts[1].lower() not in ['de', 'hk', 'hu', 'jp',
-                                                   'uk', 'us']:
-                    symbol = ".".join([symbol, 'US'])
+                if symbol_parts[1].lower() not in ["de", "hk", "hu", "jp", "uk", "us"]:
+                    symbol = ".".join([symbol, "US"])

         params = {
-            's': symbol,
-            'i': self.freq or 'd',
-            'd1': self.start.strftime('%Y%m%d'),
-            'd2': self.end.strftime('%Y%m%d')
+            "s": symbol,
+            "i": self.freq or "d",
+            "d1": self.start.strftime("%Y%m%d"),
+            "d2": self.end.strftime("%Y%m%d"),
         }

         return params
diff --git a/pandas_datareader/tests/io/test_jsdmx.py b/pandas_datareader/tests/io/test_jsdmx.py
index 1a438200..e34c128a 100644
--- a/pandas_datareader/tests/io/test_jsdmx.py
+++ b/pandas_datareader/tests/io/test_jsdmx.py
@@ -18,25 +18,46 @@ def dirpath(datapath):
     return datapath("io", "data")


-@pytest.mark.skipif(not PANDAS_0210, reason='Broken on old pandas')
+@pytest.mark.skipif(not PANDAS_0210, reason="Broken on old pandas")
 def test_tourism(dirpath):
     # OECD -> Industry and Services -> Inbound Tourism
-    result = read_jsdmx(os.path.join(dirpath, 'jsdmx',
-                                     'tourism.json'))
+    result = read_jsdmx(os.path.join(dirpath, "jsdmx", "tourism.json"))
     assert isinstance(result, pd.DataFrame)

-    jp = result['Japan']
-    visitors = ['China', 'Hong Kong, China',
-                'Total international arrivals',
-                'Korea', 'Chinese Taipei', 'United States']
+    jp = result["Japan"]
+    visitors = [
+        "China",
+        "Hong Kong, China",
+        "Total international arrivals",
+        "Korea",
+        "Chinese Taipei",
+        "United States",
+    ]

     exp_col = pd.Index(
-        ['China', 'Hong Kong, China', 'Total international arrivals',
-         'Korea', 'Chinese Taipei', 'United States'],
-        name='Variable')
-    exp_idx = pd.DatetimeIndex(['2008-01-01', '2009-01-01', '2010-01-01',
-                                '2011-01-01', '2012-01-01', '2013-01-01',
-                                '2014-01-01', '2015-01-01', '2016-01-01'],
-                               name='Year')
+        [
+            "China",
+            "Hong Kong, China",
+            "Total international arrivals",
+            "Korea",
+            "Chinese Taipei",
+            "United States",
+        ],
+        name="Variable",
+    )
+    exp_idx = pd.DatetimeIndex(
+        [
+            "2008-01-01",
+            "2009-01-01",
+            "2010-01-01",
+            "2011-01-01",
+            "2012-01-01",
+            "2013-01-01",
+            "2014-01-01",
+            "2015-01-01",
+            "2016-01-01",
+        ],
+        name="Year",
+    )

     values = [
         [1000000.0, 550000.0, 8351000.0, 2382000.0, 1390000.0, 768000.0],
         [1006000.0, 450000.0, 6790000.0, 1587000.0, 1024000.0, 700000.0],
@@ -45,44 +66,83 @@ def test_tourism(dirpath):
         [1430000.0, 482000.0, 8368000.0, 2044000.0, 1467000.0, 717000.0],
         [1314000.0, 746000.0, 10364000.0, 2456000.0, 2211000.0, 799000.0],
         [2409000.0, 926000.0, 13413000.0, 2755000.0, 2830000.0, 892000.0],
-        [4993689.0, 1524292.0, 19737409.0, 4002095.0, 3677075.0,
-         1033258.0],
-        [6373564.0, 1839193.0, 24039700.0, 5090302.0, 4167512.0, 1242719.0]
+        [4993689.0, 1524292.0, 19737409.0, 4002095.0, 3677075.0, 1033258.0],
+        [6373564.0, 1839193.0, 24039700.0, 5090302.0, 4167512.0, 1242719.0],
     ]
-    values = np.array(values, dtype='object')
+    values = np.array(values, dtype="object")
     expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
     tm.assert_frame_equal(jp[visitors], expected)


-@pytest.mark.skipif(not PANDAS_0210, reason='Broken on old pandas')
+@pytest.mark.skipif(not PANDAS_0210, reason="Broken on old pandas")
 def test_land_use(dirpath):
     # OECD -> Environment -> Resources Land Use
-    result = read_jsdmx(os.path.join(dirpath, 'jsdmx',
-                                     'land_use.json'))
+    result = read_jsdmx(os.path.join(dirpath, "jsdmx", "land_use.json"))
     assert isinstance(result, pd.DataFrame)
-    result = result.loc['2010':'2011']

-    cols = ['Arable land and permanent crops',
-            'Arable and cropland % land area',
-            'Total area', 'Forest', 'Forest % land area',
-            'Land 
area', 'Permanent meadows and pastures', - 'Meadows and pastures % land area', 'Other areas', - 'Other % land area'] - exp_col = pd.MultiIndex.from_product([ - ['Japan', 'United States'], - cols], names=['Country', 'Variable']) - exp_idx = pd.DatetimeIndex(['2010', '2011'], name='Year') + cols = [ + "Arable land and permanent crops", + "Arable and cropland % land area", + "Total area", + "Forest", + "Forest % land area", + "Land area", + "Permanent meadows and pastures", + "Meadows and pastures % land area", + "Other areas", + "Other % land area", + ] + exp_col = pd.MultiIndex.from_product( + [["Japan", "United States"], cols], names=["Country", "Variable"] + ) + exp_idx = pd.DatetimeIndex(["2010", "2011"], name="Year") values = [ - [53790.0, 14.753154141525, 377800.0, np.nan, np.nan, 364600.0, - 5000.0, 1.3713658804169, np.nan, np.nan, - 1897990.0, 20.722767650476, 9629090.0, np.nan, np.nan, 9158960.0, - 2416000.0, 26.378540795025, np.nan, - np.nan], - [53580.0, 14.691527282698, 377800.0, np.nan, np.nan, 364700.0, - 5000.0, 1.3709898546751, np.nan, np.nan, - 1897990.0, 20.722767650476, 9629090.0, np.nan, np.nan, 9158960.0, - 2416000.0, 26.378540795025, np.nan, - np.nan]] + [ + 53790.0, + 14.753154141525, + 377800.0, + np.nan, + np.nan, + 364600.0, + 5000.0, + 1.3713658804169, + np.nan, + np.nan, + 1897990.0, + 20.722767650476, + 9629090.0, + np.nan, + np.nan, + 9158960.0, + 2416000.0, + 26.378540795025, + np.nan, + np.nan, + ], + [ + 53580.0, + 14.691527282698, + 377800.0, + np.nan, + np.nan, + 364700.0, + 5000.0, + 1.3709898546751, + np.nan, + np.nan, + 1897990.0, + 20.722767650476, + 9629090.0, + np.nan, + np.nan, + 9158960.0, + 2416000.0, + 26.378540795025, + np.nan, + np.nan, + ], + ] values = np.array(values) expected = pd.DataFrame(values, index=exp_idx, columns=exp_col) tm.assert_frame_equal(result[exp_col], expected) diff --git a/pandas_datareader/tests/io/test_sdmx.py b/pandas_datareader/tests/io/test_sdmx.py index d52e56e2..de31822c 100644 --- a/pandas_datareader/tests/io/test_sdmx.py +++ b/pandas_datareader/tests/io/test_sdmx.py @@ -7,7 +7,7 @@ import pandas.util.testing as tm import pytest -from pandas_datareader.io.sdmx import read_sdmx, _read_sdmx_dsd +from pandas_datareader.io.sdmx import _read_sdmx_dsd, read_sdmx pytestmark = pytest.mark.stable @@ -21,23 +21,20 @@ def test_tourism(dirpath): # Eurostat # Employed doctorate holders in non managerial and non professional # occupations by fields of science (%) - dsd = _read_sdmx_dsd(os.path.join(dirpath, 'sdmx', - 'DSD_cdh_e_fos.xml')) - df = read_sdmx(os.path.join(dirpath, 'sdmx', - 'cdh_e_fos.xml'), dsd=dsd) + dsd = _read_sdmx_dsd(os.path.join(dirpath, "sdmx", "DSD_cdh_e_fos.xml")) + df = read_sdmx(os.path.join(dirpath, "sdmx", "cdh_e_fos.xml"), dsd=dsd) assert isinstance(df, pd.DataFrame) assert df.shape == (2, 336) - df = df['Percentage']['Total']['Natural sciences'] - df = df[['Norway', 'Poland', 'Portugal', 'Russia']] + df = df["Percentage"]["Total"]["Natural sciences"] + df = df[["Norway", "Poland", "Portugal", "Russia"]] - exp_col = pd.MultiIndex.from_product([['Norway', 'Poland', 'Portugal', - 'Russia'], ['Annual']], - names=['GEO', 'FREQ']) - exp_idx = pd.DatetimeIndex(['2009', '2006'], name='TIME_PERIOD') + exp_col = pd.MultiIndex.from_product( + [["Norway", "Poland", "Portugal", "Russia"], ["Annual"]], names=["GEO", "FREQ"] + ) + exp_idx = pd.DatetimeIndex(["2009", "2006"], name="TIME_PERIOD") - values = np.array([[20.38, 25.1, 27.77, 38.1], - [25.49, np.nan, 39.05, np.nan]]) + values = np.array([[20.38, 25.1, 
27.77, 38.1], [25.49, np.nan, 39.05, np.nan]])
 
     expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
     tm.assert_frame_equal(df, expected)
diff --git a/pandas_datareader/tests/test_bankofcanada.py b/pandas_datareader/tests/test_bankofcanada.py
index daff9f58..fe9cb12b 100644
--- a/pandas_datareader/tests/test_bankofcanada.py
+++ b/pandas_datareader/tests/test_bankofcanada.py
@@ -2,68 +2,81 @@
 
 import pytest
 
-import pandas_datareader.data as web
 from pandas_datareader._utils import RemoteDataError
+import pandas_datareader.data as web
 
 pytestmark = pytest.mark.stable
 
 
 class TestBankOfCanada(object):
-
     @staticmethod
     def get_symbol(currency_code, inverted=False):
         if inverted:
-            return 'FXCAD{}'.format(currency_code)
+            return "FXCAD{}".format(currency_code)
         else:
-            return 'FX{}CAD'.format(currency_code)
+            return "FX{}CAD".format(currency_code)
 
     def check_bankofcanada_count(self, code):
         start, end = date.today() - timedelta(days=30), date.today()
 
-        df = web.DataReader(self.get_symbol(code), 'bankofcanada', start, end)
+        df = web.DataReader(self.get_symbol(code), "bankofcanada", start, end)
         assert df.size > 15
 
     def check_bankofcanada_valid(self, code):
         symbol = self.get_symbol(code)
-        df = web.DataReader(symbol, 'bankofcanada',
-                            date.today() - timedelta(days=30), date.today())
+        df = web.DataReader(
+            symbol, "bankofcanada", date.today() - timedelta(days=30), date.today()
+        )
         assert symbol in df.columns
 
     def check_bankofcanada_inverted(self, code):
         symbol = self.get_symbol(code)
         symbol_inverted = self.get_symbol(code, inverted=True)
 
-        df = web.DataReader(symbol, 'bankofcanada',
-                            date.today() - timedelta(days=30), date.today())
-        df_i = web.DataReader(symbol_inverted, 'bankofcanada',
-                              date.today() - timedelta(days=30), date.today())
+        df = web.DataReader(
+            symbol, "bankofcanada", date.today() - timedelta(days=30), date.today()
+        )
+        df_i = web.DataReader(
+            symbol_inverted,
+            "bankofcanada",
+            date.today() - timedelta(days=30),
+            date.today(),
+        )
 
         pairs = zip((1 / df)[symbol].tolist(), df_i[symbol_inverted].tolist())
         assert all(a - b < 0.01 for a, b in pairs)
 
     def test_bankofcanada_usd_count(self):
-        self.check_bankofcanada_count('USD')
+        self.check_bankofcanada_count("USD")
 
     def test_bankofcanada_eur_count(self):
-        self.check_bankofcanada_count('EUR')
+        self.check_bankofcanada_count("EUR")
 
     def test_bankofcanada_usd_valid(self):
-        self.check_bankofcanada_valid('USD')
+        self.check_bankofcanada_valid("USD")
 
     def test_bankofcanada_eur_valid(self):
-        self.check_bankofcanada_valid('EUR')
+        self.check_bankofcanada_valid("EUR")
 
     def test_bankofcanada_usd_inverted(self):
-        self.check_bankofcanada_inverted('USD')
+        self.check_bankofcanada_inverted("USD")
 
     def test_bankofcanada_eur_inverted(self):
-        self.check_bankofcanada_inverted('EUR')
+        self.check_bankofcanada_inverted("EUR")
 
     def test_bankofcanada_bad_range(self):
         with pytest.raises(ValueError):
-            web.DataReader('FXCADUSD', 'bankofcanada',
-                           date.today(), date.today() - timedelta(days=30))
+            web.DataReader(
+                "FXCADUSD",
+                "bankofcanada",
+                date.today(),
+                date.today() - timedelta(days=30),
+            )
 
     def test_bankofcanada_bad_url(self):
         with pytest.raises(RemoteDataError):
-            web.DataReader('abcdefgh', 'bankofcanada',
-                           date.today() - timedelta(days=30), date.today())
+            web.DataReader(
+                "abcdefgh",
+                "bankofcanada",
+                date.today() - timedelta(days=30),
+                date.today(),
+            )
diff --git a/pandas_datareader/tests/test_base.py b/pandas_datareader/tests/test_base.py
index 11103592..2154b81b 100644
--- a/pandas_datareader/tests/test_base.py
+++ b/pandas_datareader/tests/test_base.py
@@ -8,11 +8,11 @@ class TestBaseReader(object):
 
     def test_requests_not_monkey_patched(self):
-        assert not hasattr(requests.Session(), 'stor')
+        assert not hasattr(requests.Session(), "stor")
 
     def test_valid_retry_count(self):
         with pytest.raises(ValueError):
-            base._BaseReader([], retry_count='stuff')
+            base._BaseReader([], retry_count="stuff")
         with pytest.raises(ValueError):
             base._BaseReader([], retry_count=-1)
 
@@ -23,8 +23,8 @@ def test_invalid_url(self):
     def test_invalid_format(self):
         with pytest.raises(NotImplementedError):
             b = base._BaseReader([])
-            b._format = 'IM_NOT_AN_IMPLEMENTED_TYPE'
-            b._read_one_data('a', None)
+            b._format = "IM_NOT_AN_IMPLEMENTED_TYPE"
+            b._read_one_data("a", None)
 
 
 class TestDailyBaseReader(object):
diff --git a/pandas_datareader/tests/test_data.py b/pandas_datareader/tests/test_data.py
index 9ecd970c..4cdea1d0 100644
--- a/pandas_datareader/tests/test_data.py
+++ b/pandas_datareader/tests/test_data.py
@@ -1,13 +1,12 @@
+from pandas import DataFrame
 import pytest
 
-from pandas import DataFrame
 from pandas_datareader.data import DataReader
 
 pytestmark = pytest.mark.stable
 
 
 class TestDataReader(object):
-
     def test_read_iex(self):
         gs = DataReader("GS", "iex-last")
         assert isinstance(gs, DataFrame)
diff --git a/pandas_datareader/tests/test_econdb.py b/pandas_datareader/tests/test_econdb.py
index 1b960dfb..58bb48cc 100644
--- a/pandas_datareader/tests/test_econdb.py
+++ b/pandas_datareader/tests/test_econdb.py
@@ -8,16 +8,16 @@
 
 
 class TestEcondb(object):
-
     def test_get_cdh_e_fos(self):
         # EUROSTAT
         # Employed doctorate holders in non managerial and non professional
         # occupations by fields of science (%)
         df = web.DataReader(
-            'dataset=CDH_E_FOS&GEO=NO,PL,PT,RU&FOS07=FOS1&Y_GRAD=TOTAL',
-            'econdb',
-            start=pd.Timestamp('2005-01-01'),
-            end=pd.Timestamp('2010-01-01'))
+            "dataset=CDH_E_FOS&GEO=NO,PL,PT,RU&FOS07=FOS1&Y_GRAD=TOTAL",
+            "econdb",
+            start=pd.Timestamp("2005-01-01"),
+            end=pd.Timestamp("2010-01-01"),
+        )
         assert isinstance(df, pd.DataFrame)
         assert df.shape == (2, 4)
 
@@ -26,11 +26,9 @@ def test_get_cdh_e_fos(self):
         levels = [lvl.values.tolist() for lvl in list(df.columns.levels)]
         exp_col = pd.MultiIndex.from_product(levels, names=names)
 
-        exp_idx = pd.DatetimeIndex(['2006-01-01', '2009-01-01'],
-                                   name='TIME_PERIOD')
+        exp_idx = pd.DatetimeIndex(["2006-01-01", "2009-01-01"], name="TIME_PERIOD")
 
-        values = np.array([[25.49, np.nan, 39.05, np.nan],
-                           [20.38, 25.1, 27.77, 38.1]])
+        values = np.array([[25.49, np.nan, 39.05, np.nan], [20.38, 25.1, 27.77, 38.1]])
         expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
         tm.assert_frame_equal(df, expected)
 
@@ -39,32 +37,36 @@ def test_get_tourism(self):
         # TOURISM_INBOUND
 
         df = web.DataReader(
-            'dataset=OE_TOURISM_INBOUND&COUNTRY=JPN,USA&'
-            'VARIABLE=INB_ARRIVALS_TOTAL', 'econdb',
-            start=pd.Timestamp('2008-01-01'), end=pd.Timestamp('2012-01-01'))
+            "dataset=OE_TOURISM_INBOUND&COUNTRY=JPN,USA&VARIABLE=INB_ARRIVALS_TOTAL",
+            "econdb",
+            start=pd.Timestamp("2008-01-01"),
+            end=pd.Timestamp("2012-01-01"),
+        )
         df = df.astype(np.float)
 
-        jp = np.array([8351000, 6790000, 8611000, 6219000,
-                       8368000], dtype=float)
-        us = np.array([175702304, 160507424, 164079728, 167600272,
-                       171320416], dtype=float)
-        index = pd.date_range('2008-01-01', '2012-01-01', freq='AS',
-                              name='TIME_PERIOD')
+        jp = np.array([8351000, 6790000, 8611000, 6219000, 8368000], dtype=float)
+        us = np.array(
+            [175702304, 160507424, 164079728, 167600272, 171320416], dtype=float
+        )
+        index = pd.date_range("2008-01-01", 
"2012-01-01", freq="AS", name="TIME_PERIOD") # sometimes the country and variable columns are swapped lvl1 = df.columns.levels[0][0] if lvl1 == "Total international arrivals": df = df.swaplevel(0, 1, axis=1) - for label, values in [('Japan', jp), ('United States', us)]: - expected = pd.Series(values, index=index, - name='Total international arrivals') - tm.assert_series_equal(df[label]['Total international arrivals'], - expected) + for label, values in [("Japan", jp), ("United States", us)]: + expected = pd.Series( + values, index=index, name="Total international arrivals" + ) + tm.assert_series_equal(df[label]["Total international arrivals"], expected) def test_bls(self): # BLS # CPI df = web.DataReader( - 'ticker=BLS_CU.CUSR0000SA0.M.US', 'econdb', - start=pd.Timestamp('2010-01-01'), end=pd.Timestamp('2013-01-27')) + "ticker=BLS_CU.CUSR0000SA0.M.US", + "econdb", + start=pd.Timestamp("2010-01-01"), + end=pd.Timestamp("2013-01-27"), + ) - assert df.loc['2010-05-01'][0] == 217.3 + assert df.loc["2010-05-01"][0] == 217.3 diff --git a/pandas_datareader/tests/test_enigma.py b/pandas_datareader/tests/test_enigma.py index 46340d5d..7c34e964 100644 --- a/pandas_datareader/tests/test_enigma.py +++ b/pandas_datareader/tests/test_enigma.py @@ -1,7 +1,8 @@ import os -import pytest +import pytest from requests.exceptions import HTTPError + import pandas_datareader as pdr import pandas_datareader.data as web @@ -12,7 +13,6 @@ @pytest.mark.skipif(TEST_API_KEY is None, reason="no enigma_api_key") class TestEnigma(object): - @property def dataset_id(self): """ @@ -28,25 +28,24 @@ def setup_class(cls): def test_enigma_datareader(self): try: - df = web.DataReader(self.dataset_id, - 'enigma', access_key=TEST_API_KEY) - assert 'case_number' in df.columns + df = web.DataReader(self.dataset_id, "enigma", access_key=TEST_API_KEY) + assert "case_number" in df.columns except HTTPError as e: pytest.skip(e) def test_enigma_get_data_enigma(self): try: df = pdr.get_data_enigma(self.dataset_id, TEST_API_KEY) - assert 'case_number' in df.columns + assert "case_number" in df.columns except HTTPError as e: pytest.skip(e) def test_bad_key(self): with pytest.raises(HTTPError): - web.DataReader(self.dataset_id, - 'enigma', access_key=TEST_API_KEY + 'xxx') + web.DataReader(self.dataset_id, "enigma", access_key=TEST_API_KEY + "xxx") def test_bad_dataset_id(self): with pytest.raises(HTTPError): - web.DataReader('zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzz', - 'enigma', access_key=TEST_API_KEY) + web.DataReader( + "zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzz", "enigma", access_key=TEST_API_KEY + ) diff --git a/pandas_datareader/tests/test_famafrench.py b/pandas_datareader/tests/test_famafrench.py index 08335968..c73d476b 100644 --- a/pandas_datareader/tests/test_famafrench.py +++ b/pandas_datareader/tests/test_famafrench.py @@ -1,7 +1,6 @@ -import pytest - import pandas as pd import pandas.util.testing as tm +import pytest import pandas_datareader.data as web from pandas_datareader.famafrench import get_available_datasets @@ -10,17 +9,19 @@ class TestFamaFrench(object): - def test_get_data(self): keys = [ - 'F-F_Research_Data_Factors', 'F-F_ST_Reversal_Factor', - '6_Portfolios_2x3', 'Portfolios_Formed_on_ME', - 'Prior_2-12_Breakpoints', 'ME_Breakpoints', + "F-F_Research_Data_Factors", + "F-F_ST_Reversal_Factor", + "6_Portfolios_2x3", + "Portfolios_Formed_on_ME", + "Prior_2-12_Breakpoints", + "ME_Breakpoints", ] for name in keys: - ff = web.DataReader(name, 'famafrench') - assert 'DESCR' in ff + ff = web.DataReader(name, "famafrench") + assert 
"DESCR" in ff assert len(ff) > 1 def test_get_available_datasets(self): @@ -29,63 +30,159 @@ def test_get_available_datasets(self): assert len(avail) > 100 def test_index(self): - ff = web.DataReader('F-F_Research_Data_Factors', 'famafrench') - assert ff[0].index.freq == 'M' - assert ff[1].index.freq == 'A-DEC' + ff = web.DataReader("F-F_Research_Data_Factors", "famafrench") + assert ff[0].index.freq == "M" + assert ff[1].index.freq == "A-DEC" def test_f_f_research(self): - results = web.DataReader("F-F_Research_Data_Factors", "famafrench", - start='2010-01-01', end='2010-12-01') + results = web.DataReader( + "F-F_Research_Data_Factors", + "famafrench", + start="2010-01-01", + end="2010-12-01", + ) assert isinstance(results, dict) assert len(results) == 3 - exp = pd.DataFrame({'Mkt-RF': [-3.36, 3.4, 6.31, 2., -7.89, -5.56, - 6.93, -4.77, 9.54, 3.88, 0.6, 6.82], - 'SMB': [0.38, 1.2, 1.42, 4.98, 0.05, -1.97, 0.16, - -3.00, 3.92, 1.15, 3.70, 0.7], - 'HML': [0.31, 3.16, 2.1, 2.81, -2.38, -4.5, -0.27, - -1.95, -3.12, -2.59, -0.9, 3.81], - 'RF': [0., 0., 0.01, 0.01, 0.01, 0.01, 0.01, - 0.01, 0.01, 0.01, 0.01, 0.01]}, - index=pd.period_range('2010-01-01', '2010-12-01', - freq='M', name='Date'), - columns=['Mkt-RF', 'SMB', 'HML', 'RF']) + exp = pd.DataFrame( + { + "Mkt-RF": [ + -3.36, + 3.4, + 6.31, + 2.0, + -7.89, + -5.56, + 6.93, + -4.77, + 9.54, + 3.88, + 0.6, + 6.82, + ], + "SMB": [ + 0.38, + 1.2, + 1.42, + 4.98, + 0.05, + -1.97, + 0.16, + -3.00, + 3.92, + 1.15, + 3.70, + 0.7, + ], + "HML": [ + 0.31, + 3.16, + 2.1, + 2.81, + -2.38, + -4.5, + -0.27, + -1.95, + -3.12, + -2.59, + -0.9, + 3.81, + ], + "RF": [ + 0.0, + 0.0, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + ], + }, + index=pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date"), + columns=["Mkt-RF", "SMB", "HML", "RF"], + ) tm.assert_frame_equal(results[0], exp, check_less_precise=0) def test_me_breakpoints(self): - results = web.DataReader("ME_Breakpoints", "famafrench", - start='2010-01-01', end='2010-12-01') + results = web.DataReader( + "ME_Breakpoints", "famafrench", start="2010-01-01", end="2010-12-01" + ) assert isinstance(results, dict) assert len(results) == 2 assert results[0].shape == (12, 21) - exp_columns = pd.Index(['Count', (0, 5), (5, 10), (10, 15), (15, 20), - (20, 25), (25, 30), (30, 35), (35, 40), - (40, 45), (45, 50), (50, 55), (55, 60), - (60, 65), (65, 70), (70, 75), (75, 80), - (80, 85), (85, 90), (90, 95), (95, 100)], - dtype='object') + exp_columns = pd.Index( + [ + "Count", + (0, 5), + (5, 10), + (10, 15), + (15, 20), + (20, 25), + (25, 30), + (30, 35), + (35, 40), + (40, 45), + (45, 50), + (50, 55), + (55, 60), + (60, 65), + (65, 70), + (70, 75), + (75, 80), + (80, 85), + (85, 90), + (90, 95), + (95, 100), + ], + dtype="object", + ) tm.assert_index_equal(results[0].columns, exp_columns) - exp_index = pd.period_range('2010-01-01', '2010-12-01', - freq='M', name='Date') + exp_index = pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date") tm.assert_index_equal(results[0].index, exp_index) def test_prior_2_12_breakpoints(self): - results = web.DataReader("Prior_2-12_Breakpoints", "famafrench", - start='2010-01-01', end='2010-12-01') + results = web.DataReader( + "Prior_2-12_Breakpoints", "famafrench", start="2010-01-01", end="2010-12-01" + ) assert isinstance(results, dict) assert len(results) == 2 assert results[0].shape == (12, 22) - exp_columns = pd.Index(['<=0', '>0', (0, 5), (5, 10), (10, 15), - (15, 20), (20, 25), (25, 30), (30, 35), - (35, 40), (40, 
45), (45, 50), (50, 55), - (55, 60), (60, 65), (65, 70), (70, 75), - (75, 80), (80, 85), (85, 90), (90, 95), - (95, 100)], dtype='object') + exp_columns = pd.Index( + [ + "<=0", + ">0", + (0, 5), + (5, 10), + (10, 15), + (15, 20), + (20, 25), + (25, 30), + (30, 35), + (35, 40), + (40, 45), + (45, 50), + (50, 55), + (55, 60), + (60, 65), + (65, 70), + (70, 75), + (75, 80), + (80, 85), + (85, 90), + (90, 95), + (95, 100), + ], + dtype="object", + ) tm.assert_index_equal(results[0].columns, exp_columns) - exp_index = pd.period_range('2010-01-01', '2010-12-01', - freq='M', name='Date') + exp_index = pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date") tm.assert_index_equal(results[0].index, exp_index) diff --git a/pandas_datareader/tests/test_fred.py b/pandas_datareader/tests/test_fred.py index d55782a4..158721be 100644 --- a/pandas_datareader/tests/test_fred.py +++ b/pandas_datareader/tests/test_fred.py @@ -1,20 +1,18 @@ from datetime import datetime -import pytest - import numpy as np import pandas as pd +from pandas import DataFrame import pandas.util.testing as tm -import pandas_datareader.data as web +import pytest -from pandas import DataFrame from pandas_datareader._utils import RemoteDataError +import pandas_datareader.data as web pytestmark = pytest.mark.stable class TestFred(object): - def test_fred(self): # Raises an exception when DataReader can't @@ -24,7 +22,7 @@ def test_fred(self): end = datetime(2013, 1, 1) df = web.DataReader("GDP", "fred", start, end) - ts = df['GDP'] + ts = df["GDP"] assert ts.index[0] == pd.to_datetime("2010-01-01") assert ts.index[-1] == pd.to_datetime("2013-01-01") @@ -33,30 +31,26 @@ def test_fred(self): assert len(ts) == 13 with pytest.raises(RemoteDataError): - web.DataReader("NON EXISTENT SERIES", 'fred', start, end) + web.DataReader("NON EXISTENT SERIES", "fred", start, end) def test_fred_nan(self): start = datetime(2010, 1, 1) end = datetime(2013, 1, 27) df = web.DataReader("DFII5", "fred", start, end) - assert pd.isnull(df.loc['2010-01-01'][0]) + assert pd.isnull(df.loc["2010-01-01"][0]) def test_fred_parts(self): # pragma: no cover start = datetime(2010, 1, 1) end = datetime(2013, 1, 27) df = web.get_data_fred("CPIAUCSL", start, end) - assert df.loc['2010-05-01'][0] == 217.29 + assert df.loc["2010-05-01"][0] == 217.29 t = df.CPIAUCSL.values assert np.issubdtype(t.dtype, np.floating) assert t.shape == (37,) def test_fred_part2(self): - expected = [[576.7], - [962.9], - [684.7], - [848.3], - [933.3]] + expected = [[576.7], [962.9], [684.7], [848.3], [933.3]] result = web.get_data_fred("A09024USA144NNBR", start="1915").iloc[:5] tm.assert_numpy_array_equal(result.values, np.array(expected)) @@ -66,18 +60,21 @@ def test_invalid_series(self): web.get_data_fred(name) def test_fred_multi(self): # pragma: no cover - names = ['CPIAUCSL', 'CPALTT01USQ661S', 'CPILFESL'] + names = ["CPIAUCSL", "CPALTT01USQ661S", "CPILFESL"] start = datetime(2010, 1, 1) end = datetime(2013, 1, 27) received = web.DataReader(names, "fred", start, end).head(1) - expected = DataFrame([[217.488, 91.712409, 220.633]], columns=names, - index=[pd.Timestamp('2010-01-01 00:00:00')]) - expected.index.rename('DATE', inplace=True) + expected = DataFrame( + [[217.488, 91.712409, 220.633]], + columns=names, + index=[pd.Timestamp("2010-01-01 00:00:00")], + ) + expected.index.rename("DATE", inplace=True) tm.assert_frame_equal(received, expected, check_less_precise=True) def test_fred_multi_bad_series(self): - names = ['NOTAREALSERIES', 'CPIAUCSL', "ALSO FAKE"] + names = 
["NOTAREALSERIES", "CPIAUCSL", "ALSO FAKE"] with pytest.raises(RemoteDataError): web.DataReader(names, data_source="fred") diff --git a/pandas_datareader/tests/test_iex.py b/pandas_datareader/tests/test_iex.py index f3933aad..6cc4b2d5 100644 --- a/pandas_datareader/tests/test_iex.py +++ b/pandas_datareader/tests/test_iex.py @@ -1,11 +1,16 @@ from datetime import datetime -import pytest from pandas import DataFrame +import pytest -from pandas_datareader.data import (DataReader, get_summary_iex, get_last_iex, - get_dailysummary_iex, get_iex_symbols, - get_iex_book) +from pandas_datareader.data import ( + DataReader, + get_dailysummary_iex, + get_iex_book, + get_iex_symbols, + get_last_iex, + get_summary_iex, +) from pandas_datareader.exceptions import UnstableAPIWarning @@ -19,25 +24,26 @@ def test_read_iex(self): assert isinstance(gs, DataFrame) def test_historical(self): - df = get_summary_iex(start=datetime(2017, 4, 1), - end=datetime(2017, 4, 30)) + df = get_summary_iex(start=datetime(2017, 4, 1), end=datetime(2017, 4, 30)) assert df.T["averageDailyVolume"].iloc[0] == 137650908.9 def test_false_ticker(self): df = get_last_iex("INVALID TICKER") assert df.shape[0] == 0 - @pytest.mark.xfail(reason='IEX daily history API is returning 500 as of ' - 'Jan 2018') + @pytest.mark.xfail( + reason="IEX daily history API is returning 500 as of " "Jan 2018" + ) def test_daily(self): with pytest.warns(UnstableAPIWarning): - df = get_dailysummary_iex(start=datetime(2017, 5, 5), - end=datetime(2017, 5, 6)) - assert df['routedVolume'].iloc[0] == 39974788 + df = get_dailysummary_iex( + start=datetime(2017, 5, 5), end=datetime(2017, 5, 6) + ) + assert df["routedVolume"].iloc[0] == 39974788 def test_symbols(self): df = get_iex_symbols() - assert 'GS' in df.symbol.values + assert "GS" in df.symbol.values def test_live_prices(self): dftickers = get_iex_symbols() @@ -46,8 +52,8 @@ def test_live_prices(self): assert df["price"].mean() > 0 def test_deep(self): - dob = get_iex_book('GS', service='book') + dob = get_iex_book("GS", service="book") if dob: - assert 'GS' in dob + assert "GS" in dob else: - pytest.xfail(reason='Can only get Book when market open') + pytest.xfail(reason="Can only get Book when market open") diff --git a/pandas_datareader/tests/test_iex_daily.py b/pandas_datareader/tests/test_iex_daily.py index e95ca82d..6df487d5 100644 --- a/pandas_datareader/tests/test_iex_daily.py +++ b/pandas_datareader/tests/test_iex_daily.py @@ -1,18 +1,17 @@ from datetime import date, datetime, timedelta - import os from pandas import DataFrame, MultiIndex - import pytest import pandas_datareader.data as web from pandas_datareader.iex.daily import IEXDailyReader -@pytest.mark.skipif(os.getenv("IEX_SANDBOX") != 'enable', - reason='All tests must be run in sandbox mode') -class TestIEXDaily(object): +@pytest.mark.skipif( + os.getenv("IEX_SANDBOX") != "enable", reason="All tests must be run in sandbox mode" +) +class TestIEXDaily(object): @classmethod def setup_class(cls): pytest.importorskip("lxml") @@ -31,8 +30,7 @@ def test_iex_bad_symbol(self): def test_iex_bad_symbol_list(self): with pytest.raises(Exception): - web.DataReader(["AAPL", "BADTICKER"], "iex", - self.start, self.end) + web.DataReader(["AAPL", "BADTICKER"], "iex", self.start, self.end) def test_daily_invalid_date(self): start = datetime(2000, 1, 5) @@ -50,7 +48,7 @@ def test_multiple_symbols(self): df = web.DataReader(syms, "iex", self.start, self.end) assert sorted(list(df.columns.levels[1])) == syms for sym in syms: - assert len(df.xs(sym, 
level='Symbols', axis=1) == 578) + assert len(df.xs(sym, level="Symbols", axis=1) == 578) def test_multiple_symbols_2(self): syms = ["AAPL", "MSFT", "TSLA"] @@ -62,8 +60,8 @@ def test_multiple_symbols_2(self): assert len(df.columns.levels[1]) == 3 assert sorted(list(df.columns.levels[1])) == syms - a = df.xs("AAPL", axis=1, level='Symbols') - t = df.xs("TSLA", axis=1, level='Symbols') + a = df.xs("AAPL", axis=1, level="Symbols") + t = df.xs("TSLA", axis=1, level="Symbols") assert len(a) == 73 assert len(t) == 73 @@ -71,35 +69,53 @@ def test_multiple_symbols_2(self): def test_range_string_from_date(self): syms = ["AAPL"] - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=5), - end=date.today() - )._range_string_from_date() == '5d' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=27), - end=date.today() - )._range_string_from_date() == '1m' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=83), - end=date.today() - )._range_string_from_date() == '3m' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=167), - end=date.today() - )._range_string_from_date() == '6m' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=170), - end=date.today() - )._range_string_from_date() == '1y' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=365), - end=date.today() - )._range_string_from_date() == '2y' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=730), - end=date.today() - )._range_string_from_date() == '5y' - assert IEXDailyReader(symbols=syms, - start=date.today() - timedelta(days=1826), - end=date.today() - )._range_string_from_date() == 'max' + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=5), end=date.today() + )._range_string_from_date() + == "5d" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=27), end=date.today() + )._range_string_from_date() + == "1m" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=83), end=date.today() + )._range_string_from_date() + == "3m" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=167), end=date.today() + )._range_string_from_date() + == "6m" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=170), end=date.today() + )._range_string_from_date() + == "1y" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=365), end=date.today() + )._range_string_from_date() + == "2y" + ) + assert ( + IEXDailyReader( + symbols=syms, start=date.today() - timedelta(days=730), end=date.today() + )._range_string_from_date() + == "5y" + ) + assert ( + IEXDailyReader( + symbols=syms, + start=date.today() - timedelta(days=1826), + end=date.today(), + )._range_string_from_date() + == "max" + ) diff --git a/pandas_datareader/tests/test_moex.py b/pandas_datareader/tests/test_moex.py index 5a26970b..5f46ac3e 100644 --- a/pandas_datareader/tests/test_moex.py +++ b/pandas_datareader/tests/test_moex.py @@ -1,6 +1,6 @@ import pytest - from requests.exceptions import HTTPError + import pandas_datareader.data as web pytestmark = pytest.mark.stable @@ -9,10 +9,9 @@ class TestMoex(object): def test_moex_datareader(self): try: - df = web.DataReader("USD000UTSTOM", - "moex", - start="2017-07-01", - end="2017-07-31") - assert 'SECID' in df.columns + df = web.DataReader( + "USD000UTSTOM", "moex", start="2017-07-01", 
end="2017-07-31" + ) + assert "SECID" in df.columns except HTTPError as e: pytest.skip(e) diff --git a/pandas_datareader/tests/test_nasdaq.py b/pandas_datareader/tests/test_nasdaq.py index d58e8869..659755c4 100644 --- a/pandas_datareader/tests/test_nasdaq.py +++ b/pandas_datareader/tests/test_nasdaq.py @@ -1,11 +1,10 @@ -import pandas_datareader.data as web -from pandas_datareader._utils import RemoteDataError from pandas_datareader._testing import skip_on_exception +from pandas_datareader._utils import RemoteDataError +import pandas_datareader.data as web class TestNasdaqSymbols(object): - @skip_on_exception(RemoteDataError) def test_get_symbols(self): - symbols = web.DataReader('symbols', 'nasdaq') - assert 'IBM' in symbols.index + symbols = web.DataReader("symbols", "nasdaq") + assert "IBM" in symbols.index diff --git a/pandas_datareader/tests/test_oecd.py b/pandas_datareader/tests/test_oecd.py index cd358a90..ca4a60ff 100644 --- a/pandas_datareader/tests/test_oecd.py +++ b/pandas_datareader/tests/test_oecd.py @@ -5,87 +5,210 @@ import pandas.util.testing as tm import pytest -import pandas_datareader.data as web from pandas_datareader._utils import RemoteDataError +import pandas_datareader.data as web class TestOECD(object): - - @pytest.mark.xfail(reason='Incorrect URL') + @pytest.mark.xfail(reason="Incorrect URL") def test_get_un_den(self): - df = web.DataReader('TUD', 'oecd', start=datetime(1960, 1, 1), - end=datetime(2012, 1, 1)) + df = web.DataReader( + "TUD", "oecd", start=datetime(1960, 1, 1), end=datetime(2012, 1, 1) + ) - au = [50.17292785, 49.47181009, 49.52106174, 49.16341327, - 48.19296375, 47.8863461, 45.83517292, 45.02021403, - 44.78983834, 44.37794217, 44.15358142, 45.38865546, - 46.33092037, 47.2343406, 48.80023876, 50.0639872, - 50.23390644, 49.8214994, 49.67636585, 49.55227375, - 48.48657368, 47.41179739, 47.52526561, 47.93048854, - 47.26327162, 45.4617105, 43.90202112, 42.32759607, - 40.35838899, 39.35157364, 39.55023059, 39.93212859, - 39.21948472, 37.24343693, 34.42549573, 32.51172639, - 31.16809569, 29.78835077, 28.14657769, 25.41970706, - 25.71752984, 24.53108811, 23.21936888, 22.99140633, - 22.29380238, 22.29160819, 20.22236326, 18.51151852, - 18.56792804, 19.31219498, 18.44405734, 18.51105048, - 18.19718895] - jp = [32.32911392, 33.73688458, 34.5969919, 35.01871257, - 35.46869345, 35.28164117, 34.749499, 34.40573103, - 34.50762389, 35.16411379, 35.10284332, 34.57209848, - 34.31168831, 33.46611342, 34.26450371, 34.53099287, - 33.69881466, 32.99814274, 32.59541985, 31.75696594, - 31.14832536, 30.8917513, 30.56612982, 29.75285171, - 29.22391559, 28.79202411, 28.18680064, 27.71454381, - 26.94358748, 26.13165206, 25.36711479, 24.78408637, - 24.49892557, 24.34256055, 24.25324675, 23.96731902, - 23.3953401, 22.78797997, 22.52794337, 22.18157944, - 21.54406273, 20.88284597, 20.26073907, 19.73945642, - 19.25116713, 18.79844243, 18.3497807, 18.25095057, - 18.2204924, 18.45787546, 18.40380743, 18.99504195, - 17.97238372] - us = [30.89748411, 29.51891217, 29.34276869, 28.51337535, - 28.30646144, 28.16661991, 28.19557735, 27.76578899, - 27.9004622, 27.30836054, 27.43402867, 26.94941363, - 26.25996487, 25.83134349, 25.74427582, 25.28771204, - 24.38412814, 23.59186681, 23.94328194, 22.36400776, - 22.06009466, 21.01328205, 20.47463895, 19.45290876, - 18.22953818, 17.44855678, 17.00126975, 16.5162476, - 16.24744487, 15.86401127, 15.45147174, 15.46986912, - 15.1499578, 15.13654544, 14.91544059, 14.31762091, - 14.02052225, 13.55213736, 13.39571457, 13.36670812, - 12.84874656, 
12.85719022, 12.63753733, 12.39142968, - 12.02130767, 11.96023574, 11.48458378, 11.56435375, - 11.91022276, 11.79401904, 11.38345975, 11.32948829, - 10.81535229] + au = [ + 50.17292785, + 49.47181009, + 49.52106174, + 49.16341327, + 48.19296375, + 47.8863461, + 45.83517292, + 45.02021403, + 44.78983834, + 44.37794217, + 44.15358142, + 45.38865546, + 46.33092037, + 47.2343406, + 48.80023876, + 50.0639872, + 50.23390644, + 49.8214994, + 49.67636585, + 49.55227375, + 48.48657368, + 47.41179739, + 47.52526561, + 47.93048854, + 47.26327162, + 45.4617105, + 43.90202112, + 42.32759607, + 40.35838899, + 39.35157364, + 39.55023059, + 39.93212859, + 39.21948472, + 37.24343693, + 34.42549573, + 32.51172639, + 31.16809569, + 29.78835077, + 28.14657769, + 25.41970706, + 25.71752984, + 24.53108811, + 23.21936888, + 22.99140633, + 22.29380238, + 22.29160819, + 20.22236326, + 18.51151852, + 18.56792804, + 19.31219498, + 18.44405734, + 18.51105048, + 18.19718895, + ] + jp = [ + 32.32911392, + 33.73688458, + 34.5969919, + 35.01871257, + 35.46869345, + 35.28164117, + 34.749499, + 34.40573103, + 34.50762389, + 35.16411379, + 35.10284332, + 34.57209848, + 34.31168831, + 33.46611342, + 34.26450371, + 34.53099287, + 33.69881466, + 32.99814274, + 32.59541985, + 31.75696594, + 31.14832536, + 30.8917513, + 30.56612982, + 29.75285171, + 29.22391559, + 28.79202411, + 28.18680064, + 27.71454381, + 26.94358748, + 26.13165206, + 25.36711479, + 24.78408637, + 24.49892557, + 24.34256055, + 24.25324675, + 23.96731902, + 23.3953401, + 22.78797997, + 22.52794337, + 22.18157944, + 21.54406273, + 20.88284597, + 20.26073907, + 19.73945642, + 19.25116713, + 18.79844243, + 18.3497807, + 18.25095057, + 18.2204924, + 18.45787546, + 18.40380743, + 18.99504195, + 17.97238372, + ] + us = [ + 30.89748411, + 29.51891217, + 29.34276869, + 28.51337535, + 28.30646144, + 28.16661991, + 28.19557735, + 27.76578899, + 27.9004622, + 27.30836054, + 27.43402867, + 26.94941363, + 26.25996487, + 25.83134349, + 25.74427582, + 25.28771204, + 24.38412814, + 23.59186681, + 23.94328194, + 22.36400776, + 22.06009466, + 21.01328205, + 20.47463895, + 19.45290876, + 18.22953818, + 17.44855678, + 17.00126975, + 16.5162476, + 16.24744487, + 15.86401127, + 15.45147174, + 15.46986912, + 15.1499578, + 15.13654544, + 14.91544059, + 14.31762091, + 14.02052225, + 13.55213736, + 13.39571457, + 13.36670812, + 12.84874656, + 12.85719022, + 12.63753733, + 12.39142968, + 12.02130767, + 11.96023574, + 11.48458378, + 11.56435375, + 11.91022276, + 11.79401904, + 11.38345975, + 11.32948829, + 10.81535229, + ] - index = pd.date_range('1960-01-01', '2012-01-01', freq='AS', - name='Time') - for label, values in [('Australia', au), ('Japan', jp), - ('United States', us)]: + index = pd.date_range("1960-01-01", "2012-01-01", freq="AS", name="Time") + for label, values in [("Australia", au), ("Japan", jp), ("United States", us)]: expected = pd.Series(values, index=index, name=label) tm.assert_series_equal(df[label], expected) def test_get_tourism(self): - df = web.DataReader('TOURISM_INBOUND', 'oecd', - start=datetime(2008, 1, 1), - end=datetime(2012, 1, 1)) + df = web.DataReader( + "TOURISM_INBOUND", + "oecd", + start=datetime(2008, 1, 1), + end=datetime(2012, 1, 1), + ) - jp = np.array([8351000, 6790000, 8611000, 6219000, - 8368000], dtype=float) - us = np.array([175702309, 160507417, 164079732, 167600277, - 171320408], dtype=float) - index = pd.date_range('2008-01-01', '2012-01-01', freq='AS', - name='Year') - for label, values in [('Japan', jp), ('United States', us)]: - 
expected = pd.Series(values, index=index, - name='Total international arrivals') - tm.assert_series_equal(df[label]['Total international arrivals'], - expected) + jp = np.array([8351000, 6790000, 8611000, 6219000, 8368000], dtype=float) + us = np.array( + [175702309, 160507417, 164079732, 167600277, 171320408], dtype=float + ) + index = pd.date_range("2008-01-01", "2012-01-01", freq="AS", name="Year") + for label, values in [("Japan", jp), ("United States", us)]: + expected = pd.Series( + values, index=index, name="Total international arrivals" + ) + tm.assert_series_equal(df[label]["Total international arrivals"], expected) def test_oecd_invalid_symbol(self): with pytest.raises(RemoteDataError): - web.DataReader('INVALID_KEY', 'oecd') + web.DataReader("INVALID_KEY", "oecd") with pytest.raises(ValueError): - web.DataReader(1234, 'oecd') + web.DataReader(1234, "oecd") diff --git a/pandas_datareader/tests/test_robinhood.py b/pandas_datareader/tests/test_robinhood.py index faab2000..8289e78e 100644 --- a/pandas_datareader/tests/test_robinhood.py +++ b/pandas_datareader/tests/test_robinhood.py @@ -2,14 +2,13 @@ import pandas as pd import pytest -from pandas_datareader.robinhood import RobinhoodQuoteReader, \ - RobinhoodHistoricalReader +from pandas_datareader.robinhood import RobinhoodHistoricalReader, RobinhoodQuoteReader -syms = ['GOOG', ['GOOG', 'AAPL']] +syms = ["GOOG", ["GOOG", "AAPL"]] ids = list(map(str, syms)) -@pytest.fixture(params=['GOOG', ['GOOG', 'AAPL']], ids=ids) +@pytest.fixture(params=["GOOG", ["GOOG", "AAPL"]], ids=ids) def symbols(request): return request.param @@ -26,7 +25,7 @@ def test_robinhood_quote(symbols): @pytest.mark.xfail(reason="Deprecated") def test_robinhood_quote_too_many(): syms = np.random.randint(65, 90, size=(10000, 4)).tolist() - syms = list(map(lambda r: ''.join(map(chr, r)), syms)) + syms = list(map(lambda r: "".join(map(chr, r)), syms)) syms = list(set(syms)) with pytest.raises(ValueError): RobinhoodQuoteReader(symbols=syms) @@ -35,7 +34,7 @@ def test_robinhood_quote_too_many(): @pytest.mark.xfail(reason="Deprecated") def test_robinhood_historical_too_many(): syms = np.random.randint(65, 90, size=(10000, 4)).tolist() - syms = list(map(lambda r: ''.join(map(chr, r)), syms)) + syms = list(map(lambda r: "".join(map(chr, r)), syms)) syms = list(set(syms)) with pytest.raises(ValueError): RobinhoodHistoricalReader(symbols=syms) diff --git a/pandas_datareader/tests/test_stooq.py b/pandas_datareader/tests/test_stooq.py index 1158b191..436d1dc3 100644 --- a/pandas_datareader/tests/test_stooq.py +++ b/pandas_datareader/tests/test_stooq.py @@ -7,35 +7,35 @@ def test_stooq_dji(): - f = web.DataReader('GS', 'stooq') + f = web.DataReader("GS", "stooq") assert f.shape[0] > 0 def test_get_data_stooq_dji(): - f = get_data_stooq('AMZN') + f = get_data_stooq("AMZN") assert f.shape[0] > 0 def test_get_data_stooq_dates(): - f = get_data_stooq('SPY', start='20180101', end='20180115') + f = get_data_stooq("SPY", start="20180101", end="20180115") assert f.shape[0] == 9 def test_stooq_sp500(): - f = get_data_stooq('^SPX') + f = get_data_stooq("^SPX") assert f.shape[0] > 0 def test_get_data_stooq_dax(): - f = get_data_stooq('^DAX') + f = get_data_stooq("^DAX") assert f.shape[0] > 0 def test_stooq_googl(): - f = get_data_stooq('GOOGL.US') + f = get_data_stooq("GOOGL.US") assert f.shape[0] > 0 def test_get_data_ibm(): - f = get_data_stooq('IBM.UK') + f = get_data_stooq("IBM.UK") assert f.shape[0] > 0 diff --git a/pandas_datareader/tests/test_tsp.py 
b/pandas_datareader/tests/test_tsp.py index d41f1afb..8ebc2998 100644 --- a/pandas_datareader/tests/test_tsp.py +++ b/pandas_datareader/tests/test_tsp.py @@ -9,19 +9,20 @@ class TestTSPFunds(object): def test_get_allfunds(self): - tspdata = tsp.TSPReader(start='2015-11-2', end='2015-11-2').read() + tspdata = tsp.TSPReader(start="2015-11-2", end="2015-11-2").read() assert len(tspdata == 1) - assert round(tspdata['I Fund'][dt.date(2015, 11, 2)], 5) == 25.0058 + assert round(tspdata["I Fund"][dt.date(2015, 11, 2)], 5) == 25.0058 def test_sanitize_response(self): class response(object): pass + r = response() - r.text = ' , ' + r.text = " , " ret = tsp.TSPReader._sanitize_response(r) - assert ret == '' - r.text = ' a,b ' + assert ret == "" + r.text = " a,b " ret = tsp.TSPReader._sanitize_response(r) - assert ret == 'a,b' + assert ret == "a,b" diff --git a/pandas_datareader/tests/test_wb.py b/pandas_datareader/tests/test_wb.py index 05cf4638..f6e82b27 100644 --- a/pandas_datareader/tests/test_wb.py +++ b/pandas_datareader/tests/test_wb.py @@ -6,39 +6,42 @@ import pytest import requests -from pandas_datareader.compat import assert_raises_regex -from pandas_datareader.wb import (search, download, get_countries, - get_indicators, WorldBankReader) -from pandas_datareader._utils import RemoteDataError from pandas_datareader._testing import skip_on_exception +from pandas_datareader._utils import RemoteDataError +from pandas_datareader.compat import assert_raises_regex +from pandas_datareader.wb import ( + WorldBankReader, + download, + get_countries, + get_indicators, + search, +) class TestWB(object): - def test_wdi_search(self): # Test that a name column exists, and that some results were returned # ...without being too strict about what the actual contents of the # results actually are. The fact that there are some, is good enough. - result = search('gdp.*capita.*constant') - assert result.name.str.contains('GDP').any() + result = search("gdp.*capita.*constant") + assert result.name.str.contains("GDP").any() # check cache returns the results within 0.5 sec current_time = time.time() - result = search('gdp.*capita.*constant') - assert result.name.str.contains('GDP').any() + result = search("gdp.*capita.*constant") + assert result.name.str.contains("GDP").any() assert time.time() - current_time < 0.5 - result2 = WorldBankReader().search('gdp.*capita.*constant') + result2 = WorldBankReader().search("gdp.*capita.*constant") session = requests.Session() - result3 = search('gdp.*capita.*constant', session=session) - result4 = WorldBankReader(session=session).search( - 'gdp.*capita.*constant') + result3 = search("gdp.*capita.*constant", session=session) + result4 = WorldBankReader(session=session).search("gdp.*capita.*constant") for result in [result2, result3, result4]: - assert result.name.str.contains('GDP').any() + assert result.name.str.contains("GDP").any() def test_wdi_download(self): @@ -51,35 +54,45 @@ def test_wdi_download(self): # own exceptions, and don't clean up legacy country codes. # ...but NOT a retired indicator (user should want it to error). - cntry_codes = ['CA', 'MX', 'USA', 'US', 'US', 'KSV', 'BLA'] - inds = ['NY.GDP.PCAP.CD', 'BAD.INDICATOR'] + cntry_codes = ["CA", "MX", "USA", "US", "US", "KSV", "BLA"] + inds = ["NY.GDP.PCAP.CD", "BAD.INDICATOR"] # These are the expected results, rounded (robust against # data revisions in the future). 
- expected = {'NY.GDP.PCAP.CD': {('Canada', '2004'): 32000.0, - ('Canada', '2003'): 28000.0, - ('Kosovo', '2004'): 2000.0, - ('Kosovo', '2003'): 2000.0, - ('Mexico', '2004'): 7000.0, - ('Mexico', '2003'): 7000.0, - ('United States', '2004'): 42000.0, - ('United States', '2003'): 39000.0}} + expected = { + "NY.GDP.PCAP.CD": { + ("Canada", "2004"): 32000.0, + ("Canada", "2003"): 28000.0, + ("Kosovo", "2004"): 2000.0, + ("Kosovo", "2003"): 2000.0, + ("Mexico", "2004"): 7000.0, + ("Mexico", "2003"): 7000.0, + ("United States", "2004"): 42000.0, + ("United States", "2003"): 39000.0, + } + } expected = pd.DataFrame(expected) expected = expected.sort_index() - result = download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='ignore') + result = download( + country=cntry_codes, indicator=inds, start=2003, end=2004, errors="ignore" + ) result = result.sort_index() # Round, to ignore revisions to data. result = np.round(result, decimals=-3) - expected.index.names = ['country', 'year'] + expected.index.names = ["country", "year"] tm.assert_frame_equal(result, expected) # pass start and end as string - result = download(country=cntry_codes, indicator=inds, - start='2003', end='2004', errors='ignore') + result = download( + country=cntry_codes, + indicator=inds, + start="2003", + end="2004", + errors="ignore", + ) result = result.sort_index() # Round, to ignore revisions to data. @@ -90,63 +103,80 @@ def test_wdi_download_str(self): # These are the expected results, rounded (robust against # data revisions in the future). - expected = {'NY.GDP.PCAP.CD': {('Japan', '2004'): 38000.0, - ('Japan', '2003'): 35000.0, - ('Japan', '2002'): 32000.0, - ('Japan', '2001'): 34000.0, - ('Japan', '2000'): 39000.0}} + expected = { + "NY.GDP.PCAP.CD": { + ("Japan", "2004"): 38000.0, + ("Japan", "2003"): 35000.0, + ("Japan", "2002"): 32000.0, + ("Japan", "2001"): 34000.0, + ("Japan", "2000"): 39000.0, + } + } expected = pd.DataFrame(expected) expected = expected.sort_index() - cntry_codes = 'JP' - inds = 'NY.GDP.PCAP.CD' - result = download(country=cntry_codes, indicator=inds, - start=2000, end=2004, errors='ignore') + cntry_codes = "JP" + inds = "NY.GDP.PCAP.CD" + result = download( + country=cntry_codes, indicator=inds, start=2000, end=2004, errors="ignore" + ) result = result.sort_index() result = np.round(result, decimals=-3) - expected.index.names = ['country', 'year'] + expected.index.names = ["country", "year"] tm.assert_frame_equal(result, expected) - result = WorldBankReader(inds, countries=cntry_codes, - start=2000, end=2004, errors='ignore').read() + result = WorldBankReader( + inds, countries=cntry_codes, start=2000, end=2004, errors="ignore" + ).read() result = result.sort_index() result = np.round(result, decimals=-3) tm.assert_frame_equal(result, expected) def test_wdi_download_error_handling(self): - cntry_codes = ['USA', 'XX'] - inds = 'NY.GDP.PCAP.CD' + cntry_codes = ["USA", "XX"] + inds = "NY.GDP.PCAP.CD" msg = "Invalid Country Code\\(s\\): XX" with assert_raises_regex(ValueError, msg): - download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='raise') + download( + country=cntry_codes, + indicator=inds, + start=2003, + end=2004, + errors="raise", + ) with tm.assert_produces_warning(): - result = download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='warn') + result = download( + country=cntry_codes, indicator=inds, start=2003, end=2004, errors="warn" + ) assert isinstance(result, pd.DataFrame) assert len(result), 2 - cntry_codes = 
['USA'] - inds = ['NY.GDP.PCAP.CD', 'BAD_INDICATOR'] + cntry_codes = ["USA"] + inds = ["NY.GDP.PCAP.CD", "BAD_INDICATOR"] - msg = ("The provided parameter value is not valid\\. " - "Indicator: BAD_INDICATOR") + msg = "The provided parameter value is not valid\\. " "Indicator: BAD_INDICATOR" with assert_raises_regex(ValueError, msg): - download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='raise') + download( + country=cntry_codes, + indicator=inds, + start=2003, + end=2004, + errors="raise", + ) with tm.assert_produces_warning(): - result = download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='warn') + result = download( + country=cntry_codes, indicator=inds, start=2003, end=2004, errors="warn" + ) assert isinstance(result, pd.DataFrame) assert len(result) == 2 def test_wdi_download_w_retired_indicator(self): - cntry_codes = ['CA', 'MX', 'US'] + cntry_codes = ["CA", "MX", "US"] # Despite showing up in the search feature, and being listed online, # the api calls to GDPPCKD don't work in their own query builder, nor # pandas module. GDPPCKD used to be a common symbol. @@ -158,11 +188,16 @@ def test_wdi_download_w_retired_indicator(self): # World bank ever finishes the deprecation of this symbol, # this test should still pass. - inds = ['GDPPCKD'] + inds = ["GDPPCKD"] with pytest.raises(ValueError): - result = download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='ignore') + result = download( + country=cntry_codes, + indicator=inds, + start=2003, + end=2004, + errors="ignore", + ) # If it ever gets here, it means WB unretired the indicator. # even if they dropped it completely, it would still @@ -174,12 +209,17 @@ def test_wdi_download_w_retired_indicator(self): def test_wdi_download_w_crash_inducing_countrycode(self): - cntry_codes = ['CA', 'MX', 'US', 'XXX'] - inds = ['NY.GDP.PCAP.CD'] + cntry_codes = ["CA", "MX", "US", "XXX"] + inds = ["NY.GDP.PCAP.CD"] with pytest.raises(ValueError): - result = download(country=cntry_codes, indicator=inds, - start=2003, end=2004, errors='ignore') + result = download( + country=cntry_codes, + indicator=inds, + start=2003, + end=2004, + errors="ignore", + ) # If it ever gets here, it means the country code XXX # got used by WB @@ -197,7 +237,7 @@ def test_wdi_get_countries(self): result4 = WorldBankReader(session=session).get_countries() for result in [result1, result2, result3, result4]: - assert 'Zimbabwe' in list(result['name']) + assert "Zimbabwe" in list(result["name"]) assert len(result) > 100 assert pd.notnull(result.latitude.mean()) assert pd.notnull(result.longitude.mean()) @@ -211,70 +251,101 @@ def test_wdi_get_indicators(self): result4 = WorldBankReader(session=session).get_indicators() for result in [result1, result2, result3, result4]: - exp_col = pd.Index(['id', 'name', 'source', 'sourceNote', - 'sourceOrganization', 'topics', 'unit']) + exp_col = pd.Index( + [ + "id", + "name", + "source", + "sourceNote", + "sourceOrganization", + "topics", + "unit", + ] + ) # assert_index_equal doesn't exists assert result.columns.equals(exp_col) assert len(result) > 10000 @skip_on_exception(RemoteDataError) def test_wdi_download_monthly(self): - expected = {'COPPER': {('World', '2012M01'): 8040.47, - ('World', '2011M12'): 7565.48, - ('World', '2011M11'): 7581.02, - ('World', '2011M10'): 7394.19, - ('World', '2011M09'): 8300.14, - ('World', '2011M08'): 9000.76, - ('World', '2011M07'): 9650.46, - ('World', '2011M06'): 9066.85, - ('World', '2011M05'): 8959.90, - ('World', '2011M04'): 
9492.79, - ('World', '2011M03'): 9503.36, - ('World', '2011M02'): 9867.60, - ('World', '2011M01'): 9555.70}} + expected = { + "COPPER": { + ("World", "2012M01"): 8040.47, + ("World", "2011M12"): 7565.48, + ("World", "2011M11"): 7581.02, + ("World", "2011M10"): 7394.19, + ("World", "2011M09"): 8300.14, + ("World", "2011M08"): 9000.76, + ("World", "2011M07"): 9650.46, + ("World", "2011M06"): 9066.85, + ("World", "2011M05"): 8959.90, + ("World", "2011M04"): 9492.79, + ("World", "2011M03"): 9503.36, + ("World", "2011M02"): 9867.60, + ("World", "2011M01"): 9555.70, + } + } expected = pd.DataFrame(expected) # Round, to ignore revisions to data. expected = np.round(expected, decimals=-3) expected = expected.sort_index() - cntry_codes = 'ALL' - inds = 'COPPER' - result = download(country=cntry_codes, indicator=inds, - start=2011, end=2012, freq='M', errors='ignore') + cntry_codes = "ALL" + inds = "COPPER" + result = download( + country=cntry_codes, + indicator=inds, + start=2011, + end=2012, + freq="M", + errors="ignore", + ) result = result.sort_index() result = np.round(result, decimals=-3) - expected.index.names = ['country', 'year'] + expected.index.names = ["country", "year"] tm.assert_frame_equal(result, expected) - result = WorldBankReader(inds, countries=cntry_codes, start=2011, - end=2012, freq='M', errors='ignore').read() + result = WorldBankReader( + inds, countries=cntry_codes, start=2011, end=2012, freq="M", errors="ignore" + ).read() result = result.sort_index() result = np.round(result, decimals=-3) tm.assert_frame_equal(result, expected) def test_wdi_download_quarterly(self): - code = 'DT.DOD.PUBS.CD.US' - expected = {code: {('Albania', '2012Q1'): 3240539817.18, - ('Albania', '2011Q4'): 3213979715.15, - ('Albania', '2011Q3'): 3187681048.95, - ('Albania', '2011Q2'): 3248041513.86, - ('Albania', '2011Q1'): 3137210567.92}} + code = "DT.DOD.PUBS.CD.US" + expected = { + code: { + ("Albania", "2012Q1"): 3240539817.18, + ("Albania", "2011Q4"): 3213979715.15, + ("Albania", "2011Q3"): 3187681048.95, + ("Albania", "2011Q2"): 3248041513.86, + ("Albania", "2011Q1"): 3137210567.92, + } + } expected = pd.DataFrame(expected) # Round, to ignore revisions to data. 
expected = np.round(expected, decimals=-3) expected = expected.sort_index() - cntry_codes = 'ALB' - inds = 'DT.DOD.PUBS.CD.US' - result = download(country=cntry_codes, indicator=inds, - start=2011, end=2012, freq='Q', errors='ignore') + cntry_codes = "ALB" + inds = "DT.DOD.PUBS.CD.US" + result = download( + country=cntry_codes, + indicator=inds, + start=2011, + end=2012, + freq="Q", + errors="ignore", + ) result = result.sort_index() result = np.round(result, decimals=-3) - expected.index.names = ['country', 'year'] + expected.index.names = ["country", "year"] tm.assert_frame_equal(result, expected) - result = WorldBankReader(inds, countries=cntry_codes, start=2011, - end=2012, freq='Q', errors='ignore').read() + result = WorldBankReader( + inds, countries=cntry_codes, start=2011, end=2012, freq="Q", errors="ignore" + ).read() result = result.sort_index() result = np.round(result, decimals=-1) tm.assert_frame_equal(result, expected) diff --git a/pandas_datareader/tests/yahoo/test_options.py b/pandas_datareader/tests/yahoo/test_options.py index eae227c9..34cda1c1 100644 --- a/pandas_datareader/tests/yahoo/test_options.py +++ b/pandas_datareader/tests/yahoo/test_options.py @@ -1,18 +1,17 @@ -import os from datetime import datetime +import os import numpy as np import pandas as pd - -import pytest import pandas.util.testing as tm +import pytest import pandas_datareader.data as web @pytest.yield_fixture def aapl(): - aapl = web.Options('aapl', 'yahoo') + aapl = web.Options("aapl", "yahoo") yield aapl aapl.close() @@ -51,18 +50,16 @@ def expiry(month, year): @pytest.fixture def json1(datapath): - dirpath = datapath('yahoo', 'data') - json1 = 'file://' + os.path.join( - dirpath, 'yahoo_options1.json') + dirpath = datapath("yahoo", "data") + json1 = "file://" + os.path.join(dirpath, "yahoo_options1.json") return json1 @pytest.fixture def json2(datapath): # see gh-22: empty table - dirpath = datapath('yahoo', 'data') - json2 = 'file://' + os.path.join( - dirpath, 'yahoo_options2.json') + dirpath = datapath("yahoo", "data") + json2 = "file://" + os.path.join(dirpath, "yahoo_options2.json") return json2 @@ -72,9 +69,8 @@ def data1(aapl, json1): class TestYahooOptions(object): - def setup_class(cls): - pytest.skip('Skip all Yahoo! tests.') + pytest.skip("Skip all Yahoo! 
tests.") def assert_option_result(self, df): """ @@ -83,16 +79,42 @@ def assert_option_result(self, df): assert isinstance(df, pd.DataFrame) assert len(df) > 1 - exp_columns = pd.Index(['Last', 'Bid', 'Ask', 'Chg', 'PctChg', 'Vol', - 'Open_Int', 'IV', 'Root', 'IsNonstandard', - 'Underlying', 'Underlying_Price', 'Quote_Time', - 'Last_Trade_Date', 'JSON']) + exp_columns = pd.Index( + [ + "Last", + "Bid", + "Ask", + "Chg", + "PctChg", + "Vol", + "Open_Int", + "IV", + "Root", + "IsNonstandard", + "Underlying", + "Underlying_Price", + "Quote_Time", + "Last_Trade_Date", + "JSON", + ] + ) tm.assert_index_equal(df.columns, exp_columns) - assert df.index.names == [u'Strike', u'Expiry', u'Type', u'Symbol'] - - dtypes = [np.dtype(x) for x in ['float64'] * 7 + - ['float64', 'object', 'bool', 'object', 'float64', - 'datetime64[ns]', 'datetime64[ns]', 'object']] + assert df.index.names == [u"Strike", u"Expiry", u"Type", u"Symbol"] + + dtypes = [ + np.dtype(x) + for x in ["float64"] * 7 + + [ + "float64", + "object", + "bool", + "object", + "float64", + "datetime64[ns]", + "datetime64[ns]", + "object", + ] + ] tm.assert_series_equal(df.dtypes, pd.Series(dtypes, index=exp_columns)) def test_get_options_data(self, aapl, expiry): @@ -107,25 +129,24 @@ def test_get_options_data(self, aapl, expiry): self.assert_option_result(options) def test_get_near_stock_price(self, aapl, expiry): - options = aapl.get_near_stock_price(call=True, put=True, - expiry=expiry) + options = aapl.get_near_stock_price(call=True, put=True, expiry=expiry) self.assert_option_result(options) def test_options_is_not_none(self): - option = web.Options('aapl', 'yahoo') + option = web.Options("aapl", "yahoo") assert option is not None def test_get_call_data(self, aapl, expiry): calls = aapl.get_call_data(expiry=expiry) self.assert_option_result(calls) - assert calls.index.levels[2][0] == 'call' + assert calls.index.levels[2][0] == "call" def test_get_put_data(self, aapl, expiry): puts = aapl.get_put_data(expiry=expiry) self.assert_option_result(puts) - assert puts.index.levels[2][1] == 'put' + assert puts.index.levels[2][1] == "put" def test_get_expiry_dates(self, aapl): dates = aapl._get_expiry_dates() @@ -151,7 +172,7 @@ def test_get_all_data_calls_only(self, aapl): def test_get_underlying_price(self, aapl): # see gh-7 - options_object = web.Options('^spxpm', 'yahoo') + options_object = web.Options("^spxpm", "yahoo") quote_price = options_object.underlying_price assert isinstance(quote_price, float) @@ -162,49 +183,44 @@ def test_get_underlying_price(self, aapl): assert isinstance(price, (int, float, complex)) assert isinstance(quote_time, (datetime, pd.Timestamp)) - @pytest.mark.xfail(reason='Invalid URL scheme') + @pytest.mark.xfail(reason="Invalid URL scheme") def test_chop(self, aapl, data1): # gh-7625: regression test - aapl._chop_data(data1, above_below=2, - underlying_price=np.nan) - chopped = aapl._chop_data(data1, above_below=2, - underlying_price=100) + aapl._chop_data(data1, above_below=2, underlying_price=np.nan) + chopped = aapl._chop_data(data1, above_below=2, underlying_price=100) assert isinstance(chopped, pd.DataFrame) assert len(chopped) > 1 - chopped2 = aapl._chop_data(data1, above_below=2, - underlying_price=None) + chopped2 = aapl._chop_data(data1, above_below=2, underlying_price=None) assert isinstance(chopped2, pd.DataFrame) assert len(chopped2) > 1 - @pytest.mark.xfail(reason='Invalid URL scheme') + @pytest.mark.xfail(reason="Invalid URL scheme") def test_chop_out_of_strike_range(self, aapl, data1): # gh-7625: 
regression test - aapl._chop_data(data1, above_below=2, - underlying_price=np.nan) - chopped = aapl._chop_data(data1, above_below=2, - underlying_price=100000) + aapl._chop_data(data1, above_below=2, underlying_price=np.nan) + chopped = aapl._chop_data(data1, above_below=2, underlying_price=100000) assert isinstance(chopped, pd.DataFrame) assert len(chopped) > 1 - @pytest.mark.xfail(reason='Invalid URL scheme') + @pytest.mark.xfail(reason="Invalid URL scheme") def test_sample_page_chg_float(self, data1): # Tests that numeric columns with comma's are appropriately dealt with - assert data1['Chg'].dtype == 'float64' + assert data1["Chg"].dtype == "float64" def test_month_year(self, aapl, month, year): # see gh-168 data = aapl.get_call_data(month=month, year=year) assert len(data) > 1 - assert data.index.levels[0].dtype == 'float64' + assert data.index.levels[0].dtype == "float64" self.assert_option_result(data) - @pytest.mark.xfail(reason='Invalid URL scheme') + @pytest.mark.xfail(reason="Invalid URL scheme") def test_empty_table(self, aapl, json2): # see gh-22 empty = aapl._process_data(aapl._parse_url(json2)) diff --git a/pandas_datareader/tests/yahoo/test_yahoo.py b/pandas_datareader/tests/yahoo/test_yahoo.py index 9a886772..16b8b79c 100644 --- a/pandas_datareader/tests/yahoo/test_yahoo.py +++ b/pandas_datareader/tests/yahoo/test_yahoo.py @@ -1,25 +1,24 @@ from datetime import datetime -import requests import numpy as np import pandas as pd from pandas import DataFrame -from requests.exceptions import ConnectionError -import pytest import pandas.util.testing as tm +import pytest +import requests +from requests.exceptions import ConnectionError +from pandas_datareader._testing import skip_on_exception +from pandas_datareader._utils import RemoteDataError import pandas_datareader.data as web from pandas_datareader.data import YahooDailyReader -from pandas_datareader._utils import RemoteDataError -from pandas_datareader._testing import skip_on_exception -XFAIL_REASON = 'Known connection failures on Yahoo when testing!' +XFAIL_REASON = "Known connection failures on Yahoo when testing!" 
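# The reordered imports above reflect the isort configuration used in this
# patch: one section for the standard library, then third-party packages,
# then the local package, with names sorted within each section. A minimal
# sketch of the resulting layout (illustrative module, not part of the diff):
from datetime import datetime  # stdlib section first

import numpy as np  # third-party section, alphabetical
import pytest

import pandas_datareader.data as web  # first-party section last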
pytestmark = pytest.mark.stable class TestYahoo(object): - @classmethod def setup_class(cls): pytest.importorskip("lxml") @@ -30,37 +29,36 @@ def test_yahoo(self): start = datetime(2010, 1, 1) end = datetime(2013, 1, 25) - assert round(web.DataReader('F', 'yahoo', start, end)['Close'][-1], - 2) == 13.68 + assert round(web.DataReader("F", "yahoo", start, end)["Close"][-1], 2) == 13.68 def test_yahoo_fails(self): start = datetime(2010, 1, 1) end = datetime(2013, 1, 27) with pytest.raises(Exception): - web.DataReader('NON EXISTENT TICKER', 'yahoo', start, end) + web.DataReader("NON EXISTENT TICKER", "yahoo", start, end) def test_get_quote_series(self): - stringlist = ['GOOG', 'AAPL'] - fields = ['exchange', 'sharesOutstanding', 'epsForward'] + stringlist = ["GOOG", "AAPL"] + fields = ["exchange", "sharesOutstanding", "epsForward"] try: - AAPL = web.get_quote_yahoo('AAPL') + AAPL = web.get_quote_yahoo("AAPL") df = web.get_quote_yahoo(pd.Series(stringlist)) except ConnectionError: pytest.xfail(reason=XFAIL_REASON) - tm.assert_series_equal(AAPL.iloc[0][fields], df.loc['AAPL'][fields]) + tm.assert_series_equal(AAPL.iloc[0][fields], df.loc["AAPL"][fields]) assert sorted(stringlist) == sorted(list(df.index.values)) def test_get_quote_string(self): try: - df = web.get_quote_yahoo('GOOG') + df = web.get_quote_yahoo("GOOG") except ConnectionError: pytest.xfail(reason=XFAIL_REASON) - assert not pd.isnull(df['marketCap'][0]) + assert not pd.isnull(df["marketCap"][0]) def test_get_quote_stringlist(self): - stringlist = ['GOOG', 'AAPL'] + stringlist = ["GOOG", "AAPL"] try: df = web.get_quote_yahoo(stringlist) except ConnectionError: @@ -69,44 +67,44 @@ def test_get_quote_stringlist(self): def test_get_quote_comma_name(self): try: - df = web.get_quote_yahoo(['RGLD']) + df = web.get_quote_yahoo(["RGLD"]) except ConnectionError: pytest.xfail(reason=XFAIL_REASON) - assert df['longName'][0] == 'Royal Gold, Inc.' + assert df["longName"][0] == "Royal Gold, Inc." 
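# The skip reasons just below, e.g.
# "Unreliable test, receive partial " "components back for dow_jones",
# rely on implicit string concatenation: black never merges adjacent string
# literals, so strings split under the old 79-character limit survive the
# reformat unchanged. A minimal check of the compile-time join (illustrative):
reason = "Unreliable test, receive partial " "components back for dow_jones"
assert reason == "Unreliable test, receive partial components back for dow_jones"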
- @pytest.mark.skip('Unreliable test, receive partial ' - 'components back for dow_jones') + @pytest.mark.skip( + "Unreliable test, receive partial " "components back for dow_jones" + ) def test_get_components_dow_jones(self): # pragma: no cover - df = web.get_components_yahoo('^DJI') # Dow Jones + df = web.get_components_yahoo("^DJI") # Dow Jones assert isinstance(df, pd.DataFrame) assert len(df) == 30 - @pytest.mark.skip('Unreliable test, receive partial ' - 'components back for dax') + @pytest.mark.skip("Unreliable test, receive partial " "components back for dax") def test_get_components_dax(self): # pragma: no cover - df = web.get_components_yahoo('^GDAXI') # DAX + df = web.get_components_yahoo("^GDAXI") # DAX assert isinstance(df, pd.DataFrame) assert len(df) == 30 - assert df[df.name.str.contains('adidas', case=False)].index == 'ADS.DE' + assert df[df.name.str.contains("adidas", case=False)].index == "ADS.DE" - @pytest.mark.skip('Unreliable test, receive partial ' - 'components back for nasdaq_100') + @pytest.mark.skip( + "Unreliable test, receive partial " "components back for nasdaq_100" + ) def test_get_components_nasdaq_100(self): # pragma: no cover # As of 7/12/13, the conditional will # return false because the link is invalid - df = web.get_components_yahoo('^NDX') # NASDAQ-100 + df = web.get_components_yahoo("^NDX") # NASDAQ-100 assert isinstance(df, pd.DataFrame) if len(df) > 1: # Usual culprits, should be around for a while - assert 'AAPL' in df.index - assert 'GOOG' in df.index - assert 'AMZN' in df.index + assert "AAPL" in df.index + assert "GOOG" in df.index + assert "AMZN" in df.index else: - expected = DataFrame({'exchange': 'N/A', 'name': '@^NDX'}, - index=['@^NDX']) + expected = DataFrame({"exchange": "N/A", "name": "@^NDX"}, index=["@^NDX"]) tm.assert_frame_equal(df, expected) @skip_on_exception(RemoteDataError) @@ -114,86 +112,84 @@ def test_get_data_single_symbol(self): # single symbol # http://finance.yahoo.com/q/hp?s=GOOG&a=09&b=08&c=2010&d=09&e=10&f=2010&g=d # just test that we succeed - web.get_data_yahoo('GOOG') + web.get_data_yahoo("GOOG") @skip_on_exception(RemoteDataError) def test_data_with_no_actions(self): - web.get_data_yahoo('TSLA') + web.get_data_yahoo("TSLA") @skip_on_exception(RemoteDataError) def test_get_data_adjust_price(self): - goog = web.get_data_yahoo('GOOG') - goog_adj = web.get_data_yahoo('GOOG', adjust_price=True) - assert 'Adj Close' not in goog_adj.columns - assert (goog['Open'] * goog_adj['Adj_Ratio']).equals(goog_adj['Open']) + goog = web.get_data_yahoo("GOOG") + goog_adj = web.get_data_yahoo("GOOG", adjust_price=True) + assert "Adj Close" not in goog_adj.columns + assert (goog["Open"] * goog_adj["Adj_Ratio"]).equals(goog_adj["Open"]) @pytest.mark.xfail(reason="Yahoo are returning an extra day 31st Dec 2012") def test_get_data_interval(self): # daily interval data - pan = web.get_data_yahoo('XOM', '2013-01-01', - '2013-12-31', interval='d') + pan = web.get_data_yahoo("XOM", "2013-01-01", "2013-12-31", interval="d") assert len(pan) == 252 # weekly interval data - pan = web.get_data_yahoo('XOM', '2013-01-01', - '2013-12-31', interval='w') + pan = web.get_data_yahoo("XOM", "2013-01-01", "2013-12-31", interval="w") assert len(pan) == 53 # monthly interval data - pan = web.get_data_yahoo('XOM', '2012-12-31', - '2013-12-31', interval='m') + pan = web.get_data_yahoo("XOM", "2012-12-31", "2013-12-31", interval="m") assert len(pan) == 12 # test fail on invalid interval with pytest.raises(ValueError): - web.get_data_yahoo('XOM', interval='NOT 
VALID') + web.get_data_yahoo("XOM", interval="NOT VALID") @skip_on_exception(RemoteDataError) def test_get_data_multiple_symbols(self): # just test that we succeed - sl = ['AAPL', 'AMZN', 'GOOG'] - web.get_data_yahoo(sl, '2012') + sl = ["AAPL", "AMZN", "GOOG"] + web.get_data_yahoo(sl, "2012") - @pytest.mark.parametrize('adj_pr', [True, False]) + @pytest.mark.parametrize("adj_pr", [True, False]) @skip_on_exception(RemoteDataError) def test_get_data_null_as_missing_data(self, adj_pr): - result = web.get_data_yahoo('SRCE', '20160626', '20160705', - adjust_price=adj_pr) + result = web.get_data_yahoo("SRCE", "20160626", "20160705", adjust_price=adj_pr) # sanity checking - floats = ['Open', 'High', 'Low', 'Close'] + floats = ["Open", "High", "Low", "Close"] if adj_pr: - floats.append('Adj_Ratio') + floats.append("Adj_Ratio") else: - floats.append('Adj Close') + floats.append("Adj Close") assert result[floats].dtypes.all() == np.floating @skip_on_exception(RemoteDataError) def test_get_data_multiple_symbols_two_dates(self): - data = web.get_data_yahoo(['GE', 'MSFT', 'INTC'], 'JAN-01-12', - 'JAN-31-12') - result = data.Close.loc['01-18-12'].T + data = web.get_data_yahoo(["GE", "MSFT", "INTC"], "JAN-01-12", "JAN-31-12") + result = data.Close.loc["01-18-12"].T assert result.size == 3 # sanity checking assert result.dtypes == np.floating - expected = np.array([[18.99, 28.4, 25.18], - [18.58, 28.31, 25.13], - [19.03, 28.16, 25.52], - [18.81, 28.82, 25.87]]) + expected = np.array( + [ + [18.99, 28.4, 25.18], + [18.58, 28.31, 25.13], + [19.03, 28.16, 25.52], + [18.81, 28.82, 25.87], + ] + ) df = data.Open - result = df[(df.index >= 'Jan-15-12') & (df.index <= 'Jan-20-12')] + result = df[(df.index >= "Jan-15-12") & (df.index <= "Jan-20-12")] assert expected.shape == result.shape def test_get_date_ret_index(self): - pan = web.get_data_yahoo(['GE', 'INTC', 'IBM'], '1977', '1987', - ret_index=True) - assert hasattr(pan, 'Ret_Index') + pan = web.get_data_yahoo(["GE", "INTC", "IBM"], "1977", "1987", ret_index=True) + assert hasattr(pan, "Ret_Index") - if hasattr(pan, 'Ret_Index') and hasattr(pan.Ret_Index, 'INTC'): + if hasattr(pan, "Ret_Index") and hasattr(pan.Ret_Index, "INTC"): tstamp = pan.Ret_Index.INTC.first_valid_index() - result = pan.Ret_Index.loc[tstamp, 'INTC'] + result = pan.Ret_Index.loc[tstamp, "INTC"] assert result == 1.0 # sanity checking @@ -203,83 +199,142 @@ def test_get_data_yahoo_actions(self): start = datetime(1990, 1, 1) end = datetime(2018, 4, 5) - actions = web.get_data_yahoo_actions('AAPL', start, end, - adjust_dividends=False) + actions = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=False) - assert sum(actions['action'] == 'DIVIDEND') == 47 - assert sum(actions['action'] == 'SPLIT') == 3 + assert sum(actions["action"] == "DIVIDEND") == 47 + assert sum(actions["action"] == "SPLIT") == 3 - assert actions.loc['2005-02-28', 'action'][0] == 'SPLIT' - assert actions.loc['2005-02-28', 'value'][0] == 1/2.0 + assert actions.loc["2005-02-28", "action"][0] == "SPLIT" + assert actions.loc["2005-02-28", "value"][0] == 1 / 2.0 - assert actions.loc['1995-11-21', 'action'][0] == 'DIVIDEND' - assert round(actions.loc['1995-11-21', 'value'][0], 3) == 0.120 + assert actions.loc["1995-11-21", "action"][0] == "DIVIDEND" + assert round(actions.loc["1995-11-21", "value"][0], 3) == 0.120 - actions = web.get_data_yahoo_actions('AAPL', start, end, - adjust_dividends=True) + actions = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=True) - assert actions.loc['1995-11-21', 
'action'][0] == 'DIVIDEND' - assert round(actions.loc['1995-11-21', 'value'][0], 4) == 0.0043 + assert actions.loc["1995-11-21", "action"][0] == "DIVIDEND" + assert round(actions.loc["1995-11-21", "value"][0], 4) == 0.0043 def test_get_data_yahoo_actions_invalid_symbol(self): start = datetime(1990, 1, 1) end = datetime(2000, 4, 5) with pytest.raises(IOError): - web.get_data_yahoo_actions('UNKNOWN TICKER', start, end) + web.get_data_yahoo_actions("UNKNOWN TICKER", start, end) @skip_on_exception(RemoteDataError) def test_yahoo_reader_class(self): - r = YahooDailyReader('GOOG') + r = YahooDailyReader("GOOG") df = r.read() - assert df.Volume.loc['JAN-02-2015'] == 1447500 + assert df.Volume.loc["JAN-02-2015"] == 1447500 session = requests.Session() - r = YahooDailyReader('GOOG', session=session) + r = YahooDailyReader("GOOG", session=session) assert r.session is session def test_yahoo_DataReader(self): start = datetime(2010, 1, 1) end = datetime(2015, 5, 9) # yahoo will adjust for dividends by default - result = web.DataReader('AAPL', 'yahoo-actions', start, end) - - exp_idx = pd.DatetimeIndex(['2015-05-07', '2015-02-05', - '2014-11-06', '2014-08-07', - '2014-06-09', '2014-05-08', - '2014-02-06', '2013-11-06', - '2013-08-08', '2013-05-09', - '2013-02-07', '2012-11-07', - '2012-08-09']) - - exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'SPLIT', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND'], - 'value': [0.52, 0.47, 0.47, 0.47, 0.14285714, - 0.47, 0.43571, 0.43571, 0.43571, - 0.43571, 0.37857, 0.37857, 0.37857]}, - index=exp_idx) - exp.index.name = 'Date' + result = web.DataReader("AAPL", "yahoo-actions", start, end) + + exp_idx = pd.DatetimeIndex( + [ + "2015-05-07", + "2015-02-05", + "2014-11-06", + "2014-08-07", + "2014-06-09", + "2014-05-08", + "2014-02-06", + "2013-11-06", + "2013-08-08", + "2013-05-09", + "2013-02-07", + "2012-11-07", + "2012-08-09", + ] + ) + + exp = pd.DataFrame( + { + "action": [ + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "SPLIT", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + ], + "value": [ + 0.52, + 0.47, + 0.47, + 0.47, + 0.14285714, + 0.47, + 0.43571, + 0.43571, + 0.43571, + 0.43571, + 0.37857, + 0.37857, + 0.37857, + ], + }, + index=exp_idx, + ) + exp.index.name = "Date" tm.assert_frame_equal(result.reindex_like(exp).round(2), exp.round(2)) # where dividends are not adjusted for splits - result = web.get_data_yahoo_actions('AAPL', start, end, - adjust_dividends=False) - - exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'SPLIT', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'DIVIDEND'], - 'value': [0.52, 0.47, 0.47, 0.47, 0.14285714, - 3.29, 3.05, 3.05, 3.05, - 3.05, 2.65, 2.65, 2.65]}, - index=exp_idx) - exp.index.name = 'Date' + result = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=False) + + exp = pd.DataFrame( + { + "action": [ + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "SPLIT", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + "DIVIDEND", + ], + "value": [ + 0.52, + 0.47, + 0.47, + 0.47, + 0.14285714, + 3.29, + 3.05, + 3.05, + 3.05, + 3.05, + 2.65, + 2.65, + 2.65, + ], + }, + index=exp_idx, + ) + exp.index.name = "Date" tm.assert_frame_equal(result.reindex_like(exp).round(4), exp.round(4)) # test cases with "1/0" split ratio in 
actions - @@ -287,22 +342,25 @@ def test_yahoo_DataReader(self): start = datetime(2017, 12, 30) end = datetime(2018, 12, 30) - result = web.DataReader('NTR', 'yahoo-actions', start, end) + result = web.DataReader("NTR", "yahoo-actions", start, end) - exp_idx = pd.DatetimeIndex(['2018-12-28', '2018-09-27', - '2018-06-28', '2018-03-28', - '2018-01-02']) + exp_idx = pd.DatetimeIndex( + ["2018-12-28", "2018-09-27", "2018-06-28", "2018-03-28", "2018-01-02"] + ) - exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND', - 'DIVIDEND', 'SPLIT'], - 'value': [0.43, 0.40, 0.40, 0.40, 1.00]}, - index=exp_idx) - exp.index.name = 'Date' + exp = pd.DataFrame( + { + "action": ["DIVIDEND", "DIVIDEND", "DIVIDEND", "DIVIDEND", "SPLIT"], + "value": [0.43, 0.40, 0.40, 0.40, 1.00], + }, + index=exp_idx, + ) + exp.index.name = "Date" tm.assert_frame_equal(result.reindex_like(exp).round(2), exp.round(2)) @skip_on_exception(RemoteDataError) def test_yahoo_DataReader_multi(self): start = datetime(2010, 1, 1) end = datetime(2015, 5, 9) - result = web.DataReader(['AAPL', 'F'], 'yahoo-actions', start, end) + result = web.DataReader(["AAPL", "F"], "yahoo-actions", start, end) assert isinstance(result, dict) diff --git a/pandas_datareader/tiingo.py b/pandas_datareader/tiingo.py index 7b9edc7c..f40e77c4 100644 --- a/pandas_datareader/tiingo.py +++ b/pandas_datareader/tiingo.py @@ -19,7 +19,7 @@ def get_tiingo_symbols(): ----- Reads https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip """ - url = 'https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip' + url = "https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip" return pd.read_csv(url) @@ -54,52 +54,69 @@ class TiingoIEXHistoricalReader(_BaseReader): TIINGO_API_KEY is read. The API key is *required*. """ - def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1, - timeout=30, session=None, freq=None, api_key=None): - super().__init__(symbols, start, end, retry_count, pause, timeout, - session, freq) + def __init__( + self, + symbols, + start=None, + end=None, + retry_count=3, + pause=0.1, + timeout=30, + session=None, + freq=None, + api_key=None, + ): + super().__init__( + symbols, start, end, retry_count, pause, timeout, session, freq + ) if isinstance(self.symbols, str): self.symbols = [self.symbols] - self._symbol = '' + self._symbol = "" if api_key is None: - api_key = os.getenv('TIINGO_API_KEY') + api_key = os.getenv("TIINGO_API_KEY") if not api_key or not isinstance(api_key, str): - raise ValueError('The tiingo API key must be provided either ' - 'through the api_key variable or through the ' - 'environmental variable TIINGO_API_KEY.') + raise ValueError( + "The tiingo API key must be provided either " + "through the api_key variable or through the " + "environmental variable TIINGO_API_KEY." 
+ ) self.api_key = api_key self._concat_axis = 0 @property def url(self): """API URL""" - _url = 'https://api.tiingo.com/iex/{ticker}/prices' + _url = "https://api.tiingo.com/iex/{ticker}/prices" return _url.format(ticker=self._symbol) @property def params(self): """Parameters to use in API calls""" - return {'startDate': self.start.strftime('%Y-%m-%d'), - 'endDate': self.end.strftime('%Y-%m-%d'), - 'resampleFreq': self.freq, - 'format': 'json'} + return { + "startDate": self.start.strftime("%Y-%m-%d"), + "endDate": self.end.strftime("%Y-%m-%d"), + "resampleFreq": self.freq, + "format": "json", + } def _get_crumb(self, *args): pass def _read_one_data(self, url, params): """ read one data from specified URL """ - headers = {'Content-Type': 'application/json', - 'Authorization': 'Token ' + self.api_key} + headers = { + "Content-Type": "application/json", + "Authorization": "Token " + self.api_key, + } out = self._get_response(url, params=params, headers=headers).json() return self._read_lines(out) def _read_lines(self, out): df = pd.DataFrame(out) - df['symbol'] = self._symbol - df['date'] = pd.to_datetime(df['date']) - df = df.set_index(['symbol', 'date']) + df["symbol"] = self._symbol + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index(["symbol", "date"]) return df def read(self): @@ -140,51 +157,67 @@ class TiingoDailyReader(_BaseReader): TIINGO_API_KEY is read. The API key is *required*. """ - def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1, - timeout=30, session=None, freq=None, api_key=None): - super(TiingoDailyReader, self).__init__(symbols, start, end, - retry_count, pause, timeout, - session, freq) + def __init__( + self, + symbols, + start=None, + end=None, + retry_count=3, + pause=0.1, + timeout=30, + session=None, + freq=None, + api_key=None, + ): + super(TiingoDailyReader, self).__init__( + symbols, start, end, retry_count, pause, timeout, session, freq + ) if isinstance(self.symbols, str): self.symbols = [self.symbols] - self._symbol = '' + self._symbol = "" if api_key is None: - api_key = os.getenv('TIINGO_API_KEY') + api_key = os.getenv("TIINGO_API_KEY") if not api_key or not isinstance(api_key, str): - raise ValueError('The tiingo API key must be provided either ' - 'through the api_key variable or through the ' - 'environmental variable TIINGO_API_KEY.') + raise ValueError( + "The tiingo API key must be provided either " + "through the api_key variable or through the " + "environmental variable TIINGO_API_KEY." 
+ ) self.api_key = api_key self._concat_axis = 0 @property def url(self): """API URL""" - _url = 'https://api.tiingo.com/tiingo/daily/{ticker}/prices' + _url = "https://api.tiingo.com/tiingo/daily/{ticker}/prices" return _url.format(ticker=self._symbol) @property def params(self): """Parameters to use in API calls""" - return {'startDate': self.start.strftime('%Y-%m-%d'), - 'endDate': self.end.strftime('%Y-%m-%d'), - 'format': 'json'} + return { + "startDate": self.start.strftime("%Y-%m-%d"), + "endDate": self.end.strftime("%Y-%m-%d"), + "format": "json", + } def _get_crumb(self, *args): pass def _read_one_data(self, url, params): """ read one data from specified URL """ - headers = {'Content-Type': 'application/json', - 'Authorization': 'Token ' + self.api_key} + headers = { + "Content-Type": "application/json", + "Authorization": "Token " + self.api_key, + } out = self._get_response(url, params=params, headers=headers).json() return self._read_lines(out) def _read_lines(self, out): df = pd.DataFrame(out) - df['symbol'] = self._symbol - df['date'] = pd.to_datetime(df['date']) - df = df.set_index(['symbol', 'date']) + df["symbol"] = self._symbol + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index(["symbol", "date"]) return df def read(self): @@ -224,17 +257,27 @@ class TiingoMetaDataReader(TiingoDailyReader): TIINGO_API_KEY is read. The API key is *required*. """ - def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1, - timeout=30, session=None, freq=None, api_key=None): - super(TiingoMetaDataReader, self).__init__(symbols, start, end, - retry_count, pause, timeout, - session, freq, api_key) + def __init__( + self, + symbols, + start=None, + end=None, + retry_count=3, + pause=0.1, + timeout=30, + session=None, + freq=None, + api_key=None, + ): + super(TiingoMetaDataReader, self).__init__( + symbols, start, end, retry_count, pause, timeout, session, freq, api_key + ) self._concat_axis = 1 @property def url(self): """API URL""" - _url = 'https://api.tiingo.com/tiingo/daily/{ticker}' + _url = "https://api.tiingo.com/tiingo/daily/{ticker}" return _url.format(ticker=self._symbol) @property diff --git a/pandas_datareader/tsp.py b/pandas_datareader/tsp.py index 108a4826..d89caf43 100644 --- a/pandas_datareader/tsp.py +++ b/pandas_datareader/tsp.py @@ -26,21 +26,29 @@ class TSPReader(_BaseReader): requests.sessions.Session instance to be used """ - def __init__(self, - symbols=('Linc', 'L2020', 'L2030', 'L2040', - 'L2050', 'G', 'F', 'C', 'S', 'I'), - start=None, end=None, retry_count=3, pause=0.1, - session=None): - super(TSPReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) - self._format = 'string' + def __init__( + self, + symbols=("Linc", "L2020", "L2030", "L2040", "L2050", "G", "F", "C", "S", "I"), + start=None, + end=None, + retry_count=3, + pause=0.1, + session=None, + ): + super(TSPReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) + self._format = "string" @property def url(self): """API URL""" - return 'https://www.tsp.gov/InvestmentFunds/FundPerformance/index.html' + return "https://www.tsp.gov/InvestmentFunds/FundPerformance/index.html" def read(self): """ read one data from specified URL """ @@ -51,10 +59,12 @@ def read(self): @property def params(self): """Parameters to use in API calls""" - return {'startdate': self.start.strftime('%m/%d/%Y'), - 'enddate': self.end.strftime('%m/%d/%Y'), - 
'fundgroup': self.symbols, - 'whichButton': 'CSV'} + return { + "startdate": self.start.strftime("%m/%d/%Y"), + "enddate": self.end.strftime("%m/%d/%Y"), + "fundgroup": self.symbols, + "whichButton": "CSV", + } @staticmethod def _sanitize_response(response): @@ -62,6 +72,6 @@ def _sanitize_response(response): Clean up the response string """ text = response.text.strip() - if text[-1] == ',': + if text[-1] == ",": return text[0:-1] return text diff --git a/pandas_datareader/wb.py b/pandas_datareader/wb.py index ecd583d7..c514f6b1 100644 --- a/pandas_datareader/wb.py +++ b/pandas_datareader/wb.py @@ -2,11 +2,11 @@ import warnings -from pandas_datareader.compat import reduce, lrange, string_types -import pandas as pd import numpy as np +import pandas as pd from pandas_datareader.base import _BaseReader +from pandas_datareader.compat import lrange, reduce, string_types # This list of country codes was pulled from wikipedia during October 2014. # While some exceptions do exist, it is the best proxy for countries supported @@ -14,65 +14,511 @@ # 3-digit ISO 3166-1 alpha-3, codes, with 'all', 'ALL', and 'All' appended ot # the end. -WB_API_URL = 'https://api.worldbank.org/v2' - -country_codes = ['AD', 'AE', 'AF', 'AG', 'AI', 'AL', 'AM', 'AO', 'AQ', 'AR', - 'AS', 'AT', 'AU', 'AW', 'AX', 'AZ', 'BA', 'BB', 'BD', 'BE', - 'BF', 'BG', 'BH', 'BI', 'BJ', 'BL', 'BM', 'BN', 'BO', 'BQ', - 'BR', 'BS', 'BT', 'BV', 'BW', 'BY', 'BZ', 'CA', 'CC', 'CD', - 'CF', 'CG', 'CH', 'CI', 'CK', 'CL', 'CM', 'CN', 'CO', 'CR', - 'CU', 'CV', 'CW', 'CX', 'CY', 'CZ', 'DE', 'DJ', 'DK', 'DM', - 'DO', 'DZ', 'EC', 'EE', 'EG', 'EH', 'ER', 'ES', 'ET', 'FI', - 'FJ', 'FK', 'FM', 'FO', 'FR', 'GA', 'GB', 'GD', 'GE', 'GF', - 'GG', 'GH', 'GI', 'GL', 'GM', 'GN', 'GP', 'GQ', 'GR', 'GS', - 'GT', 'GU', 'GW', 'GY', 'HK', 'HM', 'HN', 'HR', 'HT', 'HU', - 'ID', 'IE', 'IL', 'IM', 'IN', 'IO', 'IQ', 'IR', 'IS', 'IT', - 'JE', 'JM', 'JO', 'JP', 'KE', 'KG', 'KH', 'KI', 'KM', 'KN', - 'KP', 'KR', 'KW', 'KY', 'KZ', 'LA', 'LB', 'LC', 'LI', 'LK', - 'LR', 'LS', 'LT', 'LU', 'LV', 'LY', 'MA', 'MC', 'MD', 'ME', - 'MF', 'MG', 'MH', 'MK', 'ML', 'MM', 'MN', 'MO', 'MP', 'MQ', - 'MR', 'MS', 'MT', 'MU', 'MV', 'MW', 'MX', 'MY', 'MZ', 'NA', - 'NC', 'NE', 'NF', 'NG', 'NI', 'NL', 'NO', 'NP', 'NR', 'NU', - 'NZ', 'OM', 'PA', 'PE', 'PF', 'PG', 'PH', 'PK', 'PL', 'PM', - 'PN', 'PR', 'PS', 'PT', 'PW', 'PY', 'QA', 'RE', 'RO', 'RS', - 'RU', 'RW', 'SA', 'SB', 'SC', 'SD', 'SE', 'SG', 'SH', 'SI', - 'SJ', 'SK', 'SL', 'SM', 'SN', 'SO', 'SR', 'SS', 'ST', 'SV', - 'SX', 'SY', 'SZ', 'TC', 'TD', 'TF', 'TG', 'TH', 'TJ', 'TK', - 'TL', 'TM', 'TN', 'TO', 'TR', 'TT', 'TV', 'TW', 'TZ', 'UA', - 'UG', 'UM', 'US', 'UY', 'UZ', 'VA', 'VC', 'VE', 'VG', 'VI', - 'VN', 'VU', 'WF', 'WS', 'YE', 'YT', 'ZA', 'ZM', 'ZW', - 'ABW', 'AFG', 'AGO', 'AIA', 'ALA', 'ALB', 'AND', 'ARE', - 'ARG', 'ARM', 'ASM', 'ATA', 'ATF', 'ATG', 'AUS', 'AUT', - 'AZE', 'BDI', 'BEL', 'BEN', 'BES', 'BFA', 'BGD', 'BGR', - 'BHR', 'BHS', 'BIH', 'BLM', 'BLR', 'BLZ', 'BMU', 'BOL', - 'BRA', 'BRB', 'BRN', 'BTN', 'BVT', 'BWA', 'CAF', 'CAN', - 'CCK', 'CHE', 'CHL', 'CHN', 'CIV', 'CMR', 'COD', 'COG', - 'COK', 'COL', 'COM', 'CPV', 'CRI', 'CUB', 'CUW', 'CXR', - 'CYM', 'CYP', 'CZE', 'DEU', 'DJI', 'DMA', 'DNK', 'DOM', - 'DZA', 'ECU', 'EGY', 'ERI', 'ESH', 'ESP', 'EST', 'ETH', - 'FIN', 'FJI', 'FLK', 'FRA', 'FRO', 'FSM', 'GAB', 'GBR', - 'GEO', 'GGY', 'GHA', 'GIB', 'GIN', 'GLP', 'GMB', 'GNB', - 'GNQ', 'GRC', 'GRD', 'GRL', 'GTM', 'GUF', 'GUM', 'GUY', - 'HKG', 'HMD', 'HND', 'HRV', 'HTI', 'HUN', 'IDN', 'IMN', - 'IND', 'IOT', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 
'ITA', - 'JAM', 'JEY', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', - 'KIR', 'KNA', 'KOR', 'KWT', 'LAO', 'LBN', 'LBR', 'LBY', - 'LCA', 'LIE', 'LKA', 'LSO', 'LTU', 'LUX', 'LVA', 'MAC', - 'MAF', 'MAR', 'MCO', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL', - 'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', - 'MRT', 'MSR', 'MTQ', 'MUS', 'MWI', 'MYS', 'MYT', 'NAM', - 'NCL', 'NER', 'NFK', 'NGA', 'NIC', 'NIU', 'NLD', 'NOR', - 'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PCN', 'PER', - 'PHL', 'PLW', 'PNG', 'POL', 'PRI', 'PRK', 'PRT', 'PRY', - 'PSE', 'PYF', 'QAT', 'REU', 'ROU', 'RUS', 'RWA', 'SAU', - 'SDN', 'SEN', 'SGP', 'SGS', 'SHN', 'SJM', 'SLB', 'SLE', - 'SLV', 'SMR', 'SOM', 'SPM', 'SRB', 'SSD', 'STP', 'SUR', - 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM', 'SYC', 'SYR', 'TCA', - 'TCD', 'TGO', 'THA', 'TJK', 'TKL', 'TKM', 'TLS', 'TON', - 'TTO', 'TUN', 'TUR', 'TUV', 'TWN', 'TZA', 'UGA', 'UKR', - 'UMI', 'URY', 'USA', 'UZB', 'VAT', 'VCT', 'VEN', 'VGB', - 'VIR', 'VNM', 'VUT', 'WLF', 'WSM', 'YEM', 'ZAF', 'ZMB', - 'ZWE', 'all', 'ALL', 'All'] +WB_API_URL = "https://api.worldbank.org/v2" + +country_codes = [ + "AD", + "AE", + "AF", + "AG", + "AI", + "AL", + "AM", + "AO", + "AQ", + "AR", + "AS", + "AT", + "AU", + "AW", + "AX", + "AZ", + "BA", + "BB", + "BD", + "BE", + "BF", + "BG", + "BH", + "BI", + "BJ", + "BL", + "BM", + "BN", + "BO", + "BQ", + "BR", + "BS", + "BT", + "BV", + "BW", + "BY", + "BZ", + "CA", + "CC", + "CD", + "CF", + "CG", + "CH", + "CI", + "CK", + "CL", + "CM", + "CN", + "CO", + "CR", + "CU", + "CV", + "CW", + "CX", + "CY", + "CZ", + "DE", + "DJ", + "DK", + "DM", + "DO", + "DZ", + "EC", + "EE", + "EG", + "EH", + "ER", + "ES", + "ET", + "FI", + "FJ", + "FK", + "FM", + "FO", + "FR", + "GA", + "GB", + "GD", + "GE", + "GF", + "GG", + "GH", + "GI", + "GL", + "GM", + "GN", + "GP", + "GQ", + "GR", + "GS", + "GT", + "GU", + "GW", + "GY", + "HK", + "HM", + "HN", + "HR", + "HT", + "HU", + "ID", + "IE", + "IL", + "IM", + "IN", + "IO", + "IQ", + "IR", + "IS", + "IT", + "JE", + "JM", + "JO", + "JP", + "KE", + "KG", + "KH", + "KI", + "KM", + "KN", + "KP", + "KR", + "KW", + "KY", + "KZ", + "LA", + "LB", + "LC", + "LI", + "LK", + "LR", + "LS", + "LT", + "LU", + "LV", + "LY", + "MA", + "MC", + "MD", + "ME", + "MF", + "MG", + "MH", + "MK", + "ML", + "MM", + "MN", + "MO", + "MP", + "MQ", + "MR", + "MS", + "MT", + "MU", + "MV", + "MW", + "MX", + "MY", + "MZ", + "NA", + "NC", + "NE", + "NF", + "NG", + "NI", + "NL", + "NO", + "NP", + "NR", + "NU", + "NZ", + "OM", + "PA", + "PE", + "PF", + "PG", + "PH", + "PK", + "PL", + "PM", + "PN", + "PR", + "PS", + "PT", + "PW", + "PY", + "QA", + "RE", + "RO", + "RS", + "RU", + "RW", + "SA", + "SB", + "SC", + "SD", + "SE", + "SG", + "SH", + "SI", + "SJ", + "SK", + "SL", + "SM", + "SN", + "SO", + "SR", + "SS", + "ST", + "SV", + "SX", + "SY", + "SZ", + "TC", + "TD", + "TF", + "TG", + "TH", + "TJ", + "TK", + "TL", + "TM", + "TN", + "TO", + "TR", + "TT", + "TV", + "TW", + "TZ", + "UA", + "UG", + "UM", + "US", + "UY", + "UZ", + "VA", + "VC", + "VE", + "VG", + "VI", + "VN", + "VU", + "WF", + "WS", + "YE", + "YT", + "ZA", + "ZM", + "ZW", + "ABW", + "AFG", + "AGO", + "AIA", + "ALA", + "ALB", + "AND", + "ARE", + "ARG", + "ARM", + "ASM", + "ATA", + "ATF", + "ATG", + "AUS", + "AUT", + "AZE", + "BDI", + "BEL", + "BEN", + "BES", + "BFA", + "BGD", + "BGR", + "BHR", + "BHS", + "BIH", + "BLM", + "BLR", + "BLZ", + "BMU", + "BOL", + "BRA", + "BRB", + "BRN", + "BTN", + "BVT", + "BWA", + "CAF", + "CAN", + "CCK", + "CHE", + "CHL", + "CHN", + "CIV", + "CMR", + "COD", + "COG", + "COK", + "COL", + "COM", + "CPV", + 
"CRI", + "CUB", + "CUW", + "CXR", + "CYM", + "CYP", + "CZE", + "DEU", + "DJI", + "DMA", + "DNK", + "DOM", + "DZA", + "ECU", + "EGY", + "ERI", + "ESH", + "ESP", + "EST", + "ETH", + "FIN", + "FJI", + "FLK", + "FRA", + "FRO", + "FSM", + "GAB", + "GBR", + "GEO", + "GGY", + "GHA", + "GIB", + "GIN", + "GLP", + "GMB", + "GNB", + "GNQ", + "GRC", + "GRD", + "GRL", + "GTM", + "GUF", + "GUM", + "GUY", + "HKG", + "HMD", + "HND", + "HRV", + "HTI", + "HUN", + "IDN", + "IMN", + "IND", + "IOT", + "IRL", + "IRN", + "IRQ", + "ISL", + "ISR", + "ITA", + "JAM", + "JEY", + "JOR", + "JPN", + "KAZ", + "KEN", + "KGZ", + "KHM", + "KIR", + "KNA", + "KOR", + "KWT", + "LAO", + "LBN", + "LBR", + "LBY", + "LCA", + "LIE", + "LKA", + "LSO", + "LTU", + "LUX", + "LVA", + "MAC", + "MAF", + "MAR", + "MCO", + "MDA", + "MDG", + "MDV", + "MEX", + "MHL", + "MKD", + "MLI", + "MLT", + "MMR", + "MNE", + "MNG", + "MNP", + "MOZ", + "MRT", + "MSR", + "MTQ", + "MUS", + "MWI", + "MYS", + "MYT", + "NAM", + "NCL", + "NER", + "NFK", + "NGA", + "NIC", + "NIU", + "NLD", + "NOR", + "NPL", + "NRU", + "NZL", + "OMN", + "PAK", + "PAN", + "PCN", + "PER", + "PHL", + "PLW", + "PNG", + "POL", + "PRI", + "PRK", + "PRT", + "PRY", + "PSE", + "PYF", + "QAT", + "REU", + "ROU", + "RUS", + "RWA", + "SAU", + "SDN", + "SEN", + "SGP", + "SGS", + "SHN", + "SJM", + "SLB", + "SLE", + "SLV", + "SMR", + "SOM", + "SPM", + "SRB", + "SSD", + "STP", + "SUR", + "SVK", + "SVN", + "SWE", + "SWZ", + "SXM", + "SYC", + "SYR", + "TCA", + "TCD", + "TGO", + "THA", + "TJK", + "TKL", + "TKM", + "TLS", + "TON", + "TTO", + "TUN", + "TUR", + "TUV", + "TWN", + "TZA", + "UGA", + "UKR", + "UMI", + "URY", + "USA", + "UZB", + "VAT", + "VCT", + "VEN", + "VGB", + "VIR", + "VNM", + "VUT", + "WLF", + "WSM", + "YEM", + "ZAF", + "ZMB", + "ZWE", + "all", + "ALL", + "All", +] class WorldBankReader(_BaseReader): @@ -102,24 +548,37 @@ class WorldBankReader(_BaseReader): errors='raise', will raise a ValueError on a bad country code. 
""" - _format = 'json' - - def __init__(self, symbols=None, countries=None, - start=None, end=None, freq=None, - retry_count=3, pause=0.1, session=None, errors='warn'): + _format = "json" + + def __init__( + self, + symbols=None, + countries=None, + start=None, + end=None, + freq=None, + retry_count=3, + pause=0.1, + session=None, + errors="warn", + ): if symbols is None: - symbols = ['NY.GDP.MKTP.CD', 'NY.GNS.ICTR.ZS'] + symbols = ["NY.GDP.MKTP.CD", "NY.GNS.ICTR.ZS"] elif isinstance(symbols, string_types): symbols = [symbols] - super(WorldBankReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session) + super(WorldBankReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) if countries is None: - countries = ['MX', 'CA', 'US'] + countries = ["MX", "CA", "US"] elif isinstance(countries, string_types): countries = [countries] @@ -127,17 +586,17 @@ def __init__(self, symbols=None, countries=None, # Validate the input if len(bad_countries) > 0: tmp = ", ".join(bad_countries) - if errors == 'raise': + if errors == "raise": raise ValueError("Invalid Country Code(s): %s" % tmp) - if errors == 'warn': - warnings.warn('Non-standard ISO ' - 'country codes: %s' % tmp, UserWarning) + if errors == "warn": + warnings.warn( + "Non-standard ISO " "country codes: %s" % tmp, UserWarning + ) - freq_symbols = ['M', 'Q', 'A', None] + freq_symbols = ["M", "Q", "A", None] if freq not in freq_symbols: - msg = 'The frequency `{0}` is not in the accepted ' \ - 'list.'.format(freq) + msg = "The frequency `{0}` is not in the accepted " "list.".format(freq) raise ValueError(msg) self.freq = freq @@ -147,24 +606,34 @@ def __init__(self, symbols=None, countries=None, @property def url(self): """API URL""" - countries = ';'.join(self.countries) - return WB_API_URL + '/countries/' + countries + '/indicators/' + countries = ";".join(self.countries) + return WB_API_URL + "/countries/" + countries + "/indicators/" @property def params(self): """Parameters to use in API calls""" - if self.freq == 'M': - return {'date': '{0}M{1:02d}:{2}M{3:02d}'.format(self.start.year, - self.start.month, self.end.year, self.end.month), - 'per_page': 25000, 'format': 'json'} - elif self.freq == 'Q': - return {'date': '{0}Q{1}:{2}Q{3}'.format(self.start.year, - self.start.quarter, self.end.year, - self.end.quarter), 'per_page': 25000, - 'format': 'json'} + if self.freq == "M": + return { + "date": "{0}M{1:02d}:{2}M{3:02d}".format( + self.start.year, self.start.month, self.end.year, self.end.month + ), + "per_page": 25000, + "format": "json", + } + elif self.freq == "Q": + return { + "date": "{0}Q{1}:{2}Q{3}".format( + self.start.year, self.start.quarter, self.end.year, self.end.quarter + ), + "per_page": 25000, + "format": "json", + } else: - return {'date': '{0}:{1}'.format(self.start.year, self.end.year), - 'per_page': 25000, 'format': 'json'} + return { + "date": "{0}:{1}".format(self.start.year, self.end.year), + "per_page": 25000, + "format": "json", + } def read(self): """Read data""" @@ -179,22 +648,22 @@ def _read(self): # Build URL for api call try: df = self._read_one_data(self.url + indicator, self.params) - df.columns = ['country', 'iso_code', 'year', indicator] + df.columns = ["country", "iso_code", "year", indicator] data.append(df) except ValueError as e: - msg = str(e) + ' Indicator: ' + indicator - if self.errors == 'raise': + msg = str(e) + " Indicator: " + indicator + if self.errors == 
"raise": raise ValueError(msg) - elif self.errors == 'warn': + elif self.errors == "warn": warnings.warn(msg) # Confirm we actually got some data, and build Dataframe if len(data) > 0: - out = reduce(lambda x, y: x.merge(y, how='outer'), data) - out = out.drop('iso_code', axis=1) - out = out.set_index(['country', 'year']) - out = out.apply(pd.to_numeric, errors='ignore') + out = reduce(lambda x, y: x.merge(y, how="outer"), data) + out = out.drop("iso_code", axis=1) + out = out.set_index(["country", "year"]) + out = out.apply(pd.to_numeric, errors="ignore") return out else: @@ -205,32 +674,32 @@ def _read_lines(self, out): # Check to see if there is a possible problem possible_message = out[0] - if 'message' in possible_message.keys(): - msg = possible_message['message'][0] + if "message" in possible_message.keys(): + msg = possible_message["message"][0] try: - msg = msg['key'].split() + ["\n "] + msg['value'].split() - wb_err = ' '.join(msg) + msg = msg["key"].split() + ["\n "] + msg["value"].split() + wb_err = " ".join(msg) except Exception: wb_err = "" - if 'key' in msg.keys(): - wb_err = msg['key'] + "\n " - if 'value' in msg.keys(): - wb_err += msg['value'] + if "key" in msg.keys(): + wb_err = msg["key"] + "\n " + if "value" in msg.keys(): + wb_err += msg["value"] msg = "Problem with a World Bank Query \n %s." % wb_err raise ValueError(msg) - if 'total' in possible_message.keys(): - if possible_message['total'] == 0: + if "total" in possible_message.keys(): + if possible_message["total"] == 0: msg = "No results found from world bank." raise ValueError(msg) # Parse JSON file data = out[1] - country = [x['country']['value'] for x in data] - iso_code = [x['country']['id'] for x in data] - year = [x['date'] for x in data] - value = [x['value'] for x in data] + country = [x["country"]["value"] for x in data] + iso_code = [x["country"]["id"] for x in data] + year = [x["date"] for x in data] + value = [x["value"] for x in data] # Prepare output df = pd.DataFrame([country, iso_code, year, value]).T return df @@ -250,21 +719,19 @@ def get_countries(self): * and longitude """ - url = WB_API_URL + '/countries/?per_page=1000&format=json' + url = WB_API_URL + "/countries/?per_page=1000&format=json" resp = self._get_response(url) data = resp.json()[1] data = pd.DataFrame(data) - data.adminregion = [x['value'] for x in data.adminregion] - data.incomeLevel = [x['value'] for x in data.incomeLevel] - data.lendingType = [x['value'] for x in data.lendingType] - data.region = [x['value'] for x in data.region] - data.latitude = [float(x) if x != "" else np.nan - for x in data.latitude] - data.longitude = [float(x) if x != "" else np.nan - for x in data.longitude] - data = data.rename(columns={'id': 'iso3c', 'iso2Code': 'iso2c'}) + data.adminregion = [x["value"] for x in data.adminregion] + data.incomeLevel = [x["value"] for x in data.incomeLevel] + data.lendingType = [x["value"] for x in data.lendingType] + data.region = [x["value"] for x in data.region] + data.latitude = [float(x) if x != "" else np.nan for x in data.latitude] + data.longitude = [float(x) if x != "" else np.nan for x in data.longitude] + data = data.rename(columns={"id": "iso3c", "iso2Code": "iso2c"}) return data def get_indicators(self): @@ -273,35 +740,35 @@ def get_indicators(self): if isinstance(_cached_series, pd.DataFrame): return _cached_series.copy() - url = WB_API_URL + '/indicators?per_page=50000&format=json' + url = WB_API_URL + "/indicators?per_page=50000&format=json" resp = self._get_response(url) data = resp.json()[1] data = 
pd.DataFrame(data) # Clean fields - data.source = [x['value'] for x in data.source] + data.source = [x["value"] for x in data.source] def encode_ascii(x): - return x.encode('ascii', 'ignore') + return x.encode("ascii", "ignore") data.sourceOrganization = data.sourceOrganization.apply(encode_ascii) # Clean topic field def get_value(x): try: - return x['value'] + return x["value"] except Exception: - return '' + return "" def get_list_of_values(x): return [get_value(y) for y in x] data.topics = data.topics.apply(get_list_of_values) - data.topics = data.topics.apply(lambda x: ' ; '.join(x)) + data.topics = data.topics.apply(lambda x: " ; ".join(x)) # Clean output - data = data.sort_values(by='id') + data = data.sort_values(by="id") data.index = pd.Index(lrange(data.shape[0])) # cache @@ -309,7 +776,7 @@ def get_list_of_values(x): return data - def search(self, string='gdp.*capi', field='name', case=False): + def search(self, string="gdp.*capi", field="name", case=False): """ Search available data series from the world bank @@ -346,8 +813,15 @@ def search(self, string='gdp.*capi', field='name', case=False): return out -def download(country=None, indicator=None, start=2003, end=2005, freq=None, - errors='warn', **kwargs): +def download( + country=None, + indicator=None, + start=2003, + end=2005, + freq=None, + errors="warn", + **kwargs +): """ Download data series from the World Bank's World Development Indicators @@ -384,9 +858,15 @@ def download(country=None, indicator=None, start=2003, end=2005, freq=None, data : DataFrame DataFrame with columns country, iso_code, year, indicator value """ - return WorldBankReader(symbols=indicator, countries=country, - start=start, end=end, freq=freq, errors=errors, - **kwargs).read() + return WorldBankReader( + symbols=indicator, + countries=country, + start=start, + end=end, + freq=freq, + errors=errors, + **kwargs + ).read() def get_countries(**kwargs): @@ -422,7 +902,7 @@ def get_indicators(**kwargs): _cached_series = None -def search(string='gdp.*capi', field='name', case=False, **kwargs): +def search(string="gdp.*capi", field="name", case=False, **kwargs): """ Search available data series from the world bank @@ -455,5 +935,4 @@ def search(string='gdp.*capi', field='name', case=False, **kwargs): * topics: """ - return WorldBankReader(**kwargs).search(string=string, field=field, - case=case) + return WorldBankReader(**kwargs).search(string=string, field=field, case=case) diff --git a/pandas_datareader/yahoo/actions.py b/pandas_datareader/yahoo/actions.py index 5ca7e2b3..48a0f549 100644 --- a/pandas_datareader/yahoo/actions.py +++ b/pandas_datareader/yahoo/actions.py @@ -28,21 +28,21 @@ def get_actions(self): def _get_one_action(data): - actions = DataFrame(columns=['action', 'value']) + actions = DataFrame(columns=["action", "value"]) - if 'Dividends' in data.columns: + if "Dividends" in data.columns: # Add a label column so we can combine our two DFs - dividends = DataFrame(data['Dividends']).dropna() + dividends = DataFrame(data["Dividends"]).dropna() dividends["action"] = "DIVIDEND" - dividends = dividends.rename(columns={'Dividends': 'value'}) + dividends = dividends.rename(columns={"Dividends": "value"}) actions = concat([actions, dividends], sort=True) actions = actions.sort_index(ascending=False) - if 'Splits' in data.columns: + if "Splits" in data.columns: # Add a label column so we can combine our two DFs - splits = DataFrame(data['Splits']).dropna() + splits = DataFrame(data["Splits"]).dropna() splits["action"] = "SPLIT" - splits = 
splits.rename(columns={'Splits': 'value'}) + splits = splits.rename(columns={"Splits": "value"}) actions = concat([actions, splits], sort=True) actions = actions.sort_index(ascending=False) @@ -50,14 +50,12 @@ def _get_one_action(data): class YahooDivReader(YahooActionReader): - def read(self): data = super(YahooDivReader, self).read() - return data[data['action'] == 'DIVIDEND'] + return data[data["action"] == "DIVIDEND"] class YahooSplitReader(YahooActionReader): - def read(self): data = super(YahooSplitReader, self).read() - return data[data['action'] == 'SPLIT'] + return data[data["action"] == "SPLIT"] diff --git a/pandas_datareader/yahoo/components.py b/pandas_datareader/yahoo/components.py index 857acd23..004b6dd2 100644 --- a/pandas_datareader/yahoo/components.py +++ b/pandas_datareader/yahoo/components.py @@ -1,10 +1,9 @@ from pandas import DataFrame from pandas.io.common import urlopen -from pandas_datareader.exceptions import ImmediateDeprecationError, \ - DEP_ERROR_MSG +from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError -_URL = 'http://download.finance.yahoo.com/d/quotes.csv?' +_URL = "http://download.finance.yahoo.com/d/quotes.csv?" def _get_data(idx_sym): # pragma: no cover @@ -28,13 +27,13 @@ def _get_data(idx_sym): # pragma: no cover ------- idx_df : DataFrame """ - raise ImmediateDeprecationError(DEP_ERROR_MSG.format('Yahoo Components')) - stats = 'snx' + raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Yahoo Components")) + stats = "snx" # URL of form: # http://download.finance.yahoo.com/d/quotes.csv?s=@%5EIXIC&f=snxl1d1t1c1ohgv - url = _URL + 's={0}&f={1}&e=.csv&h={2}' + url = _URL + "s={0}&f={1}&e=.csv&h={2}" - idx_mod = idx_sym.replace('^', '@%5E') + idx_mod = idx_sym.replace("^", "@%5E") url_str = url.format(idx_mod, stats, 1) idx_df = DataFrame() @@ -47,12 +46,12 @@ def _get_data(idx_sym): # pragma: no cover url_str = url.format(idx_mod, stats, comp_idx) with urlopen(url_str) as resp: raw = resp.read() - lines = raw.decode('utf-8').strip().strip('"').split('"\r\n"') + lines = raw.decode("utf-8").strip().strip('"').split('"\r\n"') lines = [line.strip().split('","') for line in lines] - temp_df = DataFrame(lines, columns=['ticker', 'name', 'exchange']) + temp_df = DataFrame(lines, columns=["ticker", "name", "exchange"]) temp_df = temp_df.drop_duplicates() - temp_df = temp_df.set_index('ticker') + temp_df = temp_df.set_index("ticker") mask = ~temp_df.index.isin(idx_df.index) comp_idx = comp_idx + 50 diff --git a/pandas_datareader/yahoo/daily.py b/pandas_datareader/yahoo/daily.py index 821ca07d..2a383eb8 100644 --- a/pandas_datareader/yahoo/daily.py +++ b/pandas_datareader/yahoo/daily.py @@ -3,7 +3,9 @@ import json import re import time -from pandas import (DataFrame, to_datetime, notnull, isnull) + +from pandas import DataFrame, isnull, notnull, to_datetime + from pandas_datareader._utils import RemoteDataError from pandas_datareader.base import _DailyBaseReader @@ -49,26 +51,41 @@ class YahooDailyReader(_DailyBaseReader): If True, adjusts dividends for splits. 
""" - def __init__(self, symbols=None, start=None, end=None, retry_count=3, - pause=0.1, session=None, adjust_price=False, - ret_index=False, chunksize=1, interval='d', - get_actions=False, adjust_dividends=True): - super(YahooDailyReader, self).__init__(symbols=symbols, - start=start, end=end, - retry_count=retry_count, - pause=pause, session=session, - chunksize=chunksize) + def __init__( + self, + symbols=None, + start=None, + end=None, + retry_count=3, + pause=0.1, + session=None, + adjust_price=False, + ret_index=False, + chunksize=1, + interval="d", + get_actions=False, + adjust_dividends=True, + ): + super(YahooDailyReader, self).__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + chunksize=chunksize, + ) # Ladder up the wait time between subsequent requests to improve # probability of a successful retry self.pause_multiplier = 2.5 self.headers = { - 'Connection': 'keep-alive', - 'Expires': str(-1), - 'Upgrade-Insecure-Requests': str(1), + "Connection": "keep-alive", + "Expires": str(-1), + "Upgrade-Insecure-Requests": str(1), # Google Chrome: - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36' # noqa + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36", # noqa } self.adjust_price = adjust_price @@ -76,17 +93,19 @@ def __init__(self, symbols=None, start=None, end=None, retry_count=3, self.interval = interval self._get_actions = get_actions - if self.interval not in ['d', 'wk', 'mo', 'm', 'w']: - raise ValueError("Invalid interval: valid values are 'd', 'wk' and 'mo'. 'm' and 'w' have been implemented for " # noqa - "backward compatibility. 'v' has been moved to the yahoo-actions or yahoo-dividends APIs.") # noqa - elif self.interval in ['m', 'mo']: - self.pdinterval = 'm' - self.interval = 'mo' - elif self.interval in ['w', 'wk']: - self.pdinterval = 'w' - self.interval = 'wk' - - self.interval = '1' + self.interval + if self.interval not in ["d", "wk", "mo", "m", "w"]: + raise ValueError( + "Invalid interval: valid values are 'd', 'wk' and 'mo'. 'm' and 'w' have been implemented for " # noqa + "backward compatibility. 'v' has been moved to the yahoo-actions or yahoo-dividends APIs." 
+ ) # noqa + elif self.interval in ["m", "mo"]: + self.pdinterval = "m" + self.interval = "mo" + elif self.interval in ["w", "wk"]: + self.pdinterval = "w" + self.interval = "wk" + + self.interval = "1" + self.interval self.adjust_dividends = adjust_dividends @property @@ -95,7 +114,7 @@ def get_actions(self): @property def url(self): - return 'https://finance.yahoo.com/quote/{}/history' + return "https://finance.yahoo.com/quote/{}/history" # Test test_get_data_interval() crashed because of this issue, probably # whole yahoo part of package wasn't @@ -110,88 +129,87 @@ def _get_params(self, symbol): unix_end += four_hours_in_seconds params = { - 'period1': unix_start, - 'period2': unix_end, - 'interval': self.interval, - 'frequency': self.interval, - 'filter': 'history', - 'symbol': symbol + "period1": unix_start, + "period2": unix_end, + "interval": self.interval, + "frequency": self.interval, + "filter": "history", + "symbol": symbol, } return params def _read_one_data(self, url, params): """ read one data from specified symbol """ - symbol = params['symbol'] - del params['symbol'] + symbol = params["symbol"] + del params["symbol"] url = url.format(symbol) resp = self._get_response(url, params=params) - ptrn = r'root\.App\.main = (.*?);\n}\(this\)\);' + ptrn = r"root\.App\.main = (.*?);\n}\(this\)\);" try: j = json.loads(re.search(ptrn, resp.text, re.DOTALL).group(1)) - data = j['context']['dispatcher']['stores']['HistoricalPriceStore'] + data = j["context"]["dispatcher"]["stores"]["HistoricalPriceStore"] except KeyError: - msg = 'No data fetched for symbol {} using {}' + msg = "No data fetched for symbol {} using {}" raise RemoteDataError(msg.format(symbol, self.__class__.__name__)) # price data - prices = DataFrame(data['prices']) + prices = DataFrame(data["prices"]) prices.columns = [col.capitalize() for col in prices.columns] - prices['Date'] = to_datetime( - to_datetime(prices['Date'], unit='s').dt.date) + prices["Date"] = to_datetime(to_datetime(prices["Date"], unit="s").dt.date) - if 'Data' in prices.columns: - prices = prices[prices['Data'].isnull()] - prices = prices[['Date', 'High', 'Low', 'Open', 'Close', 'Volume', - 'Adjclose']] - prices = prices.rename(columns={'Adjclose': 'Adj Close'}) + if "Data" in prices.columns: + prices = prices[prices["Data"].isnull()] + prices = prices[["Date", "High", "Low", "Open", "Close", "Volume", "Adjclose"]] + prices = prices.rename(columns={"Adjclose": "Adj Close"}) - prices = prices.set_index('Date') - prices = prices.sort_index().dropna(how='all') + prices = prices.set_index("Date") + prices = prices.sort_index().dropna(how="all") if self.ret_index: - prices['Ret_Index'] = \ - _calc_return_index(prices['Adj Close']) + prices["Ret_Index"] = _calc_return_index(prices["Adj Close"]) if self.adjust_price: prices = _adjust_prices(prices) # dividends & splits data - if self.get_actions and data['eventsData']: + if self.get_actions and data["eventsData"]: - actions = DataFrame(data['eventsData']) + actions = DataFrame(data["eventsData"]) actions.columns = [col.capitalize() for col in actions.columns] - actions['Date'] = to_datetime( - to_datetime(actions['Date'], unit='s').dt.date) + actions["Date"] = to_datetime( + to_datetime(actions["Date"], unit="s").dt.date + ) - types = actions['Type'].unique() - if 'DIVIDEND' in types: - divs = actions[actions.Type == 'DIVIDEND'].copy() - divs = divs[['Date', 'Amount']].reset_index(drop=True) - divs = divs.set_index('Date') - divs = divs.rename(columns={'Amount': 'Dividends'}) - prices = prices.join(divs, 
how='outer') + types = actions["Type"].unique() + if "DIVIDEND" in types: + divs = actions[actions.Type == "DIVIDEND"].copy() + divs = divs[["Date", "Amount"]].reset_index(drop=True) + divs = divs.set_index("Date") + divs = divs.rename(columns={"Amount": "Dividends"}) + prices = prices.join(divs, how="outer") - if 'SPLIT' in types: + if "SPLIT" in types: def split_ratio(row): - if float(row['Numerator']) > 0: - return eval(row['Splitratio']) + if float(row["Numerator"]) > 0: + return eval(row["Splitratio"]) else: return 1 - splits = actions[actions.Type == 'SPLIT'].copy() - splits['SplitRatio'] = splits.apply(split_ratio, axis=1) + splits = actions[actions.Type == "SPLIT"].copy() + splits["SplitRatio"] = splits.apply(split_ratio, axis=1) splits = splits.reset_index(drop=True) - splits = splits.set_index('Date') - splits['Splits'] = splits['SplitRatio'] - prices = prices.join(splits['Splits'], how='outer') + splits = splits.set_index("Date") + splits["Splits"] = splits["SplitRatio"] + prices = prices.join(splits["Splits"], how="outer") - if 'DIVIDEND' in types and not self.adjust_dividends: + if "DIVIDEND" in types and not self.adjust_dividends: # dividends are adjusted automatically by Yahoo - adj = prices['Splits'].sort_index(ascending=False).fillna( - 1).cumprod() - prices['Dividends'] = prices['Dividends'] / adj + adj = ( + prices["Splits"].sort_index(ascending=False).fillna(1).cumprod() + ) + prices["Dividends"] = prices["Dividends"] / adj return prices @@ -202,14 +220,14 @@ def _adjust_prices(hist_data, price_list=None): 'Adj Close' price. Adds 'Adj_Ratio' column. """ if price_list is None: - price_list = 'Open', 'High', 'Low', 'Close' - adj_ratio = hist_data['Adj Close'] / hist_data['Close'] + price_list = "Open", "High", "Low", "Close" + adj_ratio = hist_data["Adj Close"] / hist_data["Close"] data = hist_data.copy() for item in price_list: data[item] = hist_data[item] * adj_ratio - data['Adj_Ratio'] = adj_ratio - del data['Adj Close'] + data["Adj_Ratio"] = adj_ratio + del data["Adj Close"] return data diff --git a/pandas_datareader/yahoo/fx.py b/pandas_datareader/yahoo/fx.py index adf264e7..68bd03f7 100644 --- a/pandas_datareader/yahoo/fx.py +++ b/pandas_datareader/yahoo/fx.py @@ -1,10 +1,12 @@ -import time import json +import time import warnings -from pandas import (DataFrame, Series, to_datetime, concat) -from pandas_datareader.yahoo.daily import YahooDailyReader -from pandas_datareader._utils import (RemoteDataError, SymbolWarning) + +from pandas import DataFrame, Series, concat, to_datetime + +from pandas_datareader._utils import RemoteDataError, SymbolWarning from pandas_datareader.compat import string_types +from pandas_datareader.yahoo.daily import YahooDailyReader class YahooFXReader(YahooDailyReader): @@ -41,13 +43,13 @@ def _get_params(self, symbol): unix_end = int(time.mktime(day_end.timetuple())) params = { - 'symbol': symbol + '=X', - 'period1': unix_start, - 'period2': unix_end, - 'interval': self.interval, # deal with this - 'includePrePost': 'true', - 'events': 'div|split|earn', - 'corsDomain': 'finance.yahoo.com' + "symbol": symbol + "=X", + "period1": unix_start, + "period2": unix_end, + "interval": self.interval, # deal with this + "includePrePost": "true", + "events": "div|split|earn", + "corsDomain": "finance.yahoo.com", } return params @@ -64,29 +66,27 @@ def read(self): else: df = self._dl_mult_symbols(self.symbols) - if 'Date' in df: - df = df.set_index('Date') + if "Date" in df: + df = df.set_index("Date") - if 'Volume' in df: - df = df.drop('Volume', 
axis=1) + if "Volume" in df: + df = df.drop("Volume", axis=1) - return df.sort_index().dropna(how='all') + return df.sort_index().dropna(how="all") finally: self.close() def _read_one_data(self, symbol): """ read one data from specified URL """ - url = 'https://query1.finance.yahoo.com/v8/finance/chart/{}=X'\ - .format(symbol) + url = "https://query1.finance.yahoo.com/v8/finance/chart/{}=X".format(symbol) params = self._get_params(symbol) resp = self._get_response(url, params=params) jsn = json.loads(resp.text) - data = jsn['chart']['result'][0] - df = DataFrame(data['indicators']['quote'][0]) - df.insert(0, 'date', to_datetime(Series(data['timestamp']), - unit='s').dt.date) + data = jsn["chart"]["result"][0] + df = DataFrame(data["indicators"]["quote"][0]) + df.insert(0, "date", to_datetime(Series(data["timestamp"]), unit="s").dt.date) df.columns = map(str.capitalize, df.columns) return df @@ -97,11 +97,11 @@ def _dl_mult_symbols(self, symbols): for sym in symbols: try: df = self._read_one_data(sym) - df['PairCode'] = sym + df["PairCode"] = sym stocks[sym] = df passed.append(sym) except IOError: - msg = 'Failed to read symbol: {0!r}, replacing with NaN.' + msg = "Failed to read symbol: {0!r}, replacing with NaN." warnings.warn(msg.format(sym), SymbolWarning) failed.append(sym) @@ -109,4 +109,4 @@ def _dl_mult_symbols(self, symbols): msg = "No data fetched using {0!r}" raise RemoteDataError(msg.format(self.__class__.__name__)) else: - return concat(stocks).set_index(['PairCode', 'Date']) + return concat(stocks).set_index(["PairCode", "Date"]) diff --git a/pandas_datareader/yahoo/options.py b/pandas_datareader/yahoo/options.py index a8a5eea8..30e7af34 100644 --- a/pandas_datareader/yahoo/options.py +++ b/pandas_datareader/yahoo/options.py @@ -1,13 +1,11 @@ -import warnings import datetime as dt -import numpy as np import json +import warnings -from pandas import to_datetime -from pandas import concat, DatetimeIndex, Series, MultiIndex +import numpy as np +from pandas import DataFrame, DatetimeIndex, MultiIndex, Series, concat, to_datetime from pandas.io.json import read_json from pandas.tseries.offsets import MonthEnd -from pandas import DataFrame from pandas_datareader._utils import RemoteDataError from pandas_datareader.base import _OptionBaseReader @@ -64,8 +62,7 @@ class Options(_OptionBaseReader): >>> all_data = aapl.get_all_data() """ - _OPTIONS_BASE_URL = ('https://query1.finance.yahoo.com/' - 'v7/finance/options/{sym}') + _OPTIONS_BASE_URL = "https://query1.finance.yahoo.com/" "v7/finance/options/{sym}" def get_options_data(self, month=None, year=None, expiry=None): """ @@ -135,27 +132,30 @@ def get_options_data(self, month=None, year=None, expiry=None): the year, month and day for the expiry of the options. 
""" - return concat([f(month, year, expiry) - for f in (self.get_put_data, - self.get_call_data)]).sort_index() + return concat( + [f(month, year, expiry) for f in (self.get_put_data, self.get_call_data)] + ).sort_index() def _option_from_url(self, url): jd = self._parse_url(url) - result = jd['optionChain']['result'] + result = jd["optionChain"]["result"] try: - calls = result['options']['calls'] - puts = result['options']['puts'] + calls = result["options"]["calls"] + puts = result["options"]["puts"] except IndexError: - raise RemoteDataError('Option json not available ' - 'for url: %s' % url) - - self.underlying_price = (result['quote']['regularMarketPrice'] - if result['quote']['marketState'] == 'PRE' - else result['quote']['preMarketPrice']) - quote_unix_time = (result['quote']['regularMarketTime'] - if result['quote']['marketState'] == 'PRE' - else result['quote']['preMarketTime']) + raise RemoteDataError("Option json not available " "for url: %s" % url) + + self.underlying_price = ( + result["quote"]["regularMarketPrice"] + if result["quote"]["marketState"] == "PRE" + else result["quote"]["preMarketPrice"] + ) + quote_unix_time = ( + result["quote"]["regularMarketTime"] + if result["quote"]["marketState"] == "PRE" + else result["quote"]["preMarketTime"] + ) self.quote_time = dt.datetime.fromtimestamp(quote_unix_time) calls = _parse_options_data(calls) @@ -164,10 +164,10 @@ def _option_from_url(self, url): calls = self._process_data(calls) puts = self._process_data(puts) - return {'calls': calls, 'puts': puts} + return {"calls": calls, "puts": puts} def _get_option_data(self, expiry, name): - frame_name = '_frames' + self._expiry_to_string(expiry) + frame_name = "_frames" + self._expiry_to_string(expiry) try: frames = getattr(self, frame_name) @@ -323,8 +323,9 @@ def get_put_data(self, month=None, year=None, expiry=None): expiry = self._try_parse_dates(year, month, expiry) return self._get_data_in_date_range(expiry, put=True, call=False) - def get_near_stock_price(self, above_below=2, call=True, put=False, - month=None, year=None, expiry=None): + def get_near_stock_price( + self, above_below=2, call=True, put=False, month=None, year=None, expiry=None + ): """ ***Experimental*** Returns a DataFrame of options that are near the current stock price. 
@@ -403,17 +404,19 @@ def _chop_data(self, df, above_below=2, underlying_price=None): except AttributeError: underlying_price = np.nan - max_strike = max(df.index.get_level_values('Strike')) - min_strike = min(df.index.get_level_values('Strike')) + max_strike = max(df.index.get_level_values("Strike")) + min_strike = min(df.index.get_level_values("Strike")) - if (not np.isnan(underlying_price) and - min_strike < underlying_price < max_strike): - start_index = np.where(df.index.get_level_values('Strike') > - underlying_price)[0][0] + if ( + not np.isnan(underlying_price) + and min_strike < underlying_price < max_strike + ): + start_index = np.where( + df.index.get_level_values("Strike") > underlying_price + )[0][0] - get_range = slice(start_index - above_below, - start_index + above_below + 1) - df = df[get_range].dropna(how='all') + get_range = slice(start_index - above_below, start_index + above_below + 1) + df = df[get_range].dropna(how="all") return df @@ -440,20 +443,25 @@ def _try_parse_dates(self, year, month, expiry): # Checks if the user gave one of the month or the year # but not both and did not provide an expiry: - if ((month is not None and year is None) or - (month is None and year is not None) and expiry is None): - msg = ("You must specify either (`year` and `month`) or `expiry` " - "or none of these options for the next expiry.") + if ( + (month is not None and year is None) + or (month is None and year is not None) + and expiry is None + ): + msg = ( + "You must specify either (`year` and `month`) or `expiry` " + "or none of these options for the next expiry." + ) raise ValueError(msg) if expiry is not None: - if hasattr(expiry, '__iter__'): + if hasattr(expiry, "__iter__"): expiry = [self._validate_expiry(exp) for exp in expiry] else: expiry = [self._validate_expiry(expiry)] if len(expiry) == 0: - raise ValueError('No expiries available for given input.') + raise ValueError("No expiries available for given input.") elif year is None and month is None: # No arguments passed, provide next expiry @@ -464,11 +472,13 @@ def _try_parse_dates(self, year, month, expiry): else: # Year and month passed, provide all expiries in that month - expiry = [expiry for expiry in self.expiry_dates - if expiry.year == year and expiry.month == month] + expiry = [ + expiry + for expiry in self.expiry_dates + if expiry.year == year and expiry.month == month + ] if len(expiry) == 0: - raise ValueError('No expiries available ' - 'in %s-%s' % (year, month)) + raise ValueError("No expiries available " "in %s-%s" % (year, month)) return expiry @@ -482,7 +492,7 @@ def _validate_expiry(self, expiry): expiry_dates = self.expiry_dates expiry = to_datetime(expiry) - if hasattr(expiry, 'date'): + if hasattr(expiry, "date"): expiry = expiry.date() if expiry in expiry_dates: @@ -491,8 +501,9 @@ def _validate_expiry(self, expiry): index = DatetimeIndex(expiry_dates).sort_values() return index[index.date >= expiry][0].date() - def get_forward_data(self, months, call=True, put=False, near=False, - above_below=2): # pragma: no cover + def get_forward_data( + self, months, call=True, put=False, near=False, above_below=2 + ): # pragma: no cover """ ***Experimental*** Gets either call, put, or both data for months starting in the current @@ -610,16 +621,16 @@ def get_all_data(self, call=True, put=True): def _get_data_in_date_range(self, dates, call=True, put=True): - to_ret = Series({'call': call, 'put': put}) + to_ret = Series({"call": call, "put": put}) to_ret = to_ret[to_ret].index df = self._load_data(dates) 
types = [typ for typ in to_ret] - df_filtered_by_type = df[df.index.map( - lambda x: x[2] in types).tolist()] + df_filtered_by_type = df[df.index.map(lambda x: x[2] in types).tolist()] df_filtered_by_expiry = df_filtered_by_type[ - df_filtered_by_type.index.get_level_values('Expiry').isin(dates)] + df_filtered_by_type.index.get_level_values("Expiry").isin(dates) + ] return df_filtered_by_expiry @property @@ -669,12 +680,13 @@ def _get_expiry_dates(self): url = self._OPTIONS_BASE_URL.format(sym=self.symbol) jd = self._parse_url(url) - expiry_dates = [dt.datetime.utcfromtimestamp(ts).date() - for ts in jd['optionChain'][ - 'result'][0]['expirationDates']] + expiry_dates = [ + dt.datetime.utcfromtimestamp(ts).date() + for ts in jd["optionChain"]["result"][0]["expirationDates"] + ] if len(expiry_dates) == 0: - raise RemoteDataError('Data not available') # pragma: no cover + raise RemoteDataError("Data not available") # pragma: no cover self._expiry_dates = expiry_dates return expiry_dates @@ -694,8 +706,9 @@ def _parse_url(self, url): """ jd = json.loads(self._read_url_as_StringIO(url).read()) if jd is None: # pragma: no cover - raise RemoteDataError("Parsed URL {0!r} is not " - "a valid json object".format(url)) + raise RemoteDataError( + "Parsed URL {0!r} is not " "a valid json object".format(url) + ) return jd def _process_data(self, jd): @@ -716,22 +729,39 @@ def _process_data(self, jd): A DataFrame with requested options data. """ - columns = ['Last', 'Bid', 'Ask', 'Chg', 'PctChg', 'Vol', - 'Open_Int', 'IV', 'Root', 'IsNonstandard', 'Underlying', - 'Underlying_Price', 'Quote_Time', 'Last_Trade_Date', 'JSON'] - indexes = ['Strike', 'Expiry', 'Type', 'Symbol'] + columns = [ + "Last", + "Bid", + "Ask", + "Chg", + "PctChg", + "Vol", + "Open_Int", + "IV", + "Root", + "IsNonstandard", + "Underlying", + "Underlying_Price", + "Quote_Time", + "Last_Trade_Date", + "JSON", + ] + indexes = ["Strike", "Expiry", "Type", "Symbol"] rows_list, index = self._process_rows(jd) if len(rows_list) > 0: - df = DataFrame(rows_list, columns=columns, - index=MultiIndex.from_tuples(index, names=indexes)) + df = DataFrame( + rows_list, + columns=columns, + index=MultiIndex.from_tuples(index, names=indexes), + ) else: df = DataFrame(columns=columns) - df['IsNonstandard'] = df['Root'] != self.symbol.replace('-', '') + df["IsNonstandard"] = df["Root"] != self.symbol.replace("-", "") # Make dtype consistent, requires float64 as there can be NaNs - df['Vol'] = df['Vol'].astype('float64') - df['Open_Int'] = df['Open_Int'].astype('float64') + df["Vol"] = df["Vol"].astype("float64") + df["Open_Int"] = df["Open_Int"].astype("float64") return df.sort_index() @@ -740,59 +770,65 @@ def _process_rows(self, jd): index = [] # handle no results - if len(jd['optionChain']['result']) <= 0: + if len(jd["optionChain"]["result"]) <= 0: return rows_list, index - quote = jd['optionChain']['result'][0]['quote'] - for option in jd['optionChain']['result'][0]['options']: - for typ in ['calls', 'puts']: + quote = jd["optionChain"]["result"][0]["quote"] + for option in jd["optionChain"]["result"][0]["options"]: + for typ in ["calls", "puts"]: options_by_type = option[typ] for option_by_strike in options_by_type: d = {} for dkey, rkey, ntype in [ - ('Last', 'lastPrice', float), - ('Bid', 'bid', float), - ('Ask', 'ask', float), - ('Chg', 'change', float), - ('PctChg', 'percentChange', float), - ('Vol', 'volume', int), - ('Open_Int', 'openInterest', int), - ('IV', 'impliedVolatility', float), - ('Last_Trade_Date', 'lastTradeDate', int) + ("Last", 
"lastPrice", float), + ("Bid", "bid", float), + ("Ask", "ask", float), + ("Chg", "change", float), + ("PctChg", "percentChange", float), + ("Vol", "volume", int), + ("Open_Int", "openInterest", int), + ("IV", "impliedVolatility", float), + ("Last_Trade_Date", "lastTradeDate", int), ]: try: d[dkey] = ntype(option_by_strike[rkey]) except (KeyError, ValueError): pass - d['JSON'] = option_by_strike - d['Root'] = option_by_strike['contractSymbol'][:-15] - d['Underlying'] = self.symbol - - d['Underlying_Price'] = quote['regularMarketPrice'] - quote_unix_time = quote['regularMarketTime'] - if (quote['marketState'] == 'PRE' and - 'preMarketPrice' in quote): - d['Underlying_Price'] = quote['preMarketPrice'] - quote_unix_time = quote['preMarketTime'] - elif (quote['marketState'] == 'POSTPOST' and - 'postMarketPrice' in quote): - d['Underlying_Price'] = quote['postMarketPrice'] - quote_unix_time = quote['postMarketTime'] - d['Quote_Time'] = dt.datetime.utcfromtimestamp( - quote_unix_time) - - self._underlying_price = d['Underlying_Price'] - self._quote_time = d['Quote_Time'] - - d['Last_Trade_Date'] = dt.datetime.utcfromtimestamp( - d['Last_Trade_Date']) + d["JSON"] = option_by_strike + d["Root"] = option_by_strike["contractSymbol"][:-15] + d["Underlying"] = self.symbol + + d["Underlying_Price"] = quote["regularMarketPrice"] + quote_unix_time = quote["regularMarketTime"] + if quote["marketState"] == "PRE" and "preMarketPrice" in quote: + d["Underlying_Price"] = quote["preMarketPrice"] + quote_unix_time = quote["preMarketTime"] + elif ( + quote["marketState"] == "POSTPOST" + and "postMarketPrice" in quote + ): + d["Underlying_Price"] = quote["postMarketPrice"] + quote_unix_time = quote["postMarketTime"] + d["Quote_Time"] = dt.datetime.utcfromtimestamp(quote_unix_time) + + self._underlying_price = d["Underlying_Price"] + self._quote_time = d["Quote_Time"] + + d["Last_Trade_Date"] = dt.datetime.utcfromtimestamp( + d["Last_Trade_Date"] + ) rows_list.append(d) - index.append((float(option_by_strike['strike']), - dt.datetime.utcfromtimestamp( - option_by_strike['expiration']), - typ.replace('s', ''), - option_by_strike['contractSymbol'])) + index.append( + ( + float(option_by_strike["strike"]), + dt.datetime.utcfromtimestamp( + option_by_strike["expiration"] + ), + typ.replace("s", ""), + option_by_strike["contractSymbol"], + ) + ) return rows_list, index def _load_data(self, exp_dates=None): @@ -816,14 +852,18 @@ def _load_data(self, exp_dates=None): try: if exp_dates is None: exp_dates = self._get_expiry_dates() - exp_unix_times = [int((dt.datetime(exp_date.year, - exp_date.month, - exp_date.day) - epoch - ).total_seconds()) - for exp_date in exp_dates] + exp_unix_times = [ + int( + ( + dt.datetime(exp_date.year, exp_date.month, exp_date.day) - epoch + ).total_seconds() + ) + for exp_date in exp_dates + ] for exp_date in exp_unix_times: - url = (self._OPTIONS_BASE_URL + '?date={exp_date}').format( - sym=self.symbol, exp_date=exp_date) + url = (self._OPTIONS_BASE_URL + "?date={exp_date}").format( + sym=self.symbol, exp_date=exp_date + ) jd = self._parse_url(url) data.append(self._process_data(jd)) return concat(data).sort_index() diff --git a/pandas_datareader/yahoo/quotes.py b/pandas_datareader/yahoo/quotes.py index b779e2ee..4dddff3b 100644 --- a/pandas_datareader/yahoo/quotes.py +++ b/pandas_datareader/yahoo/quotes.py @@ -1,14 +1,15 @@ -import json from collections import OrderedDict +import json + from pandas import DataFrame from pandas_datareader.base import _BaseReader from pandas_datareader.compat 
import string_types _DEFAULT_PARAMS = { - 'lang': 'en-US', - 'corsDomain': 'finance.yahoo.com', - '.tsrc': 'finance', + "lang": "en-US", + "corsDomain": "finance.yahoo.com", + ".tsrc": "finance", } @@ -18,7 +19,7 @@ class YahooQuotesReader(_BaseReader): @property def url(self): - return 'https://query1.finance.yahoo.com/v7/finance/quote' + return "https://query1.finance.yahoo.com/v7/finance/quote" def read(self): if isinstance(self.symbols, string_types): @@ -26,20 +27,20 @@ def read(self): else: data = OrderedDict() for symbol in self.symbols: - data[symbol] = \ - self._read_one_data( - self.url, self.params(symbol)).loc[symbol] - return DataFrame.from_dict(data, orient='index') + data[symbol] = self._read_one_data(self.url, self.params(symbol)).loc[ + symbol + ] + return DataFrame.from_dict(data, orient="index") def params(self, symbol): """Parameters to use in API calls""" # Construct the code request string. - params = {'symbols': symbol} + params = {"symbols": symbol} params.update(_DEFAULT_PARAMS) return params def _read_lines(self, out): - data = json.loads(out.read())['quoteResponse']['result'][0] - idx = data.pop('symbol') - data['price'] = data['regularMarketPrice'] + data = json.loads(out.read())["quoteResponse"]["result"][0] + idx = data.pop("symbol") + data["price"] = data["regularMarketPrice"] return DataFrame(data, index=[idx]) diff --git a/setup.cfg b/setup.cfg index 626ed74a..108cb1be 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,11 +13,13 @@ known_compat=pandas_datareader.compat.* sections=FUTURE,COMPAT,STDLIB,THIRDPARTY,PRE_CORE,FIRSTPARTY,LOCALFOLDER known_first_party=pandas_datareader known_third_party=numpy,pandas,pytest,requests -multi_line_output=0 +multi_line_output=3 force_grid_wrap=0 +use_parentheses=True +include_trailing_comma=True combine_as_imports=True force_sort_within_sections=True -line_width=99 +line_width=88 [tool:pytest] # sync minversion with setup.cfg & install.rst
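
For anyone tracing the reformatted `YahooDailyReader` above, a minimal usage sketch against the new signature; the ticker and date range are placeholders, and the Yahoo endpoint the reader scrapes may no longer respond as it did when this patch was written:

    from pandas_datareader.yahoo.daily import YahooDailyReader

    # interval accepts "d", "wk" and "mo"; "m" and "w" survive only as
    # backward-compatible aliases, per the validation logic above.
    reader = YahooDailyReader(
        symbols="AAPL",
        start="2017-01-01",
        end="2017-12-31",
        interval="mo",
        get_actions=True,  # also join Dividends/Splits onto the price frame
    )
    prices = reader.read()
    reader.close()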
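
One behavior the reformat deliberately preserves is `eval(row["Splitratio"])` in the split handler. Should that ever be revisited, `fractions.Fraction` parses the same kind of ratio string without executing arbitrary input, and it sidesteps Python 2 floor division (`eval("3/2")` yields 1 on Python 2, which this codebase still supports through its `compat` shims). A sketch, assuming `Splitratio` always arrives as an "n/m" string:

    from fractions import Fraction

    def split_ratio(row):
        # "Splitratio" is assumed to be a ratio string such as "3/2" or "1/5".
        if float(row["Numerator"]) > 0:
            return float(Fraction(row["Splitratio"]))
        return 1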
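
Similarly for the reformatted `YahooQuotesReader`: per `_read_lines` above, each row carries the raw quote fields plus a `price` column copied from `regularMarketPrice`. A sketch with placeholder symbols:

    from pandas_datareader.yahoo.quotes import YahooQuotesReader

    quotes = YahooQuotesReader(symbols=["AAPL", "MSFT"]).read()
    print(quotes["price"])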
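
On the `setup.cfg` change: `multi_line_output=3` (vertical hanging indent) combined with `use_parentheses`, `include_trailing_comma` and an 88-character line width is the isort configuration that black's documentation recommends, so the two tools converge on the same wrapped-import shape instead of rewriting each other's output. Illustrative only (the names below are real but short enough that isort would not actually wrap them):

    # Imports that fit within 88 characters stay on one line:
    from pandas import DataFrame, Series, concat, to_datetime

    # Longer ones are wrapped in parentheses, one name per line, with a
    # trailing comma -- the same shape black itself produces:
    from pandas_datareader._utils import (
        RemoteDataError,
        SymbolWarning,
    )

With isort 4.x the result can be checked with something like `isort --recursive --check-only pandas_datareader`.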