Skip to content

Commit

Permalink
Merge pull request #860 from ricequant/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Cuizi7 committed Mar 27, 2024
2 parents 598bbf4 + 5407931 commit 814e820
Show file tree
Hide file tree
Showing 18 changed files with 276 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Expand Up @@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
# Checks out a copy of your repository on the ubuntu-latest machine
- uses: actions/checkout@v2
Expand Down
5 changes: 5 additions & 0 deletions rqalpha/data/base_data_source/data_source.py
Expand Up @@ -400,3 +400,8 @@ def history_ticks(self, instrument, count, dt):

def get_algo_bar(self, id_or_ins, start_min, end_min, dt):
raise NotImplementedError("open source rqalpha not support algo order")

def get_open_auction_volume(self, instrument, dt):
# type: (Instrument, datetime.datetime) -> float
volume = self.get_open_auction_bar(instrument, dt)['volume']
return volume
103 changes: 101 additions & 2 deletions rqalpha/data/bundle.py
Expand Up @@ -18,17 +18,23 @@
import pickle
import re
from itertools import chain
from typing import Callable, Optional, List

import h5py
import numpy as np
import pandas as pd
from rqalpha.apis.api_rqdatac import rqdatac
from rqalpha.utils.concurrent import (ProgressedProcessPoolExecutor,
ProgressedTask)
from rqalpha.utils.datetime_func import (convert_date_to_date_int,
convert_date_to_int)
convert_date_to_int,)
from rqalpha.utils.exception import RQDatacVersionTooLow
from rqalpha.utils.i18n import gettext as _
from rqalpha.utils.logger import system_log, user_system_log
from rqalpha.utils.logger import system_log
from rqalpha.const import TRADING_CALENDAR_TYPE
from rqalpha.utils.functools import lru_cache
from rqalpha.environment import Environment
from rqalpha.model.instrument import Instrument

START_DATE = 20050104
END_DATE = 29991231
Expand Down Expand Up @@ -624,3 +630,96 @@ def update_futures_trading_parameters(path, end_date):
FUTURES_TRADING_PARAMETERS_FIELDS,
end_date
)


class AutomaticUpdateBundle(object):
def __init__(self, path: str, filename: str, api: Callable, fields: List[str], end_date: datetime.date) -> None:
if not os.path.exists(path):
os.makedirs(path)
self._file = os.path.join(path, filename)
self._trading_dates = None
self._filename = filename
self._api = api
self._fields = fields
self._end_date = end_date
self.updated = []
self._env = Environment.get_instance()

def get_data(self, instrument: Instrument, dt: datetime.date) -> Optional[np.ndarray]:
dt = convert_date_to_date_int(dt)
data = self._get_data_all_time(instrument)
if data is None:
return data
else:
try:
data = data[np.searchsorted(data['trading_dt'], dt)]
except IndexError:
data = None
return data

@lru_cache(128)
def _get_data_all_time(self, instrument: Instrument) -> Optional[np.ndarray]:
if instrument.order_book_id not in self.updated:
self._auto_update_task(instrument)
self.updated.append(instrument.order_book_id)
with h5py.File(self._file, "r") as h5:
data = h5[instrument.order_book_id][:]
if len(data) == 0:
return None
return data

def _auto_update_task(self, instrument: Instrument) -> None:
"""
在 rqalpha 策略运行过程中自动更新所需的日线数据
:param instrument: 合约对象
:type instrument: `Instrument`
"""
order_book_id = instrument.order_book_id
start_date = START_DATE
try:
h5 = h5py.File(self._file, "a")
if order_book_id in h5:
if len(h5[order_book_id][:]) != 0:
last_date = datetime.datetime.strptime(str(h5[order_book_id][-1]['trading_dt']), "%Y%m%d").date()
if last_date >= self._end_date:
return
start_date = self._env.data_proxy._data_source.get_next_trading_date(last_date).date()
if start_date > self._end_date:
return
arr = self._get_array(instrument, start_date)
if arr is None:
if order_book_id not in h5:
arr = np.array([])
h5.create_dataset(order_book_id, data=arr)
else:
if order_book_id in h5:
data = np.array(
[tuple(i) for i in chain(h5[order_book_id][:], arr)],
dtype=h5[order_book_id].dtype)
del h5[order_book_id]
h5.create_dataset(order_book_id, data=data)
else:
h5.create_dataset(order_book_id, data=arr)
except OSError as e:
raise OSError(_("File {} update failed, if it is using, please update later, "
"or you can delete then update again".format(self._file))) from e
finally:
h5.close()

