In [2]:
# candlestick plot functions
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from mplfinance.original_flavor import candlestick_ohlc
import seaborn as sns


def create_candlestick_grid(data_list, rows, cols, filename='candlestick_grid.png'):
    """Draw one BTC candlestick chart per dataframe on a rows x cols grid.

    The figure is sized 8.5 x 11 inches at 300 DPI, saved to ``filename`` and
    then shown.

    Args:
        data_list: Sequence of DataFrames. Each needs a 'timestamp' column
            (Unix seconds) plus 'open (BTC)', 'high (BTC)', 'low (BTC)' and
            'close (BTC)'. The inputs are not modified. Frames beyond
            ``rows * cols`` are silently not drawn.
        rows: Number of subplot rows.
        cols: Number of subplot columns.
        filename: Path the figure is written to.

    Returns:
        None.
    """
    sns.set_theme(style="whitegrid")

    # Office paper (8.5 x 11 inches) at 300 DPI.
    fig_width, fig_height, dpi = 8.5, 11, 300

    # squeeze=False always yields a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height), dpi=dpi, squeeze=False)
    axes = axes.flatten()

    ohlc_columns = ['timestamp', 'open (BTC)', 'high (BTC)', 'low (BTC)', 'close (BTC)']

    for ax, data in zip(axes, data_list):
        # Selecting columns and copying leaves the caller's frame untouched.
        btc_data = data[ohlc_columns].copy()

        # Unix seconds -> datetime -> Matplotlib date numbers.
        btc_data['timestamp'] = pd.to_datetime(btc_data['timestamp'], unit='s').apply(mdates.date2num)

        candlestick_ohlc(ax, btc_data.values, width=0.01, colorup='green', colordown='red')
        ax.xaxis_date()
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        ax.set_title('BTC-USD')
        ax.set_ylabel('Price (BTC)')
        ax.set_xlabel('Time')
        # Rotate the labels of THIS subplot; plt.xticks() would only touch the
        # pyplot "current" axes, i.e. the last subplot created.
        ax.tick_params(axis='x', labelrotation=45)

    # Remove grid cells that received no data.
    for unused_ax in axes[len(data_list):]:
        fig.delaxes(unused_ax)

    fig.tight_layout()
    fig.savefig(filename)
    plt.show()
    

