In [116]:
# Necessary imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

In [117]:
class charts():
    def __init__(self,
                 window_size=60, 
                 data_path='data-SPY-20240617/yfinance_SPY.csv',
                 ma_cols=['5_day_MA', '20_day_MA', '60_day_MA'],
                 other_cols=[],
                 other_cols_color=[],
                 freq='D',
                 percent_train=0.7,
                 ):
        # Load data from data_path
        data = pd.read_csv(data_path, index_col=0, parse_dates=True)

        self.window_size = window_size

        # Set the symbol for the data, default is SPY
        self.symbol = 'SPY'
        if 'symbol' in data.columns:
            self.symbol = data['symbol'].iloc[0]
            data.drop('symbol', axis=1, inplace=True)

        # Resample at the desired frequency, default is daily (only support daily, weekly, and monthly data for now)
        self.freq = freq
        if self.freq == 'D':
            self.data = data.copy()
        elif self.freq == 'M':
            self.data = data.resample('ME').last().copy()
        elif self.freq == 'W':
            self.data = data.resample('W').last().copy()

        # Set all column names
        self.feat_cols = ['open', 'high', 'low', 'close'] + ma_cols + other_cols
        self.ma_cols = ma_cols
        self.other_cols = other_cols
        self.other_cols_color = other_cols_color
        self.cutoff = percent_train

    # Main function
    def plot(self, df):
        if self.window_size == 60:
            const = [76, 19, 96]
        elif self.window_size == 20:
            const = [51, 12, 64]
        elif self.window_size == 5:
            const = [25, 6, 32]
        df = df.reset_index(drop=True).copy()
        high = df[['high'] + self.ma_cols + self.other_cols].values.max()
        low = df[['low'] + self.ma_cols + self.other_cols].values.min()
        for col in self.feat_cols:
            df[col] = (df[col] - low) / (high - low) * const[0]
            df[col] = df[col].round(0).astype(int) + int(const[1])

        df['volume'] = df['volume'] / df['volume'].max() * const[1]
        df['volume'] = df['volume'].round(0).astype(int)

        plot = np.zeros((const[-1], 3 * self.window_size, 3))

        for idx, row in df.iterrows():
            plot[int(row['open']), idx*3, :] = [255] * 3
            plot[int(row['low']):int(row['high']+1), idx*3+1, :] = [255] * 3
            plot[int(row['close']):int(row['close']+1), idx*3+2, :] = [255] * 3
            plot[:int(row['volume']+1), idx*3+1, :] = [255] * 3

            for col in self.ma_cols:
                pre_ma = df.loc[idx-1, col] if idx >= 1 else df.loc[idx, col]
                next_ma = df.loc[idx+1, col] if idx <= len(df)-2 else df.loc[idx, col]

                plot[int((row[col] + pre_ma)//2), idx*3, :] = [255] * 3
                plot[int(row[col]), idx*3+1, :] = [255] * 3
                plot[int((row[col] + next_ma)//2), idx*3+2, :] = [255] * 3
            
            for i, col in enumerate(self.other_cols):
                pre = df.loc[idx-1, col] if idx >= 1 else df.loc[idx, col]
                nex = df.loc[idx+1, col] if idx <= len(df)-2 else df.loc[idx, col]

                plot[int((row[col] + pre)//2), idx*3, :] = self.other_cols_color[i]
                plot[int(row[col]), idx*3+1, :] = self.other_cols_color[i]
                plot[int((row[col] + nex)//2), idx*3+2, :] = self.other_cols_color[i]

        plot = plot[::-1, :, :]  # reversion

        return plot
    
    def generate(self):
        m, _ = self.data.shape
        counter = 0
        if not os.path.exists('train'):
            os.makedirs('train')
        if not os.path.exists('test'):
            os.makedirs('test')
        train_test_cutoff = int(self.cutoff * m)
        while counter < train_test_cutoff:
            curr_plot = self.plot(self.data.iloc[counter:counter+self.window_size, :].copy())
            counter += self.window_size
            curr_dir = []
            for i in [5, 10, 15]:
                if counter + i < m:
                    if self.data.iloc[counter + i, :].loc['close'] > self.data.iloc[counter, :].loc['close']:
                        curr_dir.append(1)
                    else:
                        curr_dir.append(-1)
                else:
                    curr_dir.append(np.nan)
            plt.imshow(curr_plot)
            plt.axis('off')
            plt.savefig(f"train/spy_{self.window_size}_{self.freq}_{curr_dir[0]}_{curr_dir[1]}_{curr_dir[2]}_{self.data.index.values[counter]}_version_1.png", bbox_inches='tight', pad_inches=0)
            plt.close()
        counter += 15
        while counter < m:
            curr_plot = self.plot(self.data.iloc[counter:counter+self.window_size, :].copy())
            counter += self.window_size
            curr_dir = []
            for i in [5, 10, 15]:
                if counter + i < m:
                    if self.data.iloc[counter + i, :]['close'] > self.data.iloc[counter, :]['close']:
                        curr_dir.append(1)
                    else:
                        curr_dir.append(-1)
                else:
                    curr_dir.append(np.nan)
            plt.imshow(curr_plot)
            plt.axis('off')
            plt.savefig(f"test/spy_{self.window_size}_{self.freq}_{curr_dir[0]}_{curr_dir[1]}_{curr_dir[2]}_{self.data.index.values[counter] if counter < m else self.data.index.values[m - 1]}_version_1.png", bbox_inches='tight', pad_inches=0)
            plt.close()


In [118]:
ohlc = charts()
ohlc.generate()

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i