Skip to content

Commit

Permalink
Merge pull request microsoft#372 from zhupr/data_storage
Browse files Browse the repository at this point in the history
add data storage
  • Loading branch information
you-n-g committed May 26, 2021
2 parents 635624e + a34d596 commit 5016d8c
Show file tree
Hide file tree
Showing 11 changed files with 1,070 additions and 55 deletions.
28 changes: 28 additions & 0 deletions docs/reference/api.rst
Expand Up @@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:


Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:

.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:

.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:

.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:

.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:


Dataset
---------------

Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/backtest/position.py
Expand Up @@ -166,7 +166,7 @@ def update_weight_all(self):
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
Expand Down
108 changes: 62 additions & 46 deletions qlib/data/data.py
Expand Up @@ -6,7 +6,9 @@
from __future__ import print_function

import os
import re
import abc
import copy
import time
import queue
import bisect
Expand All @@ -27,12 +29,41 @@
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path


class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
return backend

def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)

# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)


class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class
Provide calendar data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})

@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
Expand Down Expand Up @@ -127,12 +158,15 @@ def _uri(self, start_time, end_time, freq, future=False):
return hash_args(start_time, end_time, freq, future)


class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
Provide instrument data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})

@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
Expand Down Expand Up @@ -215,12 +249,15 @@ def get_inst_type(cls, inst):
raise ValueError(f"Unknown instrument type {inst}")


class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class
Provide feature data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})

@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
Expand Down Expand Up @@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
"""

def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -517,21 +555,22 @@ def load_calendar(self, freq, future):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")

try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
if future:
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
backend_obj = self.backend_obj(freq=freq, future=False).data
else:
raise

return [pd.Timestamp(x) for x in backend_obj]

def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
Expand Down Expand Up @@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""

def __init__(self):
pass

@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")

def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)

_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data

def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
Expand All @@ -604,7 +625,7 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
Expand All @@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
"""

def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -641,14 +663,7 @@ def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]


class LocalExpressionProvider(ExpressionProvider):
Expand Down Expand Up @@ -1065,7 +1080,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")

register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")

if getattr(C, "feature_provider", None) is not None:
Expand Down
2 changes: 1 addition & 1 deletion qlib/data/dataset/__init__.py
Expand Up @@ -357,7 +357,7 @@ def build_index(data: pd.DataFrame) -> dict:
# get the previous index of a line given index
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
Expand Down
4 changes: 4 additions & 0 deletions qlib/data/storage/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT

0 comments on commit 5016d8c

Please sign in to comment.