Skip to content

Commit

Permalink
pass context to user event handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Cuizi7 committed Dec 10, 2018
1 parent 687ec31 commit 5b4135a
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 12 deletions.
3 changes: 2 additions & 1 deletion rqalpha/api/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,4 +940,5 @@ def get_position(order_book_id, direction):
)
def subscribe_event(event_type, handler):
env = Environment.get_instance()
env.event_bus.add_listener(event_type, handler, user=True)
user_strategy = env.user_strategy
env.event_bus.add_listener(event_type, user_strategy.wrap_user_event_handler(handler), user=True)
10 changes: 10 additions & 0 deletions rqalpha/core/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps

from rqalpha.events import EVENT, Event
from rqalpha.utils import run_when_strategy_not_hold
from rqalpha.utils.logger import user_system_log
Expand Down Expand Up @@ -104,3 +106,11 @@ def after_trading(self, event):
with ExecutionContext(EXECUTION_PHASE.AFTER_TRADING):
with ModifyExceptionFromType(EXC_TYPE.USER_EXC):
self._after_trading(self._user_context)

def wrap_user_event_handler(self, handler):
@wraps(handler)
def wrapped_handler(event):
with ExecutionContext(EXECUTION_PHASE.GLOBAL):
with ModifyExceptionFromType(EXC_TYPE.USER_EXC):
return handler(self._user_context, event)
return wrapped_handler
1 change: 1 addition & 0 deletions rqalpha/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, config):
self.mod_dict = None
self.plot_store = None
self.bar_dict = None
self.user_strategy = None
self._frontend_validators = []
self._account_model_dict = {}
self._position_model_dict = {}
Expand Down
1 change: 1 addition & 0 deletions rqalpha/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def run(config, source_code=None, user_funcs=None):

ucontext = StrategyContext()
user_strategy = Strategy(env.event_bus, scope, ucontext)
env.user_strategy = user_strategy
scheduler.set_user_context(ucontext)

if not config.extra.force_run_init_when_pt_resume:
Expand Down
12 changes: 5 additions & 7 deletions tests/api/test_api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,14 @@ def handle_bar(context, bar_dict):

@as_test_strategy()
def test_subscribe_event():
flags = {}

def init(_):
subscribe_event(EVENT.BEFORE_TRADING, on_before_trading)

def before_trading(_):
flags["before_trading_ran"] = True
def before_trading(context):
context.before_trading_ran = True

def on_before_trading(_):
assert flags["before_trading_ran"]
flags["before_trading_ran"] = False
def on_before_trading(context, _):
assert context.before_trading_ran
context.before_trading_ran = False

return init, before_trading
5 changes: 3 additions & 2 deletions tests/api/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.

from rqalpha.api import *
from rqalpha.environment import Environment

from ..utils import make_test_strategy_decorator

Expand Down Expand Up @@ -66,9 +67,9 @@ def handle_bar(*_):
buy_open("SC1809", 2)
sell_close("SC1809", 2, close_today=False)

def on_trade(event):
def on_trade(_, event):
trade = event.trade
contract_multiplier = instruments("SC1809").contract_multiplier
contract_multiplier = Environment.get_instance().data_proxy.instruments("SC1809").contract_multiplier
if trade.position_effect == POSITION_EFFECT.OPEN:
assert_almost_equal(
trade.transaction_cost, 0.0002 * trade.last_quantity * trade.last_price * contract_multiplier
Expand Down
5 changes: 3 additions & 2 deletions tests/test_s_tick_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

def init(context):
context.count = 0
context.tick_size = instruments(stock).tick_size()
subscribe_event(EVENT.TRADE, on_trade)


def on_trade(event):
def on_trade(context, event):
global price
trade = event.trade
assert trade.last_price == price + instruments(stock).tick_size() * SLIPPAGE
assert trade.last_price == price + context.tick_size * SLIPPAGE


def before_trading(context):
Expand Down

0 comments on commit 5b4135a

Please sign in to comment.