From bcf34435b45ad429bbd2650a8f22ae21bc4866e0 Mon Sep 17 00:00:00 2001 From: "lin.dongzhao" <542698096@qq.com> Date: Mon, 18 Mar 2024 16:08:33 +0800 Subject: [PATCH] pr update --- rqalpha/data/bundle.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/rqalpha/data/bundle.py b/rqalpha/data/bundle.py index 04ad75403..362e3374f 100644 --- a/rqalpha/data/bundle.py +++ b/rqalpha/data/bundle.py @@ -678,10 +678,6 @@ def _auto_update_task(self, instrument): start_date = START_DATE try: h5 = h5py.File(self._file, "a") - except OSError as e: - raise OSError(_("File {} update failed, if it is using, please update later, " - "or you can delete then update again".format(self._file))) from e - try: if order_book_id in h5: if len(h5[order_book_id][:]) != 0: last_date = datetime.datetime.strptime(str(h5[order_book_id][-1]['trading_dt']), "%Y%m%d").date() @@ -692,8 +688,8 @@ def _auto_update_task(self, instrument): return arr = self._get_array(instrument, start_date) if arr is None: - arr = np.array([]) if order_book_id not in h5: + arr = np.array([]) h5.create_dataset(order_book_id, data=arr) else: if order_book_id in h5: @@ -704,6 +700,9 @@ def _auto_update_task(self, instrument): h5.create_dataset(order_book_id, data=data) else: h5.create_dataset(order_book_id, data=arr) + except OSError as e: + raise OSError(_("File {} update failed, if it is using, please update later, " + "or you can delete then update again".format(self._file))) from e finally: h5.close() @@ -716,14 +715,20 @@ def _get_array(self, instrument, start_date): dtype = [('trading_dt', 'int')] for field in self._fields: dtype.append((field, record.dtype[field])) - update_list = [] - for index, row in df.iterrows(): - values = [row[field] for field in self._fields] - update_data = tuple([convert_date_to_date_int( - self._env.data_proxy._data_source.get_future_trading_date(index) - )] + values) - update_list.append(update_data) - arr = np.array(update_list, dtype=dtype) + + dt_arr = np.array(df.index.tolist()).reshape((-1, 1)) + dt_arr = np.apply_along_axis(self._get_trading_date, 1, dt_arr) + dt_arr = (dt_arr.astype('datetime64[Y]').astype(int) + 1970) * 10000 + (dt_arr.astype('datetime64[M]').astype(int) % 12 + 1) * 100 + (dt_arr.astype('datetime64[D]') - dt_arr.astype('datetime64[M]') + 1) + dt_arr.astype(int) + arr = np.ones((dt_arr.shape[0], ), dtype=dtype) + arr['trading_dt'] = dt_arr + for field in self._fields: + arr[field] = df[field].values return arr return None - \ No newline at end of file + + def _get_trading_date(self, dt): + # type: (numpy.ndarray) -> Timestamp + dt = dt[0] + dt = self._env.data_proxy._data_source.get_future_trading_date(dt) + return dt