In [40]:
import json
import os

from base_types import IdxValue, DataElements
from base_types import DataType
from data import Data
from plot import PricePlot
from utils import milliseconds_to_date

%matplotlib qt5

def read_data(symbol, exp_name, start, end, log_dir='log'):
    """Load 1-minute futures data for a window plus the logged results of one experiment.

    Parameters
    ----------
    symbol : str
        Trading symbol, e.g. ``'LUNA2BUSD'``.
    exp_name : str
        Experiment name; the result files are read from ``<log_dir>/<exp_name>/``.
    start, end : int
        Window bounds in milliseconds since the epoch (formatted with
        ``milliseconds_to_date`` and tagged as UTC+8 before being handed to ``Data``).
    log_dir : str, optional
        Root directory of the experiment logs, relative to the working
        directory unless absolute. Defaults to ``'log'``.

    Returns
    -------
    tuple
        ``(data, buy_points, sell_points, tops, bottoms, earn_points)``: the
        ``Data`` object followed by five ``IdxValue`` series built from the
        ``trade_info``, ``vertices`` and ``earn_points`` JSON files.

    Raises
    ------
    FileNotFoundError
        If one of the three JSON result files does not exist.
    KeyError
        If a JSON file lacks one of the expected keys.
    """
    start_str = milliseconds_to_date(start) + ' UTC+8'
    # NOTE(review): ``end + 1`` presumably turns an inclusive end such as
    # ...739999 into the next whole-minute boundary -- confirm against Data.
    end_str = milliseconds_to_date(end + 1) + ' UTC+8'

    data = Data(symbol, DataType.INTERVAL_1MINUTE, start_str=start_str, end_str=end_str, is_futures=True)

    # File names embed data.start_time()/data.end_time(), not the raw
    # start/end arguments.
    prefix = '{}_start_{}_end_{}_'.format(symbol, data.start_time(), data.end_time())
    # os.path.join (rather than hard-coded '\\' separators) so the path also
    # resolves on POSIX systems.
    exp_dir = os.path.join(log_dir, exp_name)

    def load_json(suffix):
        # Parse one result file for this symbol/window.
        with open(os.path.join(exp_dir, prefix + suffix), 'r', encoding='utf-8') as f:
            return json.load(f)

    trade_info = load_json('trade_info.json')
    vertices = load_json('vertices.json')
    earn_points_dict = load_json('earn_points.json')

    buy_points = IdxValue(trade_info['buy_time'], trade_info['buy_price'])
    sell_points = IdxValue(trade_info['sell_time'], trade_info['sell_price'])
    tops = IdxValue(vertices['top_time'], vertices['top_value'])
    bottoms = IdxValue(vertices['bottom_time'], vertices['bottom_value'])
    earn_points = IdxValue(earn_points_dict['earn_idx'], earn_points_dict['earn_value'])

    return data, buy_points, sell_points, tops, bottoms, earn_points
    
# Experiment to visualise. Set EXP_NAME = 'threshold_20_multi_0.5' to view the
# other logged experiment for the same symbol and window.
SYMBOL = 'LUNA2BUSD'
EXP_NAME = 'threshold_20'
START_MS = 1654052700000  # window start, ms since epoch
END_MS = 1655080739999    # window end, ms since epoch

data, buy_points, sell_points, tops, bottoms, earn_points = read_data(
    symbol=SYMBOL,
    exp_name=EXP_NAME,
    start=START_MS,
    end=END_MS,
)


In [41]:
import matplotlib.pyplot as plt
import mpl_finance
from matplotlib import ticker
from base_types import DataElements
import numpy as np

def get_data_from_time(time_value: IdxValue, begin, end):
    """Keep only the samples of ``time_value`` whose timestamp lies in [begin, end].

    Both bounds are inclusive. ``time_value.idx`` (timestamps) and
    ``time_value.value`` are converted to numpy arrays and filtered with the
    same boolean mask, so they must have equal length.

    Returns a ``(timestamps, values)`` pair of new numpy arrays.
    """
    stamps = np.array(time_value.idx)
    samples = np.array(time_value.value)
    in_window = (begin <= stamps) & (stamps <= end)
    return stamps[in_window], samples[in_window]


