In [1]:
from base_types import IdxValue, DataElements
from data import Data
from plot import PricePlot
from utils import milliseconds_to_date
from base_types import DataType
import json
from analysis_utils import SavedInfo, read_data

<Figure size 1000x500 with 0 Axes>

In [3]:
import matplotlib.pyplot as plt
from mplfinance.original_flavor import candlestick_ohlc
from matplotlib import ticker
from base_types import DataElements
import numpy as np

def get_data_from_time(time_value: IdxValue, begin, end):
    """Return the samples of ``time_value`` whose timestamp lies in ``[begin, end]``.

    Args:
        time_value: object exposing parallel ``idx`` (timestamps) and ``value``
            sequences.
        begin: inclusive lower bound, in the same unit as ``time_value.idx``.
        end: inclusive upper bound, in the same unit as ``time_value.idx``.

    Returns:
        ``(timestamps, values)``: two NumPy arrays holding only the in-window
        entries, in their original order.
    """
    stamps = np.asarray(time_value.idx)
    samples = np.asarray(time_value.value)
    # Boolean-mask indexing always copies, so callers get fresh arrays.
    in_window = np.logical_and(stamps >= begin, stamps <= end)
    return stamps[in_window], samples[in_window]


def get_plot_data(saved_info: SavedInfo, start_idx=None, end_idx=None):
    """Build the inputs ``plot_with_data_points`` needs for one experiment.

    Args:
        saved_info: loaded experiment results. ``get_all()`` must return, in
            order: data, buy_points, sell_points, tops, bottoms, earn_points,
            tops_confirm, bottoms_confirm.
        start_idx: first row label to keep (label-based ``.loc`` slice);
            ``None`` starts at the first row.
        end_idx: last row label to keep (inclusive, label-based ``.loc``
            slice); ``None`` runs to the last row.

    Returns:
        dict with keys:
          'subdata'     -- OHLC + open-time rows of the window, re-indexed from
                           0, plus a 'timestamp' column of formatted dates;
          'points'      -- list of ``PricePlot.Points`` marker layers clipped to
                           the window's open-time range;
          'earn_points' -- passed through unchanged from ``saved_info``.

    Raises:
        ValueError: if the ``[start_idx, end_idx]`` slice selects no rows.
    """
    (data, buy_points, sell_points, tops, bottoms,
     earn_points, tops_confirm, bottoms_confirm) = saved_info.get_all()

    columns = [DataElements.OPEN.value, DataElements.HIGH.value,
               DataElements.LOW.value, DataElements.CLOSE.value,
               DataElements.OPEN_TIME.value]
    # reset_index() (without drop=True) inserts the original row label as the
    # first column; candlestick_ohlc reads the first column of each row as x.
    subdata = data.data.loc[start_idx:end_idx, columns].copy().reset_index()
    if subdata.empty:
        # Without this guard the [0] / [-1] lookups below fail with a bare IndexError.
        raise ValueError(f'no rows selected for start_idx={start_idx!r}, end_idx={end_idx!r}')
    subdata['timestamp'] = subdata[DataElements.OPEN_TIME.value].map(milliseconds_to_date)

    open_times = subdata[DataElements.OPEN_TIME.value].values
    start_time, end_time = open_times[0], open_times[-1]

    # (series, marker size, colour, legend label) for each marker layer.
    marker_specs = [
        (buy_points, 90, 'r', 'buy'),
        (sell_points, 90, 'g', 'sell'),
        (tops, 30, 'b', 'top'),
        (bottoms, 30, 'y', 'bottoms'),
        (tops_confirm, 10, 'm', 'tops_confirm'),
        (bottoms_confirm, 10, 'orange', 'bottoms_confirm'),
    ]
    points = [
        PricePlot.Points(*get_data_from_time(series, start_time, end_time), s=size, c=color, label=label)
        for series, size, color, label in marker_specs
    ]

    # NOTE(review): the return value of time_list_to_idx is discarded, so this
    # only has an effect if it converts point.idx (timestamps) to row positions
    # in place -- confirm against Data.time_list_to_idx.
    for point in points:
        data.time_list_to_idx(point.idx)

    return {'subdata': subdata, 'points': points, 'earn_points': earn_points}

