diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 57f61f18b1..5e6e50b0ba 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -53,6 +53,34 @@ Cache .. autoclass:: qlib.data.cache.DiskDatasetCache :members: + +Storage +------------- +.. autoclass:: qlib.data.storage.storage.BaseStorage + :members: + +.. autoclass:: qlib.data.storage.storage.CalendarStorage + :members: + +.. autoclass:: qlib.data.storage.storage.InstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.storage.FeatureStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage + :members: + + Dataset --------------- diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 09313f9336..8a4e137ca3 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -166,7 +166,7 @@ def update_weight_all(self): def save_position(self, path, last_trade_date): path = pathlib.Path(path) p = copy.deepcopy(self.position) - cash = pd.Series(dtype=np.float) + cash = pd.Series(dtype=float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] cash["today_account_value"] = p["today_account_value"] diff --git a/qlib/data/data.py b/qlib/data/data.py index c2638e2344..eb7fbe0ead 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -6,7 +6,9 @@ from __future__ import print_function import os +import re import abc +import copy import time import queue import bisect @@ -27,12 +29,41 @@ from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path -class CalendarProvider(abc.ABC): +class ProviderBackendMixin: + def get_default_backend(self): + backend = {} + provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] + # set default storage class + backend.setdefault("class", f"File{provider_name}Storage") + # set default storage module + backend.setdefault("module_path", "qlib.data.storage.file_storage") + return backend + + def backend_obj(self, **kwargs): + backend = self.backend if self.backend else self.get_default_backend() + backend = copy.deepcopy(backend) + + # set default storage kwargs + backend_kwargs = backend.setdefault("kwargs", {}) + # default provider_uri map + if "provider_uri" not in backend_kwargs: + # if the user has no uri configured, use: uri = uri_map[freq] + freq = kwargs.get("freq", "day") + provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()}) + backend_kwargs["provider_uri"] = provider_uri_map[freq] + backend.setdefault("kwargs", {}).update(**kwargs) + return init_instance_by_config(backend) + + +class CalendarProvider(abc.ABC, ProviderBackendMixin): """Calendar provider base class Provide calendar data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @abc.abstractmethod def calendar(self, start_time=None, end_time=None, freq="day", future=False): """Get calendar of certain market in given time range. @@ -127,12 +158,15 @@ def _uri(self, start_time, end_time, freq, future=False): return hash_args(start_time, end_time, freq, future) -class InstrumentProvider(abc.ABC): +class InstrumentProvider(abc.ABC, ProviderBackendMixin): """Instrument provider base class Provide instrument data. 
""" + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @staticmethod def instruments(market="all", filter_pipe=None): """Get the general config dictionary for a base market adding several dynamic filters. @@ -215,12 +249,15 @@ def get_inst_type(cls, inst): raise ValueError(f"Unknown instrument type {inst}") -class FeatureProvider(abc.ABC): +class FeatureProvider(abc.ABC, ProviderBackendMixin): """Feature provider class Provide feature data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @abc.abstractmethod def feature(self, instrument, field, start_time, end_time, freq): """Get feature data. @@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider): """ def __init__(self, **kwargs): + super(LocalCalendarProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -517,21 +555,22 @@ def load_calendar(self, freq, future): list list of timestamps """ - if future: - fname = self._uri_cal.format(freq + "_future") - # if future calendar not exists, return current calendar - if not os.path.exists(fname): - get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + + try: + backend_obj = self.backend_obj(freq=freq, future=future).data + except ValueError: + if future: + get_module_logger("data").warning( + f"load calendar error: freq={freq}, future={future}; return current calendar!" + ) get_module_logger("data").warning( "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" ) - fname = self._uri_cal.format(freq) - else: - fname = self._uri_cal.format(freq) - if not os.path.exists(fname): - raise ValueError("calendar not exists for freq " + freq) - with open(fname) as f: - return [pd.Timestamp(x.strip()) for x in f] + backend_obj = self.backend_obj(freq=freq, future=False).data + else: + raise + + return [pd.Timestamp(x) for x in backend_obj] def calendar(self, start_time=None, end_time=None, freq="day", future=False): _calendar, _calendar_index = self._get_calendar(freq, future) @@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider): Provide instrument data from local data source. 
""" - def __init__(self): - pass - @property def _uri_inst(self): """Instrument file uri.""" return os.path.join(C.get_data_path(), "instruments", "{}.txt") - def _load_instruments(self, market): - fname = self._uri_inst.format(market) - if not os.path.exists(fname): - raise ValueError("instruments not exists for market " + market) - - _instruments = dict() - df = pd.read_csv( - fname, - sep="\t", - usecols=[0, 1, 2], - names=["inst", "start_datetime", "end_datetime"], - dtype={"inst": str}, - parse_dates=["start_datetime", "end_datetime"], - ) - for row in df.itertuples(index=False): - _instruments.setdefault(row[0], []).append((row[1], row[2])) - return _instruments + def _load_instruments(self, market, freq): + return self.backend_obj(market=market, freq=freq).data def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] else: - _instruments = self._load_instruments(market) + _instruments = self._load_instruments(market, freq=freq) H["i"][market] = _instruments # strip # use calendar boundary @@ -604,7 +625,7 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da inst: list( filter( lambda x: x[0] <= x[1], - [(max(start_time, x[0]), min(end_time, x[1])) for x in spans], + [(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans], ) ) for inst, spans in _instruments.items() @@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider): """ def __init__(self, **kwargs): + super(LocalFeatureProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -641,14 +663,7 @@ def feature(self, instrument, field, start_index, end_index, freq): # validate field = str(field).lower()[1:] instrument = code_to_fname(instrument) - uri_data = self._uri_data.format(instrument.lower(), field, freq) - if not os.path.exists(uri_data): - get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) - return pd.Series(dtype=np.float32) - # raise ValueError('uri_data not found: ' + uri_data) - # load - series = read_bin(uri_data, start_index, end_index) - return series + return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] class LocalExpressionProvider(ExpressionProvider): @@ -1065,7 +1080,8 @@ def register_all_wrappers(C): register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") - register_wrapper(Inst, C.instrument_provider, "qlib.data") + _instrument_provider = init_instance_by_config(C.instrument_provider, module) + register_wrapper(Inst, _instrument_provider, "qlib.data") logger.debug(f"registering Inst {C.instrument_provider}") if getattr(C, "feature_provider", None) is not None: diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 206561aed9..8d77863684 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -357,7 +357,7 @@ def build_index(data: pd.DataFrame) -> dict: # get the previous index of a line given index """ # object incase of pandas converting int to flaot - idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object) + idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) idx_df = lazy_sort_index(idx_df.unstack()) # NOTE: the correctness of `__getitem__` depends on columns sorted here idx_df = lazy_sort_index(idx_df, axis=1) diff --git 
a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py new file mode 100644 index 0000000000..552e1e3e8e --- /dev/null +++ b/qlib/data/storage/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py new file mode 100644 index 0000000000..a2b145c4df --- /dev/null +++ b/qlib/data/storage/file_storage.py @@ -0,0 +1,292 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import struct +from pathlib import Path +from typing import Iterable, Union, Dict, Mapping, Tuple, List + +import numpy as np +import pandas as pd + +from qlib.log import get_module_logger +from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT + +logger = get_module_logger("file_storage") + + +class FileStorageMixin: + @property + def uri(self) -> Path: + _provider_uri = self.kwargs.get("provider_uri", None) + if _provider_uri is None: + raise ValueError( + f"The `provider_uri` parameter is not found in {self.__class__.__name__}, " + f'please specify `provider_uri` in the "provider\'s backend"' + ) + return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name) + + def check(self): + """check self.uri + + Raises + ------- + ValueError + """ + if not self.uri.exists(): + raise ValueError(f"{self.storage_name} not exists: {self.uri}") + + +class FileCalendarStorage(FileStorageMixin, CalendarStorage): + def __init__(self, freq: str, future: bool, **kwargs): + super(FileCalendarStorage, self).__init__(freq, future, **kwargs) + self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower() + + def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]: + if not self.uri.exists(): + self._write_calendar(values=[]) + with self.uri.open("rb") as fp: + return [ + str(x) + for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8") + ] + + def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): + with self.uri.open(mode=mode) as fp: + np.savetxt(fp, values, fmt="%s", encoding="utf-8") + + @property + def data(self) -> List[CalVT]: + self.check() + return self._read_calendar() + + def extend(self, values: Iterable[CalVT]) -> None: + self._write_calendar(values, mode="ab") + + def clear(self) -> None: + self._write_calendar(values=[]) + + def index(self, value: CalVT) -> int: + self.check() + calendar = self._read_calendar() + return int(np.argwhere(calendar == value)[0]) + + def insert(self, index: int, value: CalVT): + calendar = self._read_calendar() + calendar = np.insert(calendar, index, value) + self._write_calendar(values=calendar) + + def remove(self, value: CalVT) -> None: + self.check() + index = self.index(value) + calendar = self._read_calendar() + calendar = np.delete(calendar, index) + self._write_calendar(values=calendar) + + def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None: + calendar = self._read_calendar() + calendar[i] = values + self._write_calendar(values=calendar) + + def __delitem__(self, i: Union[int, slice]) -> None: + self.check() + calendar = self._read_calendar() + calendar = np.delete(calendar, i) + self._write_calendar(values=calendar) + + def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]: + self.check() + return 
self._read_calendar()[i]
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
+
+    INSTRUMENT_SEP = "\t"
+    INSTRUMENT_START_FIELD = "start_datetime"
+    INSTRUMENT_END_FIELD = "end_datetime"
+    SYMBOL_FIELD_NAME = "instrument"
+
+    def __init__(self, market: str, **kwargs):
+        super(FileInstrumentStorage, self).__init__(market, **kwargs)
+        self.file_name = f"{market.lower()}.txt"
+
+    def _read_instrument(self) -> Dict[InstKT, InstVT]:
+        if not self.uri.exists():
+            self._write_instrument()
+
+        _instruments = dict()
+        df = pd.read_csv(
+            self.uri,
+            sep="\t",
+            usecols=[0, 1, 2],
+            names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
+            dtype={self.SYMBOL_FIELD_NAME: str},
+            parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
+        )
+        for row in df.itertuples(index=False):
+            _instruments.setdefault(row[0], []).append((row[1], row[2]))
+        return _instruments
+
+    def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
+        if not data:
+            with self.uri.open("w") as _:
+                pass
+            return
+
+        res = []
+        for inst, v_list in data.items():
+            _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
+            _df[self.SYMBOL_FIELD_NAME] = inst
+            res.append(_df)
+
+        df = pd.concat(res, sort=False)
+        # write columns in the same order that _read_instrument expects: symbol, start, end
+        df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
+            self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
+        )
+
+    def clear(self) -> None:
+        self._write_instrument(data={})
+
+    @property
+    def data(self) -> Dict[InstKT, InstVT]:
+        self.check()
+        return self._read_instrument()
+
+    def __setitem__(self, k: InstKT, v: InstVT) -> None:
+        inst = self._read_instrument()
+        inst[k] = v
+        self._write_instrument(inst)
+
+    def __delitem__(self, k: InstKT) -> None:
+        self.check()
+        inst = self._read_instrument()
+        del inst[k]
+        self._write_instrument(inst)
+
+    def __getitem__(self, k: InstKT) -> InstVT:
+        self.check()
+        return self._read_instrument()[k]
+
+    def update(self, *args, **kwargs) -> None:
+        if len(args) > 1:
+            raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
+        inst = self._read_instrument()
+        if args:
+            other = args[0]  # type: dict
+            if isinstance(other, Mapping):
+                for key in other:
+                    inst[key] = other[key]
+            elif hasattr(other, "keys"):
+                for key in other.keys():
+                    inst[key] = other[key]
+            else:
+                for key, value in other:
+                    inst[key] = value
+        for key, value in kwargs.items():
+            inst[key] = value
+
+        self._write_instrument(inst)
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class FileFeatureStorage(FileStorageMixin, FeatureStorage):
+    def __init__(self, instrument: str, field: str, freq: str, **kwargs):
+        super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
+        self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
+
+    def clear(self):
+        with self.uri.open("wb") as _:
+            pass
+
+    @property
+    def data(self) -> pd.Series:
+        return self[:]
+
+    def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
+        if len(data_array) == 0:
+            logger.info(
+                "len(data_array) == 0, write is ignored; "
+                "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
+            )
+            return
+        if not self.uri.exists():
+            # write
+            index = 0 if index is None else index
+            with self.uri.open("wb") as fp:
+                np.hstack([index, data_array]).astype("<f").tofile(fp)
+        else:
+            if index is None or index > self.end_index:
+                # append
+                index = 0 if index is None else index
+                with self.uri.open("ab+") as fp:
+                    np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
+            else:
+                # rewrite: overwrite the overlapping range in place
+                with self.uri.open("rb+") as fp:
+                    _old_data = np.fromfile(fp, dtype="<f")
+                    _old_index = int(_old_data[0])
+                    _old_df = pd.DataFrame(
+                        _old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
+                    )
+                    fp.seek(0)
+                    _new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
+                    _df = pd.concat([_old_df, _new_df], sort=False, axis=1)
+                    _df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
+                    np.hstack([_df.index.min(), _df["new"].fillna(_df["old"]).values]).astype("<f").tofile(fp)
+
+    @property
+    def start_index(self) -> Union[int, None]:
+        if not self.uri.exists():
+            return None
+        with self.uri.open("rb") as fp:
+            index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
+        return index
+
+    @property
+    def end_index(self) -> Union[int, None]:
+        if not self.uri.exists():
+            return None
+        # The next data appending index point will be `end_index + 1`
+        return self.start_index + len(self) - 1
+
+    def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
+        if not self.uri.exists():
+            if isinstance(i, int):
+                return None, None
+            elif isinstance(i, slice):
+                return pd.Series(dtype=np.float32)
+            else:
+                raise TypeError(f"type(i) = {type(i)}")
+
+        storage_start_index = self.start_index
+        storage_end_index = self.end_index
+        with self.uri.open("rb") as fp:
+            if isinstance(i, int):
+                if storage_start_index > i:
+                    raise IndexError(f"{i}: start index is {storage_start_index}")
+                fp.seek(4 * (i - storage_start_index) + 4)
+                return i, struct.unpack("f", fp.read(4))[0]
+            elif isinstance(i, slice):
+                start_index = storage_start_index if i.start is None else i.start
+                end_index = storage_end_index if i.stop is None else i.stop - 1
+                si = max(start_index, storage_start_index)
+                if si > end_index:
+                    return pd.Series(dtype=np.float32)
+                fp.seek(4 * (si - storage_start_index) + 4)
+                # read n 4-byte float values
+                count = end_index - si + 1
+                data = np.frombuffer(fp.read(4 * count), dtype="<f")
+                return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
+            else:
+                raise TypeError(f"type(i) = {type(i)}")
+
+    def __len__(self) -> int:
+        self.check()
+        return self.uri.stat().st_size // 4 - 1
diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py
new file mode 100644
index 0000000000..8426ebe66f
--- /dev/null
+++ b/qlib/data/storage/storage.py
@@ -0,0 +1,501 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +import re +from typing import Iterable, overload, Tuple, List, Text, Union, Dict + +import numpy as np +import pandas as pd +from qlib.log import get_module_logger + +# calendar value type +CalVT = str + +# instrument value +InstVT = List[Tuple[CalVT, CalVT]] +# instrument key +InstKT = Text + +logger = get_module_logger("storage") + +""" +If the user is only using it in `qlib`, you can customize Storage to implement only the following methods: + +class UserCalendarStorage(CalendarStorage): + + @property + def data(self) -> Iterable[CalVT]: + '''get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + + +class UserInstrumentStorage(InstrumentStorage): + + @property + def data(self) -> Dict[InstKT, InstVT]: + '''get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + + +class UserFeatureStorage(FeatureStorage): + + def __getitem__(self, s: slice) -> pd.Series: + '''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, index=pd.RangeIndex(start, len(values)) + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) + ''' + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(s: slice)` method" + ) + + +""" + + +class BaseStorage: + @property + def storage_name(self) -> str: + return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower() + + +class CalendarStorage(BaseStorage): + """ + The behavior of CalendarStorage's methods and List's methods of the same name remain consistent + """ + + def __init__(self, freq: str, future: bool, **kwargs): + self.freq = freq + self.future = future + self.kwargs = kwargs + + @property + def data(self) -> Iterable[CalVT]: + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + + def extend(self, iterable: Iterable[CalVT]) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") + + def index(self, value: CalVT) -> int: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") + + def insert(self, index: int, value: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method") + + def remove(self, value: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method") + + @overload + def __setitem__(self, i: int, value: CalVT) -> None: + """x.__setitem__(i, o) <==> (x[i] = o)""" + ... + + @overload + def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None: + """x.__setitem__(s, o) <==> (x[s] = o)""" + ... 
+ + def __setitem__(self, i, value) -> None: + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method" + ) + + @overload + def __delitem__(self, i: int) -> None: + """x.__delitem__(i) <==> del x[i]""" + ... + + @overload + def __delitem__(self, i: slice) -> None: + """x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]""" + ... + + def __delitem__(self, i) -> None: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method" + ) + + @overload + def __getitem__(self, s: slice) -> Iterable[CalVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + ... + + @overload + def __getitem__(self, i: int) -> CalVT: + """x.__getitem__(i) <==> x[i]""" + ... + + def __getitem__(self, i) -> CalVT: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + + +class InstrumentStorage(BaseStorage): + def __init__(self, market: str, **kwargs): + self.market = market + self.kwargs = kwargs + + @property + def data(self) -> Dict[InstKT, InstVT]: + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + + def update(self, *args, **kwargs) -> None: + """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. + + Notes + ------ + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + + In either case, this is followed by: for k, v in F.items(): D[k] = v + + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + """Set self[key] to value.""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") + + def __delitem__(self, k: InstKT) -> None: + """Delete self[key]. 
+ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") + + def __getitem__(self, k: InstKT) -> InstVT: + """x.__getitem__(k) <==> x[k]""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + + +class FeatureStorage(BaseStorage): + def __init__(self, instrument: str, field: str, freq: str, **kwargs): + self.instrument = instrument + self.field = field + self.freq = freq + self.kwargs = kwargs + + @property + def data(self) -> pd.Series: + """get all data + + Notes + ------ + if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)` + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") + + @property + def start_index(self) -> Union[int, None]: + """get FeatureStorage start index + + Notes + ----- + If the data(storage) does not exist, return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method") + + @property + def end_index(self) -> Union[int, None]: + """get FeatureStorage end index + + Notes + ----- + The right index of the data range (both sides are closed) + + The next data appending point will be `end_index + 1` + + If the data(storage) does not exist, return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + + def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): + """Write data_array to FeatureStorage starting from index. + + Notes + ------ + If index is None, append data_array to feature. + + If len(data_array) == 0; return + + If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan + + Examples + --------- + .. code-block:: + + feature: + 3 4 + 4 5 + 5 6 + + + >>> self.write([6, 7], index=6) + + feature: + 3 4 + 4 5 + 5 6 + 6 6 + 7 7 + + >>> self.write([8], index=9) + + feature: + 3 4 + 4 5 + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 + + >>> self.write([1, np.nan], index=3) + + feature: + 3 1 + 4 np.nan + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `write` method") + + def rebase(self, start_index: int = None, end_index: int = None): + """Rebase the start_index and end_index of the FeatureStorage. + + start_index and end_index are closed intervals: [start_index, end_index] + + Examples + --------- + + .. 
code-block:: + + feature: + 3 4 + 4 5 + 5 6 + + + >>> self.rebase(start_index=4) + + feature: + 4 5 + 5 6 + + >>> self.rebase(start_index=3) + + feature: + 3 np.nan + 4 5 + 5 6 + + >>> self.write([3], index=3) + + feature: + 3 3 + 4 5 + 5 6 + + >>> self.rebase(end_index=4) + + feature: + 3 3 + 4 5 + + >>> self.write([6, 7, 8], index=4) + + feature: + 3 3 + 4 6 + 5 7 + 6 8 + + >>> self.rebase(start_index=4, end_index=5) + + feature: + 4 6 + 5 7 + + """ + storage_si = self.start_index + storage_ei = self.end_index + if storage_si is None or storage_ei is None: + raise ValueError("storage.start_index or storage.end_index is None, storage may not exist") + + start_index = storage_si if start_index is None else start_index + end_index = storage_ei if end_index is None else end_index + + if start_index is None or end_index is None: + logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored") + return + + if start_index < 0 or end_index < 0: + logger.warning("start_index or end_index cannot be less than 0") + return + if start_index > end_index: + logger.warning( + f"start_index({start_index}) > end_index({end_index}), rebase is ignored; " + f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" + ) + return + + if start_index <= storage_si: + self.write([np.nan] * (storage_si - start_index), start_index) + else: + self.rewrite(self[start_index:].values, start_index) + + if end_index >= self.end_index: + self.write([np.nan] * (end_index - self.end_index)) + else: + self.rewrite(self[: end_index + 1].values, start_index) + + def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int): + """overwrite all data in FeatureStorage with data + + Parameters + ---------- + data: Union[List, np.ndarray, Tuple] + data + index: int + data start index + """ + self.clear() + self.write(data, index) + + @overload + def __getitem__(self, s: slice) -> pd.Series: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, index=pd.RangeIndex(start, len(values)) + """ + ... + + @overload + def __getitem__(self, i: int) -> Tuple[int, float]: + """x.__getitem__(y) <==> x[y]""" + ... 
+ + def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]: + """x.__getitem__(y) <==> x[y] + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) + """ + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index f92e727875..8b53bc53a5 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -9,19 +9,19 @@ class TestAutoData(unittest.TestCase): _setup_kwargs = {} + provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir @classmethod def setUpClass(cls) -> None: # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") + if not exists_qlib_data(cls.provider_uri): + print(f"Qlib data is not found in {cls.provider_uri}") GetData().qlib_data( name="qlib_data_simple", region="cn", interval="1d", - target_dir=provider_uri, + target_dir=cls.provider_uri, delete_old=False, ) - init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs) + init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index dbbe69d43e..1e8ee2e480 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir): return False # check calendar bin for _calendar in calendars_dir.iterdir(): - if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")): + + if ("_future" not in _calendar.name) and ( + not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")) + ): return False # check instruments diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 0b063fddac..b3a18cc902 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -120,7 +120,7 @@ def _get_date( else: df = file_or_df if df.empty or self.date_field_name not in df.columns.tolist(): - _calendars = pd.Series() + _calendars = pd.Series(dtype=np.float32) else: _calendars = df[self.date_field_name] diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py new file mode 100644 index 0000000000..aad8d11e48 --- /dev/null +++ b/tests/storage_tests/test_storage.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +from pathlib import Path +from collections.abc import Iterable + +import pytest +import numpy as np +from qlib.tests import TestAutoData + +from qlib.data.storage.file_storage import ( + FileCalendarStorage as CalendarStorage, + FileInstrumentStorage as InstrumentStorage, + FileFeatureStorage as FeatureStorage, +) + +_file_name = Path(__file__).name.split(".")[0] +DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data") +QLIB_DIR = DATA_DIR.joinpath("qlib") +QLIB_DIR.mkdir(exist_ok=True, parents=True) + + +class TestStorage(TestAutoData): + def test_calendar_storage(self): + + calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri) + assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable" + assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable" + + print(f"calendar[1: 5]: {calendar[1:5]}") + print(f"calendar[0]: {calendar[0]}") + print(f"calendar[-1]: {calendar[-1]}") + + calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found") + with pytest.raises(ValueError): + print(calendar.data) + + with pytest.raises(ValueError): + print(calendar[:]) + + with pytest.raises(ValueError): + print(calendar[0]) + + def test_instrument_storage(self): + """ + The meaning of instrument, such as CSI500: + + CSI500 composition changes: + + date add remove + 2005-01-01 SH600000 + 2005-01-01 SH600001 + 2005-01-01 SH600002 + 2005-02-01 SH600003 SH600000 + 2005-02-15 SH600000 SH600002 + + Calendar: + pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D") + + Instrument: + symbol start_time end_time + SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day) + SH600000 2005-02-15 2005-03-01 + SH600001 2005-01-01 2005-03-01 + SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day) + SH600003 2005-02-01 2005-03-01 + + InstrumentStorage: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + "SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + """ + + instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri) + + for inst, spans in instrument.data.items(): + assert isinstance(inst, str) and isinstance( + spans, Iterable + ), f"{instrument.__class__.__name__} value is not Iterable" + for s_e in spans: + assert ( + isinstance(s_e, tuple) and len(s_e) == 2 + ), f"{instrument.__class__.__name__}.__getitem__(k) TypeError" + + print(f"instrument['SH600000']: {instrument['SH600000']}") + + instrument = InstrumentStorage(market="csi300", provider_uri="not_found") + with pytest.raises(ValueError): + print(instrument.data) + + with pytest.raises(ValueError): + print(instrument["sSH600000"]) + + def test_feature_storage(self): + """ + Calendar: + pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D") + + Instrument: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + "SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + Feature: + Stock data(close): + 2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01 + SH600000 1 ... 3 ... 4 5 6 + SH600001 1 ... 4 ... 5 6 7 + SH600002 1 ... 5 ... 6 nan nan + SH600003 nan ... 1 ... 
2 3 4 + + FeatureStorage(SH600000, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-03-01"), 6) + ] + + ====> [(0, 1), ..., (59, 6)] + + + FeatureStorage(SH600002, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-02-14"), 6) + ] + + ===> [(0, 1), ..., (44, 6)] + + FeatureStorage(SH600003, close): + + [ + (calendar.index("2005-02-01"), 1), + ..., + (calendar.index("2005-03-01"), 4) + ] + + ===> [(31, 1), ..., (59, 4)] + + """ + + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri) + + with pytest.raises(IndexError): + print(feature[0]) + assert isinstance( + feature[815][1], (float, np.float32) + ), f"{feature.__class__.__name__}.__getitem__(i: int) error" + assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" + print(f"feature[815: 818]: \n{feature[815: 818]}") + + print(f"feature[:].tail(): \n{feature[:].tail()}") + + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount") + + assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`" + assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`" + assert ( + feature.data.empty + ), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"
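
A minimal usage sketch of the storage layer introduced by this patch, mirroring the calls exercised in tests/storage_tests/test_storage.py; the provider_uri value and the SH600000 instrument below are assumptions for illustration (any data directory produced by scripts/dump_bin.py or GetData().qlib_data() should work):

    # Sketch only, assuming qlib bin data exists under provider_uri.
    from qlib.data.storage.file_storage import FileCalendarStorage, FileFeatureStorage

    provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # assumed path, same default as qlib.tests.TestAutoData

    # Calendar storage: list-like access backed by calendars/day.txt
    cal = FileCalendarStorage(freq="day", future=False, provider_uri=provider_uri)
    print(len(cal), cal[0], cal[-1])  # raises ValueError if the calendar file does not exist

    # Feature storage: Series-like access backed by features/sh600000/close.day.bin
    close = FileFeatureStorage(instrument="SH600000", field="close", freq="day", provider_uri=provider_uri)
    print(close[100])      # a (calendar_index, value) tuple
    print(close[100:105])  # pd.Series indexed by calendar position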