def _get_array(self, instrument: Instrument, start_date: datetime.date) -> Optional[np.ndarray]:
df = self._api(instrument.order_book_id, start_date, self._end_date, self._fields)
if not (df is None or df.empty):
df = df[self._fields].loc[instrument.order_book_id] # rqdatac.get_open_auction_info get Futures's data will auto add 'open_interest' and 'prev_settlement'
record = df.iloc[0: 1].to_records()
dtype = [('trading_dt', 'int')]
for field in self._fields:
dtype.append((field, record.dtype[field]))
trading_dt = self._env.data_proxy._data_source.batch_get_trading_date(df.index)
trading_dt = convert_date_to_date_int(trading_dt)
arr = np.ones((trading_dt.shape[0], ), dtype=dtype)
arr['trading_dt'] = trading_dt
for field in self._fields:
arr[field] = df[field].values
return arr
return None
5 changes: 5 additions & 0 deletions rqalpha/data/data_proxy.py
Expand Up @@ -185,6 +185,11 @@ def get_open_auction_bar(self, order_book_id, dt):
"datetime", "open", "limit_up", "limit_down", "volume", "total_turnover"
]}
return PartialBarObject(instrument, bar)

def get_open_auction_volume(self, order_book_id, dt):
instrument = self.instruments(order_book_id)
volume = self._data_source.get_open_auction_volume(instrument, dt)
return volume