# plot_data_multi_0_5 = get_plot_data(info_multi_0_5, None, None)  # returns a dict, not a tuple
# plot_data_new_policy = get_plot_data(info_new_policy)
# plot_data = get_plot_data(info)


In [4]:
def plot_with_data_points(all_data_points, plot_candle: bool):
    """Draw one price panel per experiment above a shared equity-curve panel.

    Args:
        all_data_points: non-empty list of dicts shaped like the return value of
            ``get_plot_data``:
              'subdata'     -- price frame with the CLOSE and 'timestamp' columns;
              'points'      -- marker layers exposing ``idx``, ``value``, ``s``,
                               ``c`` and ``label``;
              'earn_points' -- object with parallel ``idx`` / ``value`` sequences.
        plot_candle: draw OHLC candlesticks when True, otherwise a grey line of
            close prices.

    Raises:
        ValueError: if ``all_data_points`` is empty.

    All panels share one x axis. In line mode its unit is the row position in
    ``subdata``. The bottom panel's date tick labels are taken from the last
    experiment's 'timestamp' column.
    """
    figure_num = len(all_data_points)
    if figure_num == 0:
        raise ValueError('all_data_points must contain at least one entry')

    grid_shape = (figure_num + 1, 1)  # one row per experiment + one for equity curves
    subplot = None
    for num, plot_data in enumerate(all_data_points):
        data = plot_data['subdata']
        points = plot_data['points']

        # Share x and y with the previous price panel so zoom/pan stays in sync.
        subplot = plt.subplot2grid(grid_shape, (num, 0), rowspan=1, colspan=1, sharex=subplot, sharey=subplot)

        if plot_candle:
            # candlestick_ohlc takes x from each row's first column (presumably
            # the index column added by reset_index() in get_plot_data).
            # NOTE(review): that equals the row position only for an unsliced,
            # 0-based source index -- otherwise candles and markers misalign.
            candlestick_ohlc(ax=subplot, quotes=data.values,
                             width=0.7, colorup='g', colordown='r')
        else:
            subplot.plot(range(0, len(data)), data[DataElements.CLOSE.value].values,  # type: ignore
                         color="gray", linewidth=1.0, label='base')

        subplot.set_ylabel('Price')  # type: ignore

        for p in points or []:
            subplot.scatter(p.idx, p.value, s=p.s, c=p.c, label=p.label)  # type: ignore

        # Labels are only visible through a legend; skip it when nothing is labelled.
        if points or not plot_candle:
            subplot.legend()  # type: ignore

    subplot = plt.subplot2grid(grid_shape, (figure_num, 0), rowspan=1, colspan=1, sharex=subplot)
    colors = ['b', 'r', 'g', 'y']
    drew_curve = False
    for num, plot_data in enumerate(all_data_points):
        earn_points = plot_data['earn_points']
        # list() so numpy inputs concatenate instead of broadcast-adding.
        earn_x = list(earn_points.idx)
        earn_y = list(earn_points.value)
        if not earn_y:
            continue  # nothing to draw and no final value to extend
        # Hold the final equity value flat out to this experiment's last bar so
        # the curve spans the whole shared x range.
        last_bar = len(plot_data['subdata']) - 1
        subplot.plot(earn_x + [last_bar], earn_y + [earn_y[-1]],  # type: ignore
                     color=colors[num % len(colors)], linewidth=1.0, label=str(num))
        drew_curve = True
    if drew_curve:
        subplot.legend()  # type: ignore

    # All panels share x, so the last experiment's timestamps label the ticks.
    tick_labels = all_data_points[-1]['subdata']['timestamp'].values

    def format_date(x, pos):
        # Blank out ticks that fall outside the available bars.
        if x < 0 or x > len(tick_labels) - 1:
            return ''
        return tick_labels[int(x)]
    subplot.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))  # type: ignore

%matplotlib qt5
# Load each experiment over the same BTCBUSD window and plot them as close-price
# lines above their equity curves. start/end are presumably epoch milliseconds
# (~2022-05-11 to 2022-07-19 UTC) -- confirm against read_data.
all_data = [
    get_plot_data(
        read_data(
            symbol='BTCBUSD',
            exp_name=exp_name,
            start=1652229060000,
            end=1658229059999,
        )
    )
    for exp_name in ('SetExitAtCreationAtr300', '+-3Atr300')
]
plot_with_data_points(all_data, plot_candle=False)