diff --git a/pandas_datareader/__init__.py b/pandas_datareader/__init__.py index a792a806..616a75b9 100644 --- a/pandas_datareader/__init__.py +++ b/pandas_datareader/__init__.py @@ -28,6 +28,8 @@ get_records_iex, get_summary_iex, get_tops_iex, + get_custom_datareader, + register_custom_datareader, ) PKG = os.path.dirname(__file__) @@ -62,6 +64,8 @@ "get_data_tiingo", "get_iex_data_tiingo", "get_data_alphavantage", + "get_custom_datareader", + "register_custom_datareader", "test", ] diff --git a/pandas_datareader/data.py b/pandas_datareader/data.py index c2d6223a..4dc78348 100644 --- a/pandas_datareader/data.py +++ b/pandas_datareader/data.py @@ -7,7 +7,7 @@ import warnings from pandas.util._decorators import deprecate_kwarg - +from pandas_datareader.base import _BaseReader from pandas_datareader.av.forex import AVForexReader from pandas_datareader.av.quotes import AVQuotesReader from pandas_datareader.av.sector import AVSectorPerformanceReader @@ -60,9 +60,12 @@ "get_iex_book", "get_dailysummary_iex", "get_data_stooq", + "register_custom_datareader", "DataReader", ] +custom_datareader = {} + def get_data_alphavantage(*args, **kwargs): return AVTimeSeriesReader(*args, **kwargs).read() @@ -270,6 +273,58 @@ def get_iex_book(*args, **kwargs): return IEXDeep(*args, **kwargs).read() +def register_custom_datareader(custom_name, custom_class): + """ + Registers a custom datareader to be used + + Parameters + ---------- + custom_name : str + A string represents name you want to give to your custom class + custom_class : _BaseReader + A class that extends _BaseReader + + Returns + ------- + True if successful, Error otherwise. + """ + custom_datareader[custom_name] = custom_class + return True + + +def unregister_custom_datareader(custom_name): + """ + Unregisters a custom datareader to be used + + Parameters + ---------- + custom_name : str + A string represents name you want to give to your custom class + + Returns + ------- + True if successful, Error otherwise. + """ + del custom_datareader[custom_name] + return True + + +def get_custom_datareader(custom_name): + """ + Get a custom datareader registered before + + Parameters + ---------- + custom_name : str + A string represents name you gave to your custom class + + Returns + ------- + Class registered before + """ + return custom_datareader[custom_name] + + @deprecate_kwarg("access_key", "api_key") def DataReader( name, @@ -329,6 +384,7 @@ def DataReader( ff = DataReader("6_Portfolios_2x3", "famafrench") ff = DataReader("F-F_ST_Reversal_Factor", "famafrench") """ + custom_source = list(custom_datareader.keys()) expected_source = [ "yahoo", "iex", @@ -360,7 +416,7 @@ def DataReader( "av-intraday", "econdb", "naver", - ] + ] + custom_source if data_source not in expected_source: msg = "data_source=%r is not implemented" % data_source @@ -668,6 +724,18 @@ def DataReader( session=session, ).read() + elif data_source in custom_source: + CustomDataReader = get_custom_datareader(data_source) + return CustomDataReader( + symbols=name, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + api_key=api_key, + ).read() + else: msg = "data_source=%r is not implemented" % data_source raise NotImplementedError(msg) diff --git a/pandas_datareader/tests/test_data.py b/pandas_datareader/tests/test_data.py index 4cdea1d0..0a77a5e3 100644 --- a/pandas_datareader/tests/test_data.py +++ b/pandas_datareader/tests/test_data.py @@ -1,7 +1,8 @@ from pandas import DataFrame import pytest -from pandas_datareader.data import DataReader +from pandas_datareader.base import _DailyBaseReader +from pandas_datareader.data import DataReader, register_custom_datareader pytestmark = pytest.mark.stable @@ -18,3 +19,40 @@ def test_read_fred(self): def test_not_implemented(self): with pytest.raises(NotImplementedError): DataReader("NA", "NA") + + def test_custom_reader_acc(self): + class DemoReader(_DailyBaseReader): + def __init__( + self, + symbols=None, + start=None, + end=None, + retry_count=3, + pause=0.1, + session=None, + api_key=None, + ): + super().__init__( + symbols=symbols, + start=start, + end=end, + retry_count=retry_count, + pause=pause, + session=session, + ) + + @property + def url(self): + return "https://stooq.com/q/d/l/" + + def _get_params(self, symbol): + params = {"s": symbol, "i": "d"} + return params + + register_custom_datareader("demo", DemoReader) + result = DataReader("USDJPY", "demo") + assert isinstance(result, DataFrame) + + def test_custom_reader_fail(self): + with pytest.raises(NotImplementedError): + DataReader("USDJPY", "demo1")