In [1]:
import numpy as np
import pandas as pd
from os import listdir, path
import matplotlib.pyplot as plt
import matplotlib as mpl
from typing import *
import math
import sys
from copy import deepcopy
from random import sample

def _is_colab() -> bool:
    return 'google.colab' in sys.modules

def mount_drive():
    if _is_colab():
        from google.colab import drive
        drive.mount("/content/gdrive")

mount_drive()

Mounted at /content/gdrive


In [2]:
def get_candlestick_color(close: float, open: float) -> str:
    """Choose the body colour of a single candle.

    Returns 'white' when the candle closed above its open, 'black' when it
    closed below, and 'gray' otherwise (equal prices, or incomparable values
    such as NaN, where both comparisons are False).
    """
    return 'white' if close > open else ('black' if close < open else 'gray')

def candlestick_plot(data: pd.DataFrame, save_as: Optional[str] = None):
    """Draw an axis-free candlestick chart from OHLC rows, then save or show it.

    Each row of ``data`` becomes one candle placed at its *position* in the
    frame (0, 1, 2, ...), so the frame does not need a 0-based RangeIndex.

    Args:
        data: Frame with 'open', 'high', 'low' and 'close' columns, one row
            per candle, in plotting order.
        save_as: When a string, the figure is written to that path (tight
            bounding box, no padding) and closed; otherwise ``plt.show()``
            is called.
    """
    fig, ax = plt.subplots(1, figsize=(3, 3))
    rows = zip(data['open'], data['high'], data['low'], data['close'])
    for x_idx, (o, h, l, c) in enumerate(rows):
        clr = get_candlestick_color(c, o)
        ax.plot([x_idx, x_idx], [l, h], color=clr)  # wick
        ax.plot([x_idx, x_idx], [o, o], color=clr)  # open marker (zero-length segment)
        ax.plot([x_idx, x_idx], [c, c], color=clr)  # close marker (zero-length segment)
        # Body spans open -> close; the height is negative for down candles.
        rect = mpl.patches.Rectangle((x_idx - 0.5, o), 1, (c - o), facecolor=clr,
                                     edgecolor='black', linewidth=1, zorder=3)
        ax.add_patch(rect)
    ax.axis('off')
    if isinstance(save_as, str):
        try:
            fig.savefig(save_as, bbox_inches="tight", pad_inches=0)
        finally:
            # Close only this figure (not every open figure) so repeated calls
            # in a loop don't accumulate figures, even if savefig fails.
            plt.close(fig)
        plt.ioff()
    else:
        plt.show()




In [11]:
class ImageGenerator:
    """Render sliding-window candlestick images for a set of ticker CSV files.

    For each ticker whose ``<ticker>.csv`` exists in ``input_path``, windows of
    ``period`` consecutive rows are plotted with ``candlestick_plot`` and saved
    as ``<output_path>/<ticker>_<start_row>.png``.
    """

    def __init__(self, 
                 input_path: str,
                 tickers: List[str], 
                 output_path: str, 
                 period: int):
        """
        Args:
            input_path: Directory holding one ``<ticker>.csv`` file per ticker.
            tickers: Ticker names; a ``.csv`` suffix is stripped if present.
            output_path: Path the PNG files are written into; must exist.
            period: Number of consecutive rows (candles) per image.

        Raises:
            AssertionError: If ``input_path`` or ``output_path`` does not exist.
        """
        self.input_path = input_path
        self.tickers = [x.replace(".csv", "") for x in tickers]
        self.output_path = output_path
        self.period = period
        self.__param_handler()

    def __param_handler(self):
        """Check both paths exist and drop tickers that have no CSV in input_path."""
        assert path.exists(self.input_path), f"Input path at {self.input_path} doesn't exist."
        assert path.exists(self.output_path), f"Output path at {self.output_path} doesn't exist."
        files = set(listdir(self.input_path))  # set: O(1) membership per ticker
        filtered_tickers = []
        for t in self.tickers:
            if f"{t}.csv" in files:
                filtered_tickers.append(t)
            else:
                print(f"{t}.csv doesn't exist. Skipping...")
        self.tickers = filtered_tickers

    def __load_as_df(self, ticker: str) -> Optional[pd.DataFrame]:
        """Return the open/high/low/close columns of ``<ticker>.csv``, or None if unreadable."""
        csv_path = path.join(self.input_path, f"{ticker}.csv")
        try:
            return pd.read_csv(csv_path)[["open", "high", "low", "close"]]
        except (OSError, ValueError, KeyError) as err:
            # OSError: file missing/unreadable. ValueError: pandas empty/parse
            # errors (EmptyDataError, ParserError) and decode errors.
            # KeyError: one of the OHLC columns is absent.
            print(f"Could not load {ticker}.csv ({err!r}). Skipping...")
            return None

    def generate(self) -> int:
        """Render and save every window for every loadable ticker.

        Returns:
            The number of images written.
        """
        count = 0
        for ticker in self.tickers:
            df = self.__load_as_df(ticker)
            if df is None:
                continue
            # NOTE(review): len(df) - period + 1 full windows exist; the "- 2"
            # stops 3 short, leaving 3 rows after the last window. Presumably
            # reserved for labelling future moves — confirm before changing.
            for idx in range(len(df) - self.period - 2):
                window = df.iloc[idx:idx + self.period].reset_index(drop=True)
                out_file = path.join(self.output_path, f"{ticker}_{idx}.png")
                candlestick_plot(window, save_as=out_file)
                count += 1
        return count



In [12]:
# Paths are relative to the kernel's working directory.
# NOTE(review): they resolve under the Drive mount (/content/gdrive) only if the
# cwd is /content — confirm when running outside the default Colab setup.
INPUT_PATH = "gdrive/MyDrive/PriceData/PriceData"
OUTPUT_PATH = "gdrive/MyDrive/SequentialCandles"
PERIOD = 20      # candles per generated image
N_TICKERS = 50   # how many ticker CSVs to sample
SEED = 42        # fixes the ticker sample so Restart & Run All is reproducible

# Sort before sampling: listdir order is arbitrary, which would defeat the seed.
_csv_files = sorted(f for f in listdir(INPUT_PATH) if f.endswith(".csv"))
_rng = np.random.default_rng(SEED)
TICKERS = [
    str(f).replace(".csv", "")
    for f in _rng.choice(_csv_files, size=min(N_TICKERS, len(_csv_files)), replace=False)
]


In [None]:
# Render every window of every sampled ticker into OUTPUT_PATH.
generator = ImageGenerator(INPUT_PATH, TICKERS, OUTPUT_PATH, PERIOD)
generator.generate()
