Skip to content

Commit

Permalink
config.mod新增strategy_name字段 (#827)
Browse files Browse the repository at this point in the history
* config.mod新增strategy_name字段

* ricequant水印位置调整
  • Loading branch information
Lin-Dongzhao committed Dec 4, 2023
1 parent d72b936 commit 80b655f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 20 deletions.
2 changes: 2 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_analyser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"benchmark": None,
# 当不输出 csv/pickle/plot 等内容时,关闭该项可关闭策略运行过程中部分收集数据的逻辑,用以提升性能
"record": True,
# 策略名称,可在summary、回测报告,收益图中展示
"strategy_name": None,
# 回测结果输出的文件路径,该文件为 pickle 格式,内容为每日净值、头寸、流水及风险指标等;若不设置则不输出该文件
"output_file": None,
# 回测报告的数据目录,报告为 csv 格式;若不设置则不输出报告
Expand Down
7 changes: 5 additions & 2 deletions rqalpha/mod/rqalpha_mod_sys_analyser/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,10 @@ def tear_down(self, code, exception=None):
if len(self._total_portfolios) == 0:
return

strategy_name = os.path.basename(self._env.config.base.strategy_file).split(".")[0]
if self._mod_config.strategy_name:
strategy_name = self._mod_config.strategy_name
else:
strategy_name = os.path.basename(self._env.config.base.strategy_file).split(".")[0]
data_proxy = self._env.data_proxy
start_date, end_date = attrgetter("start_date", "end_date")(self._env.config.base)
summary = {
Expand Down Expand Up @@ -536,7 +539,7 @@ def tear_down(self, code, exception=None):
_plot_template_cls = PLOT_TEMPLATE.get(self._mod_config.plot, DefaultPlot)
plot_result(
result_dict, self._mod_config.plot, self._mod_config.plot_save_file,
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls, self._mod_config.strategy_name
)

return result_dict
3 changes: 3 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_analyser/plot/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
LABEL_FONT_SIZE = 11
SUPPORT_CHINESE = True

TITLE_FONT_SIZE = 16

RED = "#aa4643"
BLUE = "#4572a7"
YELLOW = "#F3A423"
Expand All @@ -50,6 +52,7 @@
IMG_WIDTH = 15

# 两部分的相对高度
PLOT_TITLE_HEIGHT = 1
INDICATOR_AREA_HEIGHT = 3
PLOT_AREA_HEIGHT = 5
USER_PLOT_AREA_HEIGHT = 2
Expand Down
54 changes: 36 additions & 18 deletions rqalpha/mod/rqalpha_mod_sys_analyser/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from .utils import IndicatorInfo, LineInfo, max_dd as _max_dd, SpotInfo, max_ddd as _max_ddd
from .utils import weekly_returns, trading_dates_index
from .consts import PlotTemplate, DefaultPlot
from .consts import IMG_WIDTH, INDICATOR_AREA_HEIGHT, PLOT_AREA_HEIGHT, USER_PLOT_AREA_HEIGHT
from .consts import LABEL_FONT_SIZE, BLACK, SUPPORT_CHINESE
from .consts import IMG_WIDTH, INDICATOR_AREA_HEIGHT, PLOT_AREA_HEIGHT, USER_PLOT_AREA_HEIGHT, PLOT_TITLE_HEIGHT
from .consts import LABEL_FONT_SIZE, BLACK, SUPPORT_CHINESE, TITLE_FONT_SIZE
from .consts import MAX_DD, MAX_DDD, OPEN_POINT, CLOSE_POINT
from .consts import LINE_BENCHMARK, LINE_STRATEGY, LINE_WEEKLY_BENCHMARK, LINE_WEEKLY, LINE_EXCESS

Expand All @@ -49,11 +49,12 @@ class IndicatorArea(SubPlot):

def __init__(
self, indicators: List[List[IndicatorInfo]], indicator_values: Mapping[str, float],
plot_template: PlotTemplate
plot_template: PlotTemplate, strategy_name=None
):
self._indicators = indicators
self._values = indicator_values
self._template = plot_template
self._strategy_name = strategy_name

def plot(self, ax: Axes):
ax.axis("off")
Expand All @@ -70,7 +71,10 @@ def plot(self, ax: Axes):
value = "nan"
ax.text(x, y_label, i.label, color=i.color, fontsize=LABEL_FONT_SIZE),
ax.text(x, y_value, value, color=BLACK, fontsize=i.value_font_size)

if self._strategy_name:
p = TitlePlot(self._strategy_name, len(self._indicators), self._template)
p.plot(ax)


class ReturnPlot(SubPlot):
height: int = PLOT_AREA_HEIGHT
Expand Down Expand Up @@ -130,8 +134,21 @@ def plot(self, ax: Axes):
pyplot.legend(loc="best").get_frame().set_alpha(0.5)


class TitlePlot(SubPlot):
height: int = PLOT_TITLE_HEIGHT

def __init__(self, strategy_name, indicator_area_rows, plot_template: PlotTemplate):
self._strategy_name = strategy_name
self._indicator_area_rows = indicator_area_rows
self._template = plot_template

def plot(self, ax:Axes):
x = 0.57 # title 为整图居中,而非子图居中
y = (self._template.INDICATOR_LABEL_HEIGHT + self._template.INDICATOR_VALUE_HEIGHT) * self._indicator_area_rows + 0.1
ax.text(x, y, self._strategy_name, ha='center', va='bottom', color=BLACK, fontsize=TITLE_FONT_SIZE)

class WaterMark:
def __init__(self, img_width, img_height):
def __init__(self, img_width, img_height, strategy_name):
logo_file = os.path.join(
os.path.dirname(os.path.realpath(rqalpha.__file__)),
"resource", 'ricequant-logo.png')
Expand All @@ -142,21 +159,22 @@ def __init__(self, img_width, img_height):

def plot(self, fig: Figure):
fig.figimage(
self.logo_img,
xo=(self.img_width * self.dpi - self.logo_img.shape[1]) / 2,
yo=(self.img_height * self.dpi - self.logo_img.shape[0]) / 2,
alpha=0.4,
)

self.logo_img,
xo = (self.img_width * self.dpi - self.logo_img.shape[1]) / 2,
yo = (PLOT_AREA_HEIGHT * self.dpi - self.logo_img.shape[0]) / 2,
alpha=0.4
)

def _plot(title: str, sub_plots: List[SubPlot]):

def _plot(title: str, sub_plots: List[SubPlot], strategy_name):
img_height = sum(s.height for s in sub_plots)
water_mark = WaterMark(IMG_WIDTH, img_height)
water_mark = WaterMark(IMG_WIDTH, img_height, strategy_name)
fig = pyplot.figure(title, figsize=(IMG_WIDTH, img_height), dpi=water_mark.dpi, clear=True)
water_mark.plot(fig)

gs = gridspec.GridSpec(img_height, 8, figure=fig)
last_height = 0
if (strategy_name): last_height = 1
else: last_height = 0
for p in sub_plots:
p.plot(pyplot.subplot(gs[last_height:last_height + p.height, :p.right_pad]))
last_height += p.height
Expand All @@ -167,7 +185,7 @@ def _plot(title: str, sub_plots: List[SubPlot]):

def plot_result(
result_dict, show=True, save=None, weekly_indicators: bool = False, open_close_points: bool = False,
plot_template_cls=DefaultPlot
plot_template_cls=DefaultPlot, strategy_name=None
):
summary = result_dict["summary"]
portfolio = result_dict["portfolio"]
Expand Down Expand Up @@ -216,13 +234,13 @@ def plot_result(
sub_plots = [IndicatorArea(indicators, ChainMap(summary, {
"max_dd_ddd": "MaxDD {}\nMaxDDD {}".format(max_dd.repr, max_ddd.repr),
"excess_max_dd_ddd": ex_max_dd_ddd,
}), plot_template), ReturnPlot(
}), plot_template, strategy_name), ReturnPlot(
portfolio.unit_net_value - 1, return_lines, spots_on_returns
)]
if "plots" in result_dict:
sub_plots.append(UserPlot(result_dict["plots"]))
_plot(summary["strategy_file"], sub_plots)

_plot(summary["strategy_file"], sub_plots, strategy_name)

if save:
file_path = save
Expand Down

0 comments on commit 80b655f

Please sign in to comment.