diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index ec4a6995..010d6dff 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,4 +1,4 @@
- [ ] closes #xxxx
- [ ] tests added / passed
-- [ ] passes `git diff upstream/master -u -- "*.py" | flake8 --diff`
+- [ ] passes `black --check pandas_datareader`
- [ ] added entry to docs/source/whatsnew/vLATEST.txt
diff --git a/.travis.yml b/.travis.yml
index c665049e..ca050bde 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -35,15 +35,17 @@ matrix:
install:
- source ci/pypi-install.sh;
- - pip install codecov coveralls beautifulsoup4 flake8
- pip list
- python setup.py install
script:
+
- if [[ -n "${TEST_TYPE+x}" ]]; then export MARKERS="-m ${TEST_TYPE}"; fi
- pytest -s -r xX "${MARKERS}" --cov-config .coveragerc --cov=pandas_datareader --cov-report xml:/tmp/cov-datareader.xml --junitxml=/tmp/datareader.xml
- - flake8 --version
- - flake8 pandas_datareader
+ - |
+ if [[ "$TRAVIS_PYTHON_VERSION" != 2.7 ]]; then
+ black --check pandas_datareader
+ fi
after_script:
- |
diff --git a/ci/pypi-install.sh b/ci/pypi-install.sh
index 4d2ea62e..c14f9f2c 100644
--- a/ci/pypi-install.sh
+++ b/ci/pypi-install.sh
@@ -1,15 +1,19 @@
#!/usr/bin/env bash
-echo "PyPI install"
-
pip install pip --upgrade
-pip install numpy=="$NUMPY" pytz python-dateutil coverage setuptools html5lib lxml pytest pytest-cov wrapt
+pip install numpy=="$NUMPY" pytz python-dateutil coverage setuptools html5lib lxml pytest pytest-cov wrapt codecov coveralls beautifulsoup4 isort
+
+if [[ "$TRAVIS_PYTHON_VERSION" != 2.7 ]]; then
+ pip install black
+fi
+
if [[ "$PANDAS" == "MASTER" ]]; then
PRE_WHEELS="https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com"
pip install --pre --upgrade --timeout=60 -f "$PRE_WHEELS" pandas
else
pip install pandas=="$PANDAS"
fi
+
if [[ "$DOCBUILD" ]]; then
pip install sphinx ipython matplotlib sphinx_rtd_theme doctr
fi
diff --git a/pandas_datareader/__init__.py b/pandas_datareader/__init__.py
index 1246c2d7..67525138 100644
--- a/pandas_datareader/__init__.py
+++ b/pandas_datareader/__init__.py
@@ -1,24 +1,60 @@
from ._version import get_versions
-from .data import (DataReader, Options, get_components_yahoo,
- get_dailysummary_iex, get_data_enigma, get_data_famafrench,
- get_data_fred, get_data_moex, get_data_quandl,
- get_data_stooq, get_data_yahoo, get_data_yahoo_actions,
- get_iex_book, get_iex_symbols, get_last_iex,
- get_markets_iex, get_nasdaq_symbols, get_quote_yahoo,
- get_recent_iex, get_records_iex, get_summary_iex,
- get_tops_iex, get_data_tiingo, get_iex_data_tiingo,
- get_data_alphavantage)
+from .data import (
+ DataReader,
+ Options,
+ get_components_yahoo,
+ get_dailysummary_iex,
+ get_data_alphavantage,
+ get_data_enigma,
+ get_data_famafrench,
+ get_data_fred,
+ get_data_moex,
+ get_data_quandl,
+ get_data_stooq,
+ get_data_tiingo,
+ get_data_yahoo,
+ get_data_yahoo_actions,
+ get_iex_book,
+ get_iex_data_tiingo,
+ get_iex_symbols,
+ get_last_iex,
+ get_markets_iex,
+ get_nasdaq_symbols,
+ get_quote_yahoo,
+ get_recent_iex,
+ get_records_iex,
+ get_summary_iex,
+ get_tops_iex,
+)
-__version__ = get_versions()['version']
+__version__ = get_versions()["version"]
del get_versions
-__all__ = ['__version__', 'get_components_yahoo', 'get_data_enigma',
- 'get_data_famafrench', 'get_data_yahoo',
- 'get_data_yahoo_actions', 'get_quote_yahoo',
- 'get_iex_book', 'get_iex_symbols', 'get_last_iex',
- 'get_markets_iex', 'get_recent_iex', 'get_records_iex',
- 'get_summary_iex', 'get_tops_iex',
- 'get_nasdaq_symbols', 'get_data_quandl', 'get_data_moex',
- 'get_data_fred', 'get_dailysummary_iex',
- 'get_data_stooq', 'DataReader', 'Options',
- 'get_data_tiingo', 'get_iex_data_tiingo', 'get_data_alphavantage']
+__all__ = [
+ "__version__",
+ "get_components_yahoo",
+ "get_data_enigma",
+ "get_data_famafrench",
+ "get_data_yahoo",
+ "get_data_yahoo_actions",
+ "get_quote_yahoo",
+ "get_iex_book",
+ "get_iex_symbols",
+ "get_last_iex",
+ "get_markets_iex",
+ "get_recent_iex",
+ "get_records_iex",
+ "get_summary_iex",
+ "get_tops_iex",
+ "get_nasdaq_symbols",
+ "get_data_quandl",
+ "get_data_moex",
+ "get_data_fred",
+ "get_dailysummary_iex",
+ "get_data_stooq",
+ "DataReader",
+ "Options",
+ "get_data_tiingo",
+ "get_iex_data_tiingo",
+ "get_data_alphavantage",
+]
diff --git a/pandas_datareader/_utils.py b/pandas_datareader/_utils.py
index b6058620..8f1a5edf 100644
--- a/pandas_datareader/_utils.py
+++ b/pandas_datareader/_utils.py
@@ -1,7 +1,8 @@
import datetime as dt
-import requests
from pandas import to_datetime
+import requests
+
from pandas_datareader.compat import is_number
@@ -33,7 +34,7 @@ def _sanitize_dates(start, end):
if end is None:
end = dt.datetime.today()
if start > end:
- raise ValueError('start must be an earlier date than end')
+ raise ValueError("start must be an earlier date than end")
return start, end
diff --git a/pandas_datareader/_version.py b/pandas_datareader/_version.py
index 1f0230d8..da7d0e17 100644
--- a/pandas_datareader/_version.py
+++ b/pandas_datareader/_version.py
@@ -1,4 +1,3 @@
-
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by github's download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ class NotThisMethod(Exception):
def register_vcs_handler(vcs, method): # decorator
"""Decorator to mark a method as the handler for a particular VCS."""
+
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
+
return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([c] + args)
# remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen([c] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ p = subprocess.Popen(
+ [c] + args,
+ cwd=cwd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr else None),
+ )
break
except EnvironmentError:
e = sys.exc_info()[1]
@@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for i in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
+ return {
+ "version": dirname[len(parentdir_prefix) :],
+ "full-revisionid": None,
+ "dirty": False,
+ "error": None,
+ "date": None,
+ }
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
+ print(
+ "Tried directories %s but none started with prefix %s"
+ % (str(rootdirs), parentdir_prefix)
+ )
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+ tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r'\d', r)])
+ tags = set([r for r in refs if re.search(r"\d", r)])
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
+ r = ref[len(tag_prefix) :]
if verbose:
print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
+ return {
+ "version": r,
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": None,
+ "date": date,
+ }
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": "no suitable tags",
+ "date": None,
+ }
@register_vcs_handler("git", "pieces_from_vcs")
@@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match", "%s*" % tag_prefix],
- cwd=root)
+ describe_out, rc = run_command(
+ GITS,
+ [
+ "describe",
+ "--tags",
+ "--dirty",
+ "--always",
+ "--long",
+ "--match",
+ "%s*" % tag_prefix,
+ ],
+ cwd=root,
+ )
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+ full_tag,
+ tag_prefix,
+ )
return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
- cwd=root)
+ count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
- cwd=root)[0].strip()
+ date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
+ 0
+ ].strip()
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
@@ -330,8 +355,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
+ return {
+ "version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"],
+ "date": None,
+ }
if not style or style == "default":
style = "pep440" # the default
@@ -469,9 +495,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
+ return {
+ "version": rendered,
+ "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"],
+ "error": None,
+ "date": pieces.get("date"),
+ }
def get_versions():
@@ -485,8 +515,7 @@ def get_versions():
verbose = cfg.verbose
try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
@@ -495,13 +524,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'):
+ for i in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to find root of source tree",
+ "date": None,
+ }
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -515,6 +547,10 @@ def get_versions():
except NotThisMethod:
pass
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version",
+ "date": None,
+ }
diff --git a/pandas_datareader/av/__init__.py b/pandas_datareader/av/__init__.py
index 808dd490..dfe6147b 100644
--- a/pandas_datareader/av/__init__.py
+++ b/pandas_datareader/av/__init__.py
@@ -1,11 +1,11 @@
import os
-from pandas_datareader.base import _BaseReader
-from pandas_datareader._utils import RemoteDataError
-
import pandas as pd
-AV_BASE_URL = 'https://www.alphavantage.co/query'
+from pandas_datareader._utils import RemoteDataError
+from pandas_datareader.base import _BaseReader
+
+AV_BASE_URL = "https://www.alphavantage.co/query"
class AlphaVantage(_BaseReader):
@@ -16,20 +16,36 @@ class AlphaVantage(_BaseReader):
-----
See `Alpha Vantage `__
"""
- _format = 'json'
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None, api_key=None):
- super(AlphaVantage, self).__init__(symbols=symbols, start=start,
- end=end, retry_count=retry_count,
- pause=pause, session=session)
+ _format = "json"
+
+ def __init__(
+ self,
+ symbols=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ api_key=None,
+ ):
+ super(AlphaVantage, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
if api_key is None:
- api_key = os.getenv('ALPHAVANTAGE_API_KEY')
+ api_key = os.getenv("ALPHAVANTAGE_API_KEY")
if not api_key or not isinstance(api_key, str):
- raise ValueError('The AlphaVantage API key must be provided '
- 'either through the api_key variable or '
- 'through the environment variable '
- 'ALPHAVANTAGE_API_KEY')
+ raise ValueError(
+ "The AlphaVantage API key must be provided "
+ "either through the api_key variable or "
+ "through the environment variable "
+ "ALPHAVANTAGE_API_KEY"
+ )
self.api_key = api_key
@property
@@ -39,10 +55,7 @@ def url(self):
@property
def params(self):
- return {
- 'function': self.function,
- 'apikey': self.api_key
- }
+ return {"function": self.function, "apikey": self.api_key}
@property
def function(self):
@@ -56,12 +69,14 @@ def data_key(self):
def _read_lines(self, out):
try:
- df = pd.DataFrame.from_dict(out[self.data_key], orient='index')
+ df = pd.DataFrame.from_dict(out[self.data_key], orient="index")
except KeyError:
if "Error Message" in out:
- raise ValueError("The requested symbol {} could not be "
- "retrived. Check valid ticker"
- ".".format(self.symbols))
+ raise ValueError(
+ "The requested symbol {} could not be "
+ "retrived. Check valid ticker"
+ ".".format(self.symbols)
+ )
else:
raise RemoteDataError()
df = df[sorted(df.columns)]
diff --git a/pandas_datareader/av/forex.py b/pandas_datareader/av/forex.py
index b4df0577..d0a195b4 100644
--- a/pandas_datareader/av/forex.py
+++ b/pandas_datareader/av/forex.py
@@ -1,8 +1,7 @@
-from pandas_datareader.av import AlphaVantage
+import pandas as pd
from pandas_datareader._utils import RemoteDataError
-
-import pandas as pd
+from pandas_datareader.av import AlphaVantage
class AVForexReader(AlphaVantage):
@@ -27,15 +26,20 @@ class AVForexReader(AlphaVantage):
        Alpha Vantage API key. If not provided, the environment variable
ALPHAVANTAGE_API_KEY is read. The API key is *required*.
"""
- def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
- api_key=None):
- super(AVForexReader, self).__init__(symbols=symbols,
- start=None, end=None,
- retry_count=retry_count,
- pause=pause,
- session=session,
- api_key=api_key)
+ def __init__(
+ self, symbols=None, retry_count=3, pause=0.1, session=None, api_key=None
+ ):
+
+ super(AVForexReader, self).__init__(
+ symbols=symbols,
+ start=None,
+ end=None,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=api_key,
+ )
self.from_curr = {}
self.to_curr = {}
self.optional_params = {}
@@ -45,28 +49,27 @@ def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
self.symbols = symbols
try:
for pair in self.symbols:
- self.from_curr[pair] = pair.split('/')[0]
- self.to_curr[pair] = pair.split('/')[1]
+ self.from_curr[pair] = pair.split("/")[0]
+ self.to_curr[pair] = pair.split("/")[1]
except Exception as e:
print(e)
- raise ValueError("Please input a currency pair "
- "formatted 'FROM/TO' or a list of "
- "currency symbols")
+ raise ValueError(
+ "Please input a currency pair "
+ "formatted 'FROM/TO' or a list of "
+ "currency symbols"
+ )
@property
def function(self):
- return 'CURRENCY_EXCHANGE_RATE'
+ return "CURRENCY_EXCHANGE_RATE"
@property
def data_key(self):
- return 'Realtime Currency Exchange Rate'
+ return "Realtime Currency Exchange Rate"
@property
def params(self):
- params = {
- 'apikey': self.api_key,
- 'function': self.function
- }
+ params = {"apikey": self.api_key, "function": self.function}
params.update(self.optional_params)
return params
@@ -74,9 +77,9 @@ def read(self):
result = []
for pair in self.symbols:
self.optional_params = {
- 'from_currency': self.from_curr[pair],
- 'to_currency': self.to_curr[pair],
- }
+ "from_currency": self.from_curr[pair],
+ "to_currency": self.to_curr[pair],
+ }
data = super(AVForexReader, self).read()
result.append(data)
df = pd.concat(result, axis=1)
@@ -85,7 +88,7 @@ def read(self):
def _read_lines(self, out):
try:
- df = pd.DataFrame.from_dict(out[self.data_key], orient='index')
+ df = pd.DataFrame.from_dict(out[self.data_key], orient="index")
except KeyError:
raise RemoteDataError()
df.sort_index(ascending=True, inplace=True)
diff --git a/pandas_datareader/av/quotes.py b/pandas_datareader/av/quotes.py
index ae2e1657..8b9b16a6 100644
--- a/pandas_datareader/av/quotes.py
+++ b/pandas_datareader/av/quotes.py
@@ -1,7 +1,7 @@
-from pandas_datareader.av import AlphaVantage
-
-import pandas as pd
import numpy as np
+import pandas as pd
+
+from pandas_datareader.av import AlphaVantage
class AVQuotesReader(AlphaVantage):
@@ -22,8 +22,10 @@ class AVQuotesReader(AlphaVantage):
session : Session, default None
requests.sessions.Session instance to be used
"""
- def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
- api_key=None):
+
+ def __init__(
+ self, symbols=None, retry_count=3, pause=0.1, session=None, api_key=None
+ ):
if isinstance(symbols, str):
syms = [symbols]
elif isinstance(symbols, list):
@@ -31,27 +33,30 @@ def __init__(self, symbols=None, retry_count=3, pause=0.1, session=None,
raise ValueError("Up to 100 symbols at once are allowed.")
else:
syms = symbols
- super(AVQuotesReader, self).__init__(symbols=syms,
- start=None, end=None,
- retry_count=retry_count,
- pause=pause,
- session=session,
- api_key=api_key)
+ super(AVQuotesReader, self).__init__(
+ symbols=syms,
+ start=None,
+ end=None,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=api_key,
+ )
@property
def function(self):
- return 'BATCH_STOCK_QUOTES'
+ return "BATCH_STOCK_QUOTES"
@property
def data_key(self):
- return 'Stock Quotes'
+ return "Stock Quotes"
@property
def params(self):
return {
- 'symbols': ','.join(self.symbols),
- 'function': self.function,
- 'apikey': self.api_key,
+ "symbols": ",".join(self.symbols),
+ "function": self.function,
+ "apikey": self.api_key,
}
def _read_lines(self, out):
@@ -61,14 +66,13 @@ def _read_lines(self, out):
df = pd.DataFrame(quote, index=[0])
df.columns = [col[3:] for col in df.columns]
df.set_index("symbol", inplace=True)
- df["price"] = df["price"].astype('float64')
+ df["price"] = df["price"].astype("float64")
try:
- df["volume"] = df["volume"].astype('int64')
+ df["volume"] = df["volume"].astype("int64")
except ValueError:
df["volume"] = [np.nan * len(self.symbols)]
result.append(df)
if len(result) != len(self.symbols):
- raise ValueError("Not all symbols downloaded. Check valid "
- "ticker(s).")
+ raise ValueError("Not all symbols downloaded. Check valid " "ticker(s).")
else:
return pd.concat(result)
diff --git a/pandas_datareader/av/sector.py b/pandas_datareader/av/sector.py
index 43a2f3c8..46d445ca 100644
--- a/pandas_datareader/av/sector.py
+++ b/pandas_datareader/av/sector.py
@@ -1,7 +1,7 @@
import pandas as pd
-from pandas_datareader.av import AlphaVantage
from pandas_datareader._utils import RemoteDataError
+from pandas_datareader.av import AlphaVantage
class AVSectorPerformanceReader(AlphaVantage):
@@ -25,9 +25,10 @@ class AVSectorPerformanceReader(AlphaVantage):
        Alpha Vantage API key. If not provided, the environment variable
ALPHAVANTAGE_API_KEY is read. The API key is *required*.
"""
+
@property
def function(self):
- return 'SECTOR'
+ return "SECTOR"
def _read_lines(self, out):
if "Information" in out:
@@ -35,7 +36,6 @@ def _read_lines(self, out):
else:
out.pop("Meta Data")
df = pd.DataFrame(out)
- columns = ["RT", "1D", "5D", "1M", "3M", "YTD", "1Y", "3Y", "5Y",
- "10Y"]
+ columns = ["RT", "1D", "5D", "1M", "3M", "YTD", "1Y", "3Y", "5Y", "10Y"]
df.columns = columns
return df
diff --git a/pandas_datareader/av/time_series.py b/pandas_datareader/av/time_series.py
index 13b6ef2d..8513bdad 100644
--- a/pandas_datareader/av/time_series.py
+++ b/pandas_datareader/av/time_series.py
@@ -1,7 +1,7 @@
-from pandas_datareader.av import AlphaVantage
-
from datetime import datetime
+from pandas_datareader.av import AlphaVantage
+
class AVTimeSeriesReader(AlphaVantage):
"""
@@ -29,24 +29,38 @@ class AVTimeSeriesReader(AlphaVantage):
        AlphaVantage API key. If not provided, the environment variable
ALPHAVANTAGE_API_KEY is read. The API key is *required*.
"""
+
_FUNC_TO_DATA_KEY = {
- "TIME_SERIES_DAILY": "Time Series (Daily)",
- "TIME_SERIES_DAILY_ADJUSTED": "Time Series (Daily)",
- "TIME_SERIES_WEEKLY": "Weekly Time Series",
- "TIME_SERIES_WEEKLY_ADJUSTED": "Weekly Adjusted Time Series",
- "TIME_SERIES_MONTHLY": "Monthly Time Series",
- "TIME_SERIES_MONTHLY_ADJUSTED": "Monthly Adjusted Time Series",
- "TIME_SERIES_INTRADAY": "Time Series (1min)"
+ "TIME_SERIES_DAILY": "Time Series (Daily)",
+ "TIME_SERIES_DAILY_ADJUSTED": "Time Series (Daily)",
+ "TIME_SERIES_WEEKLY": "Weekly Time Series",
+ "TIME_SERIES_WEEKLY_ADJUSTED": "Weekly Adjusted Time Series",
+ "TIME_SERIES_MONTHLY": "Monthly Time Series",
+ "TIME_SERIES_MONTHLY_ADJUSTED": "Monthly Adjusted Time Series",
+ "TIME_SERIES_INTRADAY": "Time Series (1min)",
}
- def __init__(self, symbols=None, function="TIME_SERIES_DAILY",
- start=None, end=None, retry_count=3, pause=0.1,
- session=None, chunksize=25, api_key=None):
- super(AVTimeSeriesReader, self).__init__(symbols=symbols, start=start,
- end=end,
- retry_count=retry_count,
- pause=pause, session=session,
- api_key=api_key)
+ def __init__(
+ self,
+ symbols=None,
+ function="TIME_SERIES_DAILY",
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ chunksize=25,
+ api_key=None,
+ ):
+ super(AVTimeSeriesReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=api_key,
+ )
self._func = function
@@ -60,7 +74,7 @@ def output_size(self):
possible.
"""
delta = datetime.now() - self.start
- return 'full' if delta.days > 80 else 'compact'
+ return "full" if delta.days > 80 else "compact"
@property
def data_key(self):
@@ -72,7 +86,7 @@ def params(self):
"symbol": self.symbols,
"function": self.function,
"apikey": self.api_key,
- "outputsize": self.output_size
+ "outputsize": self.output_size,
}
if self.function == "TIME_SERIES_INTRADAY":
p.update({"interval": "1min"})
@@ -82,15 +96,15 @@ def _read_lines(self, out):
data = super(AVTimeSeriesReader, self)._read_lines(out)
# reverse since alphavantage returns descending by date
data = data[::-1]
- start_str = self.start.strftime('%Y-%m-%d')
- end_str = self.end.strftime('%Y-%m-%d')
+ start_str = self.start.strftime("%Y-%m-%d")
+ end_str = self.end.strftime("%Y-%m-%d")
data = data.loc[start_str:end_str]
if data.empty:
raise ValueError("Please input a valid date range")
else:
for column in data.columns:
- if column == 'volume':
- data[column] = data[column].astype('int64')
+ if column == "volume":
+ data[column] = data[column].astype("int64")
else:
- data[column] = data[column].astype('float64')
+ data[column] = data[column].astype("float64")
return data
diff --git a/pandas_datareader/bankofcanada.py b/pandas_datareader/bankofcanada.py
index c8ca75b1..0b0ea004 100644
--- a/pandas_datareader/bankofcanada.py
+++ b/pandas_datareader/bankofcanada.py
@@ -1,8 +1,7 @@
from __future__ import unicode_literals
-from pandas_datareader.compat import string_types
-
from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import string_types
class BankOfCanadaReader(_BaseReader):
@@ -12,26 +11,28 @@ class BankOfCanadaReader(_BaseReader):
-----
See `Bank of Canada `__"""
- _URL = 'http://www.bankofcanada.ca/valet/observations'
+ _URL = "http://www.bankofcanada.ca/valet/observations"
@property
def url(self):
"""API URL"""
if not isinstance(self.symbols, string_types):
- raise ValueError('data name must be string')
+ raise ValueError("data name must be string")
- return '{0}/{1}/csv'.format(self._URL, self.symbols)
+ return "{0}/{1}/csv".format(self._URL, self.symbols)
@property
def params(self):
"""Parameters to use in API calls"""
- return {'start_date': self.start.strftime('%Y-%m-%d'),
- 'end_date': self.end.strftime('%Y-%m-%d')}
+ return {
+ "start_date": self.start.strftime("%Y-%m-%d"),
+ "end_date": self.end.strftime("%Y-%m-%d"),
+ }
@staticmethod
def _sanitize_response(response):
"""
Clean up the response string
"""
- data = response.text.split('OBSERVATIONS')[1]
- return data.split('ERRORS')[0].strip()
+ data = response.text.split("OBSERVATIONS")[1]
+ return data.split("ERRORS")[0].strip()
diff --git a/pandas_datareader/base.py b/pandas_datareader/base.py
index 7e63839d..ddac3ce4 100644
--- a/pandas_datareader/base.py
+++ b/pandas_datareader/base.py
@@ -1,17 +1,18 @@
import time
import warnings
-import numpy as np
-
-import requests
-from pandas import DataFrame
-from pandas import read_csv, concat
+import numpy as np
+from pandas import DataFrame, concat, read_csv
from pandas.io.common import urlencode
-from pandas_datareader.compat import bytes_to_str, string_types, binary_type, \
- StringIO
+import requests
-from pandas_datareader._utils import (RemoteDataError, SymbolWarning,
- _sanitize_dates, _init_session)
+from pandas_datareader._utils import (
+ RemoteDataError,
+ SymbolWarning,
+ _init_session,
+ _sanitize_dates,
+)
+from pandas_datareader.compat import StringIO, binary_type, bytes_to_str, string_types
class _BaseReader(object):
@@ -36,10 +37,19 @@ class _BaseReader(object):
"""
_chunk_size = 1024 * 1024
- _format = 'string'
-
- def __init__(self, symbols, start=None, end=None, retry_count=3,
- pause=0.1, timeout=30, session=None, freq=None):
+ _format = "string"
+
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ ):
self.symbols = symbols
@@ -80,9 +90,9 @@ def read(self):
def _read_one_data(self, url, params):
""" read one data from specified URL """
- if self._format == 'string':
+ if self._format == "string":
out = self._read_url_as_StringIO(url, params=params)
- elif self._format == 'json':
+ elif self._format == "json":
out = self._get_response(url, params=params).json()
else:
raise NotImplementedError(self._format)
@@ -97,8 +107,10 @@ def _read_url_as_StringIO(self, url, params=None):
out = StringIO()
if len(text) == 0:
service = self.__class__.__name__
- raise IOError("{} request returned no data; check URL for invalid "
- "inputs: {}".format(service, self.url))
+ raise IOError(
+ "{} request returned no data; check URL for invalid "
+ "inputs: {}".format(service, self.url)
+ )
if isinstance(text, binary_type):
out.write(bytes_to_str(text))
else:
@@ -125,11 +137,9 @@ def _get_response(self, url, params=None, headers=None):
# initial attempt + retry
pause = self.pause
- last_response_text = ''
+ last_response_text = ""
for i in range(self.retry_count + 1):
- response = self.session.get(url,
- params=params,
- headers=headers)
+ response = self.session.get(url, params=params, headers=headers)
if response.status_code == requests.codes.ok:
return response
@@ -140,8 +150,8 @@ def _get_response(self, url, params=None, headers=None):
# Increase time between subsequent requests, per subclass.
pause *= self.pause_multiplier
# Get a new breadcrumb if necessary, in case ours is invalidated
- if isinstance(params, list) and 'crumb' in params:
- params['crumb'] = self._get_crumb(self.retry_count)
+ if isinstance(params, list) and "crumb" in params:
+ params["crumb"] = self._get_crumb(self.retry_count)
# If our output error function returns True, exit the loop.
if self._output_error(response):
@@ -149,9 +159,9 @@ def _get_response(self, url, params=None, headers=None):
if params is not None and len(params) > 0:
url = url + "?" + urlencode(params)
- msg = 'Unable to read URL: {0}'.format(url)
+ msg = "Unable to read URL: {0}".format(url)
if last_response_text:
- msg += '\nResponse Text:\n{0}'.format(last_response_text)
+ msg += "\nResponse Text:\n{0}".format(last_response_text)
raise RemoteDataError(msg)
@@ -169,8 +179,7 @@ def _output_error(self, out):
return False
def _read_lines(self, out):
- rs = read_csv(out, index_col=0, parse_dates=True,
- na_values=('-', 'null'))[::-1]
+ rs = read_csv(out, index_col=0, parse_dates=True, na_values=("-", "null"))[::-1]
# Needed to remove blank space character in header names
rs.columns = list(map(lambda x: x.strip(), rs.columns.values.tolist()))
@@ -180,11 +189,12 @@ def _read_lines(self, out):
rs = rs[:-1]
# Get rid of unicode characters in index name.
try:
- rs.index.name = rs.index.name.decode(
- 'unicode_escape').encode('ascii', 'ignore')
+ rs.index.name = rs.index.name.decode("unicode_escape").encode(
+ "ascii", "ignore"
+ )
except AttributeError:
# Python 3 string has no decode method.
- rs.index.name = rs.index.name.encode('ascii', 'ignore').decode()
+ rs.index.name = rs.index.name.encode("ascii", "ignore").decode()
return rs
@@ -192,12 +202,24 @@ def _read_lines(self, out):
class _DailyBaseReader(_BaseReader):
""" Base class for Google / Yahoo daily reader """
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None, chunksize=25):
- super(_DailyBaseReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ def __init__(
+ self,
+ symbols=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ chunksize=25,
+ ):
+ super(_DailyBaseReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
self.chunksize = chunksize
def _get_params(self, *args, **kwargs):
@@ -207,8 +229,7 @@ def read(self):
"""Read data"""
# If a single symbol, (e.g., 'GOOG')
if isinstance(self.symbols, (string_types, int)):
- df = self._read_one_data(self.url,
- params=self._get_params(self.symbols))
+ df = self._read_one_data(self.url, params=self._get_params(self.symbols))
# Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
elif isinstance(self.symbols, DataFrame):
df = self._dl_mult_symbols(self.symbols.index)
@@ -223,11 +244,10 @@ def _dl_mult_symbols(self, symbols):
for sym_group in _in_chunks(symbols, self.chunksize):
for sym in sym_group:
try:
- stocks[sym] = self._read_one_data(self.url,
- self._get_params(sym))
+ stocks[sym] = self._read_one_data(self.url, self._get_params(sym))
passed.append(sym)
except IOError:
- msg = 'Failed to read symbol: {0!r}, replacing with NaN.'
+ msg = "Failed to read symbol: {0!r}, replacing with NaN."
warnings.warn(msg.format(sym), SymbolWarning)
failed.append(sym)
@@ -241,7 +261,7 @@ def _dl_mult_symbols(self, symbols):
for sym in failed:
stocks[sym] = df_na
result = concat(stocks).unstack(level=0)
- result.columns.names = ['Attributes', 'Symbols']
+ result.columns.names = ["Attributes", "Symbols"]
return result
except AttributeError:
# cannot construct a panel with just 1D nans indicating no data
@@ -253,16 +273,14 @@ def _in_chunks(seq, size):
"""
Return sequence in 'chunks' of size defined by size
"""
- return (seq[pos:pos + size] for pos in range(0, len(seq), size))
+ return (seq[pos : pos + size] for pos in range(0, len(seq), size))
class _OptionBaseReader(_BaseReader):
-
def __init__(self, symbol, session=None):
""" Instantiates options_data with a ticker saved as symbol """
self.symbol = symbol.upper()
- super(_OptionBaseReader, self).__init__(symbols=symbol,
- session=session)
+ super(_OptionBaseReader, self).__init__(symbols=symbol, session=session)
def get_options_data(self, month=None, year=None, expiry=None):
"""
@@ -288,16 +306,18 @@ def get_put_data(self, month=None, year=None, expiry=None):
"""
raise NotImplementedError
- def get_near_stock_price(self, above_below=2, call=True, put=False,
- month=None, year=None, expiry=None):
+ def get_near_stock_price(
+ self, above_below=2, call=True, put=False, month=None, year=None, expiry=None
+ ):
"""
***Experimental***
Returns a data frame of options that are near the current stock price.
"""
raise NotImplementedError
- def get_forward_data(self, months, call=True, put=False, near=False,
- above_below=2): # pragma: no cover
+ def get_forward_data(
+ self, months, call=True, put=False, near=False, above_below=2
+ ): # pragma: no cover
"""
***Experimental***
Gets either call, put, or both data for months starting in the current
diff --git a/pandas_datareader/compat/__init__.py b/pandas_datareader/compat/__init__.py
index 1d364077..a342b0e7 100644
--- a/pandas_datareader/compat/__init__.py
+++ b/pandas_datareader/compat/__init__.py
@@ -1,17 +1,18 @@
# flake8: noqa
+from distutils.version import LooseVersion
+import sys
+
import pandas as pd
import pandas.io.common as com
-import sys
-from distutils.version import LooseVersion
PY3 = sys.version_info >= (3, 0)
PANDAS_VERSION = LooseVersion(pd.__version__)
-PANDAS_0190 = (PANDAS_VERSION >= LooseVersion('0.19.0'))
-PANDAS_0200 = (PANDAS_VERSION >= LooseVersion('0.20.0'))
-PANDAS_0210 = (PANDAS_VERSION >= LooseVersion('0.21.0'))
-PANDAS_0230 = (PANDAS_VERSION >= LooseVersion('0.23.0'))
+PANDAS_0190 = PANDAS_VERSION >= LooseVersion("0.19.0")
+PANDAS_0200 = PANDAS_VERSION >= LooseVersion("0.20.0")
+PANDAS_0210 = PANDAS_VERSION >= LooseVersion("0.21.0")
+PANDAS_0230 = PANDAS_VERSION >= LooseVersion("0.23.0")
if PANDAS_0190:
from pandas.api.types import is_number
@@ -23,8 +24,7 @@
if PANDAS_0200:
from pandas.util.testing import assert_raises_regex
- def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
- compression=None):
+ def get_filepath_or_buffer(filepath_or_buffer, encoding=None, compression=None):
# Dictionaries are no longer considered valid inputs
# for "get_filepath_or_buffer" starting in pandas >= 0.20.0
@@ -32,9 +32,13 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
return filepath_or_buffer, encoding, compression
return com.get_filepath_or_buffer(
- filepath_or_buffer, encoding=encoding, compression=None)
+ filepath_or_buffer, encoding=encoding, compression=None
+ )
+
+
else:
from pandas.util.testing import assertRaisesRegexp as assert_raises_regex
+
get_filepath_or_buffer = com.get_filepath_or_buffer
if PANDAS_0190:
@@ -46,28 +50,28 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
from urllib.error import HTTPError
from functools import reduce
- string_types = str,
+ string_types = (str,)
binary_type = bytes
from io import StringIO
def str_to_bytes(s, encoding=None):
- return s.encode(encoding or 'ascii')
-
+ return s.encode(encoding or "ascii")
def bytes_to_str(b, encoding=None):
- return b.decode(encoding or 'utf-8')
+ return b.decode(encoding or "utf-8")
+
+
else:
from urllib2 import HTTPError
from cStringIO import StringIO
+
reduce = reduce
binary_type = str
- string_types = basestring,
-
+ string_types = (basestring,)
def bytes_to_str(b, encoding=None):
return b
-
def str_to_bytes(s, encoding=None):
return s
@@ -84,6 +88,6 @@ def concat(*args, **kwargs):
"""
    Shim to work around the sort keyword
"""
- if not PANDAS_0230 and 'sort' in kwargs:
- del kwargs['sort']
+ if not PANDAS_0230 and "sort" in kwargs:
+ del kwargs["sort"]
return pd.concat(*args, **kwargs)
diff --git a/pandas_datareader/conftest.py b/pandas_datareader/conftest.py
index 79f3bc66..df71d4bd 100644
--- a/pandas_datareader/conftest.py
+++ b/pandas_datareader/conftest.py
@@ -21,7 +21,7 @@ def datapath(request):
ValueError
If the path doesn't exist and the --strict-data-files option is set.
"""
- BASE_PATH = os.path.join(os.path.dirname(__file__), 'tests')
+ BASE_PATH = os.path.join(os.path.dirname(__file__), "tests")
def deco(*args):
path = os.path.join(BASE_PATH, *args)
@@ -33,4 +33,5 @@ def deco(*args):
msg = "Could not find {}."
pytest.skip(msg.format(path))
return path
+
return deco
diff --git a/pandas_datareader/data.py b/pandas_datareader/data.py
index b7b80abb..6ff45f16 100644
--- a/pandas_datareader/data.py
+++ b/pandas_datareader/data.py
@@ -12,39 +12,54 @@
from pandas_datareader.econdb import EcondbReader
from pandas_datareader.enigma import EnigmaReader
from pandas_datareader.eurostat import EurostatReader
-from pandas_datareader.exceptions import DEP_ERROR_MSG, \
- ImmediateDeprecationError
+from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError
from pandas_datareader.famafrench import FamaFrenchReader
from pandas_datareader.fred import FredReader
from pandas_datareader.iex.daily import IEXDailyReader
from pandas_datareader.iex.deep import Deep as IEXDeep
-from pandas_datareader.iex.tops import LastReader as IEXLasts, \
- TopsReader as IEXTops
+from pandas_datareader.iex.tops import LastReader as IEXLasts, TopsReader as IEXTops
from pandas_datareader.moex import MoexReader
from pandas_datareader.nasdaq_trader import get_nasdaq_symbols
from pandas_datareader.oecd import OECDReader
from pandas_datareader.quandl import QuandlReader
-from pandas_datareader.robinhood import RobinhoodHistoricalReader, \
- RobinhoodQuoteReader
+from pandas_datareader.robinhood import RobinhoodHistoricalReader, RobinhoodQuoteReader
from pandas_datareader.stooq import StooqDailyReader
-from pandas_datareader.tiingo import (TiingoDailyReader, TiingoQuoteReader,
- TiingoIEXHistoricalReader)
-from pandas_datareader.yahoo.actions import (YahooActionReader, YahooDivReader)
-from pandas_datareader.yahoo.components import _get_data as \
- get_components_yahoo
+from pandas_datareader.tiingo import (
+ TiingoDailyReader,
+ TiingoIEXHistoricalReader,
+ TiingoQuoteReader,
+)
+from pandas_datareader.yahoo.actions import YahooActionReader, YahooDivReader
+from pandas_datareader.yahoo.components import _get_data as get_components_yahoo
from pandas_datareader.yahoo.daily import YahooDailyReader
from pandas_datareader.yahoo.options import Options as YahooOptions
from pandas_datareader.yahoo.quotes import YahooQuotesReader
-__all__ = ['get_components_yahoo', 'get_data_enigma', 'get_data_famafrench',
- 'get_data_fred', 'get_data_moex',
- 'get_data_quandl', 'get_data_yahoo', 'get_data_yahoo_actions',
- 'get_nasdaq_symbols', 'get_quote_yahoo',
- 'get_tops_iex', 'get_summary_iex', 'get_records_iex',
- 'get_recent_iex', 'get_markets_iex', 'get_last_iex',
- 'get_iex_symbols', 'get_iex_book', 'get_dailysummary_iex',
- 'get_data_stooq', 'get_data_robinhood',
- 'get_quotes_robinhood', 'DataReader']
+__all__ = [
+ "get_components_yahoo",
+ "get_data_enigma",
+ "get_data_famafrench",
+ "get_data_fred",
+ "get_data_moex",
+ "get_data_quandl",
+ "get_data_yahoo",
+ "get_data_yahoo_actions",
+ "get_nasdaq_symbols",
+ "get_quote_yahoo",
+ "get_tops_iex",
+ "get_summary_iex",
+ "get_records_iex",
+ "get_recent_iex",
+ "get_markets_iex",
+ "get_last_iex",
+ "get_iex_symbols",
+ "get_iex_book",
+ "get_dailysummary_iex",
+ "get_data_stooq",
+ "get_data_robinhood",
+ "get_quotes_robinhood",
+ "DataReader",
+]
def get_data_alphavantage(*args, **kwargs):
@@ -139,6 +154,7 @@ def get_markets_iex(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.market import MarketReader
+
return MarketReader(*args, **kwargs).read()
@@ -157,6 +173,7 @@ def get_dailysummary_iex(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.stats import DailySummaryReader
+
return DailySummaryReader(*args, **kwargs).read()
@@ -175,6 +192,7 @@ def get_summary_iex(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.stats import MonthlySummaryReader
+
return MonthlySummaryReader(*args, **kwargs).read()
@@ -189,6 +207,7 @@ def get_records_iex(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.stats import RecordsReader
+
return RecordsReader(*args, **kwargs).read()
@@ -203,6 +222,7 @@ def get_recent_iex(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.stats import RecentReader
+
return RecentReader(*args, **kwargs).read()
@@ -216,6 +236,7 @@ def get_iex_symbols(*args, **kwargs):
:return: DataFrame
"""
from pandas_datareader.iex.ref import SymbolsReader
+
return SymbolsReader(*args, **kwargs).read()
@@ -241,8 +262,16 @@ def get_iex_book(*args, **kwargs):
return IEXDeep(*args, **kwargs).read()
-def DataReader(name, data_source=None, start=None, end=None,
- retry_count=3, pause=0.1, session=None, access_key=None):
+def DataReader(
+ name,
+ data_source=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ access_key=None,
+):
"""
Imports data from a number of online sources.
@@ -291,170 +320,327 @@ def DataReader(name, data_source=None, start=None, end=None,
ff = DataReader("6_Portfolios_2x3", "famafrench")
ff = DataReader("F-F_ST_Reversal_Factor", "famafrench")
"""
- expected_source = ["yahoo", "iex", "iex-tops", "iex-last",
- "iex-last", "bankofcanada", "stooq", "iex-book",
- "enigma", "fred", "famafrench", "oecd", "eurostat",
- "nasdaq", "quandl", "moex", 'robinhood',
- "tiingo", "yahoo-actions", "yahoo-dividends",
- "av-forex", "av-daily", "av-daily-adjusted",
- "av-weekly", "av-weekly-adjusted", "av-monthly",
- "av-monthly-adjusted", "av-intraday", "econdb"]
+ expected_source = [
+ "yahoo",
+ "iex",
+ "iex-tops",
+ "iex-last",
+ "iex-last",
+ "bankofcanada",
+ "stooq",
+ "iex-book",
+ "enigma",
+ "fred",
+ "famafrench",
+ "oecd",
+ "eurostat",
+ "nasdaq",
+ "quandl",
+ "moex",
+ "robinhood",
+ "tiingo",
+ "yahoo-actions",
+ "yahoo-dividends",
+ "av-forex",
+ "av-daily",
+ "av-daily-adjusted",
+ "av-weekly",
+ "av-weekly-adjusted",
+ "av-monthly",
+ "av-monthly-adjusted",
+ "av-intraday",
+ "econdb",
+ ]
if data_source not in expected_source:
msg = "data_source=%r is not implemented" % data_source
raise NotImplementedError(msg)
if data_source == "yahoo":
- return YahooDailyReader(symbols=name, start=start, end=end,
- adjust_price=False, chunksize=25,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return YahooDailyReader(
+ symbols=name,
+ start=start,
+ end=end,
+ adjust_price=False,
+ chunksize=25,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "iex":
- return IEXDailyReader(symbols=name, start=start, end=end,
- chunksize=25, api_key=access_key,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return IEXDailyReader(
+ symbols=name,
+ start=start,
+ end=end,
+ chunksize=25,
+ api_key=access_key,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "iex-tops":
- return IEXTops(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return IEXTops(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "iex-last":
- return IEXLasts(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return IEXLasts(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "bankofcanada":
- return BankOfCanadaReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return BankOfCanadaReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "stooq":
- return StooqDailyReader(symbols=name,
- chunksize=25,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return StooqDailyReader(
+ symbols=name,
+ chunksize=25,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "iex-book":
- return IEXDeep(symbols=name, service="book", start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return IEXDeep(
+ symbols=name,
+ service="book",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "enigma":
return EnigmaReader(dataset_id=name, api_key=access_key).read()
elif data_source == "fred":
- return FredReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return FredReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "famafrench":
- return FamaFrenchReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return FamaFrenchReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "oecd":
- return OECDReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return OECDReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "eurostat":
- return EurostatReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
- elif data_source == 'nasdaq':
- if name != 'symbols':
- raise ValueError("Only the string 'symbols' is supported for "
- "Nasdaq, not %r" % (name,))
+ return EurostatReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
+ elif data_source == "nasdaq":
+ if name != "symbols":
+ raise ValueError(
+ "Only the string 'symbols' is supported for " "Nasdaq, not %r" % (name,)
+ )
return get_nasdaq_symbols(retry_count=retry_count, pause=pause)
elif data_source == "quandl":
- return QuandlReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session, api_key=access_key).read()
+ return QuandlReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "moex":
- return MoexReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
- elif data_source == 'robinhood':
- return RobinhoodHistoricalReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
- elif data_source == 'tiingo':
- return TiingoDailyReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session,
- api_key=access_key).read()
+ return MoexReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
+ elif data_source == "robinhood":
+ return RobinhoodHistoricalReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
+ elif data_source == "tiingo":
+ return TiingoDailyReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "yahoo-actions":
- return YahooActionReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return YahooActionReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
elif data_source == "yahoo-dividends":
- return YahooDivReader(symbols=name, start=start, end=end,
- adjust_price=False, chunksize=25,
- retry_count=retry_count, pause=pause,
- session=session, interval='d').read()
+ return YahooDivReader(
+ symbols=name,
+ start=start,
+ end=end,
+ adjust_price=False,
+ chunksize=25,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ interval="d",
+ ).read()
elif data_source == "av-forex":
- return AVForexReader(symbols=name, retry_count=retry_count,
- pause=pause, session=session,
- api_key=access_key).read()
+ return AVForexReader(
+ symbols=name,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-daily":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_DAILY", start=start,
- end=end, retry_count=retry_count,
- pause=pause, session=session,
- api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_DAILY",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-daily-adjusted":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_DAILY_ADJUSTED",
- start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session, api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_DAILY_ADJUSTED",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-weekly":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_WEEKLY", start=start,
- end=end, retry_count=retry_count,
- pause=pause, session=session,
- api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_WEEKLY",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-weekly-adjusted":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_WEEKLY_ADJUSTED",
- start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session, api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_WEEKLY_ADJUSTED",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-monthly":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_MONTHLY", start=start,
- end=end, retry_count=retry_count,
- pause=pause, session=session,
- api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_MONTHLY",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-monthly-adjusted":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_MONTHLY_ADJUSTED",
- start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session, api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_MONTHLY_ADJUSTED",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "av-intraday":
- return AVTimeSeriesReader(symbols=name,
- function="TIME_SERIES_INTRADAY",
- start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session, api_key=access_key).read()
+ return AVTimeSeriesReader(
+ symbols=name,
+ function="TIME_SERIES_INTRADAY",
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ api_key=access_key,
+ ).read()
elif data_source == "econdb":
- return EcondbReader(symbols=name, start=start, end=end,
- retry_count=retry_count, pause=pause,
- session=session).read()
+ return EcondbReader(
+ symbols=name,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ ).read()
else:
msg = "data_source=%r is not implemented" % data_source
@@ -463,11 +649,15 @@ def DataReader(name, data_source=None, start=None, end=None,
def Options(symbol, data_source=None, session=None):
if data_source is None:
- warnings.warn("Options(symbol) is deprecated, use Options(symbol,"
- " data_source) instead", FutureWarning, stacklevel=2)
+ warnings.warn(
+ "Options(symbol) is deprecated, use Options(symbol,"
+ " data_source) instead",
+ FutureWarning,
+ stacklevel=2,
+ )
data_source = "yahoo"
if data_source == "yahoo":
- raise ImmediateDeprecationError(DEP_ERROR_MSG.format('Yahoo Options'))
+ raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Yahoo Options"))
return YahooOptions(symbol, session=session)
else:
raise NotImplementedError("currently only yahoo supported")
diff --git a/pandas_datareader/econdb.py b/pandas_datareader/econdb.py
index 99312f2f..11277488 100644
--- a/pandas_datareader/econdb.py
+++ b/pandas_datareader/econdb.py
@@ -7,44 +7,50 @@
class EcondbReader(_BaseReader):
"""Get data for the given name from Econdb."""
- _URL = 'https://www.econdb.com/api/series/'
+ _URL = "https://www.econdb.com/api/series/"
_format = None
- _show = 'labels'
+ _show = "labels"
@property
def url(self):
"""API URL"""
if not isinstance(self.symbols, str):
- raise ValueError('data name must be string')
+ raise ValueError("data name must be string")
- return ('{0}?{1}&format=json&page_size=500&expand=both'
- .format(self._URL, self.symbols))
+ return "{0}?{1}&format=json&page_size=500&expand=both".format(
+ self._URL, self.symbols
+ )
def read(self):
""" read one data from specified URL """
- results = requests.get(self.url).json()['results']
- df = pd.DataFrame({'dates': []}).set_index('dates')
+ results = requests.get(self.url).json()["results"]
+ df = pd.DataFrame({"dates": []}).set_index("dates")
- if self._show == 'labels':
- def show_func(x): return x.split(':')[1]
- elif self._show == 'codes':
- def show_func(x): return x.split(':')[0]
+ if self._show == "labels":
+
+ def show_func(x):
+ return x.split(":")[1]
+
+ elif self._show == "codes":
+
+ def show_func(x):
+ return x.split(":")[0]
for entry in results:
- series = (pd.DataFrame(entry['data'])[['dates', 'values']]
- .set_index('dates'))
+ series = pd.DataFrame(entry["data"])[["dates", "values"]].set_index("dates")
- head = entry['additional_metadata']
+ head = entry["additional_metadata"]
if head != "": # this additional metadata is not blank
series.columns = pd.MultiIndex.from_tuples(
[[show_func(x) for x in head.values()]],
- names=[show_func(x) for x in head.keys()])
+ names=[show_func(x) for x in head.keys()],
+ )
if not df.empty:
- df = df.join(series, how='outer')
+ df = df.join(series, how="outer")
else:
df = series
- df.index = pd.to_datetime(df.index, errors='ignore')
- df.index.name = 'TIME_PERIOD'
+ df.index = pd.to_datetime(df.index, errors="ignore")
+ df.index.name = "TIME_PERIOD"
df = df.truncate(self.start, self.end)
return df
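
A hedged usage sketch for the reader above. Each returned series is outer-joined on the date index, with MultiIndex columns labelled from `additional_metadata`; the `ticker=RGDPUS` query string follows the Econdb convention embedded in `url`, though the exact series code is an assumption for illustration.

```python
from pandas_datareader.econdb import EcondbReader

reader = EcondbReader(
    symbols="ticker=RGDPUS", start="2010-01-01", end="2018-01-01"
)
df = reader.read()  # index named TIME_PERIOD, truncated to the range
print(df.tail())
```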
diff --git a/pandas_datareader/enigma.py b/pandas_datareader/enigma.py
index 6513b223..d2442293 100644
--- a/pandas_datareader/enigma.py
+++ b/pandas_datareader/enigma.py
@@ -42,24 +42,29 @@ class EnigmaReader(_BaseReader):
>>> df = EnigmaReader(dataset_id='bedaf052-5fcd-4758-8d27-048ce8746c6a',
... api_key='INSERT_API_KEY').read()
"""
- def __init__(self,
- dataset_id=None,
- api_key=None,
- retry_count=5,
- pause=.75,
- session=None,
- base_url="https://public.enigma.com/api"):
-
- super(EnigmaReader, self).__init__(symbols=[],
- retry_count=retry_count,
- pause=pause, session=session)
+
+ def __init__(
+ self,
+ dataset_id=None,
+ api_key=None,
+ retry_count=5,
+ pause=0.75,
+ session=None,
+ base_url="https://public.enigma.com/api",
+ ):
+
+ super(EnigmaReader, self).__init__(
+ symbols=[], retry_count=retry_count, pause=pause, session=session
+ )
if api_key is None:
- self._api_key = os.getenv('ENIGMA_API_KEY')
+ self._api_key = os.getenv("ENIGMA_API_KEY")
if self._api_key is None:
- raise ValueError("Please provide an Enigma API key or set "
- "the ENIGMA_API_KEY environment variable\n"
- "If you do not have an API key, you can get "
- "one here: http://public.enigma.com/signup")
+ raise ValueError(
+ "Please provide an Enigma API key or set "
+ "the ENIGMA_API_KEY environment variable\n"
+ "If you do not have an API key, you can get "
+ "one here: http://public.enigma.com/signup"
+ )
else:
self._api_key = api_key
@@ -67,11 +72,12 @@ def __init__(self,
if not isinstance(self._dataset_id, string_types):
raise ValueError(
"The Enigma dataset_id must be a string (ex: "
- "'bedaf052-5fcd-4758-8d27-048ce8746c6a')")
+ "'bedaf052-5fcd-4758-8d27-048ce8746c6a')"
+ )
headers = {
- 'Authorization': 'Bearer {0}'.format(self._api_key),
- 'User-Agent': 'pandas-datareader',
+ "Authorization": "Bearer {0}".format(self._api_key),
+ "User-Agent": "pandas-datareader",
}
self.session.headers.update(headers)
self._base_url = base_url
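
The key lookup above follows a common explicit-argument-then-environment fallback. Distilled as a standalone sketch (the function name is hypothetical):

```python
import os


def resolve_api_key(explicit_key=None, env_var="ENIGMA_API_KEY"):
    """Prefer an explicitly passed key, fall back to the environment,
    and fail loudly if neither is set."""
    key = explicit_key if explicit_key is not None else os.getenv(env_var)
    if key is None:
        raise ValueError(
            "Provide an API key or set the %s environment variable" % env_var
        )
    return key
```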
@@ -111,7 +117,7 @@ def _get(self, url):
def get_current_snapshot_id(self, dataset_id):
"""Get ID of the most current snapshot of a dataset"""
dataset_metadata = self.get_dataset_metadata(dataset_id)
- return dataset_metadata['current_snapshot']['id']
+ return dataset_metadata["current_snapshot"]["id"]
def get_dataset_metadata(self, dataset_id):
"""Get the Dataset Model of this EnigmaReader's dataset
diff --git a/pandas_datareader/eurostat.py b/pandas_datareader/eurostat.py
index 8777fcb3..fd59037b 100644
--- a/pandas_datareader/eurostat.py
+++ b/pandas_datareader/eurostat.py
@@ -2,34 +2,32 @@
import pandas as pd
-from pandas_datareader.io.sdmx import read_sdmx, _read_sdmx_dsd
-from pandas_datareader.compat import string_types
from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import string_types
+from pandas_datareader.io.sdmx import _read_sdmx_dsd, read_sdmx
class EurostatReader(_BaseReader):
"""Get data for the given name from Eurostat."""
- _URL = 'http://ec.europa.eu/eurostat/SDMX/diss-web/rest'
+ _URL = "http://ec.europa.eu/eurostat/SDMX/diss-web/rest"
@property
def url(self):
"""API URL"""
if not isinstance(self.symbols, string_types):
- raise ValueError('data name must be string')
+ raise ValueError("data name must be string")
- q = '{0}/data/{1}/?startperiod={2}&endperiod={3}'
- return q.format(self._URL, self.symbols,
- self.start.year, self.end.year)
+ q = "{0}/data/{1}/?startperiod={2}&endperiod={3}"
+ return q.format(self._URL, self.symbols, self.start.year, self.end.year)
@property
def dsd_url(self):
"""API DSD URL"""
if not isinstance(self.symbols, string_types):
- raise ValueError('data name must be string')
+ raise ValueError("data name must be string")
- return '{0}/datastructure/ESTAT/DSD_{1}'.format(
- self._URL, self.symbols)
+ return "{0}/datastructure/ESTAT/DSD_{1}".format(self._URL, self.symbols)
def _read_one_data(self, url, params):
resp_dsd = self._get_response(self.dsd_url)
diff --git a/pandas_datareader/famafrench.py b/pandas_datareader/famafrench.py
index 6aa6230a..eb2b9174 100644
--- a/pandas_datareader/famafrench.py
+++ b/pandas_datareader/famafrench.py
@@ -1,17 +1,16 @@
import datetime as dt
import re
import tempfile
-
from zipfile import ZipFile
from pandas import read_csv, to_datetime
-from pandas_datareader.compat import lmap, StringIO
from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import StringIO, lmap
-_URL = 'http://mba.tuck.dartmouth.edu/pages/faculty/ken.french/'
-_URL_PREFIX = 'ftp/'
-_URL_SUFFIX = '_CSV.zip'
+_URL = "http://mba.tuck.dartmouth.edu/pages/faculty/ken.french/"
+_URL_PREFIX = "ftp/"
+_URL_SUFFIX = "_CSV.zip"
def get_available_datasets(**kwargs):
@@ -27,13 +26,13 @@ def get_available_datasets(**kwargs):
-------
A list of valid inputs for get_data_famafrench.
"""
- return FamaFrenchReader(symbols='', **kwargs).get_available_datasets()
+ return FamaFrenchReader(symbols="", **kwargs).get_available_datasets()
def _parse_date_famafrench(x):
x = x.strip()
try:
- return dt.datetime.strptime(x, '%Y%m')
+ return dt.datetime.strptime(x, "%Y%m")
except Exception:
pass
return to_datetime(x)
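
The parser first tries the compact `%Y%m` stamps used in the Fama-French CSVs and only then defers to pandas. A quick check of both paths:

```python
import datetime as dt

from pandas import to_datetime

# Compact monthly stamp: handled by the strptime fast path.
print(dt.datetime.strptime("192607", "%Y%m"))  # 1926-07-01 00:00:00
# Anything else (e.g. a bare year) falls through to to_datetime.
print(to_datetime("1926"))                     # 1926-01-01 00:00:00
```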
@@ -50,7 +49,7 @@ class FamaFrenchReader(_BaseReader):
@property
def url(self):
"""API URL"""
- return ''.join([_URL, _URL_PREFIX, self.symbols, _URL_SUFFIX])
+ return "".join([_URL, _URL_PREFIX, self.symbols, _URL_SUFFIX])
def _read_zipfile(self, url):
raw = self._get_response(url).content
@@ -58,7 +57,7 @@ def _read_zipfile(self, url):
with tempfile.TemporaryFile() as tmpf:
tmpf.write(raw)
- with ZipFile(tmpf, 'r') as zf:
+ with ZipFile(tmpf, "r") as zf:
data = zf.open(zf.namelist()[0]).read().decode()
return data
@@ -77,40 +76,42 @@ def read(self):
def _read_one_data(self, url, params):
- params = {'index_col': 0,
- 'parse_dates': [0],
- 'date_parser': _parse_date_famafrench}
+ params = {
+ "index_col": 0,
+ "parse_dates": [0],
+ "date_parser": _parse_date_famafrench,
+ }
# headers in these files are not valid
- if self.symbols.endswith('_Breakpoints'):
+ if self.symbols.endswith("_Breakpoints"):
- if self.symbols.find('-') > -1:
- c = ['<=0', '>0']
+ if self.symbols.find("-") > -1:
+ c = ["<=0", ">0"]
else:
- c = ['Count']
+ c = ["Count"]
r = list(range(0, 105, 5))
- params['names'] = ['Date'] + c + list(zip(r, r[1:]))
+ params["names"] = ["Date"] + c + list(zip(r, r[1:]))
- if self.symbols != 'Prior_2-12_Breakpoints':
- params['skiprows'] = 1
+ if self.symbols != "Prior_2-12_Breakpoints":
+ params["skiprows"] = 1
else:
- params['skiprows'] = 3
+ params["skiprows"] = 3
doc_chunks, tables = [], []
data = self._read_zipfile(url)
- for chunk in data.split(2 * '\r\n'):
+ for chunk in data.split(2 * "\r\n"):
if len(chunk) < 800:
- doc_chunks.append(chunk.replace('\r\n', ' ').strip())
+ doc_chunks.append(chunk.replace("\r\n", " ").strip())
else:
tables.append(chunk)
datasets, table_desc = {}, []
for i, src in enumerate(tables):
- match = re.search(r'^\s*,', src, re.M) # the table starts there
+ match = re.search(r"^\s*,", src, re.M) # the table starts there
start = 0 if not match else match.start()
- df = read_csv(StringIO('Date' + src[start:]), **params)
+ df = read_csv(StringIO("Date" + src[start:]), **params)
try:
idx_name = df.index.name # hack for pandas 0.16.2
df = df.to_period(df.index.inferred_freq[:1])
@@ -120,17 +121,17 @@ def _read_one_data(self, url, params):
df = df.truncate(self.start, self.end)
datasets[i] = df
- title = src[:start].replace('\r\n', ' ').strip()
- shape = '({0} rows x {1} cols)'.format(*df.shape)
- table_desc.append('{0} {1}'.format(title, shape).strip())
+ title = src[:start].replace("\r\n", " ").strip()
+ shape = "({0} rows x {1} cols)".format(*df.shape)
+ table_desc.append("{0} {1}".format(title, shape).strip())
- descr = '{0}\n{1}\n\n'.format(self.symbols.replace('_', ' '),
- len(self.symbols) * '-')
+ descr = "{0}\n{1}\n\n".format(
+ self.symbols.replace("_", " "), len(self.symbols) * "-"
+ )
if doc_chunks:
- descr += ' '.join(doc_chunks).replace(2 * ' ', ' ') + '\n\n'
- table_descr = map(lambda x: '{0:3} : {1}'.format(*x),
- enumerate(table_desc))
- datasets['DESCR'] = descr + '\n'.join(table_descr)
+ descr += " ".join(doc_chunks).replace(2 * " ", " ") + "\n\n"
+ table_descr = map(lambda x: "{0:3} : {1}".format(*x), enumerate(table_desc))
+ datasets["DESCR"] = descr + "\n".join(table_descr)
return datasets
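
`_read_one_data` therefore returns a dict: integer keys for each parsed table plus the `DESCR` entry assembled above. A hedged usage sketch (the dataset name is a real Ken French library key, availability permitting):

```python
import pandas_datareader.data as web

ds = web.DataReader("F-F_Research_Data_Factors", "famafrench",
                    start="2010-01-01")
print(ds["DESCR"])   # dataset description plus per-table shapes
print(ds[0].head())  # first table, monthly PeriodIndex
```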
@@ -146,15 +147,21 @@ def get_available_datasets(self):
try:
from lxml.html import document_fromstring
except ImportError:
- raise ImportError("Please install lxml if you want to use the "
- "get_datasets_famafrench function")
+ raise ImportError(
+ "Please install lxml if you want to use the "
+ "get_datasets_famafrench function"
+ )
- response = self.session.get(_URL + 'data_library.html')
+ response = self.session.get(_URL + "data_library.html")
root = document_fromstring(response.content)
- datasets = [e.attrib['href'] for e in root.findall('.//a')
- if 'href' in e.attrib]
- datasets = [ds for ds in datasets if ds.startswith(_URL_PREFIX)
- and ds.endswith(_URL_SUFFIX)]
+ datasets = [
+ e.attrib["href"] for e in root.findall(".//a") if "href" in e.attrib
+ ]
+ datasets = [
+ ds
+ for ds in datasets
+ if ds.startswith(_URL_PREFIX) and ds.endswith(_URL_SUFFIX)
+ ]
- return lmap(lambda x: x[len(_URL_PREFIX):-len(_URL_SUFFIX)], datasets)
+ return lmap(lambda x: x[len(_URL_PREFIX) : -len(_URL_SUFFIX)], datasets)
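
The final `lmap` slice strips the URL prefix and suffix so callers get bare dataset names; the slicing is easy to verify in isolation:

```python
_URL_PREFIX = "ftp/"
_URL_SUFFIX = "_CSV.zip"

href = "ftp/F-F_Research_Data_Factors_CSV.zip"
print(href[len(_URL_PREFIX) : -len(_URL_SUFFIX)])
# -> F-F_Research_Data_Factors
```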
diff --git a/pandas_datareader/fred.py b/pandas_datareader/fred.py
index 52bffd68..699c91c3 100644
--- a/pandas_datareader/fred.py
+++ b/pandas_datareader/fred.py
@@ -1,8 +1,7 @@
-from pandas_datareader.compat import is_list_like
-
from pandas import concat, read_csv
from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import is_list_like
class FredReader(_BaseReader):
@@ -40,16 +39,26 @@ def _read(self):
def fetch_data(url, name):
"""Utillity to fetch data"""
resp = self._read_url_as_StringIO(url)
- data = read_csv(resp, index_col=0, parse_dates=True,
- header=None, skiprows=1, names=["DATE", name],
- na_values='.')
+ data = read_csv(
+ resp,
+ index_col=0,
+ parse_dates=True,
+ header=None,
+ skiprows=1,
+ names=["DATE", name],
+ na_values=".",
+ )
try:
return data.truncate(self.start, self.end)
except KeyError: # pragma: no cover
- if data.iloc[3].name[7:12] == 'Error':
- raise IOError("Failed to get the data. Check that "
- "{0!r} is a valid FRED series.".format(name))
+ if data.iloc[3].name[7:12] == "Error":
+ raise IOError(
+ "Failed to get the data. Check that "
+ "{0!r} is a valid FRED series.".format(name)
+ )
raise
- df = concat([fetch_data(url, n) for url, n in zip(urls, names)],
- axis=1, join='outer')
+
+ df = concat(
+ [fetch_data(url, n) for url, n in zip(urls, names)], axis=1, join="outer"
+ )
return df
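
A hedged sketch of the multi-series path above: one URL per FRED series, outer-joined on `DATE`. The series codes are real FRED identifiers, used here only for illustration.

```python
import pandas_datareader.data as web

df = web.DataReader(["GDP", "CPIAUCSL"], "fred",
                    start="2010-01-01", end="2015-01-01")
print(df.head())  # one column per series; NaN where the dates differ
```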
diff --git a/pandas_datareader/iex/__init__.py b/pandas_datareader/iex/__init__.py
index 195e02f7..221009b1 100644
--- a/pandas_datareader/iex/__init__.py
+++ b/pandas_datareader/iex/__init__.py
@@ -2,8 +2,8 @@
import pandas as pd
from pandas.io.common import urlencode
-from pandas_datareader.base import _BaseReader
+from pandas_datareader.base import _BaseReader
# Data provided for free by IEX
# Data is furnished in compliance with the guidelines promulgated in the IEX
@@ -17,14 +17,19 @@ class IEX(_BaseReader):
Serves as the base class for all IEX API services.
"""
- _format = 'json'
+ _format = "json"
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(IEX, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(IEX, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
@@ -36,8 +41,7 @@ def service(self):
def url(self):
"""API URL"""
qstring = urlencode(self._get_params(self.symbols))
- return "https://api.iextrading.com/1.0/{}?{}".format(self.service,
- qstring)
+ return "https://api.iextrading.com/1.0/{}?{}".format(self.service, qstring)
def read(self):
"""Read data"""
@@ -51,9 +55,9 @@ def read(self):
def _get_params(self, symbols):
p = {}
if isinstance(symbols, list):
- p['symbols'] = ','.join(symbols)
+ p["symbols"] = ",".join(symbols)
elif isinstance(symbols, str):
- p['symbols'] = symbols
+ p["symbols"] = symbols
return p
def _output_error(self, out):
@@ -69,7 +73,7 @@ def _output_error(self, out):
for key, string in content.items():
e = "IEX Output error encountered: {}".format(string)
- if key == 'error':
+ if key == "error":
raise Exception(e)
def _read_lines(self, out):
diff --git a/pandas_datareader/iex/daily.py b/pandas_datareader/iex/daily.py
index d16c69a1..cb860522 100644
--- a/pandas_datareader/iex/daily.py
+++ b/pandas_datareader/iex/daily.py
@@ -2,9 +2,9 @@
import json
import os
+from dateutil.relativedelta import relativedelta
import pandas as pd
-from dateutil.relativedelta import relativedelta
from pandas_datareader.base import _DailyBaseReader
# Data provided for free by IEX
@@ -45,32 +45,48 @@ class IEXDailyReader(_DailyBaseReader):
IEX Cloud Secret Token
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None, chunksize=25, api_key=None):
+ def __init__(
+ self,
+ symbols=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ chunksize=25,
+ api_key=None,
+ ):
if api_key is None:
- api_key = os.getenv('IEX_API_KEY')
+ api_key = os.getenv("IEX_API_KEY")
if not api_key or not isinstance(api_key, str):
- raise ValueError('The IEX Cloud API key must be provided either '
- 'through the api_key variable or through the '
- ' environment variable IEX_API_KEY')
+ raise ValueError(
+ "The IEX Cloud API key must be provided either "
+ "through the api_key variable or through the "
+ " environment variable IEX_API_KEY"
+ )
# Support for sandbox environment (testing purposes)
if os.getenv("IEX_SANDBOX") == "enable":
self.sandbox = True
else:
self.sandbox = False
self.api_key = api_key
- super(IEXDailyReader, self).__init__(symbols=symbols, start=start,
- end=end, retry_count=retry_count,
- pause=pause, session=session,
- chunksize=chunksize)
+ super(IEXDailyReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ chunksize=chunksize,
+ )
@property
def url(self):
"""API URL"""
if self.sandbox is True:
- return 'https://sandbox.iexapis.com/stable/stock/market/batch'
+ return "https://sandbox.iexapis.com/stable/stock/market/batch"
else:
- return 'https://cloud.iexapis.com/stable/stock/market/batch'
+ return "https://cloud.iexapis.com/stable/stock/market/batch"
@property
def endpoint(self):
@@ -80,20 +96,20 @@ def endpoint(self):
def _get_params(self, symbol):
chart_range = self._range_string_from_date()
if isinstance(symbol, list):
- symbolList = ','.join(symbol)
+ symbolList = ",".join(symbol)
else:
symbolList = symbol
params = {
"symbols": symbolList,
"types": self.endpoint,
"range": chart_range,
- "token": self.api_key
+ "token": self.api_key,
}
return params
def _range_string_from_date(self):
delta = relativedelta(self.start, datetime.datetime.now())
- years = (delta.years * -1)
+ years = delta.years * -1
if 5 <= years <= 15:
return "max"
if 2 <= years < 5:
@@ -113,14 +129,12 @@ def _range_string_from_date(self):
return "1y"
else:
- raise ValueError(
- "Invalid date specified. Must be within past 15 years.")
+ raise ValueError("Invalid date specified. Must be within past 15 years.")
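
A worked check of the mapping above, mirroring how the reader computes `years`: a start date three years back yields a positive 3 and lands in the `"5y"` bucket.

```python
import datetime

from dateutil.relativedelta import relativedelta

now = datetime.datetime.now()
start = now - relativedelta(years=3)
delta = relativedelta(start, now)
years = delta.years * -1
print(years)           # 3
print(2 <= years < 5)  # True -> the reader requests range="5y"
```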
def read(self):
"""Read data"""
try:
- return self._read_one_data(self.url,
- self._get_params(self.symbols))
+ return self._read_one_data(self.url, self._get_params(self.symbols))
finally:
self.close()
@@ -138,12 +152,12 @@ def _read_lines(self, out):
df.set_index("date", inplace=True)
values = ["open", "high", "low", "close", "volume"]
df = df[values]
- sstart = self.start.strftime('%Y-%m-%d')
- send = self.end.strftime('%Y-%m-%d')
+ sstart = self.start.strftime("%Y-%m-%d")
+ send = self.end.strftime("%Y-%m-%d")
df = df.loc[sstart:send]
result.update({symbol: df})
if len(result) > 1:
result = pd.concat(result).unstack(level=0)
- result.columns.names = ['Attributes', 'Symbols']
+ result.columns.names = ["Attributes", "Symbols"]
return result
return result[self.symbols]
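
With more than one symbol, the `concat`/`unstack` above yields `(Attributes, Symbols)` MultiIndex columns. A hedged usage sketch (assumes `IEX_API_KEY` is set in the environment):

```python
import pandas_datareader.data as web

df = web.DataReader(["AAPL", "MSFT"], "iex", start="2019-01-01")
# Columns are (Attributes, Symbols); select attribute, then symbol.
print(df["close"]["AAPL"].head())
```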
diff --git a/pandas_datareader/iex/deep.py b/pandas_datareader/iex/deep.py
index 9dcc27af..02304ae2 100644
--- a/pandas_datareader/iex/deep.py
+++ b/pandas_datareader/iex/deep.py
@@ -1,6 +1,7 @@
-from pandas_datareader.iex import IEX
from datetime import datetime
+from pandas_datareader.iex import IEX
+
# Data provided for free by IEX
# Data is furnished in compliance with the guidelines promulgated in the IEX
# API terms of service and manual
@@ -22,17 +23,30 @@ class Deep(IEX):
Also provides last trade price and size information. Routed executions
are not reported.
"""
- def __init__(self, symbols=None, service=None, start=None, end=None,
- retry_count=3, pause=0.1, session=None):
+
+ def __init__(
+ self,
+ symbols=None,
+ service=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ ):
if isinstance(symbols, str):
symbols = symbols.lower()
else:
symbols = [s.lower() for s in symbols]
- super(Deep, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ super(Deep, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
self.sub = service
@property
@@ -51,15 +65,15 @@ def _read_lines(self, out):
# Runs appropriate output functions per the service being accessed.
fmap = {
- 'book': '_pass',
- 'op-halt-status': '_convert_tstamp',
- 'security-event': '_convert_tstamp',
- 'ssr-status': '_convert_tstamp',
- 'system-event': '_read_system_event',
- 'trades': '_pass',
- 'trade-breaks': '_convert_tstamp',
- 'trading-status': '_read_trading_status',
- None: '_pass',
+ "book": "_pass",
+ "op-halt-status": "_convert_tstamp",
+ "security-event": "_convert_tstamp",
+ "ssr-status": "_convert_tstamp",
+ "system-event": "_read_system_event",
+ "trades": "_pass",
+ "trade-breaks": "_convert_tstamp",
+ "trading-status": "_read_trading_status",
+ None: "_pass",
}
if self.sub in fmap:
@@ -71,12 +85,12 @@ def _read_system_event(self, out):
# Map the response code to a string output per the API docs.
# Per: https://www.iextrading.com/developer/docs/#system-event-message
smap = {
- 'O': 'Start of messages',
- 'S': 'Start of system hours',
- 'R': 'Start of regular market hours',
- 'M': 'End of regular market hours',
- 'E': 'End of system hours',
- 'C': 'End of messages'
+ "O": "Start of messages",
+ "S": "Start of system hours",
+ "R": "Start of regular market hours",
+ "M": "End of regular market hours",
+ "E": "End of system hours",
+ "C": "End of messages",
}
tid = out["systemEvent"]
out["eventResponse"] = smap[tid]
@@ -90,34 +104,33 @@ def _pass(out):
def _read_trading_status(self, out):
# Reference: https://www.iextrading.com/developer/docs/#trading-status
smap = {
- 'H': 'Trading halted across all US equity markets',
- 'O': 'Trading halt released into an Order Acceptance Period '
- '(IEX-listed securities only)',
- 'P': 'Trading paused and Order Acceptance Period on IEX '
- '(IEX-listed securities only)',
- 'T': 'Trading on IEX'
+ "H": "Trading halted across all US equity markets",
+ "O": "Trading halt released into an Order Acceptance Period "
+ "(IEX-listed securities only)",
+ "P": "Trading paused and Order Acceptance Period on IEX "
+ "(IEX-listed securities only)",
+ "T": "Trading on IEX",
}
rmap = {
# Trading Halt Reasons
- 'T1': 'Halt News Pending',
- 'IPO1': 'IPO/New Issue Not Yet Trading',
- 'IPOD': 'IPO/New Issue Deferred',
- 'MCB3': 'Market-Wide Circuit Breaker Level 3 - Breached',
- 'NA': 'Reason Not Available',
-
+ "T1": "Halt News Pending",
+ "IPO1": "IPO/New Issue Not Yet Trading",
+ "IPOD": "IPO/New Issue Deferred",
+ "MCB3": "Market-Wide Circuit Breaker Level 3 - Breached",
+ "NA": "Reason Not Available",
# Order Acceptance Period Reasons
- 'T2': 'Halt News Dissemination',
- 'IPO2': 'IPO/New Issue Order Acceptance Period',
- 'IPO3': 'IPO Pre-Launch Period',
- 'MCB1': 'Market-Wide Circuit Breaker Level 1 - Breached',
- 'MCB2': 'Market-Wide Circuit Breaker Level 2 - Breached'
+ "T2": "Halt News Dissemination",
+ "IPO2": "IPO/New Issue Order Acceptance Period",
+ "IPO3": "IPO Pre-Launch Period",
+ "MCB1": "Market-Wide Circuit Breaker Level 1 - Breached",
+ "MCB2": "Market-Wide Circuit Breaker Level 2 - Breached",
}
for ticker, data in out.items():
- if data['status'] in smap:
- data['statusText'] = smap[data['status']]
+ if data["status"] in smap:
+ data["statusText"] = smap[data["status"]]
- if data['reason'] in rmap:
- data['reasonText'] = rmap[data['reason']]
+ if data["reason"] in rmap:
+ data["reasonText"] = rmap[data["reason"]]
out[ticker] = data
@@ -126,15 +139,15 @@ def _read_trading_status(self, out):
@staticmethod
def _convert_tstamp(out):
# Searches for top-level timestamp attributes or within dictionaries
- if 'timestamp' in out:
+ if "timestamp" in out:
# Convert UNIX to datetime object
f = float(out["timestamp"])
- out["timestamp"] = datetime.fromtimestamp(f/1000)
+ out["timestamp"] = datetime.fromtimestamp(f / 1000)
else:
for ticker, data in out.items():
- if 'timestamp' in data:
+ if "timestamp" in data:
f = float(data["timestamp"])
- data["timestamp"] = datetime.fromtimestamp(f/1000)
+ data["timestamp"] = datetime.fromtimestamp(f / 1000)
out[ticker] = data
return out
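
The `/ 1000` above is needed because IEX reports UNIX epoch milliseconds; converting one value in isolation:

```python
from datetime import datetime

ms = 1547562000000  # epoch milliseconds, as delivered by IEX
print(datetime.fromtimestamp(ms / 1000))  # local-time datetime
```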
diff --git a/pandas_datareader/iex/market.py b/pandas_datareader/iex/market.py
index 16fa4452..6cd1a879 100644
--- a/pandas_datareader/iex/market.py
+++ b/pandas_datareader/iex/market.py
@@ -16,12 +16,18 @@ class MarketReader(IEX):
Market data is captured by the IEX system between approximately 7:45 a.m.
and 5:15 p.m. ET.
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(MarketReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(MarketReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
diff --git a/pandas_datareader/iex/ref.py b/pandas_datareader/iex/ref.py
index 995905ed..93908fc4 100644
--- a/pandas_datareader/iex/ref.py
+++ b/pandas_datareader/iex/ref.py
@@ -16,12 +16,18 @@ class SymbolsReader(IEX):
Returns symbols IEX supports for trading. Updated daily as of 7:45 a.m.
ET.
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(SymbolsReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(SymbolsReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
diff --git a/pandas_datareader/iex/stats.py b/pandas_datareader/iex/stats.py
index 873d6bec..a0c5cf34 100644
--- a/pandas_datareader/iex/stats.py
+++ b/pandas_datareader/iex/stats.py
@@ -5,7 +5,6 @@
from pandas_datareader.exceptions import UnstableAPIWarning
from pandas_datareader.iex import IEX
-
# Data provided for free by IEX
# Data is furnished in compliance with the guidelines promulgated in the IEX
# API terms of service and manual
@@ -17,16 +16,25 @@ class DailySummaryReader(IEX):
"""
Daily statistics from IEX for a day or month
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
import warnings
- warnings.warn('Daily statistics is not working due to issues with the '
- 'IEX API', UnstableAPIWarning)
+
+ warnings.warn(
+ "Daily statistics is not working due to issues with the " "IEX API",
+ UnstableAPIWarning,
+ )
self.curr_date = start
- super(DailySummaryReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ super(DailySummaryReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
@@ -37,7 +45,7 @@ def _get_params(self, symbols):
p = {}
if self.curr_date is not None:
- p['date'] = self.curr_date.strftime('%Y%m%d')
+ p["date"] = self.curr_date.strftime("%Y%m%d")
return p
@@ -59,16 +67,21 @@ def read(self):
class MonthlySummaryReader(IEX):
"""Monthly statistics from IEX"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
self.curr_date = start
- self.date_format = '%Y%m'
+ self.date_format = "%Y%m"
- super(MonthlySummaryReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause,
- session=session)
+ super(MonthlySummaryReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
@@ -79,7 +92,7 @@ def _get_params(self, symbols):
p = {}
if self.curr_date is not None:
- p['date'] = self.curr_date.strftime(self.date_format)
+ p["date"] = self.curr_date.strftime(self.date_format)
return p
@@ -94,8 +107,7 @@ def read(self):
dfs = []
# Build list of all dates within the given range
- lrange = [x for x in (self.start + timedelta(n)
- for n in range(tlen.days))]
+ lrange = [x for x in (self.start + timedelta(n) for n in range(tlen.days))]
mrange = []
for dt in lrange:
@@ -109,7 +121,7 @@ def read(self):
# We may not return data if this was a weekend/holiday:
if not tdf.empty:
- tdf['date'] = date.strftime(self.date_format)
+ tdf["date"] = date.strftime(self.date_format)
dfs.append(tdf)
# We may not return any data if we failed to specify useful parameters:
@@ -120,12 +132,18 @@ class RecordsReader(IEX):
"""
Total matched volume information from IEX
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(RecordsReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(RecordsReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
@@ -155,12 +173,18 @@ class RecentReader(IEX):
* litVolume: refers to the number of lit shares traded on IEX
(single-counted).
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(RecentReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(RecentReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
diff --git a/pandas_datareader/iex/tops.py b/pandas_datareader/iex/tops.py
index 410fd2be..10a6c1ec 100644
--- a/pandas_datareader/iex/tops.py
+++ b/pandas_datareader/iex/tops.py
@@ -17,12 +17,17 @@ class TopsReader(IEX):
on IEX's displayed limit order book.
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(TopsReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(TopsReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
@@ -39,13 +44,19 @@ class LastReader(IEX):
Last provides trade data for executions on IEX. Provides last sale price,
size and time.
"""
+
# todo: Eventually we'll want to implement WebSockets as an option.
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None):
- super(LastReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ def __init__(
+ self, symbols=None, start=None, end=None, retry_count=3, pause=0.1, session=None
+ ):
+ super(LastReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
@property
def service(self):
diff --git a/pandas_datareader/io/jsdmx.py b/pandas_datareader/io/jsdmx.py
index e993effb..4fc0dd69 100644
--- a/pandas_datareader/io/jsdmx.py
+++ b/pandas_datareader/io/jsdmx.py
@@ -1,8 +1,8 @@
# pylint: disable-msg=E1101,W0613,W0603
from __future__ import unicode_literals
-from collections import OrderedDict
+from collections import OrderedDict
import itertools
import sys
@@ -32,7 +32,7 @@ def read_jsdmx(path_or_buf):
import simplejson as json
except ImportError:
if sys.version_info[:2] < (2, 7):
- raise ImportError('simplejson is required in python 2.6')
+ raise ImportError("simplejson is required in python 2.6")
import json
if isinstance(jdata, dict):
@@ -40,11 +40,11 @@ def read_jsdmx(path_or_buf):
else:
data = json.loads(jdata, object_pairs_hook=OrderedDict)
- structure = data['structure']
- index = _parse_dimensions(structure['dimensions']['observation'])
- columns = _parse_dimensions(structure['dimensions']['series'])
+ structure = data["structure"]
+ index = _parse_dimensions(structure["dimensions"]["observation"])
+ columns = _parse_dimensions(structure["dimensions"]["series"])
- dataset = data['dataSets']
+ dataset = data["dataSets"]
if len(dataset) != 1:
raise ValueError("length of 'dataSets' must be 1")
dataset = dataset[0]
@@ -58,20 +58,19 @@ def _get_indexer(index):
if index.nlevels == 1:
return [str(i) for i in range(len(index))]
else:
- it = itertools.product(*[range(
- len(level)) for level in index.levels])
- return [':'.join(map(str, i)) for i in it]
+ it = itertools.product(*[range(len(level)) for level in index.levels])
+ return [":".join(map(str, i)) for i in it]
def _parse_values(dataset, index, columns):
size = len(index)
- series = dataset['series']
+ series = dataset["series"]
values = []
# for s_key, s_value in iteritems(series):
for s_key in _get_indexer(columns):
try:
- observations = series[s_key]['observations']
+ observations = series[s_key]["observations"]
observed = []
for o_key in _get_indexer(index):
try:
@@ -90,14 +89,14 @@ def _parse_dimensions(dimensions):
arrays = []
names = []
for key in dimensions:
- values = [v['name'] for v in key['values']]
+ values = [v["name"] for v in key["values"]]
- role = key.get('role', None)
- if role in ('time', 'TIME_PERIOD'):
+ role = key.get("role", None)
+ if role in ("time", "TIME_PERIOD"):
values = pd.DatetimeIndex(values)
arrays.append(values)
- names.append(key['name'])
+ names.append(key["name"])
midx = pd.MultiIndex.from_product(arrays, names=names)
if len(arrays) == 1 and isinstance(midx, pd.MultiIndex):
# Fix for pandas >= 0.21
diff --git a/pandas_datareader/io/sdmx.py b/pandas_datareader/io/sdmx.py
index 4295c319..55f185f0 100644
--- a/pandas_datareader/io/sdmx.py
+++ b/pandas_datareader/io/sdmx.py
@@ -1,34 +1,33 @@
from __future__ import unicode_literals
import collections
+from io import BytesIO
import time
import zipfile
-from io import BytesIO
import pandas as pd
-from pandas_datareader.io.util import _read_content
from pandas_datareader.compat import HTTPError, str_to_bytes
+from pandas_datareader.io.util import _read_content
+_STRUCTURE = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/structure}"
+_MESSAGE = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/message}"
+_GENERIC = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic}"
+_COMMON = "{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/common}"
+_XML = "{http://www.w3.org/XML/1998/namespace}"
-_STRUCTURE = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/structure}'
-_MESSAGE = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/message}'
-_GENERIC = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic}'
-_COMMON = '{http://www.sdmx.org/resources/sdmxml/schemas/v2_1/common}'
-_XML = '{http://www.w3.org/XML/1998/namespace}'
-
-_DATASET = _MESSAGE + 'DataSet'
-_SERIES = _GENERIC + 'Series'
-_SERIES_KEY = _GENERIC + 'SeriesKey'
-_OBSERVATION = _GENERIC + 'Obs'
-_VALUE = _GENERIC + 'Value'
-_OBSDIMENSION = _GENERIC + 'ObsDimension'
-_OBSVALUE = _GENERIC + 'ObsValue'
-_CODE = _STRUCTURE + 'Code'
-_TIMEDIMENSION = _STRUCTURE + 'TimeDimension'
+_DATASET = _MESSAGE + "DataSet"
+_SERIES = _GENERIC + "Series"
+_SERIES_KEY = _GENERIC + "SeriesKey"
+_OBSERVATION = _GENERIC + "Obs"
+_VALUE = _GENERIC + "Value"
+_OBSDIMENSION = _GENERIC + "ObsDimension"
+_OBSVALUE = _GENERIC + "ObsValue"
+_CODE = _STRUCTURE + "Code"
+_TIMEDIMENSION = _STRUCTURE + "TimeDimension"
-def read_sdmx(path_or_buf, dtype='float64', dsd=None):
+def read_sdmx(path_or_buf, dtype="float64", dsd=None):
"""
Convert a SDMX-XML string to pandas object
@@ -49,14 +48,15 @@ def read_sdmx(path_or_buf, dtype='float64', dsd=None):
xdata = _read_content(path_or_buf)
import xml.etree.ElementTree as ET
+
root = ET.fromstring(xdata)
try:
- structure = _get_child(root, _MESSAGE + 'Structure')
+ structure = _get_child(root, _MESSAGE + "Structure")
except ValueError:
# get zipped path
- result = list(root.iter(_COMMON + 'Text'))[1].text
- if not result.startswith('http'):
+ result = list(root.iter(_COMMON + "Text"))[1].text
+ if not result.startswith("http"):
raise ValueError(result)
for _ in range(60):
@@ -68,11 +68,13 @@ def read_sdmx(path_or_buf, dtype='float64', dsd=None):
time.sleep(1)
continue
- msg = ('Unable to download zipped data within 60 secs, '
- 'please download it manually from: {0}')
+ msg = (
+ "Unable to download zipped data within 60 secs, "
+ "please download it manually from: {0}"
+ )
raise ValueError(msg.format(result))
- idx_name = structure.get('dimensionAtObservation')
+ idx_name = structure.get("dimensionAtObservation")
dataset = _get_child(root, _DATASET)
keys = []
@@ -141,8 +143,7 @@ def _construct_index(keys, dsd=None):
except KeyError:
values[name] = [value]
- midx = pd.MultiIndex.from_arrays([values[name] for name in names],
- names=names)
+ midx = pd.MultiIndex.from_arrays([values[name] for name in names], names=names)
return midx
@@ -151,7 +152,7 @@ def _parse_observations(observations):
for observation in observations:
obsdimension = _get_child(observation, _OBSDIMENSION)
obsvalue = _get_child(observation, _OBSVALUE)
- results.append((obsdimension.get('value'), obsvalue.get('value')))
+ results.append((obsdimension.get("value"), obsvalue.get("value")))
# return list of key/value tuple, eg: [(key, value), ...]
return results
@@ -159,7 +160,7 @@ def _parse_observations(observations):
def _parse_series_key(series):
serieskey = _get_child(series, _SERIES_KEY)
key_values = serieskey.iter(_VALUE)
- keys = [(key.get('id'), key.get('value')) for key in key_values]
+ keys = [(key.get("id"), key.get("value")) for key in key_values]
# return list of key/value tuple, eg: [(key, value), ...]
return keys
@@ -169,11 +170,11 @@ def _get_child(element, key):
if len(elements) == 1:
return elements[0]
elif len(elements) == 0:
- raise ValueError("Element {0} contains "
- "no {1}".format(element.tag, key))
+ raise ValueError("Element {0} contains " "no {1}".format(element.tag, key))
else:
- raise ValueError("Element {0} contains "
- "multiple {1}".format(element.tag, key))
+ raise ValueError(
+ "Element {0} contains " "multiple {1}".format(element.tag, key)
+ )
_NAME_EN = ".//{0}Name[@{1}lang='en']".format(_COMMON, _XML)
@@ -184,7 +185,7 @@ def _get_english_name(element):
return name
-SDMXCode = collections.namedtuple('SDMXCode', ['codes', 'ts'])
+SDMXCode = collections.namedtuple("SDMXCode", ["codes", "ts"])
def _read_sdmx_dsd(path_or_buf):
@@ -204,12 +205,13 @@ def _read_sdmx_dsd(path_or_buf):
xdata = _read_content(path_or_buf)
import xml.etree.cElementTree as ET
+
root = ET.fromstring(xdata)
- structure = _get_child(root, _MESSAGE + 'Structures')
- codes = _get_child(structure, _STRUCTURE + 'Codelists')
+ structure = _get_child(root, _MESSAGE + "Structures")
+ codes = _get_child(structure, _STRUCTURE + "Codelists")
# concepts = _get_child(structure, _STRUCTURE + 'Concepts')
- datastructures = _get_child(structure, _STRUCTURE + 'DataStructures')
+ datastructures = _get_child(structure, _STRUCTURE + "DataStructures")
code_results = {}
for codelist in codes:
@@ -217,7 +219,7 @@ def _read_sdmx_dsd(path_or_buf):
codelist_name = _get_english_name(codelist)
mapper = {}
for code in codelist.iter(_CODE):
- code_id = code.get('id')
+ code_id = code.get("id")
name = _get_english_name(code)
mapper[code_id] = name
# codeobj = SDMXCode(id=codelist_id, name=codelist_name, mapper=mapper)
@@ -225,7 +227,7 @@ def _read_sdmx_dsd(path_or_buf):
code_results[codelist_name] = mapper
times = list(datastructures.iter(_TIMEDIMENSION))
- times = [t.get('id') for t in times]
+ times = [t.get("id") for t in times]
result = SDMXCode(codes=code_results, ts=times)
return result
diff --git a/pandas_datareader/io/util.py b/pandas_datareader/io/util.py
index 25f6168e..92a3ae8b 100644
--- a/pandas_datareader/io/util.py
+++ b/pandas_datareader/io/util.py
@@ -19,11 +19,11 @@ def _read_content(path_or_buf):
exists = False
if exists:
- with open(filepath_or_buffer, 'r') as fh:
+ with open(filepath_or_buffer, "r") as fh:
data = fh.read()
else:
data = filepath_or_buffer
- elif hasattr(filepath_or_buffer, 'read'):
+ elif hasattr(filepath_or_buffer, "read"):
data = filepath_or_buffer.read()
else:
data = filepath_or_buffer
diff --git a/pandas_datareader/moex.py b/pandas_datareader/moex.py
index e439efdb..ff3aa68d 100644
--- a/pandas_datareader/moex.py
+++ b/pandas_datareader/moex.py
@@ -5,9 +5,7 @@
import pandas as pd
from pandas_datareader.base import _DailyBaseReader
-from pandas_datareader.compat import (binary_type, concat, is_list_like,
- StringIO)
-
+from pandas_datareader.compat import StringIO, binary_type, concat, is_list_like
class MoexReader(_DailyBaseReader):
@@ -54,39 +52,45 @@ def __init__(self, *args, **kwargs):
self.__engines, self.__markets = {}, {} # dicts for engines and markets
__url_metadata = "https://iss.moex.com/iss/securities/{symbol}.csv"
- __url_data = "https://iss.moex.com/iss/history/engines/{engine}/" \
- "markets/{market}/securities/{symbol}.csv"
+ __url_data = (
+ "https://iss.moex.com/iss/history/engines/{engine}/"
+ "markets/{market}/securities/{symbol}.csv"
+ )
@property
def url(self):
"""Return a list of API URLs per symbol"""
if not self.__engines or not self.__markets:
- raise Exception("Accessing url property before invocation "
- "of read() or _get_metadata() methods")
+ raise Exception(
+ "Accessing url property before invocation "
+ "of read() or _get_metadata() methods"
+ )
- return [self.__url_data.format(
- engine=self.__engines[s],
- market=self.__markets[s],
- symbol=s) for s in self.symbols]
+ return [
+ self.__url_data.format(
+ engine=self.__engines[s], market=self.__markets[s], symbol=s
+ )
+ for s in self.symbols
+ ]
def _get_params(self, start):
"""Return a dict for REST API GET request parameters"""
params = {
- 'iss.only': 'history',
- 'iss.dp': 'point',
- 'iss.df': '%Y-%m-%d',
- 'iss.tf': '%H:%M:%S',
- 'iss.dft': '%Y-%m-%d %H:%M:%S',
- 'iss.json': 'extended',
- 'callback': 'JSON_CALLBACK',
- 'from': start,
- 'till': self.end_dt.strftime('%Y-%m-%d'),
- 'limit': 100,
- 'start': 1,
- 'sort_order': 'TRADEDATE',
- 'sort_order_desc': 'asc'
+ "iss.only": "history",
+ "iss.dp": "point",
+ "iss.df": "%Y-%m-%d",
+ "iss.tf": "%H:%M:%S",
+ "iss.dft": "%Y-%m-%d %H:%M:%S",
+ "iss.json": "extended",
+ "callback": "JSON_CALLBACK",
+ "from": start,
+ "till": self.end_dt.strftime("%Y-%m-%d"),
+ "limit": 100,
+ "start": 1,
+ "sort_order": "TRADEDATE",
+ "sort_order_desc": "asc",
}
return params
@@ -96,33 +100,36 @@ def _get_metadata(self):
markets, engines = {}, {}
for symbol in self.symbols:
- response = self._get_response(
- self.__url_metadata.format(symbol=symbol)
- )
+ response = self._get_response(self.__url_metadata.format(symbol=symbol))
text = self._sanitize_response(response)
if len(text) == 0:
service = self.__class__.__name__
- raise IOError("{} request returned no data; check URL for invalid "
- "inputs: {}".format(service, self.__url_metadata))
+ raise IOError(
+ "{} request returned no data; check URL for invalid "
+ "inputs: {}".format(service, self.__url_metadata)
+ )
if isinstance(text, binary_type):
- text = text.decode('windows-1251')
+ text = text.decode("windows-1251")
- header_str = 'secid;boardid;'
+ header_str = "secid;boardid;"
get_data = False
for s in text.splitlines():
if s.startswith(header_str):
get_data = True
continue
- if get_data and s != '':
- fields = s.split(';')
+ if get_data and s != "":
+ fields = s.split(";")
markets[symbol], engines[symbol] = fields[5], fields[7]
break
if symbol not in markets or symbol not in engines:
- raise IOError("{} request returned no metadata: {}\n"
- "Typo in the security symbol `{}`?".format(
- self.__class__.__name__,
- self.__url_metadata.format(symbol=symbol),
- symbol))
+ raise IOError(
+ "{} request returned no metadata: {}\n"
+ "Typo in the security symbol `{}`?".format(
+ self.__class__.__name__,
+ self.__url_metadata.format(symbol=symbol),
+ symbol,
+ )
+ )
return markets, engines
def read(self):
@@ -140,21 +147,22 @@ def read(self):
while True: # read in a loop with small date intervals
if len(out_list) > 0:
if date_column is None:
- date_column = out_list[0].split(';').index('TRADEDATE')
+ date_column = out_list[0].split(";").index("TRADEDATE")
# get the last downloaded date
- start_str = out_list[-1].split(';', 4)[date_column]
- start = dt.datetime.strptime(start_str, '%Y-%m-%d').date()
+ start_str = out_list[-1].split(";", 4)[date_column]
+ start = dt.datetime.strptime(start_str, "%Y-%m-%d").date()
else:
- start_str = self.start.strftime('%Y-%m-%d')
+ start_str = self.start.strftime("%Y-%m-%d")
start = self.start
if start >= self.end or start >= dt.date.today():
break
params = self._get_params(start_str)
- strings_out = self._read_url_as_String(urls[i], params) \
- .splitlines()[2:]
+ strings_out = self._read_url_as_String(
+ urls[i], params
+ ).splitlines()[2:]
strings_out = list(filter(lambda x: x.strip(), strings_out))
if len(out_list) == 0:
@@ -165,13 +173,13 @@ def read(self):
out_list += strings_out[1:] # remove a CSV head line
if len(strings_out) < 100: # all data received - break
break
- str_io = StringIO('\r\n'.join(out_list))
- dfs.append(self._read_lines(str_io)) # add a new DataFrame
+ str_io = StringIO("\r\n".join(out_list))
+ dfs.append(self._read_lines(str_io)) # add a new DataFrame
finally:
self.close()
if len(dfs) > 1:
- return concat(dfs, axis=0, join='outer', sort=True)
+ return concat(dfs, axis=0, join="outer", sort=True)
else:
return dfs[0]
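
The loop above pages through MOEX ISS responses 100 rows at a time, resuming from the last downloaded trade date until a short page arrives. A generic hedged sketch of the same pattern (`fetch_page` is a hypothetical stand-in for the CSV request the reader issues):

```python
def paginate(fetch_page, start, end, page_size=100):
    """Accumulate pages of at most page_size date-sorted rows,
    resuming from the last date seen; a short page ends the loop.
    Rows on the boundary date may repeat and need de-duplication."""
    rows = []
    while start < end:
        page = fetch_page(start)   # rows sorted ascending by date
        rows.extend(page)
        if len(page) < page_size:  # short page: nothing left upstream
            break
        start = page[-1][0]        # first field of each row is the date
    return rows
```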
@@ -182,14 +190,21 @@ def _read_url_as_String(self, url, params=None):
text = self._sanitize_response(response)
if len(text) == 0:
service = self.__class__.__name__
- raise IOError("{} request returned no data; check URL for invalid "
- "inputs: {}".format(service, self.url))
+ raise IOError(
+ "{} request returned no data; check URL for invalid "
+ "inputs: {}".format(service, self.url)
+ )
if isinstance(text, binary_type):
- text = text.decode('windows-1251')
+ text = text.decode("windows-1251")
return text
def _read_lines(self, input):
"""Return a pandas DataFrame from input"""
- return pd.read_csv(input, index_col='TRADEDATE', parse_dates=True,
- sep=';', na_values=('-', 'null'))
+ return pd.read_csv(
+ input,
+ index_col="TRADEDATE",
+ parse_dates=True,
+ sep=";",
+ na_values=("-", "null"),
+ )
diff --git a/pandas_datareader/nasdaq_trader.py b/pandas_datareader/nasdaq_trader.py
index 78cffe7e..fd645504 100644
--- a/pandas_datareader/nasdaq_trader.py
+++ b/pandas_datareader/nasdaq_trader.py
@@ -1,34 +1,36 @@
from ftplib import FTP, all_errors
+import time
+import warnings
from pandas import read_csv
+
from pandas_datareader._utils import RemoteDataError
from pandas_datareader.compat import StringIO
-import time
-import warnings
-
-_NASDAQ_TICKER_LOC = '/SymbolDirectory/nasdaqtraded.txt'
-_NASDAQ_FTP_SERVER = 'ftp.nasdaqtrader.com'
-_TICKER_DTYPE = [('Nasdaq Traded', bool),
- ('Symbol', str),
- ('Security Name', str),
- ('Listing Exchange', str),
- ('Market Category', str),
- ('ETF', bool),
- ('Round Lot Size', float),
- ('Test Issue', bool),
- ('Financial Status', str),
- ('CQS Symbol', str),
- ('NASDAQ Symbol', str),
- ('NextShares', bool)]
-_CATEGORICAL = ('Listing Exchange', 'Financial Status')
-
-_DELIMITER = '|'
+_NASDAQ_TICKER_LOC = "/SymbolDirectory/nasdaqtraded.txt"
+_NASDAQ_FTP_SERVER = "ftp.nasdaqtrader.com"
+_TICKER_DTYPE = [
+ ("Nasdaq Traded", bool),
+ ("Symbol", str),
+ ("Security Name", str),
+ ("Listing Exchange", str),
+ ("Market Category", str),
+ ("ETF", bool),
+ ("Round Lot Size", float),
+ ("Test Issue", bool),
+ ("Financial Status", str),
+ ("CQS Symbol", str),
+ ("NASDAQ Symbol", str),
+ ("NextShares", bool),
+]
+_CATEGORICAL = ("Listing Exchange", "Financial Status")
+
+_DELIMITER = "|"
_ticker_cache = None
def _bool_converter(item):
- return item == 'Y'
+ return item == "Y"
def _download_nasdaq_symbols(timeout):
@@ -39,38 +41,43 @@ def _download_nasdaq_symbols(timeout):
ftp_session = FTP(_NASDAQ_FTP_SERVER, timeout=timeout)
ftp_session.login()
except all_errors as err:
- raise RemoteDataError('Error connecting to %r: %s' %
- (_NASDAQ_FTP_SERVER, err))
+ raise RemoteDataError("Error connecting to %r: %s" % (_NASDAQ_FTP_SERVER, err))
lines = []
try:
- ftp_session.retrlines('RETR ' + _NASDAQ_TICKER_LOC, lines.append)
+ ftp_session.retrlines("RETR " + _NASDAQ_TICKER_LOC, lines.append)
except all_errors as err:
- raise RemoteDataError('Error downloading from %r: %s' %
- (_NASDAQ_FTP_SERVER, err))
+ raise RemoteDataError(
+ "Error downloading from %r: %s" % (_NASDAQ_FTP_SERVER, err)
+ )
finally:
ftp_session.close()
# Sanity Checking
- if not lines[-1].startswith('File Creation Time:'):
- raise RemoteDataError('Missing expected footer. Found %r' % lines[-1])
+ if not lines[-1].startswith("File Creation Time:"):
+ raise RemoteDataError("Missing expected footer. Found %r" % lines[-1])
# Convert Y/N to True/False.
- converter_map = dict((col, _bool_converter) for col, t in _TICKER_DTYPE
- if t is bool)
+ converter_map = dict(
+ (col, _bool_converter) for col, t in _TICKER_DTYPE if t is bool
+ )
# For pandas >= 0.20.0, the Python parser issues a warning if
# both a converter and dtype are specified for the same column.
# However, this measure is probably temporary until the read_csv
# behavior is better formalized.
with warnings.catch_warnings(record=True):
- data = read_csv(StringIO('\n'.join(lines[:-1])), '|',
- dtype=_TICKER_DTYPE, converters=converter_map,
- index_col=1)
+ data = read_csv(
+ StringIO("\n".join(lines[:-1])),
+ "|",
+ dtype=_TICKER_DTYPE,
+ converters=converter_map,
+ index_col=1,
+ )
# Properly cast enumerations
for cat in _CATEGORICAL:
- data[cat] = data[cat].astype('category')
+ data[cat] = data[cat].astype("category")
return data
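
A hedged usage sketch of the public wrapper built on this download (it caches the parsed table in `_ticker_cache`, so repeated calls avoid the FTP round trip):

```python
from pandas_datareader.nasdaq_trader import get_nasdaq_symbols

symbols = get_nasdaq_symbols()  # DataFrame indexed by the Symbol column
print(symbols.loc["IBM", ["Security Name", "ETF", "Listing Exchange"]])
```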
@@ -87,12 +94,12 @@ def get_nasdaq_symbols(retry_count=3, timeout=30, pause=None):
global _ticker_cache
if timeout < 0:
- raise ValueError('timeout must be >= 0, not %r' % (timeout,))
+ raise ValueError("timeout must be >= 0, not %r" % (timeout,))
if pause is None:
pause = timeout / 3
elif pause < 0:
- raise ValueError('pause must be >= 0, not %r' % (pause,))
+ raise ValueError("pause must be >= 0, not %r" % (pause,))
if _ticker_cache is None:
while retry_count > 0:
diff --git a/pandas_datareader/oecd.py b/pandas_datareader/oecd.py
index 9d9d85e4..94364352 100644
--- a/pandas_datareader/oecd.py
+++ b/pandas_datareader/oecd.py
@@ -1,34 +1,34 @@
import pandas as pd
-from pandas_datareader.io import read_jsdmx
from pandas_datareader.base import _BaseReader
from pandas_datareader.compat import string_types
+from pandas_datareader.io import read_jsdmx
class OECDReader(_BaseReader):
"""Get data for the given name from OECD."""
- _format = 'json'
+ _format = "json"
@property
def url(self):
"""API URL"""
- url = 'http://stats.oecd.org/SDMX-JSON/data'
+ url = "http://stats.oecd.org/SDMX-JSON/data"
if not isinstance(self.symbols, string_types):
- raise ValueError('data name must be string')
+ raise ValueError("data name must be string")
# API: https://data.oecd.org/api/sdmx-json-documentation/
- return '{0}/{1}/all/all?'.format(url, self.symbols)
+ return "{0}/{1}/all/all?".format(url, self.symbols)
def _read_lines(self, out):
""" read one data from specified URL """
df = read_jsdmx(out)
try:
idx_name = df.index.name # hack for pandas 0.16.2
- df.index = pd.to_datetime(df.index, errors='ignore')
+ df.index = pd.to_datetime(df.index, errors="ignore")
for col in df:
- df[col] = pd.to_numeric(df[col], errors='ignore')
+ df[col] = pd.to_numeric(df[col], errors="ignore")
df = df.sort_index()
df = df.truncate(self.start, self.end)
df.index.name = idx_name
diff --git a/pandas_datareader/quandl.py b/pandas_datareader/quandl.py
index 54cacb5f..990d60ed 100644
--- a/pandas_datareader/quandl.py
+++ b/pandas_datareader/quandl.py
@@ -44,30 +44,40 @@ class QuandlReader(_DailyBaseReader):
_BASE_URL = "https://www.quandl.com/api/v3/datasets/"
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1,
- session=None, chunksize=25, api_key=None):
- super(QuandlReader, self).__init__(symbols, start, end, retry_count,
- pause, session, chunksize)
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ chunksize=25,
+ api_key=None,
+ ):
+ super(QuandlReader, self).__init__(
+ symbols, start, end, retry_count, pause, session, chunksize
+ )
if api_key is None:
- api_key = os.getenv('QUANDL_API_KEY')
+ api_key = os.getenv("QUANDL_API_KEY")
if not api_key or not isinstance(api_key, str):
- raise ValueError('The Quandl API key must be provided either '
- 'through the api_key variable or through the '
- 'environmental variable QUANDL_API_KEY.')
+ raise ValueError(
+ "The Quandl API key must be provided either "
+ "through the api_key variable or through the "
+ "environmental variable QUANDL_API_KEY."
+ )
self.api_key = api_key
@property
def url(self):
"""API URL"""
- symbol = self.symbols if isinstance(self.symbols, str) else \
- self.symbols[0]
+ symbol = self.symbols if isinstance(self.symbols, str) else self.symbols[0]
mm = self._fullmatch(r"([A-Z0-9]+)(([/\.])([A-Z0-9_]+))?", symbol)
- assert mm, ("Symbol '%s' must conform to Quandl convention 'DB/SYM'" %
- symbol)
- datasetname = 'WIKI'
+ assert mm, "Symbol '%s' must conform to Quandl convention 'DB/SYM'" % symbol
+ datasetname = "WIKI"
if not mm.group(2):
# bare symbol:
- datasetname = 'WIKI' # default; symbol stays itself
+ datasetname = "WIKI" # default; symbol stays itself
elif mm.group(3) == "/":
# --- normal Quandl DB/SYM convention:
symbol = mm.group(4)
@@ -77,15 +87,16 @@ def url(self):
symbol = mm.group(1)
datasetname = self._db_from_countrycode(mm.group(4))
params = {
- 'start_date': self.start.strftime('%Y-%m-%d'),
- 'end_date': self.end.strftime('%Y-%m-%d'),
- 'order': "asc",
- 'api_key': self.api_key
+ "start_date": self.start.strftime("%Y-%m-%d"),
+ "end_date": self.end.strftime("%Y-%m-%d"),
+ "order": "asc",
+ "api_key": self.api_key,
}
- paramstring = '&'.join(['%s=%s' % (k, v) for k, v in params.items()])
- url = '{url}{dataset}/{symbol}.csv?{params}'
- return url.format(url=self._BASE_URL, dataset=datasetname,
- symbol=symbol, params=paramstring)
+ paramstring = "&".join(["%s=%s" % (k, v) for k, v in params.items()])
+ url = "{url}{dataset}/{symbol}.csv?{params}"
+ return url.format(
+ url=self._BASE_URL, dataset=datasetname, symbol=symbol, params=paramstring
+ )
def _fullmatch(self, regex, string, flags=0):
"""Emulate python-3.4 re.fullmatch()."""
@@ -93,27 +104,28 @@ def _fullmatch(self, regex, string, flags=0):
_COUNTRYCODE_TO_DATASET = dict(
# https://www.quandl.com/data/EURONEXT-Euronext-Stock-Exchange
- BE='EURONEXT',
+ BE="EURONEXT",
# https://www.quandl.com/data/HKEX-Hong-Kong-Exchange
- CN='HKEX',
+ CN="HKEX",
# https://www.quandl.com/data/SSE-Boerse-Stuttgart
- DE='SSE',
- FR='EURONEXT',
+ DE="SSE",
+ FR="EURONEXT",
# https://www.quandl.com/data/NSE-National-Stock-Exchange-of-India
- IN='NSE',
+ IN="NSE",
# https://www.quandl.com/data/TSE-Tokyo-Stock-Exchange
- JP='TSE',
- NL='EURONEXT',
- PT='EURONEXT',
+ JP="TSE",
+ NL="EURONEXT",
+ PT="EURONEXT",
# https://www.quandl.com/data/LSE-London-Stock-Exchange
- UK='LSE',
+ UK="LSE",
# https://www.quandl.com/data/WIKI-Wiki-EOD-Stock-Prices
- US='WIKI',
+ US="WIKI",
)
def _db_from_countrycode(self, code):
- assert code in self._COUNTRYCODE_TO_DATASET, \
+ assert code in self._COUNTRYCODE_TO_DATASET, (
"No Quandl dataset known for country code '%s'" % code
+ )
return self._COUNTRYCODE_TO_DATASET[code]
def _get_params(self, symbol):
@@ -122,13 +134,15 @@ def _get_params(self, symbol):
def read(self):
"""Read data"""
df = super(QuandlReader, self).read()
- df.rename(columns=lambda n: n.replace(' ', '')
- .replace('.', '')
- .replace('/', '')
- .replace('%', '')
- .replace('(', '')
- .replace(')', '')
- .replace("'", '')
- .replace('-', ''),
- inplace=True)
+ df.rename(
+ columns=lambda n: n.replace(" ", "")
+ .replace(".", "")
+ .replace("/", "")
+ .replace("%", "")
+ .replace("(", "")
+ .replace(")", "")
+ .replace("'", "")
+ .replace("-", ""),
+ inplace=True,
+ )
return df
diff --git a/pandas_datareader/robinhood.py b/pandas_datareader/robinhood.py
index adca1c40..639a3e56 100644
--- a/pandas_datareader/robinhood.py
+++ b/pandas_datareader/robinhood.py
@@ -1,8 +1,7 @@
import pandas as pd
from pandas_datareader.base import _BaseReader
-from pandas_datareader.exceptions import (ImmediateDeprecationError,
- DEP_ERROR_MSG)
+from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError
class RobinhoodQuoteReader(_BaseReader):
@@ -29,14 +28,24 @@ class RobinhoodQuoteReader(_BaseReader):
freq : None
Quotes are near real-time and so this value is ignored
"""
- _format = 'json'
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
- timeout=30, session=None, freq=None):
+ _format = "json"
+
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ ):
raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Robinhood"))
- super(RobinhoodQuoteReader, self).__init__(symbols, start, end,
- retry_count, pause,
- timeout, session, freq)
+ super(RobinhoodQuoteReader, self).__init__(
+ symbols, start, end, retry_count, pause, timeout, session, freq
+ )
if isinstance(self.symbols, str):
self.symbols = [self.symbols]
self._max_symbols = 1630
@@ -45,8 +54,10 @@ def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
def _validate_symbols(self):
if len(self.symbols) > self._max_symbols:
- raise ValueError('A maximum of {0} symbols are supported '
- 'in a single call.'.format(self._max_symbols))
+ raise ValueError(
+ "A maximum of {0} symbols are supported "
+ "in a single call.".format(self._max_symbols)
+ )
def _get_crumb(self, *args):
pass
@@ -54,23 +65,23 @@ def _get_crumb(self, *args):
@property
def url(self):
"""API URL"""
- return 'https://api.robinhood.com/quotes/'
+ return "https://api.robinhood.com/quotes/"
@property
def params(self):
"""Parameters to use in API calls"""
- symbols = ','.join(self.symbols)
- return {'symbols': symbols}
+ symbols = ",".join(self.symbols)
+ return {"symbols": symbols}
def _process_json(self):
res = pd.DataFrame(self._json_results)
- return res.set_index('symbol').T
+ return res.set_index("symbol").T
def _read_lines(self, out):
- if 'next' in out:
- self._json_results.extend(out['results'])
- return self._read_one_data(out['next'])
- self._json_results.extend(out['results'])
+ if "next" in out:
+ self._json_results.extend(out["results"])
+ return self._read_one_data(out["next"])
+ self._json_results.extend(out["results"])
return self._process_json()
@@ -114,26 +125,42 @@ class RobinhoodHistoricalReader(RobinhoodQuoteReader):
* 5minute: day, week
* 10minute: day, week
"""
- _format = 'json'
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
- timeout=30, session=None, freq=None, interval='day',
- span='year'):
+ _format = "json"
+
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ interval="day",
+ span="year",
+ ):
raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Robinhood"))
- super(RobinhoodHistoricalReader, self).__init__(symbols, start, end,
- retry_count, pause,
- timeout, session, freq)
- interval_span = {'day': ['year'],
- 'week': ['5year'],
- '10minute': ['day', 'week'],
- '5minute': ['day', 'week']}
+ super(RobinhoodHistoricalReader, self).__init__(
+ symbols, start, end, retry_count, pause, timeout, session, freq
+ )
+ interval_span = {
+ "day": ["year"],
+ "week": ["5year"],
+ "10minute": ["day", "week"],
+ "5minute": ["day", "week"],
+ }
if interval not in interval_span:
- raise ValueError('Interval must be one of '
- '{0}'.format(', '.join(interval_span.keys())))
+ raise ValueError(
+ "Interval must be one of " "{0}".format(", ".join(interval_span.keys()))
+ )
valid_spans = interval_span[interval]
if span not in valid_spans:
- raise ValueError('For interval {0}, span must '
- 'be in: {1}'.format(interval, valid_spans))
+ raise ValueError(
+ "For interval {0}, span must "
+ "be in: {1}".format(interval, valid_spans)
+ )
self.interval = interval
self.span = span
self._max_symbols = 75
@@ -143,23 +170,21 @@ def __init__(self, symbols, start=None, end=None, retry_count=3, pause=.1,
@property
def url(self):
"""API URL"""
- return 'https://api.robinhood.com/quotes/historicals/'
+ return "https://api.robinhood.com/quotes/historicals/"
@property
def params(self):
"""Parameters to use in API calls"""
- symbols = ','.join(self.symbols)
- pars = {'symbols': symbols,
- 'interval': self.interval,
- 'span': self.span}
+ symbols = ",".join(self.symbols)
+ pars = {"symbols": symbols, "interval": self.interval, "span": self.span}
return pars
def _process_json(self):
df = []
for sym in self._json_results:
- vals = pd.DataFrame(sym['historicals'])
- vals['begins_at'] = pd.to_datetime(vals['begins_at'])
- vals['symbol'] = sym['symbol']
- df.append(vals.set_index(['symbol', 'begins_at']))
+ vals = pd.DataFrame(sym["historicals"])
+ vals["begins_at"] = pd.to_datetime(vals["begins_at"])
+ vals["symbol"] = sym["symbol"]
+ df.append(vals.set_index(["symbol", "begins_at"]))
return pd.concat(df, 0)
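
Editor's note: _read_lines above walks Robinhood's cursor pagination by recursing on the response's "next" URL until it is absent. The same idea written iteratively against a generic endpoint -- a sketch only, since the Robinhood API itself is deprecated in this release:

import requests

def fetch_all_pages(url, params=None):
    # Collect every page's "results" array, following the "next"
    # cursor the way RobinhoodQuoteReader._read_lines does.
    results = []
    while url:
        payload = requests.get(url, params=params).json()
        results.extend(payload["results"])
        url = payload.get("next")
        params = None  # the cursor URL already carries the query string
    return results
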
diff --git a/pandas_datareader/stooq.py b/pandas_datareader/stooq.py
index 35fe8317..0895a467 100644
--- a/pandas_datareader/stooq.py
+++ b/pandas_datareader/stooq.py
@@ -33,25 +33,24 @@ class StooqDailyReader(_DailyBaseReader):
@property
def url(self):
"""API URL"""
- return 'https://stooq.com/q/d/l/'
+ return "https://stooq.com/q/d/l/"
- def _get_params(self, symbol, country='US'):
+ def _get_params(self, symbol, country="US"):
symbol_parts = symbol.split(".")
- if not symbol.startswith('^'):
+ if not symbol.startswith("^"):
if len(symbol_parts) == 1:
symbol = ".".join([symbol, country])
- elif symbol_parts[1].lower() == 'pl':
+ elif symbol_parts[1].lower() == "pl":
symbol = symbol_parts[0]
else:
- if symbol_parts[1].lower() not in ['de', 'hk', 'hu', 'jp',
- 'uk', 'us']:
- symbol = ".".join([symbol, 'US'])
+ if symbol_parts[1].lower() not in ["de", "hk", "hu", "jp", "uk", "us"]:
+ symbol = ".".join([symbol, "US"])
params = {
- 's': symbol,
- 'i': self.freq or 'd',
- 'd1': self.start.strftime('%Y%m%d'),
- 'd2': self.end.strftime('%Y%m%d')
+ "s": symbol,
+ "i": self.freq or "d",
+ "d1": self.start.strftime("%Y%m%d"),
+ "d2": self.end.strftime("%Y%m%d"),
}
return params
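
Editor's note: the branching in _get_params normalises Stooq tickers: indices pass through, bare tickers gain a market suffix, the home-market ".PL" suffix is dropped, and unrecognised suffixes get ".US" appended. Pulled out as a standalone sketch (normalize_stooq_symbol is my name for it, not the library's):

def normalize_stooq_symbol(symbol, country="US"):
    # Indices such as "^SPX" are passed through unchanged.
    if symbol.startswith("^"):
        return symbol
    parts = symbol.split(".")
    if len(parts) == 1:
        return ".".join([symbol, country])  # bare ticker -> default market
    if parts[1].lower() == "pl":
        return parts[0]  # Stooq's home market takes no suffix
    if parts[1].lower() not in ["de", "hk", "hu", "jp", "uk", "us"]:
        return ".".join([symbol, "US"])  # unknown suffix -> treat as US
    return symbol

assert normalize_stooq_symbol("AMZN") == "AMZN.US"
assert normalize_stooq_symbol("PKO.PL") == "PKO"
assert normalize_stooq_symbol("IBM.UK") == "IBM.UK"
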
diff --git a/pandas_datareader/tests/io/test_jsdmx.py b/pandas_datareader/tests/io/test_jsdmx.py
index 1a438200..e34c128a 100644
--- a/pandas_datareader/tests/io/test_jsdmx.py
+++ b/pandas_datareader/tests/io/test_jsdmx.py
@@ -18,25 +18,46 @@ def dirpath(datapath):
return datapath("io", "data")
-@pytest.mark.skipif(not PANDAS_0210, reason='Broken on old pandas')
+@pytest.mark.skipif(not PANDAS_0210, reason="Broken on old pandas")
def test_tourism(dirpath):
# OECD -> Industry and Services -> Inbound Tourism
- result = read_jsdmx(os.path.join(dirpath, 'jsdmx',
- 'tourism.json'))
+ result = read_jsdmx(os.path.join(dirpath, "jsdmx", "tourism.json"))
assert isinstance(result, pd.DataFrame)
- jp = result['Japan']
- visitors = ['China', 'Hong Kong, China',
- 'Total international arrivals',
- 'Korea', 'Chinese Taipei', 'United States']
+ jp = result["Japan"]
+ visitors = [
+ "China",
+ "Hong Kong, China",
+ "Total international arrivals",
+ "Korea",
+ "Chinese Taipei",
+ "United States",
+ ]
exp_col = pd.Index(
- ['China', 'Hong Kong, China', 'Total international arrivals',
- 'Korea', 'Chinese Taipei', 'United States'],
- name='Variable')
- exp_idx = pd.DatetimeIndex(['2008-01-01', '2009-01-01', '2010-01-01',
- '2011-01-01', '2012-01-01', '2013-01-01',
- '2014-01-01', '2015-01-01', '2016-01-01'],
- name='Year')
+ [
+ "China",
+ "Hong Kong, China",
+ "Total international arrivals",
+ "Korea",
+ "Chinese Taipei",
+ "United States",
+ ],
+ name="Variable",
+ )
+ exp_idx = pd.DatetimeIndex(
+ [
+ "2008-01-01",
+ "2009-01-01",
+ "2010-01-01",
+ "2011-01-01",
+ "2012-01-01",
+ "2013-01-01",
+ "2014-01-01",
+ "2015-01-01",
+ "2016-01-01",
+ ],
+ name="Year",
+ )
values = [
[1000000.0, 550000.0, 8351000.0, 2382000.0, 1390000.0, 768000.0],
[1006000.0, 450000.0, 6790000.0, 1587000.0, 1024000.0, 700000.0],
@@ -45,44 +66,83 @@ def test_tourism(dirpath):
[1430000.0, 482000.0, 8368000.0, 2044000.0, 1467000.0, 717000.0],
[1314000.0, 746000.0, 10364000.0, 2456000.0, 2211000.0, 799000.0],
[2409000.0, 926000.0, 13413000.0, 2755000.0, 2830000.0, 892000.0],
- [4993689.0, 1524292.0, 19737409.0, 4002095.0, 3677075.0,
- 1033258.0],
- [6373564.0, 1839193.0, 24039700.0, 5090302.0, 4167512.0, 1242719.0]
+ [4993689.0, 1524292.0, 19737409.0, 4002095.0, 3677075.0, 1033258.0],
+ [6373564.0, 1839193.0, 24039700.0, 5090302.0, 4167512.0, 1242719.0],
]
- values = np.array(values, dtype='object')
+ values = np.array(values, dtype="object")
expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
tm.assert_frame_equal(jp[visitors], expected)
-@pytest.mark.skipif(not PANDAS_0210, reason='Broken on old pandas')
+@pytest.mark.skipif(not PANDAS_0210, reason="Broken on old pandas")
def test_land_use(dirpath):
# OECD -> Environment -> Resources Land Use
- result = read_jsdmx(os.path.join(dirpath, 'jsdmx',
- 'land_use.json'))
+ result = read_jsdmx(os.path.join(dirpath, "jsdmx", "land_use.json"))
assert isinstance(result, pd.DataFrame)
- result = result.loc['2010':'2011']
+ result = result.loc["2010":"2011"]
- cols = ['Arable land and permanent crops',
- 'Arable and cropland % land area',
- 'Total area', 'Forest', 'Forest % land area',
- 'Land area', 'Permanent meadows and pastures',
- 'Meadows and pastures % land area', 'Other areas',
- 'Other % land area']
- exp_col = pd.MultiIndex.from_product([
- ['Japan', 'United States'],
- cols], names=['Country', 'Variable'])
- exp_idx = pd.DatetimeIndex(['2010', '2011'], name='Year')
+ cols = [
+ "Arable land and permanent crops",
+ "Arable and cropland % land area",
+ "Total area",
+ "Forest",
+ "Forest % land area",
+ "Land area",
+ "Permanent meadows and pastures",
+ "Meadows and pastures % land area",
+ "Other areas",
+ "Other % land area",
+ ]
+ exp_col = pd.MultiIndex.from_product(
+ [["Japan", "United States"], cols], names=["Country", "Variable"]
+ )
+ exp_idx = pd.DatetimeIndex(["2010", "2011"], name="Year")
values = [
- [53790.0, 14.753154141525, 377800.0, np.nan, np.nan, 364600.0,
- 5000.0, 1.3713658804169, np.nan, np.nan,
- 1897990.0, 20.722767650476, 9629090.0, np.nan, np.nan, 9158960.0,
- 2416000.0, 26.378540795025, np.nan,
- np.nan],
- [53580.0, 14.691527282698, 377800.0, np.nan, np.nan, 364700.0,
- 5000.0, 1.3709898546751, np.nan, np.nan,
- 1897990.0, 20.722767650476, 9629090.0, np.nan, np.nan, 9158960.0,
- 2416000.0, 26.378540795025, np.nan,
- np.nan]]
+ [
+ 53790.0,
+ 14.753154141525,
+ 377800.0,
+ np.nan,
+ np.nan,
+ 364600.0,
+ 5000.0,
+ 1.3713658804169,
+ np.nan,
+ np.nan,
+ 1897990.0,
+ 20.722767650476,
+ 9629090.0,
+ np.nan,
+ np.nan,
+ 9158960.0,
+ 2416000.0,
+ 26.378540795025,
+ np.nan,
+ np.nan,
+ ],
+ [
+ 53580.0,
+ 14.691527282698,
+ 377800.0,
+ np.nan,
+ np.nan,
+ 364700.0,
+ 5000.0,
+ 1.3709898546751,
+ np.nan,
+ np.nan,
+ 1897990.0,
+ 20.722767650476,
+ 9629090.0,
+ np.nan,
+ np.nan,
+ 9158960.0,
+ 2416000.0,
+ 26.378540795025,
+ np.nan,
+ np.nan,
+ ],
+ ]
values = np.array(values)
expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
tm.assert_frame_equal(result[exp_col], expected)
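
Editor's note: both fixtures feed saved OECD payloads through read_jsdmx and compare against hand-built frames. Minimal usage against a local SDMX-JSON file, assuming the read_jsdmx re-export in pandas_datareader.io (the file path below is illustrative):

import os

from pandas_datareader.io import read_jsdmx

# Parse a saved SDMX-JSON response into a DataFrame whose index and
# columns are built from the payload's dimensions.
df = read_jsdmx(os.path.join("data", "jsdmx", "tourism.json"))
print(df.head())
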
diff --git a/pandas_datareader/tests/io/test_sdmx.py b/pandas_datareader/tests/io/test_sdmx.py
index d52e56e2..de31822c 100644
--- a/pandas_datareader/tests/io/test_sdmx.py
+++ b/pandas_datareader/tests/io/test_sdmx.py
@@ -7,7 +7,7 @@
import pandas.util.testing as tm
import pytest
-from pandas_datareader.io.sdmx import read_sdmx, _read_sdmx_dsd
+from pandas_datareader.io.sdmx import _read_sdmx_dsd, read_sdmx
pytestmark = pytest.mark.stable
@@ -21,23 +21,20 @@ def test_tourism(dirpath):
# Eurostat
# Employed doctorate holders in non managerial and non professional
# occupations by fields of science (%)
- dsd = _read_sdmx_dsd(os.path.join(dirpath, 'sdmx',
- 'DSD_cdh_e_fos.xml'))
- df = read_sdmx(os.path.join(dirpath, 'sdmx',
- 'cdh_e_fos.xml'), dsd=dsd)
+ dsd = _read_sdmx_dsd(os.path.join(dirpath, "sdmx", "DSD_cdh_e_fos.xml"))
+ df = read_sdmx(os.path.join(dirpath, "sdmx", "cdh_e_fos.xml"), dsd=dsd)
assert isinstance(df, pd.DataFrame)
assert df.shape == (2, 336)
- df = df['Percentage']['Total']['Natural sciences']
- df = df[['Norway', 'Poland', 'Portugal', 'Russia']]
+ df = df["Percentage"]["Total"]["Natural sciences"]
+ df = df[["Norway", "Poland", "Portugal", "Russia"]]
- exp_col = pd.MultiIndex.from_product([['Norway', 'Poland', 'Portugal',
- 'Russia'], ['Annual']],
- names=['GEO', 'FREQ'])
- exp_idx = pd.DatetimeIndex(['2009', '2006'], name='TIME_PERIOD')
+ exp_col = pd.MultiIndex.from_product(
+ [["Norway", "Poland", "Portugal", "Russia"], ["Annual"]], names=["GEO", "FREQ"]
+ )
+ exp_idx = pd.DatetimeIndex(["2009", "2006"], name="TIME_PERIOD")
- values = np.array([[20.38, 25.1, 27.77, 38.1],
- [25.49, np.nan, 39.05, np.nan]])
+ values = np.array([[20.38, 25.1, 27.77, 38.1], [25.49, np.nan, 39.05, np.nan]])
expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
tm.assert_frame_equal(df, expected)
diff --git a/pandas_datareader/tests/test_bankofcanada.py b/pandas_datareader/tests/test_bankofcanada.py
index daff9f58..fe9cb12b 100644
--- a/pandas_datareader/tests/test_bankofcanada.py
+++ b/pandas_datareader/tests/test_bankofcanada.py
@@ -2,68 +2,81 @@
import pytest
-import pandas_datareader.data as web
from pandas_datareader._utils import RemoteDataError
+import pandas_datareader.data as web
pytestmark = pytest.mark.stable
class TestBankOfCanada(object):
-
@staticmethod
def get_symbol(currency_code, inverted=False):
if inverted:
- return 'FXCAD{}'.format(currency_code)
+ return "FXCAD{}".format(currency_code)
else:
- return 'FX{}CAD'.format(currency_code)
+ return "FX{}CAD".format(currency_code)
def check_bankofcanada_count(self, code):
start, end = date.today() - timedelta(days=30), date.today()
- df = web.DataReader(self.get_symbol(code), 'bankofcanada', start, end)
+ df = web.DataReader(self.get_symbol(code), "bankofcanada", start, end)
assert df.size > 15
def check_bankofcanada_valid(self, code):
symbol = self.get_symbol(code)
- df = web.DataReader(symbol, 'bankofcanada',
- date.today() - timedelta(days=30), date.today())
+ df = web.DataReader(
+ symbol, "bankofcanada", date.today() - timedelta(days=30), date.today()
+ )
assert symbol in df.columns
def check_bankofcanada_inverted(self, code):
symbol = self.get_symbol(code)
symbol_inverted = self.get_symbol(code, inverted=True)
- df = web.DataReader(symbol, 'bankofcanada',
- date.today() - timedelta(days=30), date.today())
- df_i = web.DataReader(symbol_inverted, 'bankofcanada',
- date.today() - timedelta(days=30), date.today())
+ df = web.DataReader(
+ symbol, "bankofcanada", date.today() - timedelta(days=30), date.today()
+ )
+ df_i = web.DataReader(
+ symbol_inverted,
+ "bankofcanada",
+ date.today() - timedelta(days=30),
+ date.today(),
+ )
pairs = zip((1 / df)[symbol].tolist(), df_i[symbol_inverted].tolist())
assert all(abs(a - b) < 0.01 for a, b in pairs)
def test_bankofcanada_usd_count(self):
- self.check_bankofcanada_count('USD')
+ self.check_bankofcanada_count("USD")
def test_bankofcanada_eur_count(self):
- self.check_bankofcanada_count('EUR')
+ self.check_bankofcanada_count("EUR")
def test_bankofcanada_usd_valid(self):
- self.check_bankofcanada_valid('USD')
+ self.check_bankofcanada_valid("USD")
def test_bankofcanada_eur_valid(self):
- self.check_bankofcanada_valid('EUR')
+ self.check_bankofcanada_valid("EUR")
def test_bankofcanada_usd_inverted(self):
- self.check_bankofcanada_inverted('USD')
+ self.check_bankofcanada_inverted("USD")
def test_bankofcanada_eur_inverted(self):
- self.check_bankofcanada_inverted('EUR')
+ self.check_bankofcanada_inverted("EUR")
def test_bankofcanada_bad_range(self):
with pytest.raises(ValueError):
- web.DataReader('FXCADUSD', 'bankofcanada',
- date.today(), date.today() - timedelta(days=30))
+ web.DataReader(
+ "FXCADUSD",
+ "bankofcanada",
+ date.today(),
+ date.today() - timedelta(days=30),
+ )
def test_bankofcanada_bad_url(self):
with pytest.raises(RemoteDataError):
- web.DataReader('abcdefgh', 'bankofcanada',
- date.today() - timedelta(days=30), date.today())
+ web.DataReader(
+ "abcdefgh",
+ "bankofcanada",
+ date.today() - timedelta(days=30),
+ date.today(),
+ )
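
Editor's note: check_bankofcanada_inverted verifies that the FXCADUSD quote is, to within a cent, the reciprocal of FXUSDCAD. The comparison at its core, isolated into a helper (reciprocal_within_tolerance is an illustrative name; note the two-sided abs() tolerance):

def reciprocal_within_tolerance(direct, inverted, tol=0.01):
    # Each inverted quote should equal 1 / direct quote to within tol.
    return all(abs(1.0 / a - b) < tol for a, b in zip(direct, inverted))

assert reciprocal_within_tolerance([1.25, 1.30], [0.8000, 0.7692])
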
diff --git a/pandas_datareader/tests/test_base.py b/pandas_datareader/tests/test_base.py
index 11103592..2154b81b 100644
--- a/pandas_datareader/tests/test_base.py
+++ b/pandas_datareader/tests/test_base.py
@@ -8,11 +8,11 @@
class TestBaseReader(object):
def test_requests_not_monkey_patched(self):
- assert not hasattr(requests.Session(), 'stor')
+ assert not hasattr(requests.Session(), "stor")
def test_valid_retry_count(self):
with pytest.raises(ValueError):
- base._BaseReader([], retry_count='stuff')
+ base._BaseReader([], retry_count="stuff")
with pytest.raises(ValueError):
base._BaseReader([], retry_count=-1)
@@ -23,8 +23,8 @@ def test_invalid_url(self):
def test_invalid_format(self):
with pytest.raises(NotImplementedError):
b = base._BaseReader([])
- b._format = 'IM_NOT_AN_IMPLEMENTED_TYPE'
- b._read_one_data('a', None)
+ b._format = "IM_NOT_AN_IMPLEMENTED_TYPE"
+ b._read_one_data("a", None)
class TestDailyBaseReader(object):
diff --git a/pandas_datareader/tests/test_data.py b/pandas_datareader/tests/test_data.py
index 9ecd970c..4cdea1d0 100644
--- a/pandas_datareader/tests/test_data.py
+++ b/pandas_datareader/tests/test_data.py
@@ -1,13 +1,12 @@
+from pandas import DataFrame
import pytest
-from pandas import DataFrame
from pandas_datareader.data import DataReader
pytestmark = pytest.mark.stable
class TestDataReader(object):
-
def test_read_iex(self):
gs = DataReader("GS", "iex-last")
assert isinstance(gs, DataFrame)
diff --git a/pandas_datareader/tests/test_econdb.py b/pandas_datareader/tests/test_econdb.py
index 1b960dfb..58bb48cc 100644
--- a/pandas_datareader/tests/test_econdb.py
+++ b/pandas_datareader/tests/test_econdb.py
@@ -8,16 +8,16 @@
class TestEcondb(object):
-
def test_get_cdh_e_fos(self):
# EUROSTAT
# Employed doctorate holders in non managerial and non professional
# occupations by fields of science (%)
df = web.DataReader(
- 'dataset=CDH_E_FOS&GEO=NO,PL,PT,RU&FOS07=FOS1&Y_GRAD=TOTAL',
- 'econdb',
- start=pd.Timestamp('2005-01-01'),
- end=pd.Timestamp('2010-01-01'))
+ "dataset=CDH_E_FOS&GEO=NO,PL,PT,RU&FOS07=FOS1&Y_GRAD=TOTAL",
+ "econdb",
+ start=pd.Timestamp("2005-01-01"),
+ end=pd.Timestamp("2010-01-01"),
+ )
assert isinstance(df, pd.DataFrame)
assert df.shape == (2, 4)
@@ -26,11 +26,9 @@ def test_get_cdh_e_fos(self):
levels = [lvl.values.tolist() for lvl in list(df.columns.levels)]
exp_col = pd.MultiIndex.from_product(levels, names=names)
- exp_idx = pd.DatetimeIndex(['2006-01-01', '2009-01-01'],
- name='TIME_PERIOD')
+ exp_idx = pd.DatetimeIndex(["2006-01-01", "2009-01-01"], name="TIME_PERIOD")
- values = np.array([[25.49, np.nan, 39.05, np.nan],
- [20.38, 25.1, 27.77, 38.1]])
+ values = np.array([[25.49, np.nan, 39.05, np.nan], [20.38, 25.1, 27.77, 38.1]])
expected = pd.DataFrame(values, index=exp_idx, columns=exp_col)
tm.assert_frame_equal(df, expected)
@@ -39,32 +37,36 @@ def test_get_tourism(self):
# TOURISM_INBOUND
df = web.DataReader(
- 'dataset=OE_TOURISM_INBOUND&COUNTRY=JPN,USA&'
- 'VARIABLE=INB_ARRIVALS_TOTAL', 'econdb',
- start=pd.Timestamp('2008-01-01'), end=pd.Timestamp('2012-01-01'))
+ "dataset=OE_TOURISM_INBOUND&COUNTRY=JPN,USA&VARIABLE=INB_ARRIVALS_TOTAL",
+ "econdb",
+ start=pd.Timestamp("2008-01-01"),
+ end=pd.Timestamp("2012-01-01"),
+ )
df = df.astype(np.float)
- jp = np.array([8351000, 6790000, 8611000, 6219000,
- 8368000], dtype=float)
- us = np.array([175702304, 160507424, 164079728, 167600272,
- 171320416], dtype=float)
- index = pd.date_range('2008-01-01', '2012-01-01', freq='AS',
- name='TIME_PERIOD')
+ jp = np.array([8351000, 6790000, 8611000, 6219000, 8368000], dtype=float)
+ us = np.array(
+ [175702304, 160507424, 164079728, 167600272, 171320416], dtype=float
+ )
+ index = pd.date_range("2008-01-01", "2012-01-01", freq="AS", name="TIME_PERIOD")
# sometimes the country and variable columns are swapped
lvl1 = df.columns.levels[0][0]
if lvl1 == "Total international arrivals":
df = df.swaplevel(0, 1, axis=1)
- for label, values in [('Japan', jp), ('United States', us)]:
- expected = pd.Series(values, index=index,
- name='Total international arrivals')
- tm.assert_series_equal(df[label]['Total international arrivals'],
- expected)
+ for label, values in [("Japan", jp), ("United States", us)]:
+ expected = pd.Series(
+ values, index=index, name="Total international arrivals"
+ )
+ tm.assert_series_equal(df[label]["Total international arrivals"], expected)
def test_bls(self):
# BLS
# CPI
df = web.DataReader(
- 'ticker=BLS_CU.CUSR0000SA0.M.US', 'econdb',
- start=pd.Timestamp('2010-01-01'), end=pd.Timestamp('2013-01-27'))
+ "ticker=BLS_CU.CUSR0000SA0.M.US",
+ "econdb",
+ start=pd.Timestamp("2010-01-01"),
+ end=pd.Timestamp("2013-01-27"),
+ )
- assert df.loc['2010-05-01'][0] == 217.3
+ assert df.loc["2010-05-01"][0] == 217.3
diff --git a/pandas_datareader/tests/test_enigma.py b/pandas_datareader/tests/test_enigma.py
index 46340d5d..7c34e964 100644
--- a/pandas_datareader/tests/test_enigma.py
+++ b/pandas_datareader/tests/test_enigma.py
@@ -1,7 +1,8 @@
import os
-import pytest
+import pytest
from requests.exceptions import HTTPError
+
import pandas_datareader as pdr
import pandas_datareader.data as web
@@ -12,7 +13,6 @@
@pytest.mark.skipif(TEST_API_KEY is None, reason="no enigma_api_key")
class TestEnigma(object):
-
@property
def dataset_id(self):
"""
@@ -28,25 +28,24 @@ def setup_class(cls):
def test_enigma_datareader(self):
try:
- df = web.DataReader(self.dataset_id,
- 'enigma', access_key=TEST_API_KEY)
- assert 'case_number' in df.columns
+ df = web.DataReader(self.dataset_id, "enigma", access_key=TEST_API_KEY)
+ assert "case_number" in df.columns
except HTTPError as e:
pytest.skip(e)
def test_enigma_get_data_enigma(self):
try:
df = pdr.get_data_enigma(self.dataset_id, TEST_API_KEY)
- assert 'case_number' in df.columns
+ assert "case_number" in df.columns
except HTTPError as e:
pytest.skip(e)
def test_bad_key(self):
with pytest.raises(HTTPError):
- web.DataReader(self.dataset_id,
- 'enigma', access_key=TEST_API_KEY + 'xxx')
+ web.DataReader(self.dataset_id, "enigma", access_key=TEST_API_KEY + "xxx")
def test_bad_dataset_id(self):
with pytest.raises(HTTPError):
- web.DataReader('zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzz',
- 'enigma', access_key=TEST_API_KEY)
+ web.DataReader(
+ "zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzz", "enigma", access_key=TEST_API_KEY
+ )
diff --git a/pandas_datareader/tests/test_famafrench.py b/pandas_datareader/tests/test_famafrench.py
index 08335968..c73d476b 100644
--- a/pandas_datareader/tests/test_famafrench.py
+++ b/pandas_datareader/tests/test_famafrench.py
@@ -1,7 +1,6 @@
-import pytest
-
import pandas as pd
import pandas.util.testing as tm
+import pytest
import pandas_datareader.data as web
from pandas_datareader.famafrench import get_available_datasets
@@ -10,17 +9,19 @@
class TestFamaFrench(object):
-
def test_get_data(self):
keys = [
- 'F-F_Research_Data_Factors', 'F-F_ST_Reversal_Factor',
- '6_Portfolios_2x3', 'Portfolios_Formed_on_ME',
- 'Prior_2-12_Breakpoints', 'ME_Breakpoints',
+ "F-F_Research_Data_Factors",
+ "F-F_ST_Reversal_Factor",
+ "6_Portfolios_2x3",
+ "Portfolios_Formed_on_ME",
+ "Prior_2-12_Breakpoints",
+ "ME_Breakpoints",
]
for name in keys:
- ff = web.DataReader(name, 'famafrench')
- assert 'DESCR' in ff
+ ff = web.DataReader(name, "famafrench")
+ assert "DESCR" in ff
assert len(ff) > 1
def test_get_available_datasets(self):
@@ -29,63 +30,159 @@ def test_get_available_datasets(self):
assert len(avail) > 100
def test_index(self):
- ff = web.DataReader('F-F_Research_Data_Factors', 'famafrench')
- assert ff[0].index.freq == 'M'
- assert ff[1].index.freq == 'A-DEC'
+ ff = web.DataReader("F-F_Research_Data_Factors", "famafrench")
+ assert ff[0].index.freq == "M"
+ assert ff[1].index.freq == "A-DEC"
def test_f_f_research(self):
- results = web.DataReader("F-F_Research_Data_Factors", "famafrench",
- start='2010-01-01', end='2010-12-01')
+ results = web.DataReader(
+ "F-F_Research_Data_Factors",
+ "famafrench",
+ start="2010-01-01",
+ end="2010-12-01",
+ )
assert isinstance(results, dict)
assert len(results) == 3
- exp = pd.DataFrame({'Mkt-RF': [-3.36, 3.4, 6.31, 2., -7.89, -5.56,
- 6.93, -4.77, 9.54, 3.88, 0.6, 6.82],
- 'SMB': [0.38, 1.2, 1.42, 4.98, 0.05, -1.97, 0.16,
- -3.00, 3.92, 1.15, 3.70, 0.7],
- 'HML': [0.31, 3.16, 2.1, 2.81, -2.38, -4.5, -0.27,
- -1.95, -3.12, -2.59, -0.9, 3.81],
- 'RF': [0., 0., 0.01, 0.01, 0.01, 0.01, 0.01,
- 0.01, 0.01, 0.01, 0.01, 0.01]},
- index=pd.period_range('2010-01-01', '2010-12-01',
- freq='M', name='Date'),
- columns=['Mkt-RF', 'SMB', 'HML', 'RF'])
+ exp = pd.DataFrame(
+ {
+ "Mkt-RF": [
+ -3.36,
+ 3.4,
+ 6.31,
+ 2.0,
+ -7.89,
+ -5.56,
+ 6.93,
+ -4.77,
+ 9.54,
+ 3.88,
+ 0.6,
+ 6.82,
+ ],
+ "SMB": [
+ 0.38,
+ 1.2,
+ 1.42,
+ 4.98,
+ 0.05,
+ -1.97,
+ 0.16,
+ -3.00,
+ 3.92,
+ 1.15,
+ 3.70,
+ 0.7,
+ ],
+ "HML": [
+ 0.31,
+ 3.16,
+ 2.1,
+ 2.81,
+ -2.38,
+ -4.5,
+ -0.27,
+ -1.95,
+ -3.12,
+ -2.59,
+ -0.9,
+ 3.81,
+ ],
+ "RF": [
+ 0.0,
+ 0.0,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ 0.01,
+ ],
+ },
+ index=pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date"),
+ columns=["Mkt-RF", "SMB", "HML", "RF"],
+ )
tm.assert_frame_equal(results[0], exp, check_less_precise=0)
def test_me_breakpoints(self):
- results = web.DataReader("ME_Breakpoints", "famafrench",
- start='2010-01-01', end='2010-12-01')
+ results = web.DataReader(
+ "ME_Breakpoints", "famafrench", start="2010-01-01", end="2010-12-01"
+ )
assert isinstance(results, dict)
assert len(results) == 2
assert results[0].shape == (12, 21)
- exp_columns = pd.Index(['Count', (0, 5), (5, 10), (10, 15), (15, 20),
- (20, 25), (25, 30), (30, 35), (35, 40),
- (40, 45), (45, 50), (50, 55), (55, 60),
- (60, 65), (65, 70), (70, 75), (75, 80),
- (80, 85), (85, 90), (90, 95), (95, 100)],
- dtype='object')
+ exp_columns = pd.Index(
+ [
+ "Count",
+ (0, 5),
+ (5, 10),
+ (10, 15),
+ (15, 20),
+ (20, 25),
+ (25, 30),
+ (30, 35),
+ (35, 40),
+ (40, 45),
+ (45, 50),
+ (50, 55),
+ (55, 60),
+ (60, 65),
+ (65, 70),
+ (70, 75),
+ (75, 80),
+ (80, 85),
+ (85, 90),
+ (90, 95),
+ (95, 100),
+ ],
+ dtype="object",
+ )
tm.assert_index_equal(results[0].columns, exp_columns)
- exp_index = pd.period_range('2010-01-01', '2010-12-01',
- freq='M', name='Date')
+ exp_index = pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date")
tm.assert_index_equal(results[0].index, exp_index)
def test_prior_2_12_breakpoints(self):
- results = web.DataReader("Prior_2-12_Breakpoints", "famafrench",
- start='2010-01-01', end='2010-12-01')
+ results = web.DataReader(
+ "Prior_2-12_Breakpoints", "famafrench", start="2010-01-01", end="2010-12-01"
+ )
assert isinstance(results, dict)
assert len(results) == 2
assert results[0].shape == (12, 22)
- exp_columns = pd.Index(['<=0', '>0', (0, 5), (5, 10), (10, 15),
- (15, 20), (20, 25), (25, 30), (30, 35),
- (35, 40), (40, 45), (45, 50), (50, 55),
- (55, 60), (60, 65), (65, 70), (70, 75),
- (75, 80), (80, 85), (85, 90), (90, 95),
- (95, 100)], dtype='object')
+ exp_columns = pd.Index(
+ [
+ "<=0",
+ ">0",
+ (0, 5),
+ (5, 10),
+ (10, 15),
+ (15, 20),
+ (20, 25),
+ (25, 30),
+ (30, 35),
+ (35, 40),
+ (40, 45),
+ (45, 50),
+ (50, 55),
+ (55, 60),
+ (60, 65),
+ (65, 70),
+ (70, 75),
+ (75, 80),
+ (80, 85),
+ (85, 90),
+ (90, 95),
+ (95, 100),
+ ],
+ dtype="object",
+ )
tm.assert_index_equal(results[0].columns, exp_columns)
- exp_index = pd.period_range('2010-01-01', '2010-12-01',
- freq='M', name='Date')
+ exp_index = pd.period_range("2010-01-01", "2010-12-01", freq="M", name="Date")
tm.assert_index_equal(results[0].index, exp_index)
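
Editor's note: as these tests show, the famafrench source returns a dict keyed by table number, with a "DESCR" entry carrying the dataset's text header. A usage sketch:

import pandas_datareader.data as web

# ff[0] is the monthly table (PeriodIndex, freq "M"); ff[1] is annual;
# ff["DESCR"] is the description text shipped with the zip file.
ff = web.DataReader(
    "F-F_Research_Data_Factors", "famafrench", start="2010-01-01", end="2010-12-01"
)
print(ff["DESCR"])
print(ff[0].head())
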
diff --git a/pandas_datareader/tests/test_fred.py b/pandas_datareader/tests/test_fred.py
index d55782a4..158721be 100644
--- a/pandas_datareader/tests/test_fred.py
+++ b/pandas_datareader/tests/test_fred.py
@@ -1,20 +1,18 @@
from datetime import datetime
-import pytest
-
import numpy as np
import pandas as pd
+from pandas import DataFrame
import pandas.util.testing as tm
-import pandas_datareader.data as web
+import pytest
-from pandas import DataFrame
from pandas_datareader._utils import RemoteDataError
+import pandas_datareader.data as web
pytestmark = pytest.mark.stable
class TestFred(object):
-
def test_fred(self):
# Raises an exception when DataReader can't
@@ -24,7 +22,7 @@ def test_fred(self):
end = datetime(2013, 1, 1)
df = web.DataReader("GDP", "fred", start, end)
- ts = df['GDP']
+ ts = df["GDP"]
assert ts.index[0] == pd.to_datetime("2010-01-01")
assert ts.index[-1] == pd.to_datetime("2013-01-01")
@@ -33,30 +31,26 @@ def test_fred(self):
assert len(ts) == 13
with pytest.raises(RemoteDataError):
- web.DataReader("NON EXISTENT SERIES", 'fred', start, end)
+ web.DataReader("NON EXISTENT SERIES", "fred", start, end)
def test_fred_nan(self):
start = datetime(2010, 1, 1)
end = datetime(2013, 1, 27)
df = web.DataReader("DFII5", "fred", start, end)
- assert pd.isnull(df.loc['2010-01-01'][0])
+ assert pd.isnull(df.loc["2010-01-01"][0])
def test_fred_parts(self): # pragma: no cover
start = datetime(2010, 1, 1)
end = datetime(2013, 1, 27)
df = web.get_data_fred("CPIAUCSL", start, end)
- assert df.loc['2010-05-01'][0] == 217.29
+ assert df.loc["2010-05-01"][0] == 217.29
t = df.CPIAUCSL.values
assert np.issubdtype(t.dtype, np.floating)
assert t.shape == (37,)
def test_fred_part2(self):
- expected = [[576.7],
- [962.9],
- [684.7],
- [848.3],
- [933.3]]
+ expected = [[576.7], [962.9], [684.7], [848.3], [933.3]]
result = web.get_data_fred("A09024USA144NNBR", start="1915").iloc[:5]
tm.assert_numpy_array_equal(result.values, np.array(expected))
@@ -66,18 +60,21 @@ def test_invalid_series(self):
web.get_data_fred(name)
def test_fred_multi(self): # pragma: no cover
- names = ['CPIAUCSL', 'CPALTT01USQ661S', 'CPILFESL']
+ names = ["CPIAUCSL", "CPALTT01USQ661S", "CPILFESL"]
start = datetime(2010, 1, 1)
end = datetime(2013, 1, 27)
received = web.DataReader(names, "fred", start, end).head(1)
- expected = DataFrame([[217.488, 91.712409, 220.633]], columns=names,
- index=[pd.Timestamp('2010-01-01 00:00:00')])
- expected.index.rename('DATE', inplace=True)
+ expected = DataFrame(
+ [[217.488, 91.712409, 220.633]],
+ columns=names,
+ index=[pd.Timestamp("2010-01-01 00:00:00")],
+ )
+ expected.index.rename("DATE", inplace=True)
tm.assert_frame_equal(received, expected, check_less_precise=True)
def test_fred_multi_bad_series(self):
- names = ['NOTAREALSERIES', 'CPIAUCSL', "ALSO FAKE"]
+ names = ["NOTAREALSERIES", "CPIAUCSL", "ALSO FAKE"]
with pytest.raises(RemoteDataError):
web.DataReader(names, data_source="fred")
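
Editor's note: test_fred_multi exercises the list form of the FRED reader -- one request, one column per series, and a RemoteDataError if any series in the list is unknown. For example:

from datetime import datetime

import pandas_datareader.data as web

# One column per series; the index holds the FRED observation dates.
names = ["CPIAUCSL", "CPILFESL"]
df = web.DataReader(names, "fred", datetime(2010, 1, 1), datetime(2013, 1, 27))
print(df.head())
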
diff --git a/pandas_datareader/tests/test_iex.py b/pandas_datareader/tests/test_iex.py
index f3933aad..6cc4b2d5 100644
--- a/pandas_datareader/tests/test_iex.py
+++ b/pandas_datareader/tests/test_iex.py
@@ -1,11 +1,16 @@
from datetime import datetime
-import pytest
from pandas import DataFrame
+import pytest
-from pandas_datareader.data import (DataReader, get_summary_iex, get_last_iex,
- get_dailysummary_iex, get_iex_symbols,
- get_iex_book)
+from pandas_datareader.data import (
+ DataReader,
+ get_dailysummary_iex,
+ get_iex_book,
+ get_iex_symbols,
+ get_last_iex,
+ get_summary_iex,
+)
from pandas_datareader.exceptions import UnstableAPIWarning
@@ -19,25 +24,26 @@ def test_read_iex(self):
assert isinstance(gs, DataFrame)
def test_historical(self):
- df = get_summary_iex(start=datetime(2017, 4, 1),
- end=datetime(2017, 4, 30))
+ df = get_summary_iex(start=datetime(2017, 4, 1), end=datetime(2017, 4, 30))
assert df.T["averageDailyVolume"].iloc[0] == 137650908.9
def test_false_ticker(self):
df = get_last_iex("INVALID TICKER")
assert df.shape[0] == 0
- @pytest.mark.xfail(reason='IEX daily history API is returning 500 as of '
- 'Jan 2018')
+ @pytest.mark.xfail(
+ reason="IEX daily history API is returning 500 as of " "Jan 2018"
+ )
def test_daily(self):
with pytest.warns(UnstableAPIWarning):
- df = get_dailysummary_iex(start=datetime(2017, 5, 5),
- end=datetime(2017, 5, 6))
- assert df['routedVolume'].iloc[0] == 39974788
+ df = get_dailysummary_iex(
+ start=datetime(2017, 5, 5), end=datetime(2017, 5, 6)
+ )
+ assert df["routedVolume"].iloc[0] == 39974788
def test_symbols(self):
df = get_iex_symbols()
- assert 'GS' in df.symbol.values
+ assert "GS" in df.symbol.values
def test_live_prices(self):
dftickers = get_iex_symbols()
@@ -46,8 +52,8 @@ def test_live_prices(self):
assert df["price"].mean() > 0
def test_deep(self):
- dob = get_iex_book('GS', service='book')
+ dob = get_iex_book("GS", service="book")
if dob:
- assert 'GS' in dob
+ assert "GS" in dob
else:
- pytest.xfail(reason='Can only get Book when market open')
+ pytest.xfail(reason="Can only get Book when market open")
diff --git a/pandas_datareader/tests/test_iex_daily.py b/pandas_datareader/tests/test_iex_daily.py
index e95ca82d..6df487d5 100644
--- a/pandas_datareader/tests/test_iex_daily.py
+++ b/pandas_datareader/tests/test_iex_daily.py
@@ -1,18 +1,17 @@
from datetime import date, datetime, timedelta
-
import os
from pandas import DataFrame, MultiIndex
-
import pytest
import pandas_datareader.data as web
from pandas_datareader.iex.daily import IEXDailyReader
-@pytest.mark.skipif(os.getenv("IEX_SANDBOX") != 'enable',
- reason='All tests must be run in sandbox mode')
-class TestIEXDaily(object):
+@pytest.mark.skipif(
+ os.getenv("IEX_SANDBOX") != "enable", reason="All tests must be run in sandbox mode"
+)
+class TestIEXDaily(object):
@classmethod
def setup_class(cls):
pytest.importorskip("lxml")
@@ -31,8 +30,7 @@ def test_iex_bad_symbol(self):
def test_iex_bad_symbol_list(self):
with pytest.raises(Exception):
- web.DataReader(["AAPL", "BADTICKER"], "iex",
- self.start, self.end)
+ web.DataReader(["AAPL", "BADTICKER"], "iex", self.start, self.end)
def test_daily_invalid_date(self):
start = datetime(2000, 1, 5)
@@ -50,7 +48,7 @@ def test_multiple_symbols(self):
df = web.DataReader(syms, "iex", self.start, self.end)
assert sorted(list(df.columns.levels[1])) == syms
for sym in syms:
- assert len(df.xs(sym, level='Symbols', axis=1) == 578)
+ assert len(df.xs(sym, level="Symbols", axis=1)) == 578
def test_multiple_symbols_2(self):
syms = ["AAPL", "MSFT", "TSLA"]
@@ -62,8 +60,8 @@ def test_multiple_symbols_2(self):
assert len(df.columns.levels[1]) == 3
assert sorted(list(df.columns.levels[1])) == syms
- a = df.xs("AAPL", axis=1, level='Symbols')
- t = df.xs("TSLA", axis=1, level='Symbols')
+ a = df.xs("AAPL", axis=1, level="Symbols")
+ t = df.xs("TSLA", axis=1, level="Symbols")
assert len(a) == 73
assert len(t) == 73
@@ -71,35 +69,53 @@ def test_multiple_symbols_2(self):
def test_range_string_from_date(self):
syms = ["AAPL"]
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=5),
- end=date.today()
- )._range_string_from_date() == '5d'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=27),
- end=date.today()
- )._range_string_from_date() == '1m'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=83),
- end=date.today()
- )._range_string_from_date() == '3m'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=167),
- end=date.today()
- )._range_string_from_date() == '6m'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=170),
- end=date.today()
- )._range_string_from_date() == '1y'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=365),
- end=date.today()
- )._range_string_from_date() == '2y'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=730),
- end=date.today()
- )._range_string_from_date() == '5y'
- assert IEXDailyReader(symbols=syms,
- start=date.today() - timedelta(days=1826),
- end=date.today()
- )._range_string_from_date() == 'max'
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=5), end=date.today()
+ )._range_string_from_date()
+ == "5d"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=27), end=date.today()
+ )._range_string_from_date()
+ == "1m"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=83), end=date.today()
+ )._range_string_from_date()
+ == "3m"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=167), end=date.today()
+ )._range_string_from_date()
+ == "6m"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=170), end=date.today()
+ )._range_string_from_date()
+ == "1y"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=365), end=date.today()
+ )._range_string_from_date()
+ == "2y"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms, start=date.today() - timedelta(days=730), end=date.today()
+ )._range_string_from_date()
+ == "5y"
+ )
+ assert (
+ IEXDailyReader(
+ symbols=syms,
+ start=date.today() - timedelta(days=1826),
+ end=date.today(),
+ )._range_string_from_date()
+ == "max"
+ )
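
Editor's note: the block of assertions above pins down _range_string_from_date, which buckets the requested window into one of IEX's fixed chart ranges. A sketch of that bucketing, with thresholds inferred from the test cases (the exact boundaries in the reader may differ):

from datetime import date, timedelta

def range_string_from_date(start):
    # Map the requested look-back, in days, onto IEX's discrete
    # chart ranges; thresholds chosen to satisfy the tests above.
    delta = (date.today() - start).days
    for days, label in [(5, "5d"), (27, "1m"), (83, "3m"), (167, "6m"),
                        (364, "1y"), (729, "2y"), (1825, "5y")]:
        if delta <= days:
            return label
    return "max"

assert range_string_from_date(date.today() - timedelta(days=27)) == "1m"
assert range_string_from_date(date.today() - timedelta(days=365)) == "2y"
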
diff --git a/pandas_datareader/tests/test_moex.py b/pandas_datareader/tests/test_moex.py
index 5a26970b..5f46ac3e 100644
--- a/pandas_datareader/tests/test_moex.py
+++ b/pandas_datareader/tests/test_moex.py
@@ -1,6 +1,6 @@
import pytest
-
from requests.exceptions import HTTPError
+
import pandas_datareader.data as web
pytestmark = pytest.mark.stable
@@ -9,10 +9,9 @@
class TestMoex(object):
def test_moex_datareader(self):
try:
- df = web.DataReader("USD000UTSTOM",
- "moex",
- start="2017-07-01",
- end="2017-07-31")
- assert 'SECID' in df.columns
+ df = web.DataReader(
+ "USD000UTSTOM", "moex", start="2017-07-01", end="2017-07-31"
+ )
+ assert "SECID" in df.columns
except HTTPError as e:
pytest.skip(e)
diff --git a/pandas_datareader/tests/test_nasdaq.py b/pandas_datareader/tests/test_nasdaq.py
index d58e8869..659755c4 100644
--- a/pandas_datareader/tests/test_nasdaq.py
+++ b/pandas_datareader/tests/test_nasdaq.py
@@ -1,11 +1,10 @@
-import pandas_datareader.data as web
-from pandas_datareader._utils import RemoteDataError
from pandas_datareader._testing import skip_on_exception
+from pandas_datareader._utils import RemoteDataError
+import pandas_datareader.data as web
class TestNasdaqSymbols(object):
-
@skip_on_exception(RemoteDataError)
def test_get_symbols(self):
- symbols = web.DataReader('symbols', 'nasdaq')
- assert 'IBM' in symbols.index
+ symbols = web.DataReader("symbols", "nasdaq")
+ assert "IBM" in symbols.index
diff --git a/pandas_datareader/tests/test_oecd.py b/pandas_datareader/tests/test_oecd.py
index cd358a90..ca4a60ff 100644
--- a/pandas_datareader/tests/test_oecd.py
+++ b/pandas_datareader/tests/test_oecd.py
@@ -5,87 +5,210 @@
import pandas.util.testing as tm
import pytest
-import pandas_datareader.data as web
from pandas_datareader._utils import RemoteDataError
+import pandas_datareader.data as web
class TestOECD(object):
-
- @pytest.mark.xfail(reason='Incorrect URL')
+ @pytest.mark.xfail(reason="Incorrect URL")
def test_get_un_den(self):
- df = web.DataReader('TUD', 'oecd', start=datetime(1960, 1, 1),
- end=datetime(2012, 1, 1))
+ df = web.DataReader(
+ "TUD", "oecd", start=datetime(1960, 1, 1), end=datetime(2012, 1, 1)
+ )
- au = [50.17292785, 49.47181009, 49.52106174, 49.16341327,
- 48.19296375, 47.8863461, 45.83517292, 45.02021403,
- 44.78983834, 44.37794217, 44.15358142, 45.38865546,
- 46.33092037, 47.2343406, 48.80023876, 50.0639872,
- 50.23390644, 49.8214994, 49.67636585, 49.55227375,
- 48.48657368, 47.41179739, 47.52526561, 47.93048854,
- 47.26327162, 45.4617105, 43.90202112, 42.32759607,
- 40.35838899, 39.35157364, 39.55023059, 39.93212859,
- 39.21948472, 37.24343693, 34.42549573, 32.51172639,
- 31.16809569, 29.78835077, 28.14657769, 25.41970706,
- 25.71752984, 24.53108811, 23.21936888, 22.99140633,
- 22.29380238, 22.29160819, 20.22236326, 18.51151852,
- 18.56792804, 19.31219498, 18.44405734, 18.51105048,
- 18.19718895]
- jp = [32.32911392, 33.73688458, 34.5969919, 35.01871257,
- 35.46869345, 35.28164117, 34.749499, 34.40573103,
- 34.50762389, 35.16411379, 35.10284332, 34.57209848,
- 34.31168831, 33.46611342, 34.26450371, 34.53099287,
- 33.69881466, 32.99814274, 32.59541985, 31.75696594,
- 31.14832536, 30.8917513, 30.56612982, 29.75285171,
- 29.22391559, 28.79202411, 28.18680064, 27.71454381,
- 26.94358748, 26.13165206, 25.36711479, 24.78408637,
- 24.49892557, 24.34256055, 24.25324675, 23.96731902,
- 23.3953401, 22.78797997, 22.52794337, 22.18157944,
- 21.54406273, 20.88284597, 20.26073907, 19.73945642,
- 19.25116713, 18.79844243, 18.3497807, 18.25095057,
- 18.2204924, 18.45787546, 18.40380743, 18.99504195,
- 17.97238372]
- us = [30.89748411, 29.51891217, 29.34276869, 28.51337535,
- 28.30646144, 28.16661991, 28.19557735, 27.76578899,
- 27.9004622, 27.30836054, 27.43402867, 26.94941363,
- 26.25996487, 25.83134349, 25.74427582, 25.28771204,
- 24.38412814, 23.59186681, 23.94328194, 22.36400776,
- 22.06009466, 21.01328205, 20.47463895, 19.45290876,
- 18.22953818, 17.44855678, 17.00126975, 16.5162476,
- 16.24744487, 15.86401127, 15.45147174, 15.46986912,
- 15.1499578, 15.13654544, 14.91544059, 14.31762091,
- 14.02052225, 13.55213736, 13.39571457, 13.36670812,
- 12.84874656, 12.85719022, 12.63753733, 12.39142968,
- 12.02130767, 11.96023574, 11.48458378, 11.56435375,
- 11.91022276, 11.79401904, 11.38345975, 11.32948829,
- 10.81535229]
+ au = [
+ 50.17292785,
+ 49.47181009,
+ 49.52106174,
+ 49.16341327,
+ 48.19296375,
+ 47.8863461,
+ 45.83517292,
+ 45.02021403,
+ 44.78983834,
+ 44.37794217,
+ 44.15358142,
+ 45.38865546,
+ 46.33092037,
+ 47.2343406,
+ 48.80023876,
+ 50.0639872,
+ 50.23390644,
+ 49.8214994,
+ 49.67636585,
+ 49.55227375,
+ 48.48657368,
+ 47.41179739,
+ 47.52526561,
+ 47.93048854,
+ 47.26327162,
+ 45.4617105,
+ 43.90202112,
+ 42.32759607,
+ 40.35838899,
+ 39.35157364,
+ 39.55023059,
+ 39.93212859,
+ 39.21948472,
+ 37.24343693,
+ 34.42549573,
+ 32.51172639,
+ 31.16809569,
+ 29.78835077,
+ 28.14657769,
+ 25.41970706,
+ 25.71752984,
+ 24.53108811,
+ 23.21936888,
+ 22.99140633,
+ 22.29380238,
+ 22.29160819,
+ 20.22236326,
+ 18.51151852,
+ 18.56792804,
+ 19.31219498,
+ 18.44405734,
+ 18.51105048,
+ 18.19718895,
+ ]
+ jp = [
+ 32.32911392,
+ 33.73688458,
+ 34.5969919,
+ 35.01871257,
+ 35.46869345,
+ 35.28164117,
+ 34.749499,
+ 34.40573103,
+ 34.50762389,
+ 35.16411379,
+ 35.10284332,
+ 34.57209848,
+ 34.31168831,
+ 33.46611342,
+ 34.26450371,
+ 34.53099287,
+ 33.69881466,
+ 32.99814274,
+ 32.59541985,
+ 31.75696594,
+ 31.14832536,
+ 30.8917513,
+ 30.56612982,
+ 29.75285171,
+ 29.22391559,
+ 28.79202411,
+ 28.18680064,
+ 27.71454381,
+ 26.94358748,
+ 26.13165206,
+ 25.36711479,
+ 24.78408637,
+ 24.49892557,
+ 24.34256055,
+ 24.25324675,
+ 23.96731902,
+ 23.3953401,
+ 22.78797997,
+ 22.52794337,
+ 22.18157944,
+ 21.54406273,
+ 20.88284597,
+ 20.26073907,
+ 19.73945642,
+ 19.25116713,
+ 18.79844243,
+ 18.3497807,
+ 18.25095057,
+ 18.2204924,
+ 18.45787546,
+ 18.40380743,
+ 18.99504195,
+ 17.97238372,
+ ]
+ us = [
+ 30.89748411,
+ 29.51891217,
+ 29.34276869,
+ 28.51337535,
+ 28.30646144,
+ 28.16661991,
+ 28.19557735,
+ 27.76578899,
+ 27.9004622,
+ 27.30836054,
+ 27.43402867,
+ 26.94941363,
+ 26.25996487,
+ 25.83134349,
+ 25.74427582,
+ 25.28771204,
+ 24.38412814,
+ 23.59186681,
+ 23.94328194,
+ 22.36400776,
+ 22.06009466,
+ 21.01328205,
+ 20.47463895,
+ 19.45290876,
+ 18.22953818,
+ 17.44855678,
+ 17.00126975,
+ 16.5162476,
+ 16.24744487,
+ 15.86401127,
+ 15.45147174,
+ 15.46986912,
+ 15.1499578,
+ 15.13654544,
+ 14.91544059,
+ 14.31762091,
+ 14.02052225,
+ 13.55213736,
+ 13.39571457,
+ 13.36670812,
+ 12.84874656,
+ 12.85719022,
+ 12.63753733,
+ 12.39142968,
+ 12.02130767,
+ 11.96023574,
+ 11.48458378,
+ 11.56435375,
+ 11.91022276,
+ 11.79401904,
+ 11.38345975,
+ 11.32948829,
+ 10.81535229,
+ ]
- index = pd.date_range('1960-01-01', '2012-01-01', freq='AS',
- name='Time')
- for label, values in [('Australia', au), ('Japan', jp),
- ('United States', us)]:
+ index = pd.date_range("1960-01-01", "2012-01-01", freq="AS", name="Time")
+ for label, values in [("Australia", au), ("Japan", jp), ("United States", us)]:
expected = pd.Series(values, index=index, name=label)
tm.assert_series_equal(df[label], expected)
def test_get_tourism(self):
- df = web.DataReader('TOURISM_INBOUND', 'oecd',
- start=datetime(2008, 1, 1),
- end=datetime(2012, 1, 1))
+ df = web.DataReader(
+ "TOURISM_INBOUND",
+ "oecd",
+ start=datetime(2008, 1, 1),
+ end=datetime(2012, 1, 1),
+ )
- jp = np.array([8351000, 6790000, 8611000, 6219000,
- 8368000], dtype=float)
- us = np.array([175702309, 160507417, 164079732, 167600277,
- 171320408], dtype=float)
- index = pd.date_range('2008-01-01', '2012-01-01', freq='AS',
- name='Year')
- for label, values in [('Japan', jp), ('United States', us)]:
- expected = pd.Series(values, index=index,
- name='Total international arrivals')
- tm.assert_series_equal(df[label]['Total international arrivals'],
- expected)
+ jp = np.array([8351000, 6790000, 8611000, 6219000, 8368000], dtype=float)
+ us = np.array(
+ [175702309, 160507417, 164079732, 167600277, 171320408], dtype=float
+ )
+ index = pd.date_range("2008-01-01", "2012-01-01", freq="AS", name="Year")
+ for label, values in [("Japan", jp), ("United States", us)]:
+ expected = pd.Series(
+ values, index=index, name="Total international arrivals"
+ )
+ tm.assert_series_equal(df[label]["Total international arrivals"], expected)
def test_oecd_invalid_symbol(self):
with pytest.raises(RemoteDataError):
- web.DataReader('INVALID_KEY', 'oecd')
+ web.DataReader("INVALID_KEY", "oecd")
with pytest.raises(ValueError):
- web.DataReader(1234, 'oecd')
+ web.DataReader(1234, "oecd")
diff --git a/pandas_datareader/tests/test_robinhood.py b/pandas_datareader/tests/test_robinhood.py
index faab2000..8289e78e 100644
--- a/pandas_datareader/tests/test_robinhood.py
+++ b/pandas_datareader/tests/test_robinhood.py
@@ -2,14 +2,13 @@
import pandas as pd
import pytest
-from pandas_datareader.robinhood import RobinhoodQuoteReader, \
- RobinhoodHistoricalReader
+from pandas_datareader.robinhood import RobinhoodHistoricalReader, RobinhoodQuoteReader
-syms = ['GOOG', ['GOOG', 'AAPL']]
+syms = ["GOOG", ["GOOG", "AAPL"]]
ids = list(map(str, syms))
-@pytest.fixture(params=['GOOG', ['GOOG', 'AAPL']], ids=ids)
+@pytest.fixture(params=["GOOG", ["GOOG", "AAPL"]], ids=ids)
def symbols(request):
return request.param
@@ -26,7 +25,7 @@ def test_robinhood_quote(symbols):
@pytest.mark.xfail(reason="Deprecated")
def test_robinhood_quote_too_many():
syms = np.random.randint(65, 90, size=(10000, 4)).tolist()
- syms = list(map(lambda r: ''.join(map(chr, r)), syms))
+ syms = list(map(lambda r: "".join(map(chr, r)), syms))
syms = list(set(syms))
with pytest.raises(ValueError):
RobinhoodQuoteReader(symbols=syms)
@@ -35,7 +34,7 @@ def test_robinhood_quote_too_many():
@pytest.mark.xfail(reason="Deprecated")
def test_robinhood_historical_too_many():
syms = np.random.randint(65, 90, size=(10000, 4)).tolist()
- syms = list(map(lambda r: ''.join(map(chr, r)), syms))
+ syms = list(map(lambda r: "".join(map(chr, r)), syms))
syms = list(set(syms))
with pytest.raises(ValueError):
RobinhoodHistoricalReader(symbols=syms)
diff --git a/pandas_datareader/tests/test_stooq.py b/pandas_datareader/tests/test_stooq.py
index 1158b191..436d1dc3 100644
--- a/pandas_datareader/tests/test_stooq.py
+++ b/pandas_datareader/tests/test_stooq.py
@@ -7,35 +7,35 @@
def test_stooq_dji():
- f = web.DataReader('GS', 'stooq')
+ f = web.DataReader("GS", "stooq")
assert f.shape[0] > 0
def test_get_data_stooq_dji():
- f = get_data_stooq('AMZN')
+ f = get_data_stooq("AMZN")
assert f.shape[0] > 0
def test_get_data_stooq_dates():
- f = get_data_stooq('SPY', start='20180101', end='20180115')
+ f = get_data_stooq("SPY", start="20180101", end="20180115")
assert f.shape[0] == 9
def test_stooq_sp500():
- f = get_data_stooq('^SPX')
+ f = get_data_stooq("^SPX")
assert f.shape[0] > 0
def test_get_data_stooq_dax():
- f = get_data_stooq('^DAX')
+ f = get_data_stooq("^DAX")
assert f.shape[0] > 0
def test_stooq_googl():
- f = get_data_stooq('GOOGL.US')
+ f = get_data_stooq("GOOGL.US")
assert f.shape[0] > 0
def test_get_data_ibm():
- f = get_data_stooq('IBM.UK')
+ f = get_data_stooq("IBM.UK")
assert f.shape[0] > 0
diff --git a/pandas_datareader/tests/test_tsp.py b/pandas_datareader/tests/test_tsp.py
index d41f1afb..8ebc2998 100644
--- a/pandas_datareader/tests/test_tsp.py
+++ b/pandas_datareader/tests/test_tsp.py
@@ -9,19 +9,20 @@
class TestTSPFunds(object):
def test_get_allfunds(self):
- tspdata = tsp.TSPReader(start='2015-11-2', end='2015-11-2').read()
+ tspdata = tsp.TSPReader(start="2015-11-2", end="2015-11-2").read()
assert len(tspdata) == 1
- assert round(tspdata['I Fund'][dt.date(2015, 11, 2)], 5) == 25.0058
+ assert round(tspdata["I Fund"][dt.date(2015, 11, 2)], 5) == 25.0058
def test_sanitize_response(self):
class response(object):
pass
+
r = response()
- r.text = ' , '
+ r.text = " , "
ret = tsp.TSPReader._sanitize_response(r)
- assert ret == ''
- r.text = ' a,b '
+ assert ret == ""
+ r.text = " a,b "
ret = tsp.TSPReader._sanitize_response(r)
- assert ret == 'a,b'
+ assert ret == "a,b"
diff --git a/pandas_datareader/tests/test_wb.py b/pandas_datareader/tests/test_wb.py
index 05cf4638..f6e82b27 100644
--- a/pandas_datareader/tests/test_wb.py
+++ b/pandas_datareader/tests/test_wb.py
@@ -6,39 +6,42 @@
import pytest
import requests
-from pandas_datareader.compat import assert_raises_regex
-from pandas_datareader.wb import (search, download, get_countries,
- get_indicators, WorldBankReader)
-from pandas_datareader._utils import RemoteDataError
from pandas_datareader._testing import skip_on_exception
+from pandas_datareader._utils import RemoteDataError
+from pandas_datareader.compat import assert_raises_regex
+from pandas_datareader.wb import (
+ WorldBankReader,
+ download,
+ get_countries,
+ get_indicators,
+ search,
+)
class TestWB(object):
-
def test_wdi_search(self):
# Test that a name column exists, and that some results were returned
# ...without being too strict about what the actual contents of the
# results actually are. The fact that there are some, is good enough.
- result = search('gdp.*capita.*constant')
- assert result.name.str.contains('GDP').any()
+ result = search("gdp.*capita.*constant")
+ assert result.name.str.contains("GDP").any()
# check cache returns the results within 0.5 sec
current_time = time.time()
- result = search('gdp.*capita.*constant')
- assert result.name.str.contains('GDP').any()
+ result = search("gdp.*capita.*constant")
+ assert result.name.str.contains("GDP").any()
assert time.time() - current_time < 0.5
- result2 = WorldBankReader().search('gdp.*capita.*constant')
+ result2 = WorldBankReader().search("gdp.*capita.*constant")
session = requests.Session()
- result3 = search('gdp.*capita.*constant', session=session)
- result4 = WorldBankReader(session=session).search(
- 'gdp.*capita.*constant')
+ result3 = search("gdp.*capita.*constant", session=session)
+ result4 = WorldBankReader(session=session).search("gdp.*capita.*constant")
for result in [result2, result3, result4]:
- assert result.name.str.contains('GDP').any()
+ assert result.name.str.contains("GDP").any()
def test_wdi_download(self):
@@ -51,35 +54,45 @@ def test_wdi_download(self):
# own exceptions, and don't clean up legacy country codes.
# ...but NOT a retired indicator (user should want it to error).
- cntry_codes = ['CA', 'MX', 'USA', 'US', 'US', 'KSV', 'BLA']
- inds = ['NY.GDP.PCAP.CD', 'BAD.INDICATOR']
+ cntry_codes = ["CA", "MX", "USA", "US", "US", "KSV", "BLA"]
+ inds = ["NY.GDP.PCAP.CD", "BAD.INDICATOR"]
# These are the expected results, rounded (robust against
# data revisions in the future).
- expected = {'NY.GDP.PCAP.CD': {('Canada', '2004'): 32000.0,
- ('Canada', '2003'): 28000.0,
- ('Kosovo', '2004'): 2000.0,
- ('Kosovo', '2003'): 2000.0,
- ('Mexico', '2004'): 7000.0,
- ('Mexico', '2003'): 7000.0,
- ('United States', '2004'): 42000.0,
- ('United States', '2003'): 39000.0}}
+ expected = {
+ "NY.GDP.PCAP.CD": {
+ ("Canada", "2004"): 32000.0,
+ ("Canada", "2003"): 28000.0,
+ ("Kosovo", "2004"): 2000.0,
+ ("Kosovo", "2003"): 2000.0,
+ ("Mexico", "2004"): 7000.0,
+ ("Mexico", "2003"): 7000.0,
+ ("United States", "2004"): 42000.0,
+ ("United States", "2003"): 39000.0,
+ }
+ }
expected = pd.DataFrame(expected)
expected = expected.sort_index()
- result = download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='ignore')
+ result = download(
+ country=cntry_codes, indicator=inds, start=2003, end=2004, errors="ignore"
+ )
result = result.sort_index()
# Round, to ignore revisions to data.
result = np.round(result, decimals=-3)
- expected.index.names = ['country', 'year']
+ expected.index.names = ["country", "year"]
tm.assert_frame_equal(result, expected)
# pass start and end as string
- result = download(country=cntry_codes, indicator=inds,
- start='2003', end='2004', errors='ignore')
+ result = download(
+ country=cntry_codes,
+ indicator=inds,
+ start="2003",
+ end="2004",
+ errors="ignore",
+ )
result = result.sort_index()
# Round, to ignore revisions to data.
@@ -90,63 +103,80 @@ def test_wdi_download_str(self):
# These are the expected results, rounded (robust against
# data revisions in the future).
- expected = {'NY.GDP.PCAP.CD': {('Japan', '2004'): 38000.0,
- ('Japan', '2003'): 35000.0,
- ('Japan', '2002'): 32000.0,
- ('Japan', '2001'): 34000.0,
- ('Japan', '2000'): 39000.0}}
+ expected = {
+ "NY.GDP.PCAP.CD": {
+ ("Japan", "2004"): 38000.0,
+ ("Japan", "2003"): 35000.0,
+ ("Japan", "2002"): 32000.0,
+ ("Japan", "2001"): 34000.0,
+ ("Japan", "2000"): 39000.0,
+ }
+ }
expected = pd.DataFrame(expected)
expected = expected.sort_index()
- cntry_codes = 'JP'
- inds = 'NY.GDP.PCAP.CD'
- result = download(country=cntry_codes, indicator=inds,
- start=2000, end=2004, errors='ignore')
+ cntry_codes = "JP"
+ inds = "NY.GDP.PCAP.CD"
+ result = download(
+ country=cntry_codes, indicator=inds, start=2000, end=2004, errors="ignore"
+ )
result = result.sort_index()
result = np.round(result, decimals=-3)
- expected.index.names = ['country', 'year']
+ expected.index.names = ["country", "year"]
tm.assert_frame_equal(result, expected)
- result = WorldBankReader(inds, countries=cntry_codes,
- start=2000, end=2004, errors='ignore').read()
+ result = WorldBankReader(
+ inds, countries=cntry_codes, start=2000, end=2004, errors="ignore"
+ ).read()
result = result.sort_index()
result = np.round(result, decimals=-3)
tm.assert_frame_equal(result, expected)
def test_wdi_download_error_handling(self):
- cntry_codes = ['USA', 'XX']
- inds = 'NY.GDP.PCAP.CD'
+ cntry_codes = ["USA", "XX"]
+ inds = "NY.GDP.PCAP.CD"
msg = "Invalid Country Code\\(s\\): XX"
with assert_raises_regex(ValueError, msg):
- download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='raise')
+ download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2003,
+ end=2004,
+ errors="raise",
+ )
with tm.assert_produces_warning():
- result = download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='warn')
+ result = download(
+ country=cntry_codes, indicator=inds, start=2003, end=2004, errors="warn"
+ )
assert isinstance(result, pd.DataFrame)
assert len(result) == 2
- cntry_codes = ['USA']
- inds = ['NY.GDP.PCAP.CD', 'BAD_INDICATOR']
+ cntry_codes = ["USA"]
+ inds = ["NY.GDP.PCAP.CD", "BAD_INDICATOR"]
- msg = ("The provided parameter value is not valid\\. "
- "Indicator: BAD_INDICATOR")
+ msg = "The provided parameter value is not valid\\. " "Indicator: BAD_INDICATOR"
with assert_raises_regex(ValueError, msg):
- download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='raise')
+ download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2003,
+ end=2004,
+ errors="raise",
+ )
with tm.assert_produces_warning():
- result = download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='warn')
+ result = download(
+ country=cntry_codes, indicator=inds, start=2003, end=2004, errors="warn"
+ )
assert isinstance(result, pd.DataFrame)
assert len(result) == 2
def test_wdi_download_w_retired_indicator(self):
- cntry_codes = ['CA', 'MX', 'US']
+ cntry_codes = ["CA", "MX", "US"]
# Despite showing up in the search feature, and being listed online,
# the api calls to GDPPCKD don't work in their own query builder, nor
# pandas module. GDPPCKD used to be a common symbol.
@@ -158,11 +188,16 @@ def test_wdi_download_w_retired_indicator(self):
# World bank ever finishes the deprecation of this symbol,
# this test should still pass.
- inds = ['GDPPCKD']
+ inds = ["GDPPCKD"]
with pytest.raises(ValueError):
- result = download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='ignore')
+ result = download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2003,
+ end=2004,
+ errors="ignore",
+ )
# If it ever gets here, it means WB unretired the indicator.
# even if they dropped it completely, it would still
@@ -174,12 +209,17 @@ def test_wdi_download_w_retired_indicator(self):
def test_wdi_download_w_crash_inducing_countrycode(self):
- cntry_codes = ['CA', 'MX', 'US', 'XXX']
- inds = ['NY.GDP.PCAP.CD']
+ cntry_codes = ["CA", "MX", "US", "XXX"]
+ inds = ["NY.GDP.PCAP.CD"]
with pytest.raises(ValueError):
- result = download(country=cntry_codes, indicator=inds,
- start=2003, end=2004, errors='ignore')
+ result = download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2003,
+ end=2004,
+ errors="ignore",
+ )
# If it ever gets here, it means the country code XXX
# got used by WB
@@ -197,7 +237,7 @@ def test_wdi_get_countries(self):
result4 = WorldBankReader(session=session).get_countries()
for result in [result1, result2, result3, result4]:
- assert 'Zimbabwe' in list(result['name'])
+ assert "Zimbabwe" in list(result["name"])
assert len(result) > 100
assert pd.notnull(result.latitude.mean())
assert pd.notnull(result.longitude.mean())
@@ -211,70 +251,101 @@ def test_wdi_get_indicators(self):
result4 = WorldBankReader(session=session).get_indicators()
for result in [result1, result2, result3, result4]:
- exp_col = pd.Index(['id', 'name', 'source', 'sourceNote',
- 'sourceOrganization', 'topics', 'unit'])
+ exp_col = pd.Index(
+ [
+ "id",
+ "name",
+ "source",
+ "sourceNote",
+ "sourceOrganization",
+ "topics",
+ "unit",
+ ]
+ )
# assert_index_equal doesn't exist here, so use Index.equals
assert result.columns.equals(exp_col)
assert len(result) > 10000
@skip_on_exception(RemoteDataError)
def test_wdi_download_monthly(self):
- expected = {'COPPER': {('World', '2012M01'): 8040.47,
- ('World', '2011M12'): 7565.48,
- ('World', '2011M11'): 7581.02,
- ('World', '2011M10'): 7394.19,
- ('World', '2011M09'): 8300.14,
- ('World', '2011M08'): 9000.76,
- ('World', '2011M07'): 9650.46,
- ('World', '2011M06'): 9066.85,
- ('World', '2011M05'): 8959.90,
- ('World', '2011M04'): 9492.79,
- ('World', '2011M03'): 9503.36,
- ('World', '2011M02'): 9867.60,
- ('World', '2011M01'): 9555.70}}
+ expected = {
+ "COPPER": {
+ ("World", "2012M01"): 8040.47,
+ ("World", "2011M12"): 7565.48,
+ ("World", "2011M11"): 7581.02,
+ ("World", "2011M10"): 7394.19,
+ ("World", "2011M09"): 8300.14,
+ ("World", "2011M08"): 9000.76,
+ ("World", "2011M07"): 9650.46,
+ ("World", "2011M06"): 9066.85,
+ ("World", "2011M05"): 8959.90,
+ ("World", "2011M04"): 9492.79,
+ ("World", "2011M03"): 9503.36,
+ ("World", "2011M02"): 9867.60,
+ ("World", "2011M01"): 9555.70,
+ }
+ }
expected = pd.DataFrame(expected)
# Round, to ignore revisions to data.
expected = np.round(expected, decimals=-3)
expected = expected.sort_index()
- cntry_codes = 'ALL'
- inds = 'COPPER'
- result = download(country=cntry_codes, indicator=inds,
- start=2011, end=2012, freq='M', errors='ignore')
+ cntry_codes = "ALL"
+ inds = "COPPER"
+ result = download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2011,
+ end=2012,
+ freq="M",
+ errors="ignore",
+ )
result = result.sort_index()
result = np.round(result, decimals=-3)
- expected.index.names = ['country', 'year']
+ expected.index.names = ["country", "year"]
tm.assert_frame_equal(result, expected)
- result = WorldBankReader(inds, countries=cntry_codes, start=2011,
- end=2012, freq='M', errors='ignore').read()
+ result = WorldBankReader(
+ inds, countries=cntry_codes, start=2011, end=2012, freq="M", errors="ignore"
+ ).read()
result = result.sort_index()
result = np.round(result, decimals=-3)
tm.assert_frame_equal(result, expected)
def test_wdi_download_quarterly(self):
- code = 'DT.DOD.PUBS.CD.US'
- expected = {code: {('Albania', '2012Q1'): 3240539817.18,
- ('Albania', '2011Q4'): 3213979715.15,
- ('Albania', '2011Q3'): 3187681048.95,
- ('Albania', '2011Q2'): 3248041513.86,
- ('Albania', '2011Q1'): 3137210567.92}}
+ code = "DT.DOD.PUBS.CD.US"
+ expected = {
+ code: {
+ ("Albania", "2012Q1"): 3240539817.18,
+ ("Albania", "2011Q4"): 3213979715.15,
+ ("Albania", "2011Q3"): 3187681048.95,
+ ("Albania", "2011Q2"): 3248041513.86,
+ ("Albania", "2011Q1"): 3137210567.92,
+ }
+ }
expected = pd.DataFrame(expected)
# Round, to ignore revisions to data.
expected = np.round(expected, decimals=-3)
expected = expected.sort_index()
- cntry_codes = 'ALB'
- inds = 'DT.DOD.PUBS.CD.US'
- result = download(country=cntry_codes, indicator=inds,
- start=2011, end=2012, freq='Q', errors='ignore')
+ cntry_codes = "ALB"
+ inds = "DT.DOD.PUBS.CD.US"
+ result = download(
+ country=cntry_codes,
+ indicator=inds,
+ start=2011,
+ end=2012,
+ freq="Q",
+ errors="ignore",
+ )
result = result.sort_index()
result = np.round(result, decimals=-3)
- expected.index.names = ['country', 'year']
+ expected.index.names = ["country", "year"]
tm.assert_frame_equal(result, expected)
- result = WorldBankReader(inds, countries=cntry_codes, start=2011,
- end=2012, freq='Q', errors='ignore').read()
+ result = WorldBankReader(
+ inds, countries=cntry_codes, start=2011, end=2012, freq="Q", errors="ignore"
+ ).read()
result = result.sort_index()
result = np.round(result, decimals=-3)
tm.assert_frame_equal(result, expected)
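
As context for the tests above: they drive the public `wb.download` helper. A
minimal sketch of the quarterly call they exercise (same indicator and country
code as the test; live access to the World Bank API is assumed):

    from pandas_datareader import wb

    # Quarterly external public debt for Albania; errors="ignore" suppresses
    # complaints about retired or otherwise invalid indicator codes.
    debt = wb.download(
        country="ALB",
        indicator="DT.DOD.PUBS.CD.US",
        start=2011,
        end=2012,
        freq="Q",
        errors="ignore",
    )
    print(debt.sort_index().head())
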
diff --git a/pandas_datareader/tests/yahoo/test_options.py b/pandas_datareader/tests/yahoo/test_options.py
index eae227c9..34cda1c1 100644
--- a/pandas_datareader/tests/yahoo/test_options.py
+++ b/pandas_datareader/tests/yahoo/test_options.py
@@ -1,18 +1,17 @@
-import os
from datetime import datetime
+import os
import numpy as np
import pandas as pd
-
-import pytest
import pandas.util.testing as tm
+import pytest
import pandas_datareader.data as web
@pytest.yield_fixture
def aapl():
- aapl = web.Options('aapl', 'yahoo')
+ aapl = web.Options("aapl", "yahoo")
yield aapl
aapl.close()
@@ -51,18 +50,16 @@ def expiry(month, year):
@pytest.fixture
def json1(datapath):
- dirpath = datapath('yahoo', 'data')
- json1 = 'file://' + os.path.join(
- dirpath, 'yahoo_options1.json')
+ dirpath = datapath("yahoo", "data")
+ json1 = "file://" + os.path.join(dirpath, "yahoo_options1.json")
return json1
@pytest.fixture
def json2(datapath):
# see gh-22: empty table
- dirpath = datapath('yahoo', 'data')
- json2 = 'file://' + os.path.join(
- dirpath, 'yahoo_options2.json')
+ dirpath = datapath("yahoo", "data")
+ json2 = "file://" + os.path.join(dirpath, "yahoo_options2.json")
return json2
@@ -72,9 +69,8 @@ def data1(aapl, json1):
class TestYahooOptions(object):
-
def setup_class(cls):
- pytest.skip('Skip all Yahoo! tests.')
+ pytest.skip("Skip all Yahoo! tests.")
def assert_option_result(self, df):
"""
@@ -83,16 +79,42 @@ def assert_option_result(self, df):
assert isinstance(df, pd.DataFrame)
assert len(df) > 1
- exp_columns = pd.Index(['Last', 'Bid', 'Ask', 'Chg', 'PctChg', 'Vol',
- 'Open_Int', 'IV', 'Root', 'IsNonstandard',
- 'Underlying', 'Underlying_Price', 'Quote_Time',
- 'Last_Trade_Date', 'JSON'])
+ exp_columns = pd.Index(
+ [
+ "Last",
+ "Bid",
+ "Ask",
+ "Chg",
+ "PctChg",
+ "Vol",
+ "Open_Int",
+ "IV",
+ "Root",
+ "IsNonstandard",
+ "Underlying",
+ "Underlying_Price",
+ "Quote_Time",
+ "Last_Trade_Date",
+ "JSON",
+ ]
+ )
tm.assert_index_equal(df.columns, exp_columns)
- assert df.index.names == [u'Strike', u'Expiry', u'Type', u'Symbol']
-
- dtypes = [np.dtype(x) for x in ['float64'] * 7 +
- ['float64', 'object', 'bool', 'object', 'float64',
- 'datetime64[ns]', 'datetime64[ns]', 'object']]
+ assert df.index.names == [u"Strike", u"Expiry", u"Type", u"Symbol"]
+
+ dtypes = [
+ np.dtype(x)
+ for x in ["float64"] * 7
+ + [
+ "float64",
+ "object",
+ "bool",
+ "object",
+ "float64",
+ "datetime64[ns]",
+ "datetime64[ns]",
+ "object",
+ ]
+ ]
tm.assert_series_equal(df.dtypes, pd.Series(dtypes, index=exp_columns))
def test_get_options_data(self, aapl, expiry):
@@ -107,25 +129,24 @@ def test_get_options_data(self, aapl, expiry):
self.assert_option_result(options)
def test_get_near_stock_price(self, aapl, expiry):
- options = aapl.get_near_stock_price(call=True, put=True,
- expiry=expiry)
+ options = aapl.get_near_stock_price(call=True, put=True, expiry=expiry)
self.assert_option_result(options)
def test_options_is_not_none(self):
- option = web.Options('aapl', 'yahoo')
+ option = web.Options("aapl", "yahoo")
assert option is not None
def test_get_call_data(self, aapl, expiry):
calls = aapl.get_call_data(expiry=expiry)
self.assert_option_result(calls)
- assert calls.index.levels[2][0] == 'call'
+ assert calls.index.levels[2][0] == "call"
def test_get_put_data(self, aapl, expiry):
puts = aapl.get_put_data(expiry=expiry)
self.assert_option_result(puts)
- assert puts.index.levels[2][1] == 'put'
+ assert puts.index.levels[2][1] == "put"
def test_get_expiry_dates(self, aapl):
dates = aapl._get_expiry_dates()
@@ -151,7 +172,7 @@ def test_get_all_data_calls_only(self, aapl):
def test_get_underlying_price(self, aapl):
# see gh-7
- options_object = web.Options('^spxpm', 'yahoo')
+ options_object = web.Options("^spxpm", "yahoo")
quote_price = options_object.underlying_price
assert isinstance(quote_price, float)
@@ -162,49 +183,44 @@ def test_get_underlying_price(self, aapl):
assert isinstance(price, (int, float, complex))
assert isinstance(quote_time, (datetime, pd.Timestamp))
- @pytest.mark.xfail(reason='Invalid URL scheme')
+ @pytest.mark.xfail(reason="Invalid URL scheme")
def test_chop(self, aapl, data1):
# gh-7625: regression test
- aapl._chop_data(data1, above_below=2,
- underlying_price=np.nan)
- chopped = aapl._chop_data(data1, above_below=2,
- underlying_price=100)
+ aapl._chop_data(data1, above_below=2, underlying_price=np.nan)
+ chopped = aapl._chop_data(data1, above_below=2, underlying_price=100)
assert isinstance(chopped, pd.DataFrame)
assert len(chopped) > 1
- chopped2 = aapl._chop_data(data1, above_below=2,
- underlying_price=None)
+ chopped2 = aapl._chop_data(data1, above_below=2, underlying_price=None)
assert isinstance(chopped2, pd.DataFrame)
assert len(chopped2) > 1
- @pytest.mark.xfail(reason='Invalid URL scheme')
+ @pytest.mark.xfail(reason="Invalid URL scheme")
def test_chop_out_of_strike_range(self, aapl, data1):
# gh-7625: regression test
- aapl._chop_data(data1, above_below=2,
- underlying_price=np.nan)
- chopped = aapl._chop_data(data1, above_below=2,
- underlying_price=100000)
+ aapl._chop_data(data1, above_below=2, underlying_price=np.nan)
+ chopped = aapl._chop_data(data1, above_below=2, underlying_price=100000)
assert isinstance(chopped, pd.DataFrame)
assert len(chopped) > 1
- @pytest.mark.xfail(reason='Invalid URL scheme')
+ @pytest.mark.xfail(reason="Invalid URL scheme")
def test_sample_page_chg_float(self, data1):
# Tests that numeric columns with commas are appropriately dealt with
- assert data1['Chg'].dtype == 'float64'
+ assert data1["Chg"].dtype == "float64"
def test_month_year(self, aapl, month, year):
# see gh-168
data = aapl.get_call_data(month=month, year=year)
assert len(data) > 1
- assert data.index.levels[0].dtype == 'float64'
+ assert data.index.levels[0].dtype == "float64"
self.assert_option_result(data)
- @pytest.mark.xfail(reason='Invalid URL scheme')
+ @pytest.mark.xfail(reason="Invalid URL scheme")
def test_empty_table(self, aapl, json2):
# see gh-22
empty = aapl._process_data(aapl._parse_url(json2))
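
The options tests above are skipped wholesale (see setup_class), but for
reference this is the shape of the API they cover (a sketch only; Yahoo's
options endpoint has been unreliable, and the month/year values here are
arbitrary):

    import pandas_datareader.data as web

    aapl = web.Options("aapl", "yahoo")
    # month/year selection as in test_month_year; the result is indexed by
    # (Strike, Expiry, Type, Symbol).
    calls = aapl.get_call_data(month=1, year=2020)
    aapl.close()
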
diff --git a/pandas_datareader/tests/yahoo/test_yahoo.py b/pandas_datareader/tests/yahoo/test_yahoo.py
index 9a886772..16b8b79c 100644
--- a/pandas_datareader/tests/yahoo/test_yahoo.py
+++ b/pandas_datareader/tests/yahoo/test_yahoo.py
@@ -1,25 +1,24 @@
from datetime import datetime
-import requests
import numpy as np
import pandas as pd
from pandas import DataFrame
-from requests.exceptions import ConnectionError
-import pytest
import pandas.util.testing as tm
+import pytest
+import requests
+from requests.exceptions import ConnectionError
+from pandas_datareader._testing import skip_on_exception
+from pandas_datareader._utils import RemoteDataError
import pandas_datareader.data as web
from pandas_datareader.data import YahooDailyReader
-from pandas_datareader._utils import RemoteDataError
-from pandas_datareader._testing import skip_on_exception
-XFAIL_REASON = 'Known connection failures on Yahoo when testing!'
+XFAIL_REASON = "Known connection failures on Yahoo when testing!"
pytestmark = pytest.mark.stable
class TestYahoo(object):
-
@classmethod
def setup_class(cls):
pytest.importorskip("lxml")
@@ -30,37 +29,36 @@ def test_yahoo(self):
start = datetime(2010, 1, 1)
end = datetime(2013, 1, 25)
- assert round(web.DataReader('F', 'yahoo', start, end)['Close'][-1],
- 2) == 13.68
+ assert round(web.DataReader("F", "yahoo", start, end)["Close"][-1], 2) == 13.68
def test_yahoo_fails(self):
start = datetime(2010, 1, 1)
end = datetime(2013, 1, 27)
with pytest.raises(Exception):
- web.DataReader('NON EXISTENT TICKER', 'yahoo', start, end)
+ web.DataReader("NON EXISTENT TICKER", "yahoo", start, end)
def test_get_quote_series(self):
- stringlist = ['GOOG', 'AAPL']
- fields = ['exchange', 'sharesOutstanding', 'epsForward']
+ stringlist = ["GOOG", "AAPL"]
+ fields = ["exchange", "sharesOutstanding", "epsForward"]
try:
- AAPL = web.get_quote_yahoo('AAPL')
+ AAPL = web.get_quote_yahoo("AAPL")
df = web.get_quote_yahoo(pd.Series(stringlist))
except ConnectionError:
pytest.xfail(reason=XFAIL_REASON)
- tm.assert_series_equal(AAPL.iloc[0][fields], df.loc['AAPL'][fields])
+ tm.assert_series_equal(AAPL.iloc[0][fields], df.loc["AAPL"][fields])
assert sorted(stringlist) == sorted(list(df.index.values))
def test_get_quote_string(self):
try:
- df = web.get_quote_yahoo('GOOG')
+ df = web.get_quote_yahoo("GOOG")
except ConnectionError:
pytest.xfail(reason=XFAIL_REASON)
- assert not pd.isnull(df['marketCap'][0])
+ assert not pd.isnull(df["marketCap"][0])
def test_get_quote_stringlist(self):
- stringlist = ['GOOG', 'AAPL']
+ stringlist = ["GOOG", "AAPL"]
try:
df = web.get_quote_yahoo(stringlist)
except ConnectionError:
@@ -69,44 +67,44 @@ def test_get_quote_stringlist(self):
def test_get_quote_comma_name(self):
try:
- df = web.get_quote_yahoo(['RGLD'])
+ df = web.get_quote_yahoo(["RGLD"])
except ConnectionError:
pytest.xfail(reason=XFAIL_REASON)
- assert df['longName'][0] == 'Royal Gold, Inc.'
+ assert df["longName"][0] == "Royal Gold, Inc."
- @pytest.mark.skip('Unreliable test, receive partial '
- 'components back for dow_jones')
+ @pytest.mark.skip(
+ "Unreliable test, receive partial " "components back for dow_jones"
+ )
def test_get_components_dow_jones(self): # pragma: no cover
- df = web.get_components_yahoo('^DJI') # Dow Jones
+ df = web.get_components_yahoo("^DJI") # Dow Jones
assert isinstance(df, pd.DataFrame)
assert len(df) == 30
- @pytest.mark.skip('Unreliable test, receive partial '
- 'components back for dax')
+ @pytest.mark.skip("Unreliable test, receive partial " "components back for dax")
def test_get_components_dax(self): # pragma: no cover
- df = web.get_components_yahoo('^GDAXI') # DAX
+ df = web.get_components_yahoo("^GDAXI") # DAX
assert isinstance(df, pd.DataFrame)
assert len(df) == 30
- assert df[df.name.str.contains('adidas', case=False)].index == 'ADS.DE'
+ assert df[df.name.str.contains("adidas", case=False)].index == "ADS.DE"
- @pytest.mark.skip('Unreliable test, receive partial '
- 'components back for nasdaq_100')
+ @pytest.mark.skip(
+ "Unreliable test, receive partial " "components back for nasdaq_100"
+ )
def test_get_components_nasdaq_100(self): # pragma: no cover
# As of 7/12/13, the conditional will
# return false because the link is invalid
- df = web.get_components_yahoo('^NDX') # NASDAQ-100
+ df = web.get_components_yahoo("^NDX") # NASDAQ-100
assert isinstance(df, pd.DataFrame)
if len(df) > 1:
# Usual culprits, should be around for a while
- assert 'AAPL' in df.index
- assert 'GOOG' in df.index
- assert 'AMZN' in df.index
+ assert "AAPL" in df.index
+ assert "GOOG" in df.index
+ assert "AMZN" in df.index
else:
- expected = DataFrame({'exchange': 'N/A', 'name': '@^NDX'},
- index=['@^NDX'])
+ expected = DataFrame({"exchange": "N/A", "name": "@^NDX"}, index=["@^NDX"])
tm.assert_frame_equal(df, expected)
@skip_on_exception(RemoteDataError)
@@ -114,86 +112,84 @@ def test_get_data_single_symbol(self):
# single symbol
# http://finance.yahoo.com/q/hp?s=GOOG&a=09&b=08&c=2010&d=09&e=10&f=2010&g=d
# just test that we succeed
- web.get_data_yahoo('GOOG')
+ web.get_data_yahoo("GOOG")
@skip_on_exception(RemoteDataError)
def test_data_with_no_actions(self):
- web.get_data_yahoo('TSLA')
+ web.get_data_yahoo("TSLA")
@skip_on_exception(RemoteDataError)
def test_get_data_adjust_price(self):
- goog = web.get_data_yahoo('GOOG')
- goog_adj = web.get_data_yahoo('GOOG', adjust_price=True)
- assert 'Adj Close' not in goog_adj.columns
- assert (goog['Open'] * goog_adj['Adj_Ratio']).equals(goog_adj['Open'])
+ goog = web.get_data_yahoo("GOOG")
+ goog_adj = web.get_data_yahoo("GOOG", adjust_price=True)
+ assert "Adj Close" not in goog_adj.columns
+ assert (goog["Open"] * goog_adj["Adj_Ratio"]).equals(goog_adj["Open"])
@pytest.mark.xfail(reason="Yahoo are returning an extra day 31st Dec 2012")
def test_get_data_interval(self):
# daily interval data
- pan = web.get_data_yahoo('XOM', '2013-01-01',
- '2013-12-31', interval='d')
+ pan = web.get_data_yahoo("XOM", "2013-01-01", "2013-12-31", interval="d")
assert len(pan) == 252
# weekly interval data
- pan = web.get_data_yahoo('XOM', '2013-01-01',
- '2013-12-31', interval='w')
+ pan = web.get_data_yahoo("XOM", "2013-01-01", "2013-12-31", interval="w")
assert len(pan) == 53
# monthly interval data
- pan = web.get_data_yahoo('XOM', '2012-12-31',
- '2013-12-31', interval='m')
+ pan = web.get_data_yahoo("XOM", "2012-12-31", "2013-12-31", interval="m")
assert len(pan) == 12
# test fail on invalid interval
with pytest.raises(ValueError):
- web.get_data_yahoo('XOM', interval='NOT VALID')
+ web.get_data_yahoo("XOM", interval="NOT VALID")
@skip_on_exception(RemoteDataError)
def test_get_data_multiple_symbols(self):
# just test that we succeed
- sl = ['AAPL', 'AMZN', 'GOOG']
- web.get_data_yahoo(sl, '2012')
+ sl = ["AAPL", "AMZN", "GOOG"]
+ web.get_data_yahoo(sl, "2012")
- @pytest.mark.parametrize('adj_pr', [True, False])
+ @pytest.mark.parametrize("adj_pr", [True, False])
@skip_on_exception(RemoteDataError)
def test_get_data_null_as_missing_data(self, adj_pr):
- result = web.get_data_yahoo('SRCE', '20160626', '20160705',
- adjust_price=adj_pr)
+ result = web.get_data_yahoo("SRCE", "20160626", "20160705", adjust_price=adj_pr)
# sanity checking
- floats = ['Open', 'High', 'Low', 'Close']
+ floats = ["Open", "High", "Low", "Close"]
if adj_pr:
- floats.append('Adj_Ratio')
+ floats.append("Adj_Ratio")
else:
- floats.append('Adj Close')
+ floats.append("Adj Close")
assert result[floats].dtypes.all() == np.floating
@skip_on_exception(RemoteDataError)
def test_get_data_multiple_symbols_two_dates(self):
- data = web.get_data_yahoo(['GE', 'MSFT', 'INTC'], 'JAN-01-12',
- 'JAN-31-12')
- result = data.Close.loc['01-18-12'].T
+ data = web.get_data_yahoo(["GE", "MSFT", "INTC"], "JAN-01-12", "JAN-31-12")
+ result = data.Close.loc["01-18-12"].T
assert result.size == 3
# sanity checking
assert result.dtypes == np.floating
- expected = np.array([[18.99, 28.4, 25.18],
- [18.58, 28.31, 25.13],
- [19.03, 28.16, 25.52],
- [18.81, 28.82, 25.87]])
+ expected = np.array(
+ [
+ [18.99, 28.4, 25.18],
+ [18.58, 28.31, 25.13],
+ [19.03, 28.16, 25.52],
+ [18.81, 28.82, 25.87],
+ ]
+ )
df = data.Open
- result = df[(df.index >= 'Jan-15-12') & (df.index <= 'Jan-20-12')]
+ result = df[(df.index >= "Jan-15-12") & (df.index <= "Jan-20-12")]
assert expected.shape == result.shape
def test_get_date_ret_index(self):
- pan = web.get_data_yahoo(['GE', 'INTC', 'IBM'], '1977', '1987',
- ret_index=True)
- assert hasattr(pan, 'Ret_Index')
+ pan = web.get_data_yahoo(["GE", "INTC", "IBM"], "1977", "1987", ret_index=True)
+ assert hasattr(pan, "Ret_Index")
- if hasattr(pan, 'Ret_Index') and hasattr(pan.Ret_Index, 'INTC'):
+ if hasattr(pan, "Ret_Index") and hasattr(pan.Ret_Index, "INTC"):
tstamp = pan.Ret_Index.INTC.first_valid_index()
- result = pan.Ret_Index.loc[tstamp, 'INTC']
+ result = pan.Ret_Index.loc[tstamp, "INTC"]
assert result == 1.0
# sanity checking
@@ -203,83 +199,142 @@ def test_get_data_yahoo_actions(self):
start = datetime(1990, 1, 1)
end = datetime(2018, 4, 5)
- actions = web.get_data_yahoo_actions('AAPL', start, end,
- adjust_dividends=False)
+ actions = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=False)
- assert sum(actions['action'] == 'DIVIDEND') == 47
- assert sum(actions['action'] == 'SPLIT') == 3
+ assert sum(actions["action"] == "DIVIDEND") == 47
+ assert sum(actions["action"] == "SPLIT") == 3
- assert actions.loc['2005-02-28', 'action'][0] == 'SPLIT'
- assert actions.loc['2005-02-28', 'value'][0] == 1/2.0
+ assert actions.loc["2005-02-28", "action"][0] == "SPLIT"
+ assert actions.loc["2005-02-28", "value"][0] == 1 / 2.0
- assert actions.loc['1995-11-21', 'action'][0] == 'DIVIDEND'
- assert round(actions.loc['1995-11-21', 'value'][0], 3) == 0.120
+ assert actions.loc["1995-11-21", "action"][0] == "DIVIDEND"
+ assert round(actions.loc["1995-11-21", "value"][0], 3) == 0.120
- actions = web.get_data_yahoo_actions('AAPL', start, end,
- adjust_dividends=True)
+ actions = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=True)
- assert actions.loc['1995-11-21', 'action'][0] == 'DIVIDEND'
- assert round(actions.loc['1995-11-21', 'value'][0], 4) == 0.0043
+ assert actions.loc["1995-11-21", "action"][0] == "DIVIDEND"
+ assert round(actions.loc["1995-11-21", "value"][0], 4) == 0.0043
def test_get_data_yahoo_actions_invalid_symbol(self):
start = datetime(1990, 1, 1)
end = datetime(2000, 4, 5)
with pytest.raises(IOError):
- web.get_data_yahoo_actions('UNKNOWN TICKER', start, end)
+ web.get_data_yahoo_actions("UNKNOWN TICKER", start, end)
@skip_on_exception(RemoteDataError)
def test_yahoo_reader_class(self):
- r = YahooDailyReader('GOOG')
+ r = YahooDailyReader("GOOG")
df = r.read()
- assert df.Volume.loc['JAN-02-2015'] == 1447500
+ assert df.Volume.loc["JAN-02-2015"] == 1447500
session = requests.Session()
- r = YahooDailyReader('GOOG', session=session)
+ r = YahooDailyReader("GOOG", session=session)
assert r.session is session
def test_yahoo_DataReader(self):
start = datetime(2010, 1, 1)
end = datetime(2015, 5, 9)
# yahoo will adjust for dividends by default
- result = web.DataReader('AAPL', 'yahoo-actions', start, end)
-
- exp_idx = pd.DatetimeIndex(['2015-05-07', '2015-02-05',
- '2014-11-06', '2014-08-07',
- '2014-06-09', '2014-05-08',
- '2014-02-06', '2013-11-06',
- '2013-08-08', '2013-05-09',
- '2013-02-07', '2012-11-07',
- '2012-08-09'])
-
- exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'SPLIT', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND'],
- 'value': [0.52, 0.47, 0.47, 0.47, 0.14285714,
- 0.47, 0.43571, 0.43571, 0.43571,
- 0.43571, 0.37857, 0.37857, 0.37857]},
- index=exp_idx)
- exp.index.name = 'Date'
+ result = web.DataReader("AAPL", "yahoo-actions", start, end)
+
+ exp_idx = pd.DatetimeIndex(
+ [
+ "2015-05-07",
+ "2015-02-05",
+ "2014-11-06",
+ "2014-08-07",
+ "2014-06-09",
+ "2014-05-08",
+ "2014-02-06",
+ "2013-11-06",
+ "2013-08-08",
+ "2013-05-09",
+ "2013-02-07",
+ "2012-11-07",
+ "2012-08-09",
+ ]
+ )
+
+ exp = pd.DataFrame(
+ {
+ "action": [
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "SPLIT",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ ],
+ "value": [
+ 0.52,
+ 0.47,
+ 0.47,
+ 0.47,
+ 0.14285714,
+ 0.47,
+ 0.43571,
+ 0.43571,
+ 0.43571,
+ 0.43571,
+ 0.37857,
+ 0.37857,
+ 0.37857,
+ ],
+ },
+ index=exp_idx,
+ )
+ exp.index.name = "Date"
tm.assert_frame_equal(result.reindex_like(exp).round(2), exp.round(2))
# where dividends are not adjusted for splits
- result = web.get_data_yahoo_actions('AAPL', start, end,
- adjust_dividends=False)
-
- exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'SPLIT', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'DIVIDEND'],
- 'value': [0.52, 0.47, 0.47, 0.47, 0.14285714,
- 3.29, 3.05, 3.05, 3.05,
- 3.05, 2.65, 2.65, 2.65]},
- index=exp_idx)
- exp.index.name = 'Date'
+ result = web.get_data_yahoo_actions("AAPL", start, end, adjust_dividends=False)
+
+ exp = pd.DataFrame(
+ {
+ "action": [
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "SPLIT",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ "DIVIDEND",
+ ],
+ "value": [
+ 0.52,
+ 0.47,
+ 0.47,
+ 0.47,
+ 0.14285714,
+ 3.29,
+ 3.05,
+ 3.05,
+ 3.05,
+ 3.05,
+ 2.65,
+ 2.65,
+ 2.65,
+ ],
+ },
+ index=exp_idx,
+ )
+ exp.index.name = "Date"
tm.assert_frame_equal(result.reindex_like(exp).round(4), exp.round(4))
# test cases with "1/0" split ratio in actions -
@@ -287,22 +342,25 @@ def test_yahoo_DataReader(self):
start = datetime(2017, 12, 30)
end = datetime(2018, 12, 30)
- result = web.DataReader('NTR', 'yahoo-actions', start, end)
+ result = web.DataReader("NTR", "yahoo-actions", start, end)
- exp_idx = pd.DatetimeIndex(['2018-12-28', '2018-09-27',
- '2018-06-28', '2018-03-28',
- '2018-01-02'])
+ exp_idx = pd.DatetimeIndex(
+ ["2018-12-28", "2018-09-27", "2018-06-28", "2018-03-28", "2018-01-02"]
+ )
- exp = pd.DataFrame({'action': ['DIVIDEND', 'DIVIDEND', 'DIVIDEND',
- 'DIVIDEND', 'SPLIT'],
- 'value': [0.43, 0.40, 0.40, 0.40, 1.00]},
- index=exp_idx)
- exp.index.name = 'Date'
+ exp = pd.DataFrame(
+ {
+ "action": ["DIVIDEND", "DIVIDEND", "DIVIDEND", "DIVIDEND", "SPLIT"],
+ "value": [0.43, 0.40, 0.40, 0.40, 1.00],
+ },
+ index=exp_idx,
+ )
+ exp.index.name = "Date"
tm.assert_frame_equal(result.reindex_like(exp).round(2), exp.round(2))
@skip_on_exception(RemoteDataError)
def test_yahoo_DataReader_multi(self):
start = datetime(2010, 1, 1)
end = datetime(2015, 5, 9)
- result = web.DataReader(['AAPL', 'F'], 'yahoo-actions', start, end)
+ result = web.DataReader(["AAPL", "F"], "yahoo-actions", start, end)
assert isinstance(result, dict)
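
These tests pin the actions API down to exact counts; as a usage sketch (live
Yahoo access assumed, values as asserted in test_get_data_yahoo_actions):

    from datetime import datetime

    import pandas_datareader.data as web

    # Dividend and split history with dividends *not* adjusted for splits.
    actions = web.get_data_yahoo_actions(
        "AAPL", datetime(1990, 1, 1), datetime(2018, 4, 5), adjust_dividends=False
    )
    splits = actions[actions["action"] == "SPLIT"]  # 3 rows per the test
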
diff --git a/pandas_datareader/tiingo.py b/pandas_datareader/tiingo.py
index 7b9edc7c..f40e77c4 100644
--- a/pandas_datareader/tiingo.py
+++ b/pandas_datareader/tiingo.py
@@ -19,7 +19,7 @@ def get_tiingo_symbols():
-----
Reads https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip
"""
- url = 'https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip'
+ url = "https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip"
return pd.read_csv(url)
@@ -54,52 +54,69 @@ class TiingoIEXHistoricalReader(_BaseReader):
TIINGO_API_KEY is read. The API key is *required*.
"""
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1,
- timeout=30, session=None, freq=None, api_key=None):
- super().__init__(symbols, start, end, retry_count, pause, timeout,
- session, freq)
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ api_key=None,
+ ):
+ super().__init__(
+ symbols, start, end, retry_count, pause, timeout, session, freq
+ )
if isinstance(self.symbols, str):
self.symbols = [self.symbols]
- self._symbol = ''
+ self._symbol = ""
if api_key is None:
- api_key = os.getenv('TIINGO_API_KEY')
+ api_key = os.getenv("TIINGO_API_KEY")
if not api_key or not isinstance(api_key, str):
- raise ValueError('The tiingo API key must be provided either '
- 'through the api_key variable or through the '
- 'environmental variable TIINGO_API_KEY.')
+ raise ValueError(
+ "The tiingo API key must be provided either "
+ "through the api_key variable or through the "
+ "environmental variable TIINGO_API_KEY."
+ )
self.api_key = api_key
self._concat_axis = 0
@property
def url(self):
"""API URL"""
- _url = 'https://api.tiingo.com/iex/{ticker}/prices'
+ _url = "https://api.tiingo.com/iex/{ticker}/prices"
return _url.format(ticker=self._symbol)
@property
def params(self):
"""Parameters to use in API calls"""
- return {'startDate': self.start.strftime('%Y-%m-%d'),
- 'endDate': self.end.strftime('%Y-%m-%d'),
- 'resampleFreq': self.freq,
- 'format': 'json'}
+ return {
+ "startDate": self.start.strftime("%Y-%m-%d"),
+ "endDate": self.end.strftime("%Y-%m-%d"),
+ "resampleFreq": self.freq,
+ "format": "json",
+ }
def _get_crumb(self, *args):
pass
def _read_one_data(self, url, params):
""" read one data from specified URL """
- headers = {'Content-Type': 'application/json',
- 'Authorization': 'Token ' + self.api_key}
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Token " + self.api_key,
+ }
out = self._get_response(url, params=params, headers=headers).json()
return self._read_lines(out)
def _read_lines(self, out):
df = pd.DataFrame(out)
- df['symbol'] = self._symbol
- df['date'] = pd.to_datetime(df['date'])
- df = df.set_index(['symbol', 'date'])
+ df["symbol"] = self._symbol
+ df["date"] = pd.to_datetime(df["date"])
+ df = df.set_index(["symbol", "date"])
return df
def read(self):
@@ -140,51 +157,67 @@ class TiingoDailyReader(_BaseReader):
TIINGO_API_KEY is read. The API key is *required*.
"""
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1,
- timeout=30, session=None, freq=None, api_key=None):
- super(TiingoDailyReader, self).__init__(symbols, start, end,
- retry_count, pause, timeout,
- session, freq)
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ api_key=None,
+ ):
+ super(TiingoDailyReader, self).__init__(
+ symbols, start, end, retry_count, pause, timeout, session, freq
+ )
if isinstance(self.symbols, str):
self.symbols = [self.symbols]
- self._symbol = ''
+ self._symbol = ""
if api_key is None:
- api_key = os.getenv('TIINGO_API_KEY')
+ api_key = os.getenv("TIINGO_API_KEY")
if not api_key or not isinstance(api_key, str):
- raise ValueError('The tiingo API key must be provided either '
- 'through the api_key variable or through the '
- 'environmental variable TIINGO_API_KEY.')
+ raise ValueError(
+ "The tiingo API key must be provided either "
+ "through the api_key variable or through the "
+ "environmental variable TIINGO_API_KEY."
+ )
self.api_key = api_key
self._concat_axis = 0
@property
def url(self):
"""API URL"""
- _url = 'https://api.tiingo.com/tiingo/daily/{ticker}/prices'
+ _url = "https://api.tiingo.com/tiingo/daily/{ticker}/prices"
return _url.format(ticker=self._symbol)
@property
def params(self):
"""Parameters to use in API calls"""
- return {'startDate': self.start.strftime('%Y-%m-%d'),
- 'endDate': self.end.strftime('%Y-%m-%d'),
- 'format': 'json'}
+ return {
+ "startDate": self.start.strftime("%Y-%m-%d"),
+ "endDate": self.end.strftime("%Y-%m-%d"),
+ "format": "json",
+ }
def _get_crumb(self, *args):
pass
def _read_one_data(self, url, params):
""" read one data from specified URL """
- headers = {'Content-Type': 'application/json',
- 'Authorization': 'Token ' + self.api_key}
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Token " + self.api_key,
+ }
out = self._get_response(url, params=params, headers=headers).json()
return self._read_lines(out)
def _read_lines(self, out):
df = pd.DataFrame(out)
- df['symbol'] = self._symbol
- df['date'] = pd.to_datetime(df['date'])
- df = df.set_index(['symbol', 'date'])
+ df["symbol"] = self._symbol
+ df["date"] = pd.to_datetime(df["date"])
+ df = df.set_index(["symbol", "date"])
return df
def read(self):
@@ -224,17 +257,27 @@ class TiingoMetaDataReader(TiingoDailyReader):
TIINGO_API_KEY is read. The API key is *required*.
"""
- def __init__(self, symbols, start=None, end=None, retry_count=3, pause=0.1,
- timeout=30, session=None, freq=None, api_key=None):
- super(TiingoMetaDataReader, self).__init__(symbols, start, end,
- retry_count, pause, timeout,
- session, freq, api_key)
+ def __init__(
+ self,
+ symbols,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ timeout=30,
+ session=None,
+ freq=None,
+ api_key=None,
+ ):
+ super(TiingoMetaDataReader, self).__init__(
+ symbols, start, end, retry_count, pause, timeout, session, freq, api_key
+ )
self._concat_axis = 1
@property
def url(self):
"""API URL"""
- _url = 'https://api.tiingo.com/tiingo/daily/{ticker}'
+ _url = "https://api.tiingo.com/tiingo/daily/{ticker}"
return _url.format(ticker=self._symbol)
@property
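
Both Tiingo readers resolve the key the same way; a minimal usage sketch (a
valid TIINGO_API_KEY is assumed to be set in the environment):

    import os

    from pandas_datareader.tiingo import TiingoDailyReader

    # Passing api_key=None would fall back to os.getenv("TIINGO_API_KEY").
    reader = TiingoDailyReader("GOOG", api_key=os.getenv("TIINGO_API_KEY"))
    prices = reader.read()  # DataFrame indexed by (symbol, date)
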
diff --git a/pandas_datareader/tsp.py b/pandas_datareader/tsp.py
index 108a4826..d89caf43 100644
--- a/pandas_datareader/tsp.py
+++ b/pandas_datareader/tsp.py
@@ -26,21 +26,29 @@ class TSPReader(_BaseReader):
requests.sessions.Session instance to be used
"""
- def __init__(self,
- symbols=('Linc', 'L2020', 'L2030', 'L2040',
- 'L2050', 'G', 'F', 'C', 'S', 'I'),
- start=None, end=None, retry_count=3, pause=0.1,
- session=None):
- super(TSPReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
- self._format = 'string'
+ def __init__(
+ self,
+ symbols=("Linc", "L2020", "L2030", "L2040", "L2050", "G", "F", "C", "S", "I"),
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ ):
+ super(TSPReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
+ self._format = "string"
@property
def url(self):
"""API URL"""
- return 'https://www.tsp.gov/InvestmentFunds/FundPerformance/index.html'
+ return "https://www.tsp.gov/InvestmentFunds/FundPerformance/index.html"
def read(self):
""" read one data from specified URL """
@@ -51,10 +59,12 @@ def read(self):
@property
def params(self):
"""Parameters to use in API calls"""
- return {'startdate': self.start.strftime('%m/%d/%Y'),
- 'enddate': self.end.strftime('%m/%d/%Y'),
- 'fundgroup': self.symbols,
- 'whichButton': 'CSV'}
+ return {
+ "startdate": self.start.strftime("%m/%d/%Y"),
+ "enddate": self.end.strftime("%m/%d/%Y"),
+ "fundgroup": self.symbols,
+ "whichButton": "CSV",
+ }
@staticmethod
def _sanitize_response(response):
@@ -62,6 +72,6 @@ def _sanitize_response(response):
Clean up the response string
"""
text = response.text.strip()
- if text[-1] == ',':
+ if text[-1] == ",":
return text[0:-1]
return text
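
For reference, the reader above is used like this (a sketch; tsp.gov must be
reachable, and the default symbols cover the lifecycle funds plus G/F/C/S/I):

    from pandas_datareader.tsp import TSPReader

    reader = TSPReader(start="2015-11-02", end="2015-11-05")
    prices = reader.read()
    # reader.params shows the CSV query sent to tsp.gov, e.g.
    # {"startdate": "11/02/2015", "enddate": "11/05/2015", ...}
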
diff --git a/pandas_datareader/wb.py b/pandas_datareader/wb.py
index ecd583d7..c514f6b1 100644
--- a/pandas_datareader/wb.py
+++ b/pandas_datareader/wb.py
@@ -2,11 +2,11 @@
import warnings
-from pandas_datareader.compat import reduce, lrange, string_types
-import pandas as pd
import numpy as np
+import pandas as pd
from pandas_datareader.base import _BaseReader
+from pandas_datareader.compat import lrange, reduce, string_types
# This list of country codes was pulled from wikipedia during October 2014.
# While some exceptions do exist, it is the best proxy for countries supported
@@ -14,65 +14,511 @@
# 3-digit ISO 3166-1 alpha-3 codes, with 'all', 'ALL', and 'All' appended to
# the end.
-WB_API_URL = 'https://api.worldbank.org/v2'
-
-country_codes = ['AD', 'AE', 'AF', 'AG', 'AI', 'AL', 'AM', 'AO', 'AQ', 'AR',
- 'AS', 'AT', 'AU', 'AW', 'AX', 'AZ', 'BA', 'BB', 'BD', 'BE',
- 'BF', 'BG', 'BH', 'BI', 'BJ', 'BL', 'BM', 'BN', 'BO', 'BQ',
- 'BR', 'BS', 'BT', 'BV', 'BW', 'BY', 'BZ', 'CA', 'CC', 'CD',
- 'CF', 'CG', 'CH', 'CI', 'CK', 'CL', 'CM', 'CN', 'CO', 'CR',
- 'CU', 'CV', 'CW', 'CX', 'CY', 'CZ', 'DE', 'DJ', 'DK', 'DM',
- 'DO', 'DZ', 'EC', 'EE', 'EG', 'EH', 'ER', 'ES', 'ET', 'FI',
- 'FJ', 'FK', 'FM', 'FO', 'FR', 'GA', 'GB', 'GD', 'GE', 'GF',
- 'GG', 'GH', 'GI', 'GL', 'GM', 'GN', 'GP', 'GQ', 'GR', 'GS',
- 'GT', 'GU', 'GW', 'GY', 'HK', 'HM', 'HN', 'HR', 'HT', 'HU',
- 'ID', 'IE', 'IL', 'IM', 'IN', 'IO', 'IQ', 'IR', 'IS', 'IT',
- 'JE', 'JM', 'JO', 'JP', 'KE', 'KG', 'KH', 'KI', 'KM', 'KN',
- 'KP', 'KR', 'KW', 'KY', 'KZ', 'LA', 'LB', 'LC', 'LI', 'LK',
- 'LR', 'LS', 'LT', 'LU', 'LV', 'LY', 'MA', 'MC', 'MD', 'ME',
- 'MF', 'MG', 'MH', 'MK', 'ML', 'MM', 'MN', 'MO', 'MP', 'MQ',
- 'MR', 'MS', 'MT', 'MU', 'MV', 'MW', 'MX', 'MY', 'MZ', 'NA',
- 'NC', 'NE', 'NF', 'NG', 'NI', 'NL', 'NO', 'NP', 'NR', 'NU',
- 'NZ', 'OM', 'PA', 'PE', 'PF', 'PG', 'PH', 'PK', 'PL', 'PM',
- 'PN', 'PR', 'PS', 'PT', 'PW', 'PY', 'QA', 'RE', 'RO', 'RS',
- 'RU', 'RW', 'SA', 'SB', 'SC', 'SD', 'SE', 'SG', 'SH', 'SI',
- 'SJ', 'SK', 'SL', 'SM', 'SN', 'SO', 'SR', 'SS', 'ST', 'SV',
- 'SX', 'SY', 'SZ', 'TC', 'TD', 'TF', 'TG', 'TH', 'TJ', 'TK',
- 'TL', 'TM', 'TN', 'TO', 'TR', 'TT', 'TV', 'TW', 'TZ', 'UA',
- 'UG', 'UM', 'US', 'UY', 'UZ', 'VA', 'VC', 'VE', 'VG', 'VI',
- 'VN', 'VU', 'WF', 'WS', 'YE', 'YT', 'ZA', 'ZM', 'ZW',
- 'ABW', 'AFG', 'AGO', 'AIA', 'ALA', 'ALB', 'AND', 'ARE',
- 'ARG', 'ARM', 'ASM', 'ATA', 'ATF', 'ATG', 'AUS', 'AUT',
- 'AZE', 'BDI', 'BEL', 'BEN', 'BES', 'BFA', 'BGD', 'BGR',
- 'BHR', 'BHS', 'BIH', 'BLM', 'BLR', 'BLZ', 'BMU', 'BOL',
- 'BRA', 'BRB', 'BRN', 'BTN', 'BVT', 'BWA', 'CAF', 'CAN',
- 'CCK', 'CHE', 'CHL', 'CHN', 'CIV', 'CMR', 'COD', 'COG',
- 'COK', 'COL', 'COM', 'CPV', 'CRI', 'CUB', 'CUW', 'CXR',
- 'CYM', 'CYP', 'CZE', 'DEU', 'DJI', 'DMA', 'DNK', 'DOM',
- 'DZA', 'ECU', 'EGY', 'ERI', 'ESH', 'ESP', 'EST', 'ETH',
- 'FIN', 'FJI', 'FLK', 'FRA', 'FRO', 'FSM', 'GAB', 'GBR',
- 'GEO', 'GGY', 'GHA', 'GIB', 'GIN', 'GLP', 'GMB', 'GNB',
- 'GNQ', 'GRC', 'GRD', 'GRL', 'GTM', 'GUF', 'GUM', 'GUY',
- 'HKG', 'HMD', 'HND', 'HRV', 'HTI', 'HUN', 'IDN', 'IMN',
- 'IND', 'IOT', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
- 'JAM', 'JEY', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM',
- 'KIR', 'KNA', 'KOR', 'KWT', 'LAO', 'LBN', 'LBR', 'LBY',
- 'LCA', 'LIE', 'LKA', 'LSO', 'LTU', 'LUX', 'LVA', 'MAC',
- 'MAF', 'MAR', 'MCO', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
- 'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ',
- 'MRT', 'MSR', 'MTQ', 'MUS', 'MWI', 'MYS', 'MYT', 'NAM',
- 'NCL', 'NER', 'NFK', 'NGA', 'NIC', 'NIU', 'NLD', 'NOR',
- 'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PCN', 'PER',
- 'PHL', 'PLW', 'PNG', 'POL', 'PRI', 'PRK', 'PRT', 'PRY',
- 'PSE', 'PYF', 'QAT', 'REU', 'ROU', 'RUS', 'RWA', 'SAU',
- 'SDN', 'SEN', 'SGP', 'SGS', 'SHN', 'SJM', 'SLB', 'SLE',
- 'SLV', 'SMR', 'SOM', 'SPM', 'SRB', 'SSD', 'STP', 'SUR',
- 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM', 'SYC', 'SYR', 'TCA',
- 'TCD', 'TGO', 'THA', 'TJK', 'TKL', 'TKM', 'TLS', 'TON',
- 'TTO', 'TUN', 'TUR', 'TUV', 'TWN', 'TZA', 'UGA', 'UKR',
- 'UMI', 'URY', 'USA', 'UZB', 'VAT', 'VCT', 'VEN', 'VGB',
- 'VIR', 'VNM', 'VUT', 'WLF', 'WSM', 'YEM', 'ZAF', 'ZMB',
- 'ZWE', 'all', 'ALL', 'All']
+WB_API_URL = "https://api.worldbank.org/v2"
+
+country_codes = [
+ "AD",
+ "AE",
+ "AF",
+ "AG",
+ "AI",
+ "AL",
+ "AM",
+ "AO",
+ "AQ",
+ "AR",
+ "AS",
+ "AT",
+ "AU",
+ "AW",
+ "AX",
+ "AZ",
+ "BA",
+ "BB",
+ "BD",
+ "BE",
+ "BF",
+ "BG",
+ "BH",
+ "BI",
+ "BJ",
+ "BL",
+ "BM",
+ "BN",
+ "BO",
+ "BQ",
+ "BR",
+ "BS",
+ "BT",
+ "BV",
+ "BW",
+ "BY",
+ "BZ",
+ "CA",
+ "CC",
+ "CD",
+ "CF",
+ "CG",
+ "CH",
+ "CI",
+ "CK",
+ "CL",
+ "CM",
+ "CN",
+ "CO",
+ "CR",
+ "CU",
+ "CV",
+ "CW",
+ "CX",
+ "CY",
+ "CZ",
+ "DE",
+ "DJ",
+ "DK",
+ "DM",
+ "DO",
+ "DZ",
+ "EC",
+ "EE",
+ "EG",
+ "EH",
+ "ER",
+ "ES",
+ "ET",
+ "FI",
+ "FJ",
+ "FK",
+ "FM",
+ "FO",
+ "FR",
+ "GA",
+ "GB",
+ "GD",
+ "GE",
+ "GF",
+ "GG",
+ "GH",
+ "GI",
+ "GL",
+ "GM",
+ "GN",
+ "GP",
+ "GQ",
+ "GR",
+ "GS",
+ "GT",
+ "GU",
+ "GW",
+ "GY",
+ "HK",
+ "HM",
+ "HN",
+ "HR",
+ "HT",
+ "HU",
+ "ID",
+ "IE",
+ "IL",
+ "IM",
+ "IN",
+ "IO",
+ "IQ",
+ "IR",
+ "IS",
+ "IT",
+ "JE",
+ "JM",
+ "JO",
+ "JP",
+ "KE",
+ "KG",
+ "KH",
+ "KI",
+ "KM",
+ "KN",
+ "KP",
+ "KR",
+ "KW",
+ "KY",
+ "KZ",
+ "LA",
+ "LB",
+ "LC",
+ "LI",
+ "LK",
+ "LR",
+ "LS",
+ "LT",
+ "LU",
+ "LV",
+ "LY",
+ "MA",
+ "MC",
+ "MD",
+ "ME",
+ "MF",
+ "MG",
+ "MH",
+ "MK",
+ "ML",
+ "MM",
+ "MN",
+ "MO",
+ "MP",
+ "MQ",
+ "MR",
+ "MS",
+ "MT",
+ "MU",
+ "MV",
+ "MW",
+ "MX",
+ "MY",
+ "MZ",
+ "NA",
+ "NC",
+ "NE",
+ "NF",
+ "NG",
+ "NI",
+ "NL",
+ "NO",
+ "NP",
+ "NR",
+ "NU",
+ "NZ",
+ "OM",
+ "PA",
+ "PE",
+ "PF",
+ "PG",
+ "PH",
+ "PK",
+ "PL",
+ "PM",
+ "PN",
+ "PR",
+ "PS",
+ "PT",
+ "PW",
+ "PY",
+ "QA",
+ "RE",
+ "RO",
+ "RS",
+ "RU",
+ "RW",
+ "SA",
+ "SB",
+ "SC",
+ "SD",
+ "SE",
+ "SG",
+ "SH",
+ "SI",
+ "SJ",
+ "SK",
+ "SL",
+ "SM",
+ "SN",
+ "SO",
+ "SR",
+ "SS",
+ "ST",
+ "SV",
+ "SX",
+ "SY",
+ "SZ",
+ "TC",
+ "TD",
+ "TF",
+ "TG",
+ "TH",
+ "TJ",
+ "TK",
+ "TL",
+ "TM",
+ "TN",
+ "TO",
+ "TR",
+ "TT",
+ "TV",
+ "TW",
+ "TZ",
+ "UA",
+ "UG",
+ "UM",
+ "US",
+ "UY",
+ "UZ",
+ "VA",
+ "VC",
+ "VE",
+ "VG",
+ "VI",
+ "VN",
+ "VU",
+ "WF",
+ "WS",
+ "YE",
+ "YT",
+ "ZA",
+ "ZM",
+ "ZW",
+ "ABW",
+ "AFG",
+ "AGO",
+ "AIA",
+ "ALA",
+ "ALB",
+ "AND",
+ "ARE",
+ "ARG",
+ "ARM",
+ "ASM",
+ "ATA",
+ "ATF",
+ "ATG",
+ "AUS",
+ "AUT",
+ "AZE",
+ "BDI",
+ "BEL",
+ "BEN",
+ "BES",
+ "BFA",
+ "BGD",
+ "BGR",
+ "BHR",
+ "BHS",
+ "BIH",
+ "BLM",
+ "BLR",
+ "BLZ",
+ "BMU",
+ "BOL",
+ "BRA",
+ "BRB",
+ "BRN",
+ "BTN",
+ "BVT",
+ "BWA",
+ "CAF",
+ "CAN",
+ "CCK",
+ "CHE",
+ "CHL",
+ "CHN",
+ "CIV",
+ "CMR",
+ "COD",
+ "COG",
+ "COK",
+ "COL",
+ "COM",
+ "CPV",
+ "CRI",
+ "CUB",
+ "CUW",
+ "CXR",
+ "CYM",
+ "CYP",
+ "CZE",
+ "DEU",
+ "DJI",
+ "DMA",
+ "DNK",
+ "DOM",
+ "DZA",
+ "ECU",
+ "EGY",
+ "ERI",
+ "ESH",
+ "ESP",
+ "EST",
+ "ETH",
+ "FIN",
+ "FJI",
+ "FLK",
+ "FRA",
+ "FRO",
+ "FSM",
+ "GAB",
+ "GBR",
+ "GEO",
+ "GGY",
+ "GHA",
+ "GIB",
+ "GIN",
+ "GLP",
+ "GMB",
+ "GNB",
+ "GNQ",
+ "GRC",
+ "GRD",
+ "GRL",
+ "GTM",
+ "GUF",
+ "GUM",
+ "GUY",
+ "HKG",
+ "HMD",
+ "HND",
+ "HRV",
+ "HTI",
+ "HUN",
+ "IDN",
+ "IMN",
+ "IND",
+ "IOT",
+ "IRL",
+ "IRN",
+ "IRQ",
+ "ISL",
+ "ISR",
+ "ITA",
+ "JAM",
+ "JEY",
+ "JOR",
+ "JPN",
+ "KAZ",
+ "KEN",
+ "KGZ",
+ "KHM",
+ "KIR",
+ "KNA",
+ "KOR",
+ "KWT",
+ "LAO",
+ "LBN",
+ "LBR",
+ "LBY",
+ "LCA",
+ "LIE",
+ "LKA",
+ "LSO",
+ "LTU",
+ "LUX",
+ "LVA",
+ "MAC",
+ "MAF",
+ "MAR",
+ "MCO",
+ "MDA",
+ "MDG",
+ "MDV",
+ "MEX",
+ "MHL",
+ "MKD",
+ "MLI",
+ "MLT",
+ "MMR",
+ "MNE",
+ "MNG",
+ "MNP",
+ "MOZ",
+ "MRT",
+ "MSR",
+ "MTQ",
+ "MUS",
+ "MWI",
+ "MYS",
+ "MYT",
+ "NAM",
+ "NCL",
+ "NER",
+ "NFK",
+ "NGA",
+ "NIC",
+ "NIU",
+ "NLD",
+ "NOR",
+ "NPL",
+ "NRU",
+ "NZL",
+ "OMN",
+ "PAK",
+ "PAN",
+ "PCN",
+ "PER",
+ "PHL",
+ "PLW",
+ "PNG",
+ "POL",
+ "PRI",
+ "PRK",
+ "PRT",
+ "PRY",
+ "PSE",
+ "PYF",
+ "QAT",
+ "REU",
+ "ROU",
+ "RUS",
+ "RWA",
+ "SAU",
+ "SDN",
+ "SEN",
+ "SGP",
+ "SGS",
+ "SHN",
+ "SJM",
+ "SLB",
+ "SLE",
+ "SLV",
+ "SMR",
+ "SOM",
+ "SPM",
+ "SRB",
+ "SSD",
+ "STP",
+ "SUR",
+ "SVK",
+ "SVN",
+ "SWE",
+ "SWZ",
+ "SXM",
+ "SYC",
+ "SYR",
+ "TCA",
+ "TCD",
+ "TGO",
+ "THA",
+ "TJK",
+ "TKL",
+ "TKM",
+ "TLS",
+ "TON",
+ "TTO",
+ "TUN",
+ "TUR",
+ "TUV",
+ "TWN",
+ "TZA",
+ "UGA",
+ "UKR",
+ "UMI",
+ "URY",
+ "USA",
+ "UZB",
+ "VAT",
+ "VCT",
+ "VEN",
+ "VGB",
+ "VIR",
+ "VNM",
+ "VUT",
+ "WLF",
+ "WSM",
+ "YEM",
+ "ZAF",
+ "ZMB",
+ "ZWE",
+ "all",
+ "ALL",
+ "All",
+]
class WorldBankReader(_BaseReader):
@@ -102,24 +548,37 @@ class WorldBankReader(_BaseReader):
errors='raise', will raise a ValueError on a bad country code.
"""
- _format = 'json'
-
- def __init__(self, symbols=None, countries=None,
- start=None, end=None, freq=None,
- retry_count=3, pause=0.1, session=None, errors='warn'):
+ _format = "json"
+
+ def __init__(
+ self,
+ symbols=None,
+ countries=None,
+ start=None,
+ end=None,
+ freq=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ errors="warn",
+ ):
if symbols is None:
- symbols = ['NY.GDP.MKTP.CD', 'NY.GNS.ICTR.ZS']
+ symbols = ["NY.GDP.MKTP.CD", "NY.GNS.ICTR.ZS"]
elif isinstance(symbols, string_types):
symbols = [symbols]
- super(WorldBankReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session)
+ super(WorldBankReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ )
if countries is None:
- countries = ['MX', 'CA', 'US']
+ countries = ["MX", "CA", "US"]
elif isinstance(countries, string_types):
countries = [countries]
@@ -127,17 +586,17 @@ def __init__(self, symbols=None, countries=None,
# Validate the input
if len(bad_countries) > 0:
tmp = ", ".join(bad_countries)
- if errors == 'raise':
+ if errors == "raise":
raise ValueError("Invalid Country Code(s): %s" % tmp)
- if errors == 'warn':
- warnings.warn('Non-standard ISO '
- 'country codes: %s' % tmp, UserWarning)
+ if errors == "warn":
+ warnings.warn(
+ "Non-standard ISO " "country codes: %s" % tmp, UserWarning
+ )
- freq_symbols = ['M', 'Q', 'A', None]
+ freq_symbols = ["M", "Q", "A", None]
if freq not in freq_symbols:
- msg = 'The frequency `{0}` is not in the accepted ' \
- 'list.'.format(freq)
+ msg = "The frequency `{0}` is not in the accepted " "list.".format(freq)
raise ValueError(msg)
self.freq = freq
@@ -147,24 +606,34 @@ def __init__(self, symbols=None, countries=None,
@property
def url(self):
"""API URL"""
- countries = ';'.join(self.countries)
- return WB_API_URL + '/countries/' + countries + '/indicators/'
+ countries = ";".join(self.countries)
+ return WB_API_URL + "/countries/" + countries + "/indicators/"
@property
def params(self):
"""Parameters to use in API calls"""
- if self.freq == 'M':
- return {'date': '{0}M{1:02d}:{2}M{3:02d}'.format(self.start.year,
- self.start.month, self.end.year, self.end.month),
- 'per_page': 25000, 'format': 'json'}
- elif self.freq == 'Q':
- return {'date': '{0}Q{1}:{2}Q{3}'.format(self.start.year,
- self.start.quarter, self.end.year,
- self.end.quarter), 'per_page': 25000,
- 'format': 'json'}
+ if self.freq == "M":
+ return {
+ "date": "{0}M{1:02d}:{2}M{3:02d}".format(
+ self.start.year, self.start.month, self.end.year, self.end.month
+ ),
+ "per_page": 25000,
+ "format": "json",
+ }
+ elif self.freq == "Q":
+ return {
+ "date": "{0}Q{1}:{2}Q{3}".format(
+ self.start.year, self.start.quarter, self.end.year, self.end.quarter
+ ),
+ "per_page": 25000,
+ "format": "json",
+ }
else:
- return {'date': '{0}:{1}'.format(self.start.year, self.end.year),
- 'per_page': 25000, 'format': 'json'}
+ return {
+ "date": "{0}:{1}".format(self.start.year, self.end.year),
+ "per_page": 25000,
+ "format": "json",
+ }
def read(self):
"""Read data"""
@@ -179,22 +648,22 @@ def _read(self):
# Build URL for api call
try:
df = self._read_one_data(self.url + indicator, self.params)
- df.columns = ['country', 'iso_code', 'year', indicator]
+ df.columns = ["country", "iso_code", "year", indicator]
data.append(df)
except ValueError as e:
- msg = str(e) + ' Indicator: ' + indicator
- if self.errors == 'raise':
+ msg = str(e) + " Indicator: " + indicator
+ if self.errors == "raise":
raise ValueError(msg)
- elif self.errors == 'warn':
+ elif self.errors == "warn":
warnings.warn(msg)
# Confirm we actually got some data, and build the DataFrame
if len(data) > 0:
- out = reduce(lambda x, y: x.merge(y, how='outer'), data)
- out = out.drop('iso_code', axis=1)
- out = out.set_index(['country', 'year'])
- out = out.apply(pd.to_numeric, errors='ignore')
+ out = reduce(lambda x, y: x.merge(y, how="outer"), data)
+ out = out.drop("iso_code", axis=1)
+ out = out.set_index(["country", "year"])
+ out = out.apply(pd.to_numeric, errors="ignore")
return out
else:
@@ -205,32 +674,32 @@ def _read_lines(self, out):
# Check to see if there is a possible problem
possible_message = out[0]
- if 'message' in possible_message.keys():
- msg = possible_message['message'][0]
+ if "message" in possible_message.keys():
+ msg = possible_message["message"][0]
try:
- msg = msg['key'].split() + ["\n "] + msg['value'].split()
- wb_err = ' '.join(msg)
+ msg = msg["key"].split() + ["\n "] + msg["value"].split()
+ wb_err = " ".join(msg)
except Exception:
wb_err = ""
- if 'key' in msg.keys():
- wb_err = msg['key'] + "\n "
- if 'value' in msg.keys():
- wb_err += msg['value']
+ if "key" in msg.keys():
+ wb_err = msg["key"] + "\n "
+ if "value" in msg.keys():
+ wb_err += msg["value"]
msg = "Problem with a World Bank Query \n %s." % wb_err
raise ValueError(msg)
- if 'total' in possible_message.keys():
- if possible_message['total'] == 0:
+ if "total" in possible_message.keys():
+ if possible_message["total"] == 0:
msg = "No results found from world bank."
raise ValueError(msg)
# Parse JSON file
data = out[1]
- country = [x['country']['value'] for x in data]
- iso_code = [x['country']['id'] for x in data]
- year = [x['date'] for x in data]
- value = [x['value'] for x in data]
+ country = [x["country"]["value"] for x in data]
+ iso_code = [x["country"]["id"] for x in data]
+ year = [x["date"] for x in data]
+ value = [x["value"] for x in data]
# Prepare output
df = pd.DataFrame([country, iso_code, year, value]).T
return df
@@ -250,21 +719,19 @@ def get_countries(self):
* and longitude
"""
- url = WB_API_URL + '/countries/?per_page=1000&format=json'
+ url = WB_API_URL + "/countries/?per_page=1000&format=json"
resp = self._get_response(url)
data = resp.json()[1]
data = pd.DataFrame(data)
- data.adminregion = [x['value'] for x in data.adminregion]
- data.incomeLevel = [x['value'] for x in data.incomeLevel]
- data.lendingType = [x['value'] for x in data.lendingType]
- data.region = [x['value'] for x in data.region]
- data.latitude = [float(x) if x != "" else np.nan
- for x in data.latitude]
- data.longitude = [float(x) if x != "" else np.nan
- for x in data.longitude]
- data = data.rename(columns={'id': 'iso3c', 'iso2Code': 'iso2c'})
+ data.adminregion = [x["value"] for x in data.adminregion]
+ data.incomeLevel = [x["value"] for x in data.incomeLevel]
+ data.lendingType = [x["value"] for x in data.lendingType]
+ data.region = [x["value"] for x in data.region]
+ data.latitude = [float(x) if x != "" else np.nan for x in data.latitude]
+ data.longitude = [float(x) if x != "" else np.nan for x in data.longitude]
+ data = data.rename(columns={"id": "iso3c", "iso2Code": "iso2c"})
return data
def get_indicators(self):
@@ -273,35 +740,35 @@ def get_indicators(self):
if isinstance(_cached_series, pd.DataFrame):
return _cached_series.copy()
- url = WB_API_URL + '/indicators?per_page=50000&format=json'
+ url = WB_API_URL + "/indicators?per_page=50000&format=json"
resp = self._get_response(url)
data = resp.json()[1]
data = pd.DataFrame(data)
# Clean fields
- data.source = [x['value'] for x in data.source]
+ data.source = [x["value"] for x in data.source]
def encode_ascii(x):
- return x.encode('ascii', 'ignore')
+ return x.encode("ascii", "ignore")
data.sourceOrganization = data.sourceOrganization.apply(encode_ascii)
# Clean topic field
def get_value(x):
try:
- return x['value']
+ return x["value"]
except Exception:
- return ''
+ return ""
def get_list_of_values(x):
return [get_value(y) for y in x]
data.topics = data.topics.apply(get_list_of_values)
- data.topics = data.topics.apply(lambda x: ' ; '.join(x))
+ data.topics = data.topics.apply(lambda x: " ; ".join(x))
# Clean output
- data = data.sort_values(by='id')
+ data = data.sort_values(by="id")
data.index = pd.Index(lrange(data.shape[0]))
# cache
@@ -309,7 +776,7 @@ def get_list_of_values(x):
return data
- def search(self, string='gdp.*capi', field='name', case=False):
+ def search(self, string="gdp.*capi", field="name", case=False):
"""
Search available data series from the world bank
@@ -346,8 +813,15 @@ def search(self, string='gdp.*capi', field='name', case=False):
return out
-def download(country=None, indicator=None, start=2003, end=2005, freq=None,
- errors='warn', **kwargs):
+def download(
+ country=None,
+ indicator=None,
+ start=2003,
+ end=2005,
+ freq=None,
+ errors="warn",
+ **kwargs
+):
"""
Download data series from the World Bank's World Development Indicators
@@ -384,9 +858,15 @@ def download(country=None, indicator=None, start=2003, end=2005, freq=None,
data : DataFrame
DataFrame with columns country, iso_code, year, indicator value
"""
- return WorldBankReader(symbols=indicator, countries=country,
- start=start, end=end, freq=freq, errors=errors,
- **kwargs).read()
+ return WorldBankReader(
+ symbols=indicator,
+ countries=country,
+ start=start,
+ end=end,
+ freq=freq,
+ errors=errors,
+ **kwargs
+ ).read()
def get_countries(**kwargs):
@@ -422,7 +902,7 @@ def get_indicators(**kwargs):
_cached_series = None
-def search(string='gdp.*capi', field='name', case=False, **kwargs):
+def search(string="gdp.*capi", field="name", case=False, **kwargs):
"""
Search available data series from the world bank
@@ -455,5 +935,4 @@ def search(string='gdp.*capi', field='name', case=False, **kwargs):
* topics:
"""
- return WorldBankReader(**kwargs).search(string=string, field=field,
- case=case)
+ return WorldBankReader(**kwargs).search(string=string, field=field, case=case)
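
Taken together, the module-level helpers are thin wrappers over
WorldBankReader; a short sketch of the search-then-download flow (live API
access assumed):

    from pandas_datareader import wb

    # Regex search over the cached indicator metadata.
    matches = wb.search("gdp.*capita.*const")

    # Monthly frequency: the params property above renders this date range as
    # "2011M01:2012M01" before querying the API.
    copper = wb.download(
        country="ALL", indicator="COPPER", start=2011, end=2012, freq="M"
    )
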
diff --git a/pandas_datareader/yahoo/actions.py b/pandas_datareader/yahoo/actions.py
index 5ca7e2b3..48a0f549 100644
--- a/pandas_datareader/yahoo/actions.py
+++ b/pandas_datareader/yahoo/actions.py
@@ -28,21 +28,21 @@ def get_actions(self):
def _get_one_action(data):
- actions = DataFrame(columns=['action', 'value'])
+ actions = DataFrame(columns=["action", "value"])
- if 'Dividends' in data.columns:
+ if "Dividends" in data.columns:
# Add a label column so we can combine our two DFs
- dividends = DataFrame(data['Dividends']).dropna()
+ dividends = DataFrame(data["Dividends"]).dropna()
dividends["action"] = "DIVIDEND"
- dividends = dividends.rename(columns={'Dividends': 'value'})
+ dividends = dividends.rename(columns={"Dividends": "value"})
actions = concat([actions, dividends], sort=True)
actions = actions.sort_index(ascending=False)
- if 'Splits' in data.columns:
+ if "Splits" in data.columns:
# Add a label column so we can combine our two DFs
- splits = DataFrame(data['Splits']).dropna()
+ splits = DataFrame(data["Splits"]).dropna()
splits["action"] = "SPLIT"
- splits = splits.rename(columns={'Splits': 'value'})
+ splits = splits.rename(columns={"Splits": "value"})
actions = concat([actions, splits], sort=True)
actions = actions.sort_index(ascending=False)
@@ -50,14 +50,12 @@ def _get_one_action(data):
class YahooDivReader(YahooActionReader):
-
def read(self):
data = super(YahooDivReader, self).read()
- return data[data['action'] == 'DIVIDEND']
+ return data[data["action"] == "DIVIDEND"]
class YahooSplitReader(YahooActionReader):
-
def read(self):
data = super(YahooSplitReader, self).read()
- return data[data['action'] == 'SPLIT']
+ return data[data["action"] == "SPLIT"]
diff --git a/pandas_datareader/yahoo/components.py b/pandas_datareader/yahoo/components.py
index 857acd23..004b6dd2 100644
--- a/pandas_datareader/yahoo/components.py
+++ b/pandas_datareader/yahoo/components.py
@@ -1,10 +1,9 @@
from pandas import DataFrame
from pandas.io.common import urlopen
-from pandas_datareader.exceptions import ImmediateDeprecationError, \
- DEP_ERROR_MSG
+from pandas_datareader.exceptions import DEP_ERROR_MSG, ImmediateDeprecationError
-_URL = 'http://download.finance.yahoo.com/d/quotes.csv?'
+_URL = "http://download.finance.yahoo.com/d/quotes.csv?"
def _get_data(idx_sym): # pragma: no cover
@@ -28,13 +27,13 @@ def _get_data(idx_sym): # pragma: no cover
-------
idx_df : DataFrame
"""
- raise ImmediateDeprecationError(DEP_ERROR_MSG.format('Yahoo Components'))
- stats = 'snx'
+ raise ImmediateDeprecationError(DEP_ERROR_MSG.format("Yahoo Components"))
+ stats = "snx"
# URL of form:
# http://download.finance.yahoo.com/d/quotes.csv?s=@%5EIXIC&f=snxl1d1t1c1ohgv
- url = _URL + 's={0}&f={1}&e=.csv&h={2}'
+ url = _URL + "s={0}&f={1}&e=.csv&h={2}"
- idx_mod = idx_sym.replace('^', '@%5E')
+ idx_mod = idx_sym.replace("^", "@%5E")
url_str = url.format(idx_mod, stats, 1)
idx_df = DataFrame()
@@ -47,12 +46,12 @@ def _get_data(idx_sym): # pragma: no cover
url_str = url.format(idx_mod, stats, comp_idx)
with urlopen(url_str) as resp:
raw = resp.read()
- lines = raw.decode('utf-8').strip().strip('"').split('"\r\n"')
+ lines = raw.decode("utf-8").strip().strip('"').split('"\r\n"')
lines = [line.strip().split('","') for line in lines]
- temp_df = DataFrame(lines, columns=['ticker', 'name', 'exchange'])
+ temp_df = DataFrame(lines, columns=["ticker", "name", "exchange"])
temp_df = temp_df.drop_duplicates()
- temp_df = temp_df.set_index('ticker')
+ temp_df = temp_df.set_index("ticker")
mask = ~temp_df.index.isin(idx_df.index)
comp_idx = comp_idx + 50
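
Note that everything below the raise in _get_data is now dead code; callers
hit the deprecation immediately:

    from pandas_datareader.exceptions import ImmediateDeprecationError
    from pandas_datareader.yahoo.components import _get_data

    try:
        _get_data("^DJI")
    except ImmediateDeprecationError:
        pass  # expected: the Yahoo components endpoint has been removed
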
diff --git a/pandas_datareader/yahoo/daily.py b/pandas_datareader/yahoo/daily.py
index 821ca07d..2a383eb8 100644
--- a/pandas_datareader/yahoo/daily.py
+++ b/pandas_datareader/yahoo/daily.py
@@ -3,7 +3,9 @@
import json
import re
import time
-from pandas import (DataFrame, to_datetime, notnull, isnull)
+
+from pandas import DataFrame, isnull, notnull, to_datetime
+
from pandas_datareader._utils import RemoteDataError
from pandas_datareader.base import _DailyBaseReader
@@ -49,26 +51,41 @@ class YahooDailyReader(_DailyBaseReader):
If True, adjusts dividends for splits.
"""
- def __init__(self, symbols=None, start=None, end=None, retry_count=3,
- pause=0.1, session=None, adjust_price=False,
- ret_index=False, chunksize=1, interval='d',
- get_actions=False, adjust_dividends=True):
- super(YahooDailyReader, self).__init__(symbols=symbols,
- start=start, end=end,
- retry_count=retry_count,
- pause=pause, session=session,
- chunksize=chunksize)
+ def __init__(
+ self,
+ symbols=None,
+ start=None,
+ end=None,
+ retry_count=3,
+ pause=0.1,
+ session=None,
+ adjust_price=False,
+ ret_index=False,
+ chunksize=1,
+ interval="d",
+ get_actions=False,
+ adjust_dividends=True,
+ ):
+ super(YahooDailyReader, self).__init__(
+ symbols=symbols,
+ start=start,
+ end=end,
+ retry_count=retry_count,
+ pause=pause,
+ session=session,
+ chunksize=chunksize,
+ )
# Ladder up the wait time between subsequent requests to improve
# probability of a successful retry
self.pause_multiplier = 2.5
self.headers = {
- 'Connection': 'keep-alive',
- 'Expires': str(-1),
- 'Upgrade-Insecure-Requests': str(1),
+ "Connection": "keep-alive",
+ "Expires": str(-1),
+ "Upgrade-Insecure-Requests": str(1),
# Google Chrome:
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36' # noqa
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36", # noqa
}
self.adjust_price = adjust_price
@@ -76,17 +93,19 @@ def __init__(self, symbols=None, start=None, end=None, retry_count=3,
self.interval = interval
self._get_actions = get_actions
- if self.interval not in ['d', 'wk', 'mo', 'm', 'w']:
- raise ValueError("Invalid interval: valid values are 'd', 'wk' and 'mo'. 'm' and 'w' have been implemented for " # noqa
- "backward compatibility. 'v' has been moved to the yahoo-actions or yahoo-dividends APIs.") # noqa
- elif self.interval in ['m', 'mo']:
- self.pdinterval = 'm'
- self.interval = 'mo'
- elif self.interval in ['w', 'wk']:
- self.pdinterval = 'w'
- self.interval = 'wk'
-
- self.interval = '1' + self.interval
+ if self.interval not in ["d", "wk", "mo", "m", "w"]:
+ raise ValueError(
+ "Invalid interval: valid values are 'd', 'wk' and 'mo'. 'm' and 'w' have been implemented for " # noqa
+ "backward compatibility. 'v' has been moved to the yahoo-actions or yahoo-dividends APIs."
+ ) # noqa
+ elif self.interval in ["m", "mo"]:
+ self.pdinterval = "m"
+ self.interval = "mo"
+ elif self.interval in ["w", "wk"]:
+ self.pdinterval = "w"
+ self.interval = "wk"
+
+ self.interval = "1" + self.interval
self.adjust_dividends = adjust_dividends
@property
@@ -95,7 +114,7 @@ def get_actions(self):
@property
def url(self):
- return 'https://finance.yahoo.com/quote/{}/history'
+ return "https://finance.yahoo.com/quote/{}/history"
    # Test test_get_data_interval() crashed because of this issue, probably
    # the whole yahoo part of the package wasn't working
@@ -110,88 +129,87 @@ def _get_params(self, symbol):
unix_end += four_hours_in_seconds
params = {
- 'period1': unix_start,
- 'period2': unix_end,
- 'interval': self.interval,
- 'frequency': self.interval,
- 'filter': 'history',
- 'symbol': symbol
+ "period1": unix_start,
+ "period2": unix_end,
+ "interval": self.interval,
+ "frequency": self.interval,
+ "filter": "history",
+ "symbol": symbol,
}
return params
def _read_one_data(self, url, params):
""" read one data from specified symbol """
- symbol = params['symbol']
- del params['symbol']
+ symbol = params["symbol"]
+ del params["symbol"]
url = url.format(symbol)
resp = self._get_response(url, params=params)
- ptrn = r'root\.App\.main = (.*?);\n}\(this\)\);'
+ ptrn = r"root\.App\.main = (.*?);\n}\(this\)\);"
try:
j = json.loads(re.search(ptrn, resp.text, re.DOTALL).group(1))
- data = j['context']['dispatcher']['stores']['HistoricalPriceStore']
+ data = j["context"]["dispatcher"]["stores"]["HistoricalPriceStore"]
except KeyError:
- msg = 'No data fetched for symbol {} using {}'
+ msg = "No data fetched for symbol {} using {}"
raise RemoteDataError(msg.format(symbol, self.__class__.__name__))
# price data
- prices = DataFrame(data['prices'])
+ prices = DataFrame(data["prices"])
prices.columns = [col.capitalize() for col in prices.columns]
- prices['Date'] = to_datetime(
- to_datetime(prices['Date'], unit='s').dt.date)
+ prices["Date"] = to_datetime(to_datetime(prices["Date"], unit="s").dt.date)
- if 'Data' in prices.columns:
- prices = prices[prices['Data'].isnull()]
- prices = prices[['Date', 'High', 'Low', 'Open', 'Close', 'Volume',
- 'Adjclose']]
- prices = prices.rename(columns={'Adjclose': 'Adj Close'})
+ if "Data" in prices.columns:
+ prices = prices[prices["Data"].isnull()]
+ prices = prices[["Date", "High", "Low", "Open", "Close", "Volume", "Adjclose"]]
+ prices = prices.rename(columns={"Adjclose": "Adj Close"})
- prices = prices.set_index('Date')
- prices = prices.sort_index().dropna(how='all')
+ prices = prices.set_index("Date")
+ prices = prices.sort_index().dropna(how="all")
if self.ret_index:
- prices['Ret_Index'] = \
- _calc_return_index(prices['Adj Close'])
+ prices["Ret_Index"] = _calc_return_index(prices["Adj Close"])
if self.adjust_price:
prices = _adjust_prices(prices)
# dividends & splits data
- if self.get_actions and data['eventsData']:
+ if self.get_actions and data["eventsData"]:
- actions = DataFrame(data['eventsData'])
+ actions = DataFrame(data["eventsData"])
actions.columns = [col.capitalize() for col in actions.columns]
- actions['Date'] = to_datetime(
- to_datetime(actions['Date'], unit='s').dt.date)
+ actions["Date"] = to_datetime(
+ to_datetime(actions["Date"], unit="s").dt.date
+ )
- types = actions['Type'].unique()
- if 'DIVIDEND' in types:
- divs = actions[actions.Type == 'DIVIDEND'].copy()
- divs = divs[['Date', 'Amount']].reset_index(drop=True)
- divs = divs.set_index('Date')
- divs = divs.rename(columns={'Amount': 'Dividends'})
- prices = prices.join(divs, how='outer')
+ types = actions["Type"].unique()
+ if "DIVIDEND" in types:
+ divs = actions[actions.Type == "DIVIDEND"].copy()
+ divs = divs[["Date", "Amount"]].reset_index(drop=True)
+ divs = divs.set_index("Date")
+ divs = divs.rename(columns={"Amount": "Dividends"})
+ prices = prices.join(divs, how="outer")
- if 'SPLIT' in types:
+ if "SPLIT" in types:
def split_ratio(row):
- if float(row['Numerator']) > 0:
- return eval(row['Splitratio'])
+ if float(row["Numerator"]) > 0:
+ return eval(row["Splitratio"])
else:
return 1
- splits = actions[actions.Type == 'SPLIT'].copy()
- splits['SplitRatio'] = splits.apply(split_ratio, axis=1)
+ splits = actions[actions.Type == "SPLIT"].copy()
+ splits["SplitRatio"] = splits.apply(split_ratio, axis=1)
splits = splits.reset_index(drop=True)
- splits = splits.set_index('Date')
- splits['Splits'] = splits['SplitRatio']
- prices = prices.join(splits['Splits'], how='outer')
+ splits = splits.set_index("Date")
+ splits["Splits"] = splits["SplitRatio"]
+ prices = prices.join(splits["Splits"], how="outer")
- if 'DIVIDEND' in types and not self.adjust_dividends:
+ if "DIVIDEND" in types and not self.adjust_dividends:
# dividends are adjusted automatically by Yahoo
- adj = prices['Splits'].sort_index(ascending=False).fillna(
- 1).cumprod()
- prices['Dividends'] = prices['Dividends'] / adj
+ adj = (
+ prices["Splits"].sort_index(ascending=False).fillna(1).cumprod()
+ )
+ prices["Dividends"] = prices["Dividends"] / adj
return prices
@@ -202,14 +220,14 @@ def _adjust_prices(hist_data, price_list=None):
'Adj Close' price. Adds 'Adj_Ratio' column.
"""
if price_list is None:
- price_list = 'Open', 'High', 'Low', 'Close'
- adj_ratio = hist_data['Adj Close'] / hist_data['Close']
+ price_list = "Open", "High", "Low", "Close"
+ adj_ratio = hist_data["Adj Close"] / hist_data["Close"]
data = hist_data.copy()
for item in price_list:
data[item] = hist_data[item] * adj_ratio
- data['Adj_Ratio'] = adj_ratio
- del data['Adj Close']
+ data["Adj_Ratio"] = adj_ratio
+ del data["Adj Close"]
return data
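
black only re-quoted the split_ratio helper above; it still relies on eval to turn a Splitratio string such as "2/1" into a number. A behavior-equivalent sketch without eval (assuming the field always has the numerator/denominator form, not part of the patch):

    def split_ratio(row):
        # Parse a ratio string such as "2/1" without eval; assumes the
        # "numerator/denominator" form seen in the Splitratio field.
        if float(row["Numerator"]) > 0:
            numerator, denominator = row["Splitratio"].split("/")
            return float(numerator) / float(denominator)
        return 1
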
diff --git a/pandas_datareader/yahoo/fx.py b/pandas_datareader/yahoo/fx.py
index adf264e7..68bd03f7 100644
--- a/pandas_datareader/yahoo/fx.py
+++ b/pandas_datareader/yahoo/fx.py
@@ -1,10 +1,12 @@
-import time
import json
+import time
import warnings
-from pandas import (DataFrame, Series, to_datetime, concat)
-from pandas_datareader.yahoo.daily import YahooDailyReader
-from pandas_datareader._utils import (RemoteDataError, SymbolWarning)
+
+from pandas import DataFrame, Series, concat, to_datetime
+
+from pandas_datareader._utils import RemoteDataError, SymbolWarning
from pandas_datareader.compat import string_types
+from pandas_datareader.yahoo.daily import YahooDailyReader
class YahooFXReader(YahooDailyReader):
@@ -41,13 +43,13 @@ def _get_params(self, symbol):
unix_end = int(time.mktime(day_end.timetuple()))
params = {
- 'symbol': symbol + '=X',
- 'period1': unix_start,
- 'period2': unix_end,
- 'interval': self.interval, # deal with this
- 'includePrePost': 'true',
- 'events': 'div|split|earn',
- 'corsDomain': 'finance.yahoo.com'
+ "symbol": symbol + "=X",
+ "period1": unix_start,
+ "period2": unix_end,
+ "interval": self.interval, # deal with this
+ "includePrePost": "true",
+ "events": "div|split|earn",
+ "corsDomain": "finance.yahoo.com",
}
return params
@@ -64,29 +66,27 @@ def read(self):
else:
df = self._dl_mult_symbols(self.symbols)
- if 'Date' in df:
- df = df.set_index('Date')
+ if "Date" in df:
+ df = df.set_index("Date")
- if 'Volume' in df:
- df = df.drop('Volume', axis=1)
+ if "Volume" in df:
+ df = df.drop("Volume", axis=1)
- return df.sort_index().dropna(how='all')
+ return df.sort_index().dropna(how="all")
finally:
self.close()
def _read_one_data(self, symbol):
""" read one data from specified URL """
- url = 'https://query1.finance.yahoo.com/v8/finance/chart/{}=X'\
- .format(symbol)
+ url = "https://query1.finance.yahoo.com/v8/finance/chart/{}=X".format(symbol)
params = self._get_params(symbol)
resp = self._get_response(url, params=params)
jsn = json.loads(resp.text)
- data = jsn['chart']['result'][0]
- df = DataFrame(data['indicators']['quote'][0])
- df.insert(0, 'date', to_datetime(Series(data['timestamp']),
- unit='s').dt.date)
+ data = jsn["chart"]["result"][0]
+ df = DataFrame(data["indicators"]["quote"][0])
+ df.insert(0, "date", to_datetime(Series(data["timestamp"]), unit="s").dt.date)
df.columns = map(str.capitalize, df.columns)
return df
@@ -97,11 +97,11 @@ def _dl_mult_symbols(self, symbols):
for sym in symbols:
try:
df = self._read_one_data(sym)
- df['PairCode'] = sym
+ df["PairCode"] = sym
stocks[sym] = df
passed.append(sym)
except IOError:
- msg = 'Failed to read symbol: {0!r}, replacing with NaN.'
+ msg = "Failed to read symbol: {0!r}, replacing with NaN."
warnings.warn(msg.format(sym), SymbolWarning)
failed.append(sym)
@@ -109,4 +109,4 @@ def _dl_mult_symbols(self, symbols):
msg = "No data fetched using {0!r}"
raise RemoteDataError(msg.format(self.__class__.__name__))
else:
- return concat(stocks).set_index(['PairCode', 'Date'])
+ return concat(stocks).set_index(["PairCode", "Date"])
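
As reformatted above, YahooFXReader.read() drops the Volume column, sorts by date, and stacks multiple pairs under a (PairCode, Date) MultiIndex. A hedged usage sketch (assumes network access and that the v8 chart endpoint still responds):

    from pandas_datareader.yahoo.fx import YahooFXReader

    # read() closes the session itself in its finally block.
    df = YahooFXReader(symbols=["EURUSD", "GBPUSD"], interval="d").read()
    print(df.head())
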
diff --git a/pandas_datareader/yahoo/options.py b/pandas_datareader/yahoo/options.py
index a8a5eea8..30e7af34 100644
--- a/pandas_datareader/yahoo/options.py
+++ b/pandas_datareader/yahoo/options.py
@@ -1,13 +1,11 @@
-import warnings
import datetime as dt
-import numpy as np
import json
+import warnings
-from pandas import to_datetime
-from pandas import concat, DatetimeIndex, Series, MultiIndex
+import numpy as np
+from pandas import DataFrame, DatetimeIndex, MultiIndex, Series, concat, to_datetime
from pandas.io.json import read_json
from pandas.tseries.offsets import MonthEnd
-from pandas import DataFrame
from pandas_datareader._utils import RemoteDataError
from pandas_datareader.base import _OptionBaseReader
@@ -64,8 +62,7 @@ class Options(_OptionBaseReader):
>>> all_data = aapl.get_all_data()
"""
- _OPTIONS_BASE_URL = ('https://query1.finance.yahoo.com/'
- 'v7/finance/options/{sym}')
+ _OPTIONS_BASE_URL = "https://query1.finance.yahoo.com/" "v7/finance/options/{sym}"
def get_options_data(self, month=None, year=None, expiry=None):
"""
@@ -135,27 +132,30 @@ def get_options_data(self, month=None, year=None, expiry=None):
the year, month and day for the expiry of the options.
"""
- return concat([f(month, year, expiry)
- for f in (self.get_put_data,
- self.get_call_data)]).sort_index()
+ return concat(
+ [f(month, year, expiry) for f in (self.get_put_data, self.get_call_data)]
+ ).sort_index()
def _option_from_url(self, url):
jd = self._parse_url(url)
- result = jd['optionChain']['result']
+ result = jd["optionChain"]["result"]
try:
- calls = result['options']['calls']
- puts = result['options']['puts']
+ calls = result["options"]["calls"]
+ puts = result["options"]["puts"]
except IndexError:
- raise RemoteDataError('Option json not available '
- 'for url: %s' % url)
-
- self.underlying_price = (result['quote']['regularMarketPrice']
- if result['quote']['marketState'] == 'PRE'
- else result['quote']['preMarketPrice'])
- quote_unix_time = (result['quote']['regularMarketTime']
- if result['quote']['marketState'] == 'PRE'
- else result['quote']['preMarketTime'])
+ raise RemoteDataError("Option json not available " "for url: %s" % url)
+
+ self.underlying_price = (
+ result["quote"]["regularMarketPrice"]
+ if result["quote"]["marketState"] == "PRE"
+ else result["quote"]["preMarketPrice"]
+ )
+ quote_unix_time = (
+ result["quote"]["regularMarketTime"]
+ if result["quote"]["marketState"] == "PRE"
+ else result["quote"]["preMarketTime"]
+ )
self.quote_time = dt.datetime.fromtimestamp(quote_unix_time)
calls = _parse_options_data(calls)
@@ -164,10 +164,10 @@ def _option_from_url(self, url):
calls = self._process_data(calls)
puts = self._process_data(puts)
- return {'calls': calls, 'puts': puts}
+ return {"calls": calls, "puts": puts}
def _get_option_data(self, expiry, name):
- frame_name = '_frames' + self._expiry_to_string(expiry)
+ frame_name = "_frames" + self._expiry_to_string(expiry)
try:
frames = getattr(self, frame_name)
@@ -323,8 +323,9 @@ def get_put_data(self, month=None, year=None, expiry=None):
expiry = self._try_parse_dates(year, month, expiry)
return self._get_data_in_date_range(expiry, put=True, call=False)
- def get_near_stock_price(self, above_below=2, call=True, put=False,
- month=None, year=None, expiry=None):
+ def get_near_stock_price(
+ self, above_below=2, call=True, put=False, month=None, year=None, expiry=None
+ ):
"""
***Experimental***
Returns a DataFrame of options that are near the current stock price.
@@ -403,17 +404,19 @@ def _chop_data(self, df, above_below=2, underlying_price=None):
except AttributeError:
underlying_price = np.nan
- max_strike = max(df.index.get_level_values('Strike'))
- min_strike = min(df.index.get_level_values('Strike'))
+ max_strike = max(df.index.get_level_values("Strike"))
+ min_strike = min(df.index.get_level_values("Strike"))
- if (not np.isnan(underlying_price) and
- min_strike < underlying_price < max_strike):
- start_index = np.where(df.index.get_level_values('Strike') >
- underlying_price)[0][0]
+ if (
+ not np.isnan(underlying_price)
+ and min_strike < underlying_price < max_strike
+ ):
+ start_index = np.where(
+ df.index.get_level_values("Strike") > underlying_price
+ )[0][0]
- get_range = slice(start_index - above_below,
- start_index + above_below + 1)
- df = df[get_range].dropna(how='all')
+ get_range = slice(start_index - above_below, start_index + above_below + 1)
+ df = df[get_range].dropna(how="all")
return df
@@ -440,20 +443,25 @@ def _try_parse_dates(self, year, month, expiry):
# Checks if the user gave one of the month or the year
# but not both and did not provide an expiry:
- if ((month is not None and year is None) or
- (month is None and year is not None) and expiry is None):
- msg = ("You must specify either (`year` and `month`) or `expiry` "
- "or none of these options for the next expiry.")
+ if (
+ (month is not None and year is None)
+ or (month is None and year is not None)
+ and expiry is None
+ ):
+ msg = (
+ "You must specify either (`year` and `month`) or `expiry` "
+ "or none of these options for the next expiry."
+ )
raise ValueError(msg)
if expiry is not None:
- if hasattr(expiry, '__iter__'):
+ if hasattr(expiry, "__iter__"):
expiry = [self._validate_expiry(exp) for exp in expiry]
else:
expiry = [self._validate_expiry(expiry)]
if len(expiry) == 0:
- raise ValueError('No expiries available for given input.')
+ raise ValueError("No expiries available for given input.")
elif year is None and month is None:
# No arguments passed, provide next expiry
@@ -464,11 +472,13 @@ def _try_parse_dates(self, year, month, expiry):
else:
# Year and month passed, provide all expiries in that month
- expiry = [expiry for expiry in self.expiry_dates
- if expiry.year == year and expiry.month == month]
+ expiry = [
+ expiry
+ for expiry in self.expiry_dates
+ if expiry.year == year and expiry.month == month
+ ]
if len(expiry) == 0:
- raise ValueError('No expiries available '
- 'in %s-%s' % (year, month))
+ raise ValueError("No expiries available " "in %s-%s" % (year, month))
return expiry
@@ -482,7 +492,7 @@ def _validate_expiry(self, expiry):
expiry_dates = self.expiry_dates
expiry = to_datetime(expiry)
- if hasattr(expiry, 'date'):
+ if hasattr(expiry, "date"):
expiry = expiry.date()
if expiry in expiry_dates:
@@ -491,8 +501,9 @@ def _validate_expiry(self, expiry):
index = DatetimeIndex(expiry_dates).sort_values()
return index[index.date >= expiry][0].date()
- def get_forward_data(self, months, call=True, put=False, near=False,
- above_below=2): # pragma: no cover
+ def get_forward_data(
+ self, months, call=True, put=False, near=False, above_below=2
+ ): # pragma: no cover
"""
***Experimental***
Gets either call, put, or both data for months starting in the current
@@ -610,16 +621,16 @@ def get_all_data(self, call=True, put=True):
def _get_data_in_date_range(self, dates, call=True, put=True):
- to_ret = Series({'call': call, 'put': put})
+ to_ret = Series({"call": call, "put": put})
to_ret = to_ret[to_ret].index
df = self._load_data(dates)
types = [typ for typ in to_ret]
- df_filtered_by_type = df[df.index.map(
- lambda x: x[2] in types).tolist()]
+ df_filtered_by_type = df[df.index.map(lambda x: x[2] in types).tolist()]
df_filtered_by_expiry = df_filtered_by_type[
- df_filtered_by_type.index.get_level_values('Expiry').isin(dates)]
+ df_filtered_by_type.index.get_level_values("Expiry").isin(dates)
+ ]
return df_filtered_by_expiry
@property
@@ -669,12 +680,13 @@ def _get_expiry_dates(self):
url = self._OPTIONS_BASE_URL.format(sym=self.symbol)
jd = self._parse_url(url)
- expiry_dates = [dt.datetime.utcfromtimestamp(ts).date()
- for ts in jd['optionChain'][
- 'result'][0]['expirationDates']]
+ expiry_dates = [
+ dt.datetime.utcfromtimestamp(ts).date()
+ for ts in jd["optionChain"]["result"][0]["expirationDates"]
+ ]
if len(expiry_dates) == 0:
- raise RemoteDataError('Data not available') # pragma: no cover
+ raise RemoteDataError("Data not available") # pragma: no cover
self._expiry_dates = expiry_dates
return expiry_dates
@@ -694,8 +706,9 @@ def _parse_url(self, url):
"""
jd = json.loads(self._read_url_as_StringIO(url).read())
if jd is None: # pragma: no cover
- raise RemoteDataError("Parsed URL {0!r} is not "
- "a valid json object".format(url))
+            raise RemoteDataError(
+                "Parsed URL {0!r} is not a valid json object".format(url)
+            )
return jd
def _process_data(self, jd):
@@ -716,22 +729,39 @@ def _process_data(self, jd):
A DataFrame with requested options data.
"""
- columns = ['Last', 'Bid', 'Ask', 'Chg', 'PctChg', 'Vol',
- 'Open_Int', 'IV', 'Root', 'IsNonstandard', 'Underlying',
- 'Underlying_Price', 'Quote_Time', 'Last_Trade_Date', 'JSON']
- indexes = ['Strike', 'Expiry', 'Type', 'Symbol']
+ columns = [
+ "Last",
+ "Bid",
+ "Ask",
+ "Chg",
+ "PctChg",
+ "Vol",
+ "Open_Int",
+ "IV",
+ "Root",
+ "IsNonstandard",
+ "Underlying",
+ "Underlying_Price",
+ "Quote_Time",
+ "Last_Trade_Date",
+ "JSON",
+ ]
+ indexes = ["Strike", "Expiry", "Type", "Symbol"]
rows_list, index = self._process_rows(jd)
if len(rows_list) > 0:
- df = DataFrame(rows_list, columns=columns,
- index=MultiIndex.from_tuples(index, names=indexes))
+ df = DataFrame(
+ rows_list,
+ columns=columns,
+ index=MultiIndex.from_tuples(index, names=indexes),
+ )
else:
df = DataFrame(columns=columns)
- df['IsNonstandard'] = df['Root'] != self.symbol.replace('-', '')
+ df["IsNonstandard"] = df["Root"] != self.symbol.replace("-", "")
# Make dtype consistent, requires float64 as there can be NaNs
- df['Vol'] = df['Vol'].astype('float64')
- df['Open_Int'] = df['Open_Int'].astype('float64')
+ df["Vol"] = df["Vol"].astype("float64")
+ df["Open_Int"] = df["Open_Int"].astype("float64")
return df.sort_index()
@@ -740,59 +770,65 @@ def _process_rows(self, jd):
index = []
# handle no results
- if len(jd['optionChain']['result']) <= 0:
+ if len(jd["optionChain"]["result"]) <= 0:
return rows_list, index
- quote = jd['optionChain']['result'][0]['quote']
- for option in jd['optionChain']['result'][0]['options']:
- for typ in ['calls', 'puts']:
+ quote = jd["optionChain"]["result"][0]["quote"]
+ for option in jd["optionChain"]["result"][0]["options"]:
+ for typ in ["calls", "puts"]:
options_by_type = option[typ]
for option_by_strike in options_by_type:
d = {}
for dkey, rkey, ntype in [
- ('Last', 'lastPrice', float),
- ('Bid', 'bid', float),
- ('Ask', 'ask', float),
- ('Chg', 'change', float),
- ('PctChg', 'percentChange', float),
- ('Vol', 'volume', int),
- ('Open_Int', 'openInterest', int),
- ('IV', 'impliedVolatility', float),
- ('Last_Trade_Date', 'lastTradeDate', int)
+ ("Last", "lastPrice", float),
+ ("Bid", "bid", float),
+ ("Ask", "ask", float),
+ ("Chg", "change", float),
+ ("PctChg", "percentChange", float),
+ ("Vol", "volume", int),
+ ("Open_Int", "openInterest", int),
+ ("IV", "impliedVolatility", float),
+ ("Last_Trade_Date", "lastTradeDate", int),
]:
try:
d[dkey] = ntype(option_by_strike[rkey])
except (KeyError, ValueError):
pass
- d['JSON'] = option_by_strike
- d['Root'] = option_by_strike['contractSymbol'][:-15]
- d['Underlying'] = self.symbol
-
- d['Underlying_Price'] = quote['regularMarketPrice']
- quote_unix_time = quote['regularMarketTime']
- if (quote['marketState'] == 'PRE' and
- 'preMarketPrice' in quote):
- d['Underlying_Price'] = quote['preMarketPrice']
- quote_unix_time = quote['preMarketTime']
- elif (quote['marketState'] == 'POSTPOST' and
- 'postMarketPrice' in quote):
- d['Underlying_Price'] = quote['postMarketPrice']
- quote_unix_time = quote['postMarketTime']
- d['Quote_Time'] = dt.datetime.utcfromtimestamp(
- quote_unix_time)
-
- self._underlying_price = d['Underlying_Price']
- self._quote_time = d['Quote_Time']
-
- d['Last_Trade_Date'] = dt.datetime.utcfromtimestamp(
- d['Last_Trade_Date'])
+ d["JSON"] = option_by_strike
+ d["Root"] = option_by_strike["contractSymbol"][:-15]
+ d["Underlying"] = self.symbol
+
+ d["Underlying_Price"] = quote["regularMarketPrice"]
+ quote_unix_time = quote["regularMarketTime"]
+ if quote["marketState"] == "PRE" and "preMarketPrice" in quote:
+ d["Underlying_Price"] = quote["preMarketPrice"]
+ quote_unix_time = quote["preMarketTime"]
+ elif (
+ quote["marketState"] == "POSTPOST"
+ and "postMarketPrice" in quote
+ ):
+ d["Underlying_Price"] = quote["postMarketPrice"]
+ quote_unix_time = quote["postMarketTime"]
+ d["Quote_Time"] = dt.datetime.utcfromtimestamp(quote_unix_time)
+
+ self._underlying_price = d["Underlying_Price"]
+ self._quote_time = d["Quote_Time"]
+
+ d["Last_Trade_Date"] = dt.datetime.utcfromtimestamp(
+ d["Last_Trade_Date"]
+ )
rows_list.append(d)
- index.append((float(option_by_strike['strike']),
- dt.datetime.utcfromtimestamp(
- option_by_strike['expiration']),
- typ.replace('s', ''),
- option_by_strike['contractSymbol']))
+ index.append(
+ (
+ float(option_by_strike["strike"]),
+ dt.datetime.utcfromtimestamp(
+ option_by_strike["expiration"]
+ ),
+ typ.replace("s", ""),
+ option_by_strike["contractSymbol"],
+ )
+ )
return rows_list, index
def _load_data(self, exp_dates=None):
@@ -816,14 +852,18 @@ def _load_data(self, exp_dates=None):
try:
if exp_dates is None:
exp_dates = self._get_expiry_dates()
- exp_unix_times = [int((dt.datetime(exp_date.year,
- exp_date.month,
- exp_date.day) - epoch
- ).total_seconds())
- for exp_date in exp_dates]
+ exp_unix_times = [
+ int(
+ (
+ dt.datetime(exp_date.year, exp_date.month, exp_date.day) - epoch
+ ).total_seconds()
+ )
+ for exp_date in exp_dates
+ ]
for exp_date in exp_unix_times:
- url = (self._OPTIONS_BASE_URL + '?date={exp_date}').format(
- sym=self.symbol, exp_date=exp_date)
+ url = (self._OPTIONS_BASE_URL + "?date={exp_date}").format(
+ sym=self.symbol, exp_date=exp_date
+ )
jd = self._parse_url(url)
data.append(self._process_data(jd))
return concat(data).sort_index()
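
The options reader above builds one frame per expiry and concatenates them under the (Strike, Expiry, Type, Symbol) MultiIndex constructed in _process_data. A hedged usage sketch mirroring the class docstring (network access assumed, not part of the patch):

    from pandas_datareader.yahoo.options import Options

    aapl = Options("aapl")
    # Calls and puts for the nearest expiry, sorted on the MultiIndex.
    data = aapl.get_options_data(expiry=aapl.expiry_dates[0])
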
diff --git a/pandas_datareader/yahoo/quotes.py b/pandas_datareader/yahoo/quotes.py
index b779e2ee..4dddff3b 100644
--- a/pandas_datareader/yahoo/quotes.py
+++ b/pandas_datareader/yahoo/quotes.py
@@ -1,14 +1,15 @@
-import json
from collections import OrderedDict
+import json
+
from pandas import DataFrame
from pandas_datareader.base import _BaseReader
from pandas_datareader.compat import string_types
_DEFAULT_PARAMS = {
- 'lang': 'en-US',
- 'corsDomain': 'finance.yahoo.com',
- '.tsrc': 'finance',
+ "lang": "en-US",
+ "corsDomain": "finance.yahoo.com",
+ ".tsrc": "finance",
}
@@ -18,7 +19,7 @@ class YahooQuotesReader(_BaseReader):
@property
def url(self):
- return 'https://query1.finance.yahoo.com/v7/finance/quote'
+ return "https://query1.finance.yahoo.com/v7/finance/quote"
def read(self):
if isinstance(self.symbols, string_types):
@@ -26,20 +27,20 @@ def read(self):
else:
data = OrderedDict()
for symbol in self.symbols:
- data[symbol] = \
- self._read_one_data(
- self.url, self.params(symbol)).loc[symbol]
- return DataFrame.from_dict(data, orient='index')
+ data[symbol] = self._read_one_data(self.url, self.params(symbol)).loc[
+ symbol
+ ]
+ return DataFrame.from_dict(data, orient="index")
def params(self, symbol):
"""Parameters to use in API calls"""
# Construct the code request string.
- params = {'symbols': symbol}
+ params = {"symbols": symbol}
params.update(_DEFAULT_PARAMS)
return params
def _read_lines(self, out):
- data = json.loads(out.read())['quoteResponse']['result'][0]
- idx = data.pop('symbol')
- data['price'] = data['regularMarketPrice']
+ data = json.loads(out.read())["quoteResponse"]["result"][0]
+ idx = data.pop("symbol")
+ data["price"] = data["regularMarketPrice"]
return DataFrame(data, index=[idx])
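
YahooQuotesReader returns one row per symbol, with _read_lines copying regularMarketPrice into a price column. A hedged usage sketch (network access assumed, not part of the patch):

    from pandas_datareader.yahoo.quotes import YahooQuotesReader

    quotes = YahooQuotesReader(symbols=["AAPL", "MSFT"]).read()
    print(quotes["price"])
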
diff --git a/setup.cfg b/setup.cfg
index 626ed74a..108cb1be 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,11 +13,13 @@ known_compat=pandas_datareader.compat.*
sections=FUTURE,COMPAT,STDLIB,THIRDPARTY,PRE_CORE,FIRSTPARTY,LOCALFOLDER
known_first_party=pandas_datareader
known_third_party=numpy,pandas,pytest,requests
-multi_line_output=0
+multi_line_output=3
force_grid_wrap=0
+use_parentheses=True
+include_trailing_comma=True
combine_as_imports=True
force_sort_within_sections=True
-line_width=99
+line_width=88
[tool:pytest]
# sync minversion with setup.cfg & install.rst
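
With multi_line_output=3 (vertical hanging indent) plus use_parentheses and include_trailing_comma, isort wraps long imports in black's preferred style at the shared line_width of 88. For illustration, a wrapped import under these settings would look like:

    from pandas_datareader.yahoo.daily import (
        YahooDailyReader,
        _adjust_prices,
        _calc_return_index,
    )
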