From 3d75a2f11cf8b702fdb9538e230a40ad23a96047 Mon Sep 17 00:00:00 2001 From: Cedric Porter Date: Fri, 7 Apr 2017 22:25:43 +0800 Subject: [PATCH] format test code --- rqalpha/model/instrument.py | 2 +- rqalpha/model/order.py | 2 -- test.py | 4 ++-- tests/api/__init__.py | 33 +++++++++++++++------------------ tests/api/test_api_base.py | 36 +++++++++++++++++++++++++++++++++++- tests/api/test_api_future.py | 22 +++++++++++++++++++++- tests/api/test_api_stock.py | 23 +++++++++++++++++++++-- tests/test_s_dma.py | 3 +-- tests/test_s_dual_thrust.py | 1 - tests/test_s_scheduler.py | 2 +- tests/test_s_turtle.py | 13 ++++++------- 11 files changed, 103 insertions(+), 38 deletions(-) diff --git a/rqalpha/model/instrument.py b/rqalpha/model/instrument.py index 1caa09ae4..759c9731e 100644 --- a/rqalpha/model/instrument.py +++ b/rqalpha/model/instrument.py @@ -124,7 +124,7 @@ def name(self): return self.__name def __repr__(self): - return "{0}:{1}".format(self.__code,self.__name) + return "{0}:{1}".format(self.__code, self.__name) class IndustryCode(object): diff --git a/rqalpha/model/order.py b/rqalpha/model/order.py index 312715c38..595d546f0 100644 --- a/rqalpha/model/order.py +++ b/rqalpha/model/order.py @@ -22,8 +22,6 @@ from ..utils.logger import user_system_log - - class Order(object): order_id_gen = id_gen(int(time.time())) diff --git a/test.py b/test.py index af1cab12f..8ee7bc760 100644 --- a/test.py +++ b/test.py @@ -276,11 +276,11 @@ def write_csv(path, fields): writer = csv.DictWriter(csv_file, fieldnames=fields) writer.writerow({'date_time': end_time, 'time_spend': time_spend}) else: - if 0 < len(old_test_times) < 5 and time_spend > float(sum(float(i['time_spend']) for i in old_test_times))/len(old_test_times) * 1.1: + if 0 < len(old_test_times) < 5 and time_spend > float(sum(float(i['time_spend']) for i in old_test_times)) / len(old_test_times) * 1.1: print('Average time of last 5 runs:', float(sum(float(i['time_spend']) for i in old_test_times))/len(old_test_times)) print('Now time spend:', time_spend) raise RuntimeError('Performance regresses!') - elif len(old_test_times) >= 5 and time_spend > float(sum(float(i['time_spend']) for i in old_test_times[-5:]))/5 * 1.1: + elif len(old_test_times) >= 5 and time_spend > float(sum(float(i['time_spend']) for i in old_test_times[-5:])) / 5 * 1.1: print('Average time of last 5 runs:', float(sum(float(i['time_spend']) for i in old_test_times[-5:])) / 5) print('Now time spend:', time_spend) diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 64dea0dba..c6d24b7bd 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -1,19 +1,16 @@ #!/usr/bin/env python -# encoding: utf-8 - - -""" -@author: Gu Xi -@contact: guxi@ricequant.com -@site: http://www.ricequant.com -@file: __init__.py.py -@time: 2017/2/17 下午1:35 -""" - - -def func(): - pass - - -if __name__ == '__main__': - pass \ No newline at end of file +# -*- coding: utf-8 -*- +# +# Copyright 2017 Ricequant, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/api/test_api_base.py b/tests/api/test_api_base.py index 45429eb24..acab3c96e 100644 --- a/tests/api/test_api_base.py +++ b/tests/api/test_api_base.py @@ -1,9 +1,26 @@ #!/usr/bin/env python -# encoding: utf-8 +# -*- coding: utf-8 -*- +# +# Copyright 2017 Ricequant, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect + def test_get_order(): from rqalpha.api import order_shares, get_order + def init(context): context.s1 = '000001.XSHE' context.amount = 100 @@ -19,6 +36,7 @@ def handle_bar(context, bar_dict): def test_get_open_order(): from rqalpha.api import order_shares, get_open_orders, get_order + def init(context): context.s1 = '000001.XSHE' context.limitprice = 8.9 @@ -45,6 +63,7 @@ def handle_bar(context, bar_dict): def test_cancel_order(): from rqalpha.api import order_shares, cancel_order, get_order + def init(context): context.s1 = '000001.XSHE' context.limitprice = 8.59 @@ -63,6 +82,7 @@ def handle_bar(context, bar_dict): def test_update_universe(): from rqalpha.api import update_universe, history_bars + def init(context): context.s1 = '000001.XSHE' context.s2 = '600340.XSHG' @@ -82,6 +102,7 @@ def handle_bar(context, bar_dict): def test_subscribe(): from rqalpha.api import subscribe + def init(context): context.f1 = 'AU88' context.amount = 1 @@ -94,6 +115,7 @@ def handle_bar(context, bar_dict): def test_unsubscribe(): from rqalpha.api import subscribe, unsubscribe + def init(context): context.f1 = 'AU88' context.amount = 1 @@ -107,6 +129,7 @@ def handle_bar(context, bar_dict): def test_get_yield_curve(): from rqalpha.api import get_yield_curve + def init(context): pass @@ -119,6 +142,7 @@ def handle_bar(context, bar_dict): def test_history_bars(): from rqalpha.api import history_bars + def init(context): context.s1 = '000001.XSHE' pass @@ -132,6 +156,7 @@ def handle_bar(context, bar_dict): def test_all_instruments(): from rqalpha.api import all_instruments + def init(context): pass @@ -151,8 +176,10 @@ def handle_bar(context, bar_dict): assert all_instruments('Future').shape >= (3500, 16) test_all_instruments_code_new = "".join(inspect.getsourcelines(test_all_instruments)[0]) + def test_instruments_code(): from rqalpha.api import instruments + def init(context): context.s1 = '000001.XSHE' pass @@ -170,6 +197,7 @@ def handle_bar(context, bar_dict): def test_sector(): from rqalpha.api import sector + def init(context): pass @@ -180,6 +208,7 @@ def handle_bar(context, bar_dict): def test_industry(): from rqalpha.api import industry, instruments + def init(context): context.s1 = '000001.XSHE' context.s2 = '600340.XSHG' @@ -196,6 +225,7 @@ def handle_bar(context, bar_dict): def test_concept(): from rqalpha.api import concept, instruments + def init(context): context.s1 = '000002.XSHE' @@ -210,6 +240,7 @@ def handle_bar(context, bar_dict): def test_get_trading_dates(): from rqalpha.api import get_trading_dates import datetime + def init(context): pass @@ -228,6 +259,7 @@ def handle_bar(context, bar_dict): def test_get_previous_trading_date(): from rqalpha.api import get_previous_trading_date + def init(context): pass @@ -244,6 +276,7 @@ def handle_bar(context, bar_dict): def test_get_next_trading_date(): from rqalpha.api import get_next_trading_date + def init(context): pass @@ -256,6 +289,7 @@ def handle_bar(context, bar_dict): def test_get_dividend(): from rqalpha.api import get_dividend import pandas + def init(context): pass diff --git a/tests/api/test_api_future.py b/tests/api/test_api_future.py index 76c71b826..a638d965d 100644 --- a/tests/api/test_api_future.py +++ b/tests/api/test_api_future.py @@ -1,9 +1,26 @@ #!/usr/bin/env python -# encoding: utf-8 +# -*- coding: utf-8 -*- +# +# Copyright 2017 Ricequant, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect + def test_buy_open(): from rqalpha.api import buy_open, subscribe, get_order, ORDER_STATUS, POSITION_EFFECT, SIDE + def init(context): context.f1 = 'P88' context.amount = 1 @@ -27,6 +44,7 @@ def handle_bar(context, bar_dict): def test_sell_open(): from rqalpha.api import sell_open, subscribe, get_order, ORDER_STATUS, POSITION_EFFECT, SIDE + def init(context): context.f1 = 'P88' context.amount = 1 @@ -50,6 +68,7 @@ def handle_bar(context, bar_dict): def test_buy_close(): from rqalpha.api import buy_close, subscribe, get_order, ORDER_STATUS, POSITION_EFFECT, SIDE + def init(context): context.f1 = 'P88' context.amount = 1 @@ -73,6 +92,7 @@ def handle_bar(context, bar_dict): def test_sell_close(): from rqalpha.api import sell_close, subscribe, get_order, ORDER_STATUS, POSITION_EFFECT, SIDE + def init(context): context.f1 = 'P88' context.amount = 1 diff --git a/tests/api/test_api_stock.py b/tests/api/test_api_stock.py index ebf3f1352..38edb96ac 100644 --- a/tests/api/test_api_stock.py +++ b/tests/api/test_api_stock.py @@ -1,7 +1,21 @@ #!/usr/bin/env python -# encoding: utf-8 +# -*- coding: utf-8 -*- +# +# Copyright 2017 Ricequant, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect -import datetime def test_order_shares(): @@ -14,6 +28,7 @@ def handle_bar(context, bar_dict): def test_order_shares(): from rqalpha.api import order_shares, get_order, SIDE, LimitOrder + def init(context): context.order_count = 0 context.s1 = "000001.XSHE" @@ -36,6 +51,7 @@ def handle_bar(context, bar_dict): def test_order_lots(): from rqalpha.api import order_lots, get_order, SIDE, LimitOrder + def init(context): context.order_count = 0 context.s1 = "000001.XSHE" @@ -58,6 +74,7 @@ def handle_bar(context, bar_dict): def test_order_value(): from rqalpha.api import order_value, get_order, SIDE, LimitOrder + def init(context): context.order_count = 0 context.s1 = "000001.XSHE" @@ -78,6 +95,7 @@ def handle_bar(context, bar_dict): def test_order_percent(): from rqalpha.api import order_percent, get_order, SIDE, LimitOrder + def init(context): context.order_count = 0 context.s1 = "000001.XSHE" @@ -97,6 +115,7 @@ def handle_bar(context, bar_dict): def test_order_target_value(): from rqalpha.api import order_target_percent, get_order, SIDE, LimitOrder + def init(context): context.order_count = 0 context.s1 = "000001.XSHE" diff --git a/tests/test_s_dma.py b/tests/test_s_dma.py index 2b665f4ad..7557f52d4 100644 --- a/tests/test_s_dma.py +++ b/tests/test_s_dma.py @@ -24,8 +24,7 @@ def handle_bar(context, bar_dict): if DDD < AMA and cur_position > 0: order_target_percent(context.s1, 0) - if (HHV(MAX(O, C), 50) / LLV(MIN(O, C), 50) < 2 - and CROSS(DDD, AMA) and cur_position == 0): + if (HHV(MAX(O, C), 50) / LLV(MIN(O, C), 50) < 2 and CROSS(DDD, AMA) and cur_position == 0): order_target_percent(context.s1, 1) diff --git a/tests/test_s_dual_thrust.py b/tests/test_s_dual_thrust.py index aae048ee8..33a46e92d 100644 --- a/tests/test_s_dual_thrust.py +++ b/tests/test_s_dual_thrust.py @@ -48,7 +48,6 @@ def handle_bar(context, bar_dict): # 使用第n-1日的收盘价作为当前价 current_price = Close[2] - Range = max((HH - LC), (HC - LL)) K1 = 0.9 BuyLine = Openprice + K1 * Range diff --git a/tests/test_s_scheduler.py b/tests/test_s_scheduler.py index 56e96f8e7..5046749ed 100644 --- a/tests/test_s_scheduler.py +++ b/tests/test_s_scheduler.py @@ -2,7 +2,7 @@ def init(context): - scheduler.run_weekly(rebalance, 1, time_rule = market_open(0, 0)) + scheduler.run_weekly(rebalance, 1, time_rule=market_open(0, 0)) def rebalance(context, bar_dict): diff --git a/tests/test_s_turtle.py b/tests/test_s_turtle.py index ae159ff95..b8a43c780 100644 --- a/tests/test_s_turtle.py +++ b/tests/test_s_turtle.py @@ -11,15 +11,14 @@ def get_extreme(array_high_price_result, array_low_price_result): return [max_result, min_result] -def get_atr_and_unit( atr_array_result, atr_length_result, portfolio_value_result): - atr = atr_array_result[ atr_length_result-1] +def get_atr_and_unit(atr_array_result, atr_length_result, portfolio_value_result): + atr = atr_array_result[atr_length_result - 1] unit = math.floor(portfolio_value_result * .01 / atr) return [atr, unit] def get_stop_price(first_open_price_result, units_hold_result, atr_result): - stop_price = first_open_price_result - 2 * atr_result \ - + (units_hold_result - 1) * 0.5 * atr_result + stop_price = first_open_price_result - 2 * atr_result + (units_hold_result - 1) * 0.5 * atr_result return stop_price @@ -42,9 +41,9 @@ def init(context): def handle_bar(context, bar_dict): portfolio_value = context.portfolio.portfolio_value - high_price = history_bars(context.s, context.open_observe_time+1, '1d', 'high') - low_price_for_atr = history_bars(context.s, context.open_observe_time+1, '1d', 'low') - low_price_for_extreme = history_bars(context.s, context.close_observe_time+1, '1d', 'low') + high_price = history_bars(context.s, context.open_observe_time + 1, '1d', 'high') + low_price_for_atr = history_bars(context.s, context.open_observe_time + 1, '1d', 'low') + low_price_for_extreme = history_bars(context.s, context.close_observe_time + 1, '1d', 'low') close_price = history_bars(context.s, context.open_observe_time+2, '1d', 'close') close_price_for_atr = close_price[:-1]