Skip to content

Commit

Permalink
implement RealTimePriceSeriesBenchmarkProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
Cuizi7 committed Dec 4, 2018
1 parent 5c2bf91 commit 406bed5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
33 changes: 33 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_benchmark/benchmark_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from rqalpha.interface import AbstractBenchmarkProvider
from rqalpha.environment import Environment
from rqalpha.events import EVENT
from rqalpha.utils.logger import system_log
from rqalpha.utils.i18n import gettext as _


Expand Down Expand Up @@ -63,3 +64,35 @@ def total_returns(self):
class RealTimePriceSeriesBenchmarkProvider(AbstractBenchmarkProvider):
def __init__(self, order_book_id):
self._order_book_id = order_book_id

self._daily_returns = 0
self._total_returns = 0

event_bus = Environment.get_instance().event_bus
event_bus.prepend_listener(EVENT.AFTER_TRADING, self._on_after_trading)

def _refresh_returns(self, end_date):
env = Environment.get_instance()
bar_count = env.data_proxy.count_trading_dates(env.config.base.start_date, end_date) + 1

close_series = env.data_proxy.history_bars(
self._order_book_id, bar_count, "1d", "close", end_date, skip_suspended=False, adjust_type='pre'
)

if len(close_series) < bar_count:
system_log.error(_("Valid benchmark: unable to load enough close price."))

self._daily_returns = float((close_series[-1] - close_series[-2]) / close_series[-2])
self._total_returns = float((close_series[-1] - close_series[0]) / close_series[0])

def _on_after_trading(self, _):
env = Environment.get_instance()
self._refresh_returns(env.trading_dt.date())

@property
def daily_returns(self):
return self._daily_returns

@property
def total_returns(self):
return self._total_returns
12 changes: 6 additions & 6 deletions rqalpha/mod/rqalpha_mod_sys_benchmark/testing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from rqalpha.utils.testing import DataProxyFixture
from rqalpha.mod.rqalpha_mod_sys_benchmark.benchmark_provider import BackTestPriceSeriesBenchmarkProvider


class BackTestPriceSeriesBenchmarkProviderFixture(DataProxyFixture):
class PriceSeriesBenchmarkProviderFixture(DataProxyFixture):
def __init__(self, *args, **kwargs):
super(BackTestPriceSeriesBenchmarkProviderFixture, self).__init__(*args, **kwargs)
super(PriceSeriesBenchmarkProviderFixture, self).__init__(*args, **kwargs)

self.benchmark_provider = None
self.benchmark_order_book_id = None
self.provider_class = BackTestPriceSeriesBenchmarkProvider

def init_fixture(self):
from rqalpha.mod.rqalpha_mod_sys_benchmark.benchmark_provider import BackTestPriceSeriesBenchmarkProvider

super(BackTestPriceSeriesBenchmarkProviderFixture, self).init_fixture()
self.benchmark_provider = BackTestPriceSeriesBenchmarkProvider(self.benchmark_order_book_id)
super(PriceSeriesBenchmarkProviderFixture, self).init_fixture()
self.benchmark_provider = self.provider_class (self.benchmark_order_book_id)
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from datetime import date
from datetime import date, datetime, time

from rqalpha.utils.testing import RQAlphaTestCase
from rqalpha.mod.rqalpha_mod_sys_benchmark.testing import BackTestPriceSeriesBenchmarkProviderFixture
from rqalpha.mod.rqalpha_mod_sys_benchmark.testing import PriceSeriesBenchmarkProviderFixture
from rqalpha.events import EVENT, Event


class BackTestPriceSeriesBenchmarkProviderTestCase(BackTestPriceSeriesBenchmarkProviderFixture, RQAlphaTestCase):
class BackTestPriceSeriesBenchmarkProviderTestCase(PriceSeriesBenchmarkProviderFixture, RQAlphaTestCase):
def __init__(self, *args, **kwargs):
super(BackTestPriceSeriesBenchmarkProviderTestCase, self).__init__(*args, **kwargs)
self.benchmark_order_book_id = "000300.XSHG"
Expand All @@ -25,6 +25,35 @@ def test_returns(self):
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3204.92 - 3334.50) / 3334.50)


class RealTimePriceSeriesBenchmarkProviderTestCase(PriceSeriesBenchmarkProviderFixture, RQAlphaTestCase):
def __init__(self, *args, **kwargs):
from rqalpha.mod.rqalpha_mod_sys_benchmark.benchmark_provider import RealTimePriceSeriesBenchmarkProvider

super(RealTimePriceSeriesBenchmarkProviderTestCase, self).__init__(*args, **kwargs)
self.provider_class = RealTimePriceSeriesBenchmarkProvider
self.benchmark_order_book_id = "000300.XSHG"
self.env_config["base"].update({
"start_date": date(2018, 9, 3), "end_date": date(2018, 9, 25)
})

def test_returns(self):
trading_date_gen = (trading_date for trading_date in self.env.config.base.trading_calendar)

self.env.event_bus.publish_event(Event(EVENT.POST_SYSTEM_INIT))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))

self.assertEqual(self.benchmark_provider.daily_returns, 0)
self.env.event_bus.publish_event(Event(EVENT.AFTER_TRADING))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))
self.assertAlmostEqual(self.benchmark_provider.daily_returns, (3321.82 - 3334.50) / 3334.50)
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3321.82 - 3334.50) / 3334.50)
for i in range(10):
self.env.event_bus.publish_event(Event(EVENT.AFTER_TRADING))
self.env.trading_dt = datetime.combine(next(trading_date_gen), time(10))
self.assertAlmostEqual(self.benchmark_provider.daily_returns, (3204.92 - 3242.09) / 3242.09)
self.assertAlmostEqual(self.benchmark_provider.total_returns, (3204.92 - 3334.50) / 3334.50)


if __name__ == "__main__":
import unittest
unittest.main()

0 comments on commit 406bed5

Please sign in to comment.