Skip to content

Commit

Permalink
pr update
Browse files Browse the repository at this point in the history
  • Loading branch information
Lin-Dongzhao committed Mar 18, 2024
1 parent 7c0b56e commit bcf3443
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions rqalpha/data/bundle.py
Expand Up @@ -678,10 +678,6 @@ def _auto_update_task(self, instrument):
start_date = START_DATE
try:
h5 = h5py.File(self._file, "a")
except OSError as e:
raise OSError(_("File {} update failed, if it is using, please update later, "
"or you can delete then update again".format(self._file))) from e
try:
if order_book_id in h5:
if len(h5[order_book_id][:]) != 0:
last_date = datetime.datetime.strptime(str(h5[order_book_id][-1]['trading_dt']), "%Y%m%d").date()
Expand All @@ -692,8 +688,8 @@ def _auto_update_task(self, instrument):
return
arr = self._get_array(instrument, start_date)
if arr is None:
arr = np.array([])
if order_book_id not in h5:
arr = np.array([])
h5.create_dataset(order_book_id, data=arr)
else:
if order_book_id in h5:
Expand All @@ -704,6 +700,9 @@ def _auto_update_task(self, instrument):
h5.create_dataset(order_book_id, data=data)
else:
h5.create_dataset(order_book_id, data=arr)
except OSError as e:
raise OSError(_("File {} update failed, if it is using, please update later, "
"or you can delete then update again".format(self._file))) from e
finally:
h5.close()

Expand All @@ -716,14 +715,20 @@ def _get_array(self, instrument, start_date):
dtype = [('trading_dt', 'int')]
for field in self._fields:
dtype.append((field, record.dtype[field]))
update_list = []
for index, row in df.iterrows():
values = [row[field] for field in self._fields]
update_data = tuple([convert_date_to_date_int(
self._env.data_proxy._data_source.get_future_trading_date(index)
)] + values)
update_list.append(update_data)
arr = np.array(update_list, dtype=dtype)

dt_arr = np.array(df.index.tolist()).reshape((-1, 1))
dt_arr = np.apply_along_axis(self._get_trading_date, 1, dt_arr)
dt_arr = (dt_arr.astype('datetime64[Y]').astype(int) + 1970) * 10000 + (dt_arr.astype('datetime64[M]').astype(int) % 12 + 1) * 100 + (dt_arr.astype('datetime64[D]') - dt_arr.astype('datetime64[M]') + 1)
dt_arr.astype(int)
arr = np.ones((dt_arr.shape[0], ), dtype=dtype)
arr['trading_dt'] = dt_arr
for field in self._fields:
arr[field] = df[field].values
return arr
return None


def _get_trading_date(self, dt):
# type: (numpy.ndarray) -> Timestamp
dt = dt[0]
dt = self._env.data_proxy._data_source.get_future_trading_date(dt)
return dt

0 comments on commit bcf3443

Please sign in to comment.