def create_dual_candlestick_plot(data, symbol1, symbol2, y_range1=None, y_range2=None):
    """Show two stacked candlestick charts that share one time axis.

    Args:
        data: DataFrame with a 'timestamp' column (Unix seconds) and
            'open/high/low/close (<symbol>)' columns for both symbols.
            A copy is used; the caller's frame is left alone.
        symbol1: Column suffix for the upper chart (green/red candles).
        symbol2: Column suffix for the lower chart (blue/orange candles).
        y_range1: Optional y-limits passed to ``set_ylim`` for the upper chart.
        y_range2: Optional y-limits passed to ``set_ylim`` for the lower chart.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    frame = data.copy()
    frame['timestamp'] = pd.to_datetime(frame['timestamp'], unit='s')

    def ohlc_for(symbol):
        # timestamp + OHLC columns of one symbol, in the order candlestick_ohlc expects
        columns = ['timestamp'] + [f'{field} ({symbol})' for field in ('open', 'high', 'low', 'close')]
        return frame[columns].copy()

    first_ohlc = ohlc_for(symbol1)
    second_ohlc = ohlc_for(symbol2)

    # Matplotlib wants float date numbers rather than datetimes.
    for ohlc in (first_ohlc, second_ohlc):
        ohlc['timestamp'] = ohlc['timestamp'].apply(mdates.date2num)

    sns.set_theme(style="darkgrid")

    fig, (top_ax, bottom_ax) = plt.subplots(nrows=2, ncols=1, figsize=(12, 10), sharex=True)

    panels = (
        (top_ax, symbol1, first_ohlc, 'green', 'red', y_range1),
        (bottom_ax, symbol2, second_ohlc, 'blue', 'orange', y_range2),
    )
    for axis, symbol, ohlc, up_color, down_color, y_limits in panels:
        candlestick_ohlc(axis, ohlc.values, width=0.01, colorup=up_color, colordown=down_color)
        axis.xaxis_date()
        axis.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
        axis.set_title(f'{symbol} Candlestick Chart')
        axis.set_ylabel(f'Price ({symbol})')
        # Only override the automatic limits when the caller supplied some.
        if y_limits:
            axis.set_ylim(y_limits)

    # Only the lower chart carries the x-axis label.
    bottom_ax.set_xlabel('Timestamp')

    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def create_minimal_candlestick_plot(data, symbol1, symbol2, y_range1=None, y_range2=None, output_dir='plots', sample_id=1, image_size=(64, 64)):
    """Render a label-free, two-panel candlestick image and save it as a PNG.

    The upper panel shows ``symbol1`` (green/red candles), the lower panel
    ``symbol2`` (blue/orange candles). Axes, ticks and grid are switched off
    and the panels fill the whole canvas.

    Args:
        data: DataFrame with a 'timestamp' column (Unix seconds) and
            'open/high/low/close (<symbol>)' columns for both symbols.
            A copy is used; the caller's frame is left alone.
        symbol1: Column suffix for the upper panel.
        symbol2: Column suffix for the lower panel.
        y_range1: Optional y-limits for the upper panel.
        y_range2: Optional y-limits for the lower panel.
        output_dir: Directory for the image; created if missing.
        sample_id: Used in the file name ``candlestick_plot_<sample_id>.png``.
        image_size: Output size as (width, height) in PIXELS. Passing this
            straight to ``figsize`` (which is in inches) would turn the
            default (64, 64) into a roughly 6400 x 6400 pixel image.

    Returns:
        Path of the saved PNG file.
    """
    os.makedirs(output_dir, exist_ok=True)

    dpi = 100
    width_px, height_px = image_size

    data = data.copy()
    data['timestamp'] = pd.to_datetime(data['timestamp'], unit='s')

    # Build (ohlc frame, up colour, down colour, y-limits) for each panel.
    panels = []
    for symbol, up_color, down_color, y_range in (
        (symbol1, 'green', 'red', y_range1),
        (symbol2, 'blue', 'orange', y_range2),
    ):
        ohlc = data[['timestamp', f'open ({symbol})', f'high ({symbol})', f'low ({symbol})', f'close ({symbol})']].copy()
        # Matplotlib wants float date numbers rather than datetimes.
        ohlc['timestamp'] = ohlc['timestamp'].apply(mdates.date2num)
        panels.append((ohlc, up_color, down_color, y_range))

    sns.set_theme(style="darkgrid")

    # figsize is in inches, so convert the requested pixel size.
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(width_px / dpi, height_px / dpi), dpi=dpi)
    try:
        for ax, (ohlc, up_color, down_color, y_range) in zip(axes, panels):
            candlestick_ohlc(ax, ohlc.values, width=0.01, colorup=up_color, colordown=down_color)
            ax.axis('off')  # no labels, ticks or grid
            if y_range is not None:
                ax.set_ylim(y_range)

        # Let the two panels fill the canvas edge to edge. No tight bbox is
        # used when saving, so the file has exactly the requested pixel size.
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

        plot_filename = os.path.join(output_dir, f'candlestick_plot_{sample_id}.png')
        fig.savefig(plot_filename, dpi=dpi)
    finally:
        # Always release the figure; this function is called once per sample.
        plt.close(fig)

    return plot_filename


In [3]:
#get_market_data_for_symbols(list_of_symbols, timeframe):
import pandas as pd
from data_service import DataService


def get_market_data_for_symbols(list_of_symbols, timeframe):
    """Load market data for each symbol and inner-join the frames on 'timestamp'.

    Every column except 'timestamp' is renamed to '<column> (<BASE>)', where
    <BASE> is the part of the symbol before the first '-' (e.g. 'open (ETH)'
    for 'ETH-USDT').

        params:
            list_of_symbols: format like 'ETH-USDT'
            timeframe: format like '1min'

        return: merged dataframe; the columns of the LAST symbol come first,
            followed by the other symbols in list order.

        raises:
            ValueError: if list_of_symbols is empty.
    """
    dataframes = []
    for symbol in list_of_symbols:
        # Computed outside the f-string: nested double quotes inside an
        # f-string are a SyntaxError before Python 3.12.
        base_asset = symbol.split("-")[0]
        df = DataService(symbol, timeframe).load_market_data()
        df.columns = [col if col == "timestamp" else f"{col} ({base_asset})" for col in df.columns]
        dataframes.append(df)

    if not dataframes:
        raise ValueError("list_of_symbols must contain at least one symbol")

    # Start from the last frame and join the earlier ones onto it; this keeps
    # the column order callers have been getting so far.
    merged = dataframes[-1]
    for earlier in dataframes[:-1]:
        merged = pd.merge(merged, earlier, on='timestamp', how='inner')
    return merged

In [4]:
# get_previous_candles(data, window_size=45, win_value=1) and extract_previous_candles_with_outcome(data, window_size=45)

from datetime import datetime
def get_previous_candles(data, window_size=45, win_value=1):
    """Collect the `window_size` rows that precede every row whose 'win' equals `win_value`.

    Rows are located by POSITION, so the result is correct for any index
    (filtered, sliced, datetime, ...), not just a default RangeIndex.

    Args:
        data: DataFrame with a 'win' column.
        window_size: Number of rows to take before each matching row.
        win_value: Value of 'win' that marks a row of interest.

    Returns:
        List of DataFrame copies, one per matching row that has at least
        `window_size` rows before it. The matching row itself is excluded.
    """
    # Positional offsets of the matching rows. Index labels must not be used
    # here because the slices below go through iloc.
    win_positions = (data['win'] == win_value).to_numpy().nonzero()[0].tolist()

    return [
        data.iloc[pos - window_size:pos].copy()
        for pos in win_positions
        if pos >= window_size
    ]


def extract_previous_candles_with_outcome(data, window_size=45):
    """Build one labelled sample for every row that has a full look-back window.

    Args:
        data: DataFrame with 'win' and 'timestamp' columns.
        window_size: Number of preceding rows stored with each sample.

    Returns:
        List of dicts with keys 'data' (copy of the `window_size` rows before
        the row, the row itself excluded), 'label' (the row's 'win' as int)
        and 'timestamp' (the row's 'timestamp' as int).
    """
    def build_sample(position):
        # The row at `position` supplies the outcome; the rows before it the features.
        current_row = data.iloc[position]
        return {
            'data': data.iloc[position - window_size:position].copy(),
            'label': int(current_row['win']),
            'timestamp': int(current_row['timestamp']),
        }

    return [
        build_sample(position)
        for position in range(len(data))
        if position >= window_size
    ]

In [5]:
#get_following_candles(data, window_size=45):
def get_following_candles(data, window_size=45):
    """Collect the `window_size` rows that follow every row whose 'win' equals 1.

    Rows are located by POSITION, so the result is correct for any index,
    not just a default RangeIndex.

    Args:
        data: DataFrame with a 'win' column.
        window_size: Number of rows to take after each winning row.

    Returns:
        List of DataFrame copies, one per winning row that still has
        `window_size` rows after it. The winning row itself is excluded.
    """
    n_rows = len(data)
    # Positional offsets of rows where 'win' is 1. Index labels must not be
    # used here because the slices below go through iloc.
    win_positions = (data['win'] == 1).to_numpy().nonzero()[0].tolist()

    return [
        data.iloc[pos + 1:pos + 1 + window_size].copy()
        for pos in win_positions
        # Skip wins too close to the end to have a complete window.
        if pos + 1 + window_size <= n_rows
    ]

In [6]:
# scale_dfs_with_multiple_scalers(list_of_dataframes, symbol_a, symbol_b):
from sklearn.preprocessing import RobustScaler, StandardScaler, MinMaxScaler


def scale_dfs_with_multiple_scalers(list_of_dataframes, symbol_a="ETH-USDT", symbol_b="BTC-USDT"):
    """Add Robust/Standard/MinMax-scaled copies of the price and volume columns.

    For every dataframe, every scaler and every column
    '<open|close|high|low|volume|amount> (<BASE>)' of both symbols, a new
    column '<column>_<ScalerClassName>' is added. Each scaler is fitted on
    that single column of that single dataframe.

    Args:
        list_of_dataframes: DataFrames to extend. They are modified IN PLACE.
        symbol_a: First symbol, e.g. "ETH-USDT"; only the part before '-' is used.
        symbol_b: Second symbol, e.g. "BTC-USDT"; only the part before '-' is used.

    Returns:
        The same list object that was passed in.
    """
    asset_a, asset_b = (pair.split("-")[0] for pair in (symbol_a, symbol_b))

    fields = ('open', 'close', 'high', 'low', 'volume', 'amount')
    # symbol_b's columns first, then symbol_a's, as before.
    cols_to_scale = [f'{field} ({asset})' for asset in (asset_b, asset_a) for field in fields]

    scalers = [RobustScaler(), StandardScaler(), MinMaxScaler()]
    for df in list_of_dataframes:
        for scaler in scalers:
            scaler_name = scaler.__class__.__name__
            for col in cols_to_scale:
                if col in df.columns:
                    # df[[col]] is already a fresh one-column frame, so the
                    # whole dataframe need not be copied first. ravel() turns
                    # the (n, 1) result into a 1-D column.
                    df[col + "_" + scaler_name] = scaler.fit_transform(df[[col]]).ravel()
                else:
                    print(f"Warning: {col} not found in DataFrame")

    return list_of_dataframes

In [7]:
#drop_unwanted_columns(df):
def drop_unwanted_columns(df):
    """Return a copy of `df` without its id / symbol / timeframe columns.

    A column is removed when its name contains 'id (', 'symbol (' or
    'timeframe (' anywhere (substring match). The input is not modified.
    """
    unwanted_markers = ('id (', 'symbol (', 'timeframe (')
    to_remove = [
        name for name in df.columns
        if any(marker in name for marker in unwanted_markers)
    ]
    return df.drop(to_remove, axis=1)

In [8]:
#label_wins(df, symbol, long_or_short, candle_span, pct_chg_threshold=0.01):
from labeler import BinaryWinFinder 

def label_wins(df, symbol, long_or_short, candle_span, pct_chg_threshold=0.01):
    """Run BinaryWinFinder over `df` and return the result of its find_wins().

    The arguments are handed to BinaryWinFinder positionally, with `symbol`
    formatted as a string first. What find_wins() returns is defined in the
    `labeler` module; callers in this notebook store it as the 'win' column.
    """
    symbol_text = f"{symbol}"
    return BinaryWinFinder(df, symbol_text, long_or_short, candle_span, pct_chg_threshold).find_wins()

In [9]:
def replace_with_scaled_columns(list_of_dataframes):
    """Overwrite each raw column with its scaled counterpart and drop the scaler columns.

    Scaler columns are those whose name contains "Scaler"; they are expected to
    be named '<raw column>_<ScalerClassName>'. A raw column is replaced by the
    FIRST scaler column (in column order) derived from it. Raw columns without
    a scaled counterpart are left as they are.

    Args:
        list_of_dataframes: DataFrames to rewrite. They are modified IN PLACE.

    Returns:
        The same list object that was passed in.
    """
    for df in list_of_dataframes:
        scaler_cols = [col for col in df.columns if "Scaler" in col]

        # Keep the scaled values before the columns are removed.
        scaled_cols = df[scaler_cols].copy()
        df.drop(columns=scaler_cols, inplace=True)

        # Map raw column name -> first scaler column built from it. The name is
        # matched exactly (text before the last '_'): a plain substring test
        # would also hit unrelated columns, e.g. a column called 'a' is
        # contained in every "..._RobustScaler" name.
        first_scaled_for = {}
        for scaler_col in scaler_cols:
            first_scaled_for.setdefault(scaler_col.rsplit("_", 1)[0], scaler_col)

        for col in list(df.columns):
            if col in first_scaled_for:
                df[col] = scaled_cols[first_scaled_for[col]]
    return list_of_dataframes

In [10]:
def scale_and_annotate_samples(samples_with_metadata, symbol_a="ETH-USDT", symbol_b="BTC-USDT"):
    """
    Scales the dataframe of every sample and re-packages it with its timestamp and label.

    Scaling is delegated to the notebook-level scale_dfs_with_multiple_scalers,
    which adds '<column>_<ScalerClassName>' columns to each sample's dataframe
    IN PLACE; the returned dicts therefore reference the same dataframe objects
    as the input samples.

    Args:
        samples_with_metadata (list of dict): List of samples, each containing 'data', 'timestamp', and 'label'.
        symbol_a (str): The symbol for the first asset (default is "ETH-USDT").
        symbol_b (str): The symbol for the second asset (default is "BTC-USDT").

    Returns:
        list of dict: List of dictionaries, each containing 'dataframe', 'timestamp', and 'label'.
    """
    data_list = [sample['data'] for sample in samples_with_metadata]

    # Reuse the shared scaler function rather than a private duplicate of it,
    # so the scaling logic lives in one place only.
    scaled_data = scale_dfs_with_multiple_scalers(data_list, symbol_a=symbol_a, symbol_b=symbol_b)

    return [
        {'dataframe': scaled_df, 'timestamp': sample['timestamp'], 'label': sample['label']}
        for scaled_df, sample in zip(scaled_data, samples_with_metadata)
    ]

In [11]:
# get market data for multiple symbols
# Load "3min" candles for ETH-USDT and BTC-USDT, inner-joined on 'timestamp'.
df = get_market_data_for_symbols(['ETH-USDT', 'BTC-USDT'], "3min")
# Remove the 'id (', 'symbol (' and 'timeframe (' columns.
df = drop_unwanted_columns(df)
# Label rows for a "long" ETH-USDT trade with candle_span=40 and the default
# pct_chg_threshold of 0.01.
df["win"] = label_wins(df, 'ETH-USDT', "long", 40)
# One sample per row that has a full look-back window (default window_size=45):
# the preceding rows plus that row's 'win' label and timestamp.
samples_with_metadata = extract_previous_candles_with_outcome(df)
# NOTE: scaling adds columns to the dataframes inside samples_with_metadata in place.
scaled_samples_with_metadata = scale_and_annotate_samples(samples_with_metadata)



In [12]:
# Hand the scaled samples to DataService for storage under the given collection
# name. NOTE(review): presumably writes to the project database — confirm
# before re-running, as the "_TEST" collection may be overwritten.
data_service = DataService("ETH-USDT", "3min") 
data_service.save_samples_to_collection(samples_with_metadata=scaled_samples_with_metadata, collection_name="ETH-USDT_with_BTC_45_previous_candles_TEST")

Collection 'ETH-USDT_with_BTC_45_previous_candles_TEST' saved successfully with 19955 samples.


In [None]:
# NOTE(review): `wins` is not assigned in any cell visible in this notebook, so
# this raises NameError on a fresh kernel. Presumably it was created with
# something like get_previous_candles(df) — confirm and add the defining cell.
display(wins[0])

In [None]:
def get_ranges(symbol, list_of_dataframes):
    """Find the overall price range of `symbol` across several dataframes.

    Only dataframes that contain both 'high (<symbol>)' and 'low (<symbol>)'
    contribute.

    Returns:
        Tuple (range_min, range_max): the smallest low and the largest high.
        If no dataframe qualifies the result is (inf, -inf).
    """
    high_name, low_name = f"high ({symbol})", f"low ({symbol})"
    lowest, highest = float('inf'), float('-inf')

    for frame in list_of_dataframes:
        # Skip frames that lack either column.
        if high_name not in frame or low_name not in frame:
            continue
        highest = max(highest, frame[high_name].max())
        lowest = min(lowest, frame[low_name].min())

    return lowest, highest

In [None]:
# Overall (min low, max high) per symbol across every dataframe in `wins`.
# range_a belongs to symbols[0] ("ETH"), range_b to symbols[1] ("BTC").
# NOTE(review): depends on `wins`, which no visible cell defines.
symbols = ["ETH", "BTC"]
range_a, range_b = [get_ranges(symbol, wins) for symbol in symbols]
print(symbols)
print(range_a, range_b)


In [None]:
# Show the column labels of the first dataframe in `wins`.
print(wins[0].columns)

In [None]:
# NOTE(review): repeats the display(wins[0]) cell a few cells earlier; one of
# the two can probably be removed.
display(wins[0])

In [None]:
# Render one minimal two-panel image per dataframe in `wins` into "test/",
# using the shared y-ranges so every image has the same vertical scale.
# The loop variable is not called `df`, so the merged market dataframe `df`
# built earlier in the notebook is not overwritten by the last window.
for i, window_df in enumerate(wins):
    create_minimal_candlestick_plot(window_df, symbols[0], symbols[1], range_a, range_b, output_dir="test", sample_id=i)