def history(self, order_book_id, bar_count, frequency, field, dt):
data = self.history_bars(order_book_id, bar_count, frequency,
Expand Down
9 changes: 9 additions & 0 deletions rqalpha/data/trading_dates_mixin.py
Expand Up @@ -19,6 +19,7 @@
from typing import Dict, Optional, Union

import pandas as pd
import numpy as np

from rqalpha.utils.functools import lru_cache
from rqalpha.const import TRADING_CALENDAR_TYPE
Expand Down Expand Up @@ -114,3 +115,11 @@ def _get_future_trading_date(self, dt):
return trading_dates[pos + 1]

return td

def batch_get_trading_date(self, dt_index: pd.DatetimeIndex):
# 获取 numpy.array 中所有时间所在的交易日
# 认为晚八点后为第二个交易日,认为晚八点至次日凌晨四点为夜盘
dt = dt_index - datetime.timedelta(hours=4)
trading_dates = self.get_trading_calendar(TRADING_CALENDAR_TYPE.EXCHANGE)
pos = trading_dates.searchsorted(dt.date) + np.where(dt.hour >= 16, 1, 0)
return trading_dates[pos]
3 changes: 2 additions & 1 deletion rqalpha/environment.py
Expand Up @@ -30,7 +30,7 @@
class Environment(object):
_env = None # type: Environment

def __init__(self, config):
def __init__(self, config, rqdatac_init):
Environment._env = self
self.config = config
self.data_proxy = None # type: Optional[rqalpha.data.data_proxy.DataProxy]
Expand All @@ -55,6 +55,7 @@ def __init__(self, config):
self._frontend_validators = {} # type: Dict[str, List]
self._default_frontend_validators = []
self._transaction_cost_decider_dict = {}
self.rqdatac_init = rqdatac_init # type: Boolean

# Environment.event_bus used in StrategyUniverse()
from rqalpha.core.strategy_universe import StrategyUniverse
Expand Down
14 changes: 14 additions & 0 deletions rqalpha/interface.py
Expand Up @@ -347,6 +347,20 @@ def get_open_auction_bar(self, instrument, dt):
datetime, open, limit_up, limit_down, volume, total_turnover
"""
raise NotImplementedError

def get_open_auction_volume(self, instrument, dt):
"""
获取指定资产当日的集合竞价成交量
:param instrument: 合约对象
:type instrument: class:`~Instrument`
:param dt: 集合竞价时间
:type dt: datetime.datetime
:return: `float`
"""
raise NotImplementedError

def get_settle_price(self, instrument, date):
"""
Expand Down
5 changes: 3 additions & 2 deletions rqalpha/main.py
Expand Up @@ -122,12 +122,14 @@ def init_rqdatac(rqdatac_uri):
init_rqdatac_env(rqdatac_uri)
try:
rqdatac.init()
return True
except Exception as e:
system_log.warn(_('rqdatac init failed, some apis will not function properly: {}').format(str(e)))
return


def run(config, source_code=None, user_funcs=None):
env = Environment(config)
env = Environment(config, init_rqdatac(getattr(config.base, 'rqdatac_uri', None)))
persist_helper = None
init_succeed = False
mod_handler = ModHandler()
Expand All @@ -136,7 +138,6 @@ def run(config, source_code=None, user_funcs=None):
# avoid register handlers everytime
# when running in ipython
set_loggers(config)
init_rqdatac(getattr(config.base, 'rqdatac_uri', None))
system_log.debug("\n" + pformat(config.convert_to_dict()))

env.set_strategy_loader(init_strategy_loader(env, source_code, user_funcs, config))
Expand Down
6 changes: 6 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py
Expand Up @@ -318,6 +318,12 @@ def settlement(self, trading_date):
user_system_log.warn(_(u"{order_book_id} is expired, close all positions by system").format(
order_book_id=self._order_book_id
))
account = self._env.get_account(self._order_book_id)
side = SIDE.SELL if self.direction == POSITION_DIRECTION.LONG else SIDE.BUY
trade = Trade.__from_create__(
None, self.last_price, self._quantity, side, POSITION_EFFECT.CLOSE, self._order_book_id
)
self._env.event_bus.publish_event(Event(EVENT.TRADE, account=account, trade=trade, order=None))
self._quantity = self._old_quantity = 0
return delta_cash

Expand Down
2 changes: 1 addition & 1 deletion rqalpha/mod/rqalpha_mod_sys_simulation/matcher.py
Expand Up @@ -85,7 +85,7 @@ def _open_auction_deal_price_decider(self, order_book_id, _):

def _get_bar_volume(self, order, open_auction=False):
if open_auction:
volume = self._env.data_proxy.get_open_auction_bar(order.order_book_id, self._env.trading_dt).volume
volume = self._env.data_proxy.get_open_auction_volume(order.order_book_id, self._env.trading_dt)
else:
if isinstance(order.style, ALGO_ORDER_STYLES):
_, volume = self._env.data_proxy.get_algo_bar(order.order_book_id, order.style, self._env.calendar_dt)
Expand Down
2 changes: 1 addition & 1 deletion rqalpha/utils/testing/fixtures.py
Expand Up @@ -26,7 +26,7 @@ def init_fixture(self):
from rqalpha.environment import Environment

super(EnvironmentFixture, self).init_fixture()
self.env = Environment(RqAttrDict(self.env_config))
self.env = Environment(RqAttrDict(self.env_config), False)

@contextmanager
def mock_env_method(self, name, mock_method):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -5,7 +5,7 @@

[metadata]
name = rqalpha
version = 5.3.7
version = 5.3.8

[versioneer]
VCS = git
Expand Down
3 changes: 3 additions & 0 deletions tests/api_tests/mod/sys_simulation/test_simulation_broker.py
Expand Up @@ -43,6 +43,9 @@

def test_open_auction_match():
__config__ = {
"base": {
"auto_update_bundle": False,
},
"mod": {
"sys_simulation": {
"volume_limit": True,
Expand Down
Binary file added tests/outs/test_f_delivery.pkl
Binary file not shown.
42 changes: 42 additions & 0 deletions tests/test_f_delivery.py
@@ -0,0 +1,42 @@
import datetime

def init(context):
context.f1 = "IH2402"
context.f2 = "IC2402"
context.fired = False


def handle_bar(context, bar_dict):
if not context.fired:
buy_open(context.f1, 3)
sell_open(context.f2, 3)
context.fired = True
if bar_dict._dt.date() == datetime.date(2024, 2, 19):
context.cash_before_delivery = context.portfolio.cash
context.daily_pnl = context.portfolio.daily_pnl
if bar_dict._dt.date() == datetime.date(2024, 2, 20):
assert get_position(context.f1).quantity == 0
assert get_position(context.f2).quantity == 0
assert abs(context.portfolio.cash - (((2363 * 300) + (5105.6 * 200)) * 0.12 * 3 + context.cash_before_delivery + context.daily_pnl)) < 0.0000001
assert context.portfolio.total_value == context.portfolio.cash


__config__ = {
"base": {
"start_date": "2024-02-05",
"end_date": "2024-02-20",
"frequency": "1d",
"accounts": {
"future": 10000000,
},
},
"extra": {
"log_level": "error",
},
"mod": {
"sys_progress": {
"enabled":True,
"show": True,
},
},
}
7 changes: 7 additions & 0 deletions tests/unittest/test_data/test_auto_update_bundle/__init__.py
@@ -0,0 +1,7 @@
import os


def load_tests(loader, standard_tests, pattern):
this_dir = os.path.dirname(__file__)
standard_tests.addTests(loader.discover(start_dir=this_dir, pattern=pattern))
return standard_tests
Binary file not shown.

0 comments on commit 814e820

Please sign in to comment.