def get_plot_data(data: Data, buy_points: IdxValue, sell_points: IdxValue, tops: IdxValue, bottoms: IdxValue, start_idx, end_idx):
    """Slice OHLC rows from ``data`` and gather the markers that fall inside that slice.

    Rows are selected with ``data.data.loc[start_idx:end_idx]`` (label-based;
    ``None`` leaves that side open). The result is reset to a fresh index, so
    the former index becomes the first column, followed by open, high, low,
    close, open_time and a formatted ``'timestamp'`` column.

    Each marker series is then filtered to the open-time range of the slice.

    Returns
    -------
    (subdata, points)
        The sliced DataFrame, and a list of four ``PricePlot.Points``
        (buy, sell, top, bottoms) in that order.

    Raises
    ------
    IndexError
        If the selected slice is empty.
    """
    open_time_col = DataElements.OPEN_TIME.value
    wanted_columns = [
        DataElements.OPEN.value,
        DataElements.HIGH.value,
        DataElements.LOW.value,
        DataElements.CLOSE.value,
        open_time_col,
    ]
    subdata = data.data.loc[start_idx:end_idx, wanted_columns].copy().reset_index()
    subdata['timestamp'] = subdata[open_time_col].map(milliseconds_to_date)

    open_times = subdata[open_time_col].values
    first_time, last_time = open_times[0], open_times[-1]

    # (series, marker size, colour, legend label) for each overlay.
    marker_specs = (
        (buy_points, 90, 'r', 'buy'),
        (sell_points, 90, 'g', 'sell'),
        (tops, 30, 'b', 'top'),
        (bottoms, 30, 'y', 'bottoms'),
    )
    points = []
    for series, size, colour, label in marker_specs:
        xs, ys = get_data_from_time(series, first_time, last_time)
        points.append(PricePlot.Points(xs, ys, s=size, c=colour, label=label))

    # NOTE(review): the return value of time_list_to_idx is discarded, so this
    # loop only has an effect if it converts ``point.idx`` in place. Confirm
    # this; otherwise the markers stay in millisecond timestamps while the
    # candles are placed by row index.
    for marker in points:
        data.time_list_to_idx(marker.idx)

    return subdata, points

# Rows of data.data to plot, selected by index label (pandas .loc slicing
# includes both endpoints; None means "from the first row").
PLOT_START_IDX = None
PLOT_END_IDX = 1000

subdata, points = get_plot_data(data, buy_points, sell_points, tops, bottoms,
                                start_idx=PLOT_START_IDX, end_idx=PLOT_END_IDX)


In [46]:
def plot_with_data_points(data, points):
    """Draw ``data`` as a candlestick chart and overlay scatter markers on it.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows passed to ``mpl_finance.candlestick_ohlc`` through ``data.values``.
        That function reads each row as (x, open, high, low, close, ...), so
        the first five columns must be in that order. A ``'timestamp'`` column
        supplies the x-axis tick labels.
    points : iterable or None
        Marker groups drawn on top of the candles. Each item must expose
        ``idx`` (x positions), ``value`` (y positions), ``s``, ``c`` and
        ``label``. A falsy value (``None`` or empty) draws no markers.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the chart was drawn on.
    """
    # Use a fresh figure per call so that re-running the cell does not add
    # axes to whatever the current figure already holds.
    _, ax = plt.subplots()

    mpl_finance.candlestick_ohlc(ax=ax, quotes=data.values,
                                 width=0.7, colorup='g', colordown='r')
    ax.set_ylabel('Price')

    if points:
        for p in points:
            ax.scatter(p.idx, p.value, s=p.s, c=p.c, label=p.label)
        # The markers carry labels; without a legend they are never shown.
        ax.legend()

    # candlestick_ohlc places each candle at x = row[0] (the first column), not
    # at the row's position. Tick positions are therefore mapped back through
    # that column; a positional lookup is only right when the column is 0..n-1.
    x_values = data.iloc[:, 0].to_numpy()
    label_by_x = dict(zip(x_values.astype(int).tolist(), data['timestamp'].tolist()))
    x_min, x_max = (x_values.min(), x_values.max()) if len(x_values) else (0, -1)

    def format_date(x, pos):
        # Ticks outside the plotted candles, or between them, get no label.
        if x < x_min or x > x_max:
            return ''
        return label_by_x.get(int(x), '')

    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    return ax

# The trailing ';' stops IPython from echoing the call's return value.
plot_with_data_points(data=subdata, points=points);