Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config.mod新增strategy_name字段 #827

Merged
merged 3 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_analyser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"benchmark": None,
# 当不输出 csv/pickle/plot 等内容时,关闭该项可关闭策略运行过程中部分收集数据的逻辑,用以提升性能
"record": True,
# 策略名称,可在summary、回测报告,收益图中展示
"strategy_name": None,
# 回测结果输出的文件路径,该文件为 pickle 格式,内容为每日净值、头寸、流水及风险指标等;若不设置则不输出该文件
"output_file": None,
# 回测报告的数据目录,报告为 csv 格式;若不设置则不输出报告
Expand Down
7 changes: 5 additions & 2 deletions rqalpha/mod/rqalpha_mod_sys_analyser/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,10 @@ def tear_down(self, code, exception=None):
if len(self._total_portfolios) == 0:
return

strategy_name = os.path.basename(self._env.config.base.strategy_file).split(".")[0]
if self._mod_config.strategy_name:
strategy_name = self._mod_config.strategy_name
else:
strategy_name = os.path.basename(self._env.config.base.strategy_file).split(".")[0]
data_proxy = self._env.data_proxy
start_date, end_date = attrgetter("start_date", "end_date")(self._env.config.base)
summary = {
Expand Down Expand Up @@ -536,7 +539,7 @@ def tear_down(self, code, exception=None):
_plot_template_cls = PLOT_TEMPLATE.get(self._mod_config.plot, DefaultPlot)
plot_result(
result_dict, self._mod_config.plot, self._mod_config.plot_save_file,
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls, self._mod_config.strategy_name
)

return result_dict
3 changes: 3 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_analyser/plot/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
LABEL_FONT_SIZE = 11
SUPPORT_CHINESE = True

TITLE_FONT_SIZE = 16

RED = "#aa4643"
BLUE = "#4572a7"
YELLOW = "#F3A423"
Expand All @@ -50,6 +52,7 @@
IMG_WIDTH = 15

# 两部分的相对高度
PLOT_TITLE_HEIGHT = 1
INDICATOR_AREA_HEIGHT = 3
PLOT_AREA_HEIGHT = 5
USER_PLOT_AREA_HEIGHT = 2
Expand Down
54 changes: 36 additions & 18 deletions rqalpha/mod/rqalpha_mod_sys_analyser/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from .utils import IndicatorInfo, LineInfo, max_dd as _max_dd, SpotInfo, max_ddd as _max_ddd
from .utils import weekly_returns, trading_dates_index
from .consts import PlotTemplate, DefaultPlot
from .consts import IMG_WIDTH, INDICATOR_AREA_HEIGHT, PLOT_AREA_HEIGHT, USER_PLOT_AREA_HEIGHT
from .consts import LABEL_FONT_SIZE, BLACK, SUPPORT_CHINESE
from .consts import IMG_WIDTH, INDICATOR_AREA_HEIGHT, PLOT_AREA_HEIGHT, USER_PLOT_AREA_HEIGHT, PLOT_TITLE_HEIGHT
from .consts import LABEL_FONT_SIZE, BLACK, SUPPORT_CHINESE, TITLE_FONT_SIZE
from .consts import MAX_DD, MAX_DDD, OPEN_POINT, CLOSE_POINT
from .consts import LINE_BENCHMARK, LINE_STRATEGY, LINE_WEEKLY_BENCHMARK, LINE_WEEKLY, LINE_EXCESS

Expand All @@ -49,11 +49,12 @@ class IndicatorArea(SubPlot):

def __init__(
self, indicators: List[List[IndicatorInfo]], indicator_values: Mapping[str, float],
plot_template: PlotTemplate
plot_template: PlotTemplate, strategy_name=None
):
self._indicators = indicators
self._values = indicator_values
self._template = plot_template
self._strategy_name = strategy_name

def plot(self, ax: Axes):
ax.axis("off")
Expand All @@ -70,7 +71,10 @@ def plot(self, ax: Axes):
value = "nan"
ax.text(x, y_label, i.label, color=i.color, fontsize=LABEL_FONT_SIZE),
ax.text(x, y_value, value, color=BLACK, fontsize=i.value_font_size)

if self._strategy_name:
p = TitlePlot(self._strategy_name, len(self._indicators), self._template)
p.plot(ax)


class ReturnPlot(SubPlot):
height: int = PLOT_AREA_HEIGHT
Expand Down Expand Up @@ -130,8 +134,21 @@ def plot(self, ax: Axes):
pyplot.legend(loc="best").get_frame().set_alpha(0.5)


class TitlePlot(SubPlot):
height: int = PLOT_TITLE_HEIGHT

def __init__(self, strategy_name, indicator_area_rows, plot_template: PlotTemplate):
self._strategy_name = strategy_name
self._indicator_area_rows = indicator_area_rows
self._template = plot_template

def plot(self, ax:Axes):
x = 0.57 # title 为整图居中,而非子图居中
y = (self._template.INDICATOR_LABEL_HEIGHT + self._template.INDICATOR_VALUE_HEIGHT) * self._indicator_area_rows + 0.1
ax.text(x, y, self._strategy_name, ha='center', va='bottom', color=BLACK, fontsize=TITLE_FONT_SIZE)

class WaterMark:
def __init__(self, img_width, img_height):
def __init__(self, img_width, img_height, strategy_name):
logo_file = os.path.join(
os.path.dirname(os.path.realpath(rqalpha.__file__)),
"resource", 'ricequant-logo.png')
Expand All @@ -142,21 +159,22 @@ def __init__(self, img_width, img_height):

def plot(self, fig: Figure):
fig.figimage(
self.logo_img,
xo=(self.img_width * self.dpi - self.logo_img.shape[1]) / 2,
yo=(self.img_height * self.dpi - self.logo_img.shape[0]) / 2,
alpha=0.4,
)

self.logo_img,
xo = (self.img_width * self.dpi - self.logo_img.shape[1]) / 2,
yo = (PLOT_AREA_HEIGHT * self.dpi - self.logo_img.shape[0]) / 2,
alpha=0.4
)

