Skip to content

Commit

Permalink
refactor RealTimePriceSeriesBenchmarkProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
Cuizi7 committed Dec 10, 2018
1 parent 3d96c23 commit 687ec31
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 42 deletions.
34 changes: 21 additions & 13 deletions rqalpha/mod/rqalpha_mod_sys_benchmark/benchmark_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from rqalpha.interface import AbstractBenchmarkProvider
from rqalpha.environment import Environment
from rqalpha.events import EVENT
from rqalpha.utils.logger import system_log
from rqalpha.utils.i18n import gettext as _


Expand Down Expand Up @@ -65,29 +64,38 @@ class RealTimePriceSeriesBenchmarkProvider(AbstractBenchmarkProvider):
def __init__(self, order_book_id):
self._order_book_id = order_book_id

self._first_close = None
self._last_close = None

self._daily_returns = 0
self._total_returns = 0

event_bus = Environment.get_instance().event_bus
event_bus.add_listener(EVENT.POST_SYSTEM_INIT, self._on_system_init)
event_bus.prepend_listener(EVENT.AFTER_TRADING, self._on_after_trading)
event_bus.prepend_listener(EVENT.BAR, self._on_bar)

def _refresh_returns(self, end_date):
def _get_close(self, frequency, dt):
env = Environment.get_instance()
bar_count = env.data_proxy.count_trading_dates(env.config.base.start_date, end_date) + 1
return env.data_proxy.history_bars(
self._order_book_id, 1, frequency, "close", dt, skip_suspended=False, adjust_type='pre'
)[0]

close_series = env.data_proxy.history_bars(
self._order_book_id, bar_count, "1d", "close", end_date, skip_suspended=False, adjust_type='pre'
def _on_system_init(self, _):
env = Environment.get_instance()
self._first_close = self._last_close = self._get_close(
"1d", env.data_proxy.get_previous_trading_date(env.config.base.start_date)
)
print(env.data_proxy.get_previous_trading_date(env.config.base.start_date))
print(self._first_close)

if len(close_series) < bar_count:
system_log.error(_("Valid benchmark: unable to load enough close price."))

self._daily_returns = float((close_series[-1] - close_series[-2]) / close_series[-2])
self._total_returns = float((close_series[-1] - close_series[0]) / close_series[0])
def _on_after_trading(self, event):
self._last_close = self._get_close("1d", event.calendar_dt)

def _on_after_trading(self, _):
env = Environment.get_instance()
self._refresh_returns(env.trading_dt.date())
def _on_bar(self, event):
close = self._get_close("1m", event.calendar_dt)
self._daily_returns = float((close - self._last_close) / self._last_close)
self._total_returns = float((close - self._first_close) / self._first_close)

@property
def daily_returns(self):
Expand Down
3 changes: 3 additions & 0 deletions rqalpha/model/benchmark_portfolio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

from rqalpha.environment import Environment
from rqalpha.const import DAYS_CNT
from rqalpha.utils.repr import property_repr


class BenchmarkPortfolio(object):
__repr__ = property_repr

def __init__(self, benchmark_provider, units):
self._provider = benchmark_provider
self._units = units
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,6 @@ def test_returns(self):
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3204.92 - 3334.50) / 3334.50)


class RealTimePriceSeriesBenchmarkProviderTestCase(PriceSeriesBenchmarkProviderFixture, RQAlphaTestCase):
def __init__(self, *args, **kwargs):
from rqalpha.mod.rqalpha_mod_sys_benchmark.benchmark_provider import RealTimePriceSeriesBenchmarkProvider

super(RealTimePriceSeriesBenchmarkProviderTestCase, self).__init__(*args, **kwargs)
self.provider_class = RealTimePriceSeriesBenchmarkProvider
self.benchmark_order_book_id = "000300.XSHG"
self.env_config["base"].update({
"start_date": date(2018, 9, 3), "end_date": date(2018, 9, 25)
})

def test_returns(self):
trading_date_gen = (trading_date for trading_date in self.env.config.base.trading_calendar)

self.env.event_bus.publish_event(Event(EVENT.POST_SYSTEM_INIT))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))

self.assertEqual(self.benchmark_provider.daily_returns, 0)
self.env.event_bus.publish_event(Event(EVENT.AFTER_TRADING))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))
self.assertAlmostEqual(self.benchmark_provider.daily_returns, (3321.82 - 3334.50) / 3334.50)
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3321.82 - 3334.50) / 3334.50)
for i in range(10):
self.env.event_bus.publish_event(Event(EVENT.AFTER_TRADING))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))
self.assertAlmostEqual(self.benchmark_provider.daily_returns, (3204.92 - 3242.09) / 3242.09)
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3204.92 - 3334.50) / 3334.50)


if __name__ == "__main__":
import unittest
unittest.main()

0 comments on commit 687ec31

Please sign in to comment.