From 56a123919b88a01d5c1a25ffe3b40ee3b15368aa Mon Sep 17 00:00:00 2001 From: "lin.dongzhao" <542698096@qq.com> Date: Mon, 25 Mar 2024 16:33:56 +0800 Subject: [PATCH] pr update --- rqalpha/data/bundle.py | 20 +++---------------- rqalpha/data/trading_dates_mixin.py | 3 +-- .../test_auto_update_bundle_mixin.py | 2 +- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/rqalpha/data/bundle.py b/rqalpha/data/bundle.py index 51b65caa2..5c8ccca51 100644 --- a/rqalpha/data/bundle.py +++ b/rqalpha/data/bundle.py @@ -633,16 +633,15 @@ def update_futures_trading_parameters(path, end_date): class AutomaticUpdateBundle(object): - def __init__(self, path: str, filename: str, rqdata_api: Callable, fields: List[str], end_date: datetime.date, completion: bool =False) -> None: + def __init__(self, path: str, filename: str, api: Callable, fields: List[str], end_date: datetime.date) -> None: if not os.path.exists(path): os.makedirs(path) self._file = os.path.join(path, filename) self._trading_dates = None self._filename = filename - self._rqdata_api = rqdata_api + self._api = api self._fields = fields self._end_date = end_date - self._completion = completion # 缓存 h5 文件时,是否需要对缺失数据的日期进行补 0 self.updated = [] self._env = Environment.get_instance() @@ -709,7 +708,7 @@ def _auto_update_task(self, instrument: Instrument) -> None: h5.close() def _get_array(self, instrument: Instrument, start_date: datetime.date) -> Optional[np.ndarray]: - df = self._rqdata_api(instrument.order_book_id, start_date, self._end_date, self._fields) + df = self._api(instrument.order_book_id, start_date, self._end_date, self._fields) if not (df is None or df.empty): df = df[self._fields].loc[instrument.order_book_id] # rqdatac.get_open_auction_info get Futures's data will auto add 'open_interest' and 'prev_settlement' record = df.iloc[0: 1].to_records() @@ -722,18 +721,5 @@ def _get_array(self, instrument: Instrument, start_date: datetime.date) -> Optio arr['trading_dt'] = trading_dt for field in self._fields: arr[field] = df[field].values - if self._completion: - arr = self._completion_zero(instrument, arr) return arr return None - - def _completion_zero(self, instrument: Instrument, arr: np.ndarray) -> np.ndarray: - completion_start_date = max(instrument.listed_date.date(), datetime.date(2005, 1, 4)) - trading_dates = self._env.data_proxy._data_source.get_trading_dates(completion_start_date, self._end_date) - trading_dates = convert_date_to_date_int(trading_dates) - completion_dt = np.array(list(set(trading_dates).difference(set(arr['trading_dt'])))) - arr_zero = np.zeros((completion_dt.shape[0], ), dtype=arr.dtype) - arr_zero['trading_dt'] = completion_dt - arr = np.sort(np.concatenate((arr, arr_zero)), order="trading_dt") - return arr - diff --git a/rqalpha/data/trading_dates_mixin.py b/rqalpha/data/trading_dates_mixin.py index 92445c765..920106804 100644 --- a/rqalpha/data/trading_dates_mixin.py +++ b/rqalpha/data/trading_dates_mixin.py @@ -116,8 +116,7 @@ def _get_future_trading_date(self, dt): return td - def batch_get_trading_date(self, dt_index): - # type: (DatetimeIndex) -> DatetimeIndex + def batch_get_trading_date(self, dt_index: pd.DatetimeIndex): # 获取 numpy.array 中所有时间所在的交易日 # 认为晚八点后为第二个交易日,认为晚八点至次日凌晨四点为夜盘 dt = dt_index - datetime.timedelta(hours=4) diff --git a/tests/unittest/test_data/test_auto_update_bundle/test_auto_update_bundle_mixin.py b/tests/unittest/test_data/test_auto_update_bundle/test_auto_update_bundle_mixin.py index 817cd74a1..57f36d3f7 100644 --- a/tests/unittest/test_data/test_auto_update_bundle/test_auto_update_bundle_mixin.py +++ b/tests/unittest/test_data/test_auto_update_bundle/test_auto_update_bundle_mixin.py @@ -32,7 +32,7 @@ def init_fixture(self): self._auto_update_bundle_module = AutomaticUpdateBundle( path=self._path, filename="open_auction_volume.h5", - rqdata_api=self._mock_get_open_auction_info, + api=self._mock_get_open_auction_info, fields=['volume'], end_date=datetime.date(2024, 2, 28), )