def _plot(title: str, sub_plots: List[SubPlot]):

def _plot(title: str, sub_plots: List[SubPlot], strategy_name):
img_height = sum(s.height for s in sub_plots)
water_mark = WaterMark(IMG_WIDTH, img_height)
water_mark = WaterMark(IMG_WIDTH, img_height, strategy_name)
fig = pyplot.figure(title, figsize=(IMG_WIDTH, img_height), dpi=water_mark.dpi, clear=True)
water_mark.plot(fig)

gs = gridspec.GridSpec(img_height, 8, figure=fig)
last_height = 0
if (strategy_name): last_height = 1
else: last_height = 0
for p in sub_plots:
p.plot(pyplot.subplot(gs[last_height:last_height + p.height, :p.right_pad]))
last_height += p.height
Expand All @@ -167,7 +185,7 @@ def _plot(title: str, sub_plots: List[SubPlot]):

def plot_result(
result_dict, show=True, save=None, weekly_indicators: bool = False, open_close_points: bool = False,
plot_template_cls=DefaultPlot
plot_template_cls=DefaultPlot, strategy_name=None
):
summary = result_dict["summary"]
portfolio = result_dict["portfolio"]
Expand Down Expand Up @@ -216,13 +234,13 @@ def plot_result(
sub_plots = [IndicatorArea(indicators, ChainMap(summary, {
"max_dd_ddd": "MaxDD {}\nMaxDDD {}".format(max_dd.repr, max_ddd.repr),
"excess_max_dd_ddd": ex_max_dd_ddd,
}), plot_template), ReturnPlot(
}), plot_template, strategy_name), ReturnPlot(
portfolio.unit_net_value - 1, return_lines, spots_on_returns
)]
if "plots" in result_dict:
sub_plots.append(UserPlot(result_dict["plots"]))
_plot(summary["strategy_file"], sub_plots)

_plot(summary["strategy_file"], sub_plots, strategy_name)

if save:
file_path = save
Expand Down