In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
import cv2
import yfinance as yf
from datetime import datetime, timedelta
from ultralytics import YOLO
import random
import shutil
from tqdm import tqdm
import yaml
import torch

## 2. Preparing the dataset

### 2.1 Generating candlestick chart images

In [2]:
def generate_candlestick_chart(ticker, period="2y", interval="1d", output_dir="charts"):
    """
    Generate candlestick chart images from financial data
    
    Args:
        ticker (str): Stock ticker symbol
        period (str): Data period (e.g., '2y' for 2 years)
        interval (str): Data interval (e.g., '1d' for daily)
        output_dir (str): Directory to save generated charts
    
    Returns:
        str: Path to the saved chart image
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Fetch data using yfinance
    data = yf.download(ticker, period=period, interval=interval)
    data.reset_index(inplace=True, drop=True)
    data.columns = ['Close', 'High', 'Low', 'Open', 'Volume']
    
    # Check if data is empty
    if len(data) == 0:
        print(f"No data found for {ticker}")
        return None
    
    # Create figure and axis
    fig, ax = plt.figure(figsize=(12, 8), dpi=100), plt.subplot(1, 1, 1)
    
    # Format date axis
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    
    # Plot candlestick chart
    width = 0.6
    width2 = 0.05
    
    # Loop through data and plot each candlestick
    for i, (idx, row) in enumerate(data.iterrows()):
        # Calculate colors based on price movement
        if row['Close'] >= row['Open']:
            color = 'green'
            body_height = row['Close'] - row['Open']
        else:
            color = 'red'
            body_height = row['Open'] - row['Close']
        
        # Plot the candlestick body
        rect = Rectangle(
            xy=(i-width/2, min(row['Open'], row['Close'])),
            width=width,
            height=body_height,
            facecolor=color,
            edgecolor='black',
            linewidth=0.5
        )
        ax.add_patch(rect)
        
        # Plot the upper and lower wicks
        ax.plot([i, i], [row['Low'], min(row['Open'], row['Close'])], color='black', linewidth=0.5)
        ax.plot([i, i], [max(row['Open'], row['Close']), row['High']], color='black', linewidth=0.5)
    
    # Set title and labels
    ax.set_title(f"{ticker} Candlestick Chart", fontsize=12)
    ax.set_xlabel("Date", fontsize=10)
    ax.set_ylabel("Price", fontsize=10)
    
    # Adjust the x-axis to show the most recent time period
    ax.set_xlim(-1, len(data) + 1)
    
    # Set y-axis limits with some padding
    y_min = data['Low'].min() * 0.95
    y_max = data['High'].max() * 1.05
    ax.set_ylim(y_min, y_max)
    
    # Generate a unique filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"{ticker}_{timestamp}.jpg")
    
    # Save the chart as an image
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
    
    return output_path

### 2.2 Generating candlestick patterns

In [3]:
def generate_pattern_dataset(tickers, patterns, num_samples_per_pattern=100, output_dir="pattern_dataset"):
    """
    Generate a dataset of candlestick patterns
    
    Args:
        tickers (list): List of stock ticker symbols
        patterns (dict): Dictionary mapping pattern names to functions that check for patterns
        num_samples_per_pattern (int): Number of samples to generate per pattern
        output_dir (str): Directory to save the dataset
    
    Returns:
        list: Paths to the generated chart images with pattern annotations
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Create pattern-specific directories
    for pattern_name in patterns.keys():
        os.makedirs(os.path.join(output_dir, pattern_name), exist_ok=True)
    
    generated_charts = []
    
    # Number of patterns detected per type
    pattern_counts = {pattern_name: 0 for pattern_name in patterns.keys()}
    
    # Loop through tickers to get data
    for ticker in tqdm(tickers, desc="Processing tickers"):
        try:
            # Get longer period data to have enough patterns
            data = yf.download(ticker, period="5y", interval="1d")
            print(data.info(), data.columns)
            data.reset_index(inplace=True, drop=True)
            data.columns = ['Close', 'High', 'Low', 'Open', 'Volume']
            print(data.info(), data.columns)
            
            if len(data) < 30:  # Skip if not enough data
                continue
                
            print(f"Downloaded {len(data)} days of data for {ticker}")
            
            # Check for patterns in data
            for pattern_name, pattern_func in patterns.items():
                try:
                    pattern_signals = pattern_func(data)
                    
                    # Convert to regular Python list of indices where patterns are found
                    pattern_indices = []
                    for i in range(len(pattern_signals)):
                        if pattern_signals[i]:
                            pattern_indices.append(i)
                    
                    pattern_count = len(pattern_indices)
                    
                    print(f"Found {pattern_count} {pattern_name} patterns in {ticker}")
                    pattern_counts[pattern_name] += pattern_count
                    
                    # If patterns are found, generate images
                    if pattern_count > 0:
                        # Select random samples if there are more than needed
                        samples_per_ticker = min(
                            num_samples_per_pattern // len(tickers), 
                            len(pattern_indices)
                        )
                        samples_per_ticker = max(1, samples_per_ticker)  # Ensure at least 1 sample
                        if len(pattern_indices) > samples_per_ticker:
                            pattern_indices = np.random.choice(
                                pattern_indices, 
                                size=samples_per_ticker, 
                                replace=False
                            ).tolist()
                        
                        # Generate chart for each pattern instance
                        for idx in pattern_indices:
                            # Ensure we have enough data before and after the pattern
                            if idx < 10 or idx >= len(data) - 6:
                                continue
                            # Extract data window (10 candles before, pattern, 5 candles after)
                            window_data = data.iloc[idx-10:idx+6]
                            # Create figure and axis with a specific figure number to avoid warnings
                            fig_num = len(generated_charts) % 20  # Reuse figure numbers to avoid too many open figures
                            plt.close(fig_num)  # Close if previously opened
                            fig = plt.figure(num=fig_num, figsize=(12, 8), dpi=100)
                            ax = fig.add_subplot(1, 1, 1)
                            
                            # Plot candlestick chart for the window
                            width = 0.6
                            # Loop through window data and plot each candlestick
#                             print(window_data.iterrows())
                            for i, (window_idx, row) in enumerate(window_data.iterrows()):
                                # Calculate colors based on price movement
                                if row['Close'] >= row['Open']:
                                    color = 'green'
                                    body_height = row['Close'] - row['Open']
                                else:
                                    color = 'red'
                                    body_height = row['Open'] - row['Close']
                                
                                # Is this candle part of the pattern? (center of the window)
                                is_pattern = False
                                
                                # Single-candle patterns
                                if pattern_name in ['doji', 'hammer', 'shooting_star'] and i == 10:
                                    is_pattern = True
                                # Two-candle patterns
                                elif pattern_name == 'engulfing' and i in [9, 10]:
                                    is_pattern = True
                                # Three-candle patterns
                                elif pattern_name in ['morning_star', 'three_white_soldiers'] and i in [8, 9, 10]:
                                    is_pattern = True
                                
                                # Plot the candlestick body
                                rect = Rectangle(
                                    xy=(i-width/2, min(row['Open'], row['Close'])),
                                    width=width,
                                    height=body_height,
                                    facecolor=color,
                                    edgecolor='black' if not is_pattern else 'blue',
                                    linewidth=0.5 if not is_pattern else 2.0
                                )
                                ax.add_patch(rect)
                                
                                # Plot the upper and lower wicks
                                ax.plot([i, i], [row['Low'], min(row['Open'], row['Close'])], 
                                       color='black' if not is_pattern else 'blue', 
                                       linewidth=0.5 if not is_pattern else 2.0)
                                ax.plot([i, i], [max(row['Open'], row['Close']), row['High']], 
                                       color='black' if not is_pattern else 'blue', 
                                       linewidth=0.5 if not is_pattern else 2.0)
                            
                            # Set y-axis limits with some padding
                            y_min = window_data['Low'].min() * 0.95
                            y_max = window_data['High'].max() * 1.05
                            ax.set_ylim(y_min, y_max)
                            
                            # Generate a unique filename
                            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                            rand_id = np.random.randint(0, 10000)
                            chart_path = os.path.join(
                                output_dir, 
                                pattern_name, 
                                f"{ticker}_{pattern_name}_{timestamp}_{rand_id}.jpg"
                            )
                            
                            # Save the chart as an image
                            plt.tight_layout()
                            plt.savefig(chart_path)
                            plt.close(fig)
                            
                            # Calculate bounding box coordinates for the pattern
                            pattern_x_center = 10  # Default for single-candle patterns
                            
                            # Depending on pattern type, set different bounding box coordinates
                            if pattern_name in ['doji', 'hammer', 'shooting_star']:
                                print("Setting Bounding Box")
                                # Single candle patterns
                                x_min = pattern_x_center - width/2
                                x_max = pattern_x_center + width/2
                                
                                # Get the specific candle
                                candle = window_data.iloc[10]
                                
                                y_min = min(candle['Open'], candle['Close'])
                                y_max = max(candle['Open'], candle['Close'])
                                
                                # For hammer, include the lower shadow
                                if pattern_name == 'hammer':
                                    y_min = candle['Low']
                                
                                # For shooting star, include the upper shadow
                                if pattern_name == 'shooting_star':
                                    y_max = candle['High']
                                
                            elif pattern_name == 'engulfing':
                                # Two candle pattern (indices 9-10)
                                x_min = 9 - width/2
                                x_max = 10 + width/2
                                
                                # First candle
                                candle1 = window_data.iloc[9]
                                # Second candle
                                candle2 = window_data.iloc[10]
                                
                                y_min = min(candle1['Open'], candle1['Close'], candle2['Open'], candle2['Close'])
                                y_max = max(candle1['Open'], candle1['Close'], candle2['Open'], candle2['Close'])
                            
                            elif pattern_name == 'morning_star':
                                # Three candle pattern (indices 8-10)
                                x_min = 8 - width/2
                                x_max = 10 + width/2
                                
                                # Get all three candles
                                candles = window_data.iloc[8:11]
                                
                                y_min = min(candles['Low'].min(), candles['Open'].min(), candles['Close'].min())
                                y_max = max(candles['High'].max(), candles['Open'].max(), candles['Close'].max())
                            
                            elif pattern_name == 'three_white_soldiers':
                                # Three candle pattern (indices 8-10)
                                x_min = 8 - width/2
                                x_max = 10 + width/2
                                
                                # Get all three candles
                                candles = window_data.iloc[8:11]
                                
                                y_min = min(candles['Low'].min(), candles['Open'].min(), candles['Close'].min())
                                y_max = max(candles['High'].max(), candles['Open'].max(), candles['Close'].max())
                            
                            # Normalize coordinates to 0-1 range for YOLO format
                            img_width = fig.get_size_inches()[0] * fig.dpi
                            img_height = fig.get_size_inches()[1] * fig.dpi
                            
                            # Convert to YOLO format: [class_id, x_center, y_center, width, height]
                            # All values normalized to 0-1
                            class_id = list(patterns.keys()).index(pattern_name)
                            x_center = (x_min + x_max) / 2 / img_width
                            y_center = (y_min + y_max) / 2 / img_height
                            width_norm = (x_max - x_min) / img_width
                            height_norm = (y_max - y_min) / img_height
                            
                            # Create YOLO annotation file
                            annotation_path = chart_path.replace('.jpg', '.txt')
                            with open(annotation_path, 'w') as f:
                                f.write(f"{class_id} {x_center} {y_center} {width_norm} {height_norm}")
                            
                            generated_charts.append(chart_path)
                            
                            # Print out success message
                            print(f"Generated chart for {pattern_name} pattern in {ticker}")
                except Exception as e:
                    print(f"Error processing {pattern_name} pattern for {ticker}: {e}")
                    continue
                        
        except Exception as e:
            print(f"Error processing {ticker}: {e}")
            continue
    
    # Print summary of pattern counts
    print("\nPattern detection summary:")
    for pattern_name, count in pattern_counts.items():
        print(f"{pattern_name}: {count}")
    
    return generated_charts


### 2.3 Define pattern detection functions

In [4]:
def detect_doji(data):
    """Detect Doji pattern"""
    # A doji has open and close prices that are almost equal
    body_size = abs(data['Open'] - data['Close']).values
    shadow_size = (data['High'] - data['Low']).values
    
    # Doji has very small body compared to its shadows
    # Using element-wise comparison
    doji_pattern = np.zeros(len(data), dtype=bool)
    
    for i in range(len(data)):
        # Avoid division by zero
        if shadow_size[i] > 0:
            if body_size[i] / shadow_size[i] < 0.1:
                doji_pattern[i] = True
    
    return doji_pattern

def detect_hammer(data):
    """Detect Hammer pattern"""
    hammer_pattern = np.zeros(len(data), dtype=bool)
    
    for i in range(len(data)):
        candle = data.iloc[i]
        
        # Calculate body size
        body_size = abs(candle['Open'] - candle['Close'])
        
        # Calculate lower and upper shadows
        lower_shadow = min(candle['Open'], candle['Close']) - candle['Low']
        upper_shadow = candle['High'] - max(candle['Open'], candle['Close'])
        
        # Hammer criteria:
        # 1. Lower shadow at least 2x body
        # 2. Upper shadow less than half of body
        if lower_shadow > (body_size * 2) and upper_shadow < (body_size * 0.5):
            hammer_pattern[i] = True
    
    return hammer_pattern

def detect_shooting_star(data):
    """Detect Shooting Star pattern"""
    shooting_star_pattern = np.zeros(len(data), dtype=bool)
    
    for i in range(len(data)):
        candle = data.iloc[i]
        
        # Calculate body size
        body_size = abs(candle['Open'] - candle['Close'])
        
        # Calculate lower and upper shadows
        lower_shadow = min(candle['Open'], candle['Close']) - candle['Low']
        upper_shadow = candle['High'] - max(candle['Open'], candle['Close'])
        
        # Shooting star criteria:
        # 1. Upper shadow at least 2x body
        # 2. Lower shadow less than half of body
        if upper_shadow > (body_size * 2) and lower_shadow < (body_size * 0.5):
            shooting_star_pattern[i] = True
    
    return shooting_star_pattern

def detect_engulfing(data):
    """Detect Bullish Engulfing pattern"""
    engulfing_pattern = np.zeros(len(data), dtype=bool)
    
    for i in range(1, len(data)):
        prev_candle = data.iloc[i-1]
        curr_candle = data.iloc[i]
        
        # Previous candle is bearish (red), current is bullish (green)
        prev_bearish = prev_candle['Close'] < prev_candle['Open']
        curr_bullish = curr_candle['Close'] > curr_candle['Open']
        
        # Current candle body completely engulfs previous candle body
        body_engulfing = (
            curr_candle['Open'] < prev_candle['Close'] and
            curr_candle['Close'] > prev_candle['Open']
        )
        
        if prev_bearish and curr_bullish and body_engulfing:
            engulfing_pattern[i] = True
    
    return engulfing_pattern

def detect_morning_star(data):
    """Detect Morning Star pattern (three-candle pattern)"""
    morning_star = np.zeros(len(data), dtype=bool)
    
    for i in range(2, len(data)):
        first_candle = data.iloc[i-2]
        middle_candle = data.iloc[i-1]
        last_candle = data.iloc[i]
        
        # First candle is bearish (red)
        first_bearish = first_candle['Close'] < first_candle['Open']
        
        # Middle candle is small (doji or small body)
        middle_body = abs(middle_candle['Close'] - middle_candle['Open'])
        first_body = abs(first_candle['Open'] - first_candle['Close'])
        middle_small = middle_body < (first_body * 0.3)
        
        # Last candle is bullish (green)
        last_bullish = last_candle['Close'] > last_candle['Open']
        
        # Gap down from first to middle
        gap_down = middle_candle['High'] < first_candle['Close']
        
        # Gap up from middle to last
        gap_up = last_candle['Low'] > middle_candle['High']
        
        # Last candle closes above middle point of first candle
        first_midpoint = (first_candle['Open'] + first_candle['Close']) / 2
        last_closes_high = last_candle['Close'] > first_midpoint
        
        if (first_bearish and middle_small and last_bullish and 
            (gap_down or gap_up) and last_closes_high):
            morning_star[i] = True
    
    return morning_star

def detect_three_white_soldiers(data):
    """Detect Three White Soldiers pattern"""
    three_soldiers = np.zeros(len(data), dtype=bool)
    
    for i in range(2, len(data)):
        first = data.iloc[i-2]
        second = data.iloc[i-1]
        third = data.iloc[i]
        
        # All three candles are bullish (green)
        all_bullish = (
            first['Close'] > first['Open'] and
            second['Close'] > second['Open'] and
            third['Close'] > third['Open']
        )
        
        # Each candle opens within the previous candle's body
        progressive_opens = (
            second['Open'] > first['Open'] and
            second['Open'] < first['Close'] and
            third['Open'] > second['Open'] and
            third['Open'] < second['Close']
        )
        
        # Each candle closes higher than the previous
        progressive_closes = (
            second['Close'] > first['Close'] and
            third['Close'] > second['Close']
        )
        
        if all_bullish and progressive_opens and progressive_closes:
            three_soldiers[i] = True
    
    return three_soldiers

### 2.4 Create the dataset

In [5]:
def create_dataset(output_dir="candlestick_dataset"):
    """
    Create a complete dataset for training YOLOv8
    
    Args:
        output_dir (str): Directory to save the dataset
    """
    # Define stock tickers to use (major indices, tech stocks, etc.)
    tickers = [
        "SPY", "QQQ", "DIA", "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", 
        "NVDA", "AMD", "INTC", "JPM", "V", "MA", "DIS", "NFLX", "CSCO", "VZ",
        "T", "PFE", "MRK", "JNJ", "PG", "KO", "PEP", "WMT", "HD", "BA", "CAT"
    ]
    
    # Define patterns to detect
    patterns = {
        "doji": detect_doji,
        "hammer": detect_hammer,
        "shooting_star": detect_shooting_star,
        "engulfing": detect_engulfing,
        "morning_star": detect_morning_star,
        "three_white_soldiers": detect_three_white_soldiers
    }
    
    # Generate pattern dataset
    print("Generating pattern dataset...")
    chart_paths = generate_pattern_dataset(
        tickers=tickers,
        patterns=patterns,
        num_samples_per_pattern=300,  # 300 samples per pattern
        output_dir=os.path.join(output_dir, "raw_data")
    )
    
    # Check if any charts were generated
    if len(chart_paths) == 0:
        print("No pattern charts were generated. Please check the pattern detection logic.")
        return None
    
    # Split into train, val, test sets
    print(f"Generated {len(chart_paths)} chart images")
    
    # Shuffle paths
    random.shuffle(chart_paths)
    
    # Split: 70% train, 20% val, 10% test
    train_split = int(0.7 * len(chart_paths))
    val_split = int(0.9 * len(chart_paths))
    
    train_paths = chart_paths[:train_split]
    val_paths = chart_paths[train_split:val_split]
    test_paths = chart_paths[val_split:]
    
    # Create directories
    os.makedirs(os.path.join(output_dir, "train", "images"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "train", "labels"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "val", "images"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "val", "labels"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "test", "images"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "test", "labels"), exist_ok=True)
    
    # Copy files to appropriate directories
    def copy_files(paths, dest_type):
        for img_path in tqdm(paths, desc=f"Copying {dest_type} files"):
            # Copy image
            img_filename = os.path.basename(img_path)
            shutil.copy(
                img_path,
                os.path.join(output_dir, dest_type, "images", img_filename)
            )
            
            # Copy annotation
            label_path = img_path.replace(".jpg", ".txt")
            label_filename = os.path.basename(label_path)
            shutil.copy(
                label_path,
                os.path.join(output_dir, dest_type, "labels", label_filename)
            )
    
    print("Organizing dataset...")
    copy_files(train_paths, "train")
    copy_files(val_paths, "val")
    copy_files(test_paths, "test")
    
    # Create dataset YAML file
    dataset_yaml = {
        "path": os.path.abspath(output_dir),
        "train": "train/images",
        "val": "val/images",
        "test": "test/images",
        "names": list(patterns.keys())
    }
    
    with open(os.path.join(output_dir, "dataset.yaml"), "w") as f:
        yaml.dump(dataset_yaml, f)
    
    print(f"Dataset created at {output_dir}")
    print(f"Training set: {len(train_paths)} images")
    print(f"Validation set: {len(val_paths)} images")
    print(f"Test set: {len(test_paths)} images")
    
    return os.path.join(output_dir, "dataset.yaml")

## 3. Training the YOLOv8 model

In [6]:
def train_yolov8_model(dataset_yaml, model_size="n", epochs=100, batch_size=16, image_size=640):
    """
    Train a YOLOv8 model on the candlestick pattern dataset
    
    Args:
        dataset_yaml (str): Path to dataset YAML file
        model_size (str): YOLOv8 model size (n, s, m, l, x)
        epochs (int): Number of training epochs
        batch_size (int): Batch size
        image_size (int): Image size
        
    Returns:
        str: Path to the trained model weights
    """
    print(f"Training YOLOv8{model_size} model on {dataset_yaml}...")
    
    # Load a pre-trained YOLOv8 model
    model = YOLO(f"yolov8{model_size}.pt")
    
    # Train the model
    results = model.train(
        data=dataset_yaml,
        epochs=epochs,
        batch=batch_size,
        imgsz=image_size,
        patience=20,  # Early stopping patience
        save=True,
        device="0" if torch.cuda.is_available() else "cpu",
        verbose=True
    )
    
    # Get path to best model weights
#     best_weights_path = model.best
    
#     print(f"Training completed. Best weights saved to: {best_weights_path}")
    
    return model

## 4. Fine-tuning the model

In [7]:
def fine_tune_model(best_weights_path, dataset_yaml, epochs=50, batch_size=8, image_size=640):
    """
    Fine-tune the YOLOv8 model
    
    Args:
        best_weights_path (str): Path to the best weights from initial training
        dataset_yaml (str): Path to dataset YAML file
        epochs (int): Number of fine-tuning epochs
        batch_size (int): Batch size
        image_size (int): Image size
        
    Returns:
        str: Path to the fine-tuned model weights
    """
    print(f"Fine-tuning model {best_weights_path}...")
    
    # Load the trained model
    model = YOLO(best_weights_path)
    
    # Fine-tune with a lower learning rate
    results = model.train(
        data=dataset_yaml,
        epochs=epochs,
        batch=batch_size,
        imgsz=image_size,
        patience=20,  # Early stopping patience
        save=True,
        device="0" if torch.cuda.is_available() else "cpu",
        verbose=True,
        lr0=0.001,  # Lower learning rate for fine-tuning
        lrf=0.01,   # Final learning rate as a fraction of initial lr
    )
    
    # Get path to best fine-tuned weights
#     best_weights_path = model.best
    
#     print(f"Fine-tuning completed. Best weights saved to: {best_weights_path}")
    
    return model

## 5. Inference for pattern detection

In [8]:
def detect_patterns(chart_image_path, model_path, conf_threshold=0.25):
    """
    Detect candlestick patterns in a chart image
    
    Args:
        chart_image_path (str): Path to the chart image
        model_path (str): Path to the trained YOLOv8 model weights
        conf_threshold (float): Confidence threshold for detections
        
    Returns:
        list: List of detected patterns with confidence scores
    """
    # Load the model
    model = YOLO(model_path)
    
    # Run inference
    results = model(chart_image_path, conf=conf_threshold)
    
    # Process results
    detections = []
    
    for result in results:
        for i, (box, conf, cls) in enumerate(zip(result.boxes.xyxy, result.boxes.conf, result.boxes.cls)):
            x1, y1, x2, y2 = box.tolist()
            confidence = conf.item()
            class_id = int(cls.item())
            class_name = result.names[class_id]
            
            detections.append({
                "pattern": class_name,
                "confidence": confidence,
                "box": [x1, y1, x2, y2]
            })
    
    return detections

def visualize_detections(chart_image_path, detections, output_path=None):
    """
    Visualize detected patterns on a chart image
    
    Args:
        chart_image_path (str): Path to the chart image
        detections (list): List of detected patterns
        output_path (str, optional): Path to save the output image
        
    Returns:
        numpy.ndarray: Image with visualized detections
    """
    # Load the image
    image = cv2.imread(chart_image_path)
    
    # Convert BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Colors for different patterns (in RGB)
    colors = {
        "doji": (255, 0, 0),          # Red
        "hammer": (0, 255, 0),        # Green
        "shooting_star": (0, 0, 255), # Blue
        "engulfing": (255, 255, 0),   # Yellow
        "morning_star": (255, 0, 255),# Magenta
        "three_white_soldiers": (0, 255, 255) # Cyan
    }
    
    # Draw bounding boxes and labels
    for detection in detections:
        pattern = detection["pattern"]
        conf = detection["confidence"]
        box = detection["box"]
        
        # Get color for this pattern
        color = colors.get(pattern, (200, 200, 200))
        
        # Draw bounding box
        cv2.rectangle(
            image, 
            (int(box[0]), int(box[1])), 
            (int(box[2]), int(box[3])), 
            color, 
            2
        )
        
        # Prepare label text
        label = f"{pattern}: {conf:.2f}"
        
        # Get text size
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
        
        # Calculate text background rectangle
        text_bg_x1 = int(box[0])
        text_bg_y1 = int(box[1]) - text_size[1] - 10
        text_bg_x2 = int(box[0]) + text_size[0] + 10
        text_bg_y2 = int(box[1])
        
        # Draw text background
        cv2.rectangle(
            image,
            (text_bg_x1, text_bg_y1),
            (text_bg_x2, text_bg_y2),
            color,
            -1
        )
        
        # Draw text
        cv2.putText(
            image,
            label,
            (int(box[0] + 5), int(box[1] - 5)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            2
        )
    
    # Save the output image if requested
    if output_path:
        # Convert RGB back to BGR for saving
        output_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(output_path, output_img)
    
    return image

def detect_patterns_from_ticker(ticker, model_path, period="1y", interval="1d", conf_threshold=0.3):
    """
    Generate a chart for a ticker and detect patterns
    
    Args:
        ticker (str): Stock ticker symbol
        model_path (str): Path to trained model
        period (str): Data period
        interval (str): Data interval
        conf_threshold (float): Detection confidence threshold
        
    Returns:
        tuple: (chart_path, detections, visualization)
    """
    # Generate chart
    chart_path = generate_candlestick_chart(
        ticker=ticker,
        period=period,
        interval=interval,
        output_dir="inference_charts"
    )
    
    if chart_path is None:
        return None, [], None
    
    # Detect patterns
    detections = detect_patterns(
        chart_image_path=chart_path,
        model_path=model_path,
        conf_threshold=conf_threshold
    )
    
    # Visualize detections
    output_path = chart_path.replace(".jpg", "_detected.jpg")
    visualization = visualize_detections(
        chart_image_path=chart_path,
        detections=detections,
        output_path=output_path
    )
    
    return chart_path, detections, visualization

## 6. Main execution flow

In [9]:
# 1. Create dataset
dataset_yaml = create_dataset(output_dir="candlestick_dataset")

# Check if dataset was created successfully
if dataset_yaml is None:
    print("Dataset creation failed. Please check the pattern detection functions.")

Generating pattern dataset...


Processing tickers:   0%|          | 0/30 [00:00<?, ?it/s]

YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, SPY)   1257 non-null   float64
 1   (High, SPY)    1257 non-null   float64
 2   (Low, SPY)     1257 non-null   float64
 3   (Open, SPY)    1257 non-null   float64
 4   (Volume, SPY)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'SPY'),
            (  'High', 'SPY'),
            (   'Low', 'SPY'),
            (  'Open', 'SPY'),
            ('Volume', 'SPY')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:   3%|▎         | 1/30 [00:15<07:35, 15.71s/it]

Generated chart for three_white_soldiers pattern in SPY


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, QQQ)   1257 non-null   float64
 1   (High, QQQ)    1257 non-null   float64
 2   (Low, QQQ)     1257 non-null   float64
 3   (Open, QQQ)    1257 non-null   float64
 4   (Volume, QQQ)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'QQQ'),
            (  'High', 'QQQ'),
            (   'Low', 'QQQ'),
            (  'Open', 'QQQ'),
            ('Volume', 'QQQ')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:   7%|▋         | 2/30 [00:29<06:42, 14.38s/it]

Generated chart for three_white_soldiers pattern in QQQ


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, DIA)   1257 non-null   float64
 1   (High, DIA)    1257 non-null   float64
 2   (Low, DIA)     1257 non-null   float64
 3   (Open, DIA)    1257 non-null   float64
 4   (Volume, DIA)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'DIA'),
            (  'High', 'DIA'),
            (   'Low', 'DIA'),
            (  'Open', 'DIA'),
            ('Volume', 'DIA')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  10%|█         | 3/30 [00:42<06:16, 13.93s/it]

Generated chart for three_white_soldiers pattern in DIA


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, AAPL)   1257 non-null   float64
 1   (High, AAPL)    1257 non-null   float64
 2   (Low, AAPL)     1257 non-null   float64
 3   (Open, AAPL)    1257 non-null   float64
 4   (Volume, AAPL)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'AAPL'),
            (  'High', 'AAPL'),
            (   'Low', 'AAPL'),
            (  'Open', 'AAPL'),
            ('Volume', 'AAPL')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  13%|█▎        | 4/30 [00:55<05:49, 13.43s/it]

Generated chart for three_white_soldiers pattern in AAPL


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, MSFT)   1257 non-null   float64
 1   (High, MSFT)    1257 non-null   float64
 2   (Low, MSFT)     1257 non-null   float64
 3   (Open, MSFT)    1257 non-null   float64
 4   (Volume, MSFT)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'MSFT'),
            (  'High', 'MSFT'),
            (   'Low', 'MSFT'),
            (  'Open', 'MSFT'),
            ('Volume', 'MSFT')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  17%|█▋        | 5/30 [01:09<05:42, 13.70s/it]

Generated chart for three_white_soldiers pattern in MSFT


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   (Close, GOOGL)   1257 non-null   float64
 1   (High, GOOGL)    1257 non-null   float64
 2   (Low, GOOGL)     1257 non-null   float64
 3   (Open, GOOGL)    1257 non-null   float64
 4   (Volume, GOOGL)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'GOOGL'),
            (  'High', 'GOOGL'),
            (   'Low', 'GOOGL'),
            (  'Open', 'GOOGL'),
            ('Volume', 'GOOGL')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3 

Processing tickers:  20%|██        | 6/30 [01:22<05:23, 13.49s/it]

Generated chart for three_white_soldiers pattern in GOOGL


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, AMZN)   1257 non-null   float64
 1   (High, AMZN)    1257 non-null   float64
 2   (Low, AMZN)     1257 non-null   float64
 3   (Open, AMZN)    1257 non-null   float64
 4   (Volume, AMZN)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'AMZN'),
            (  'High', 'AMZN'),
            (   'Low', 'AMZN'),
            (  'Open', 'AMZN'),
            ('Volume', 'AMZN')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  23%|██▎       | 7/30 [01:35<05:08, 13.41s/it]

Generated chart for three_white_soldiers pattern in AMZN


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, META)   1257 non-null   float64
 1   (High, META)    1257 non-null   float64
 2   (Low, META)     1257 non-null   float64
 3   (Open, META)    1257 non-null   float64
 4   (Volume, META)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'META'),
            (  'High', 'META'),
            (   'Low', 'META'),
            (  'Open', 'META'),
            ('Volume', 'META')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  27%|██▋       | 8/30 [01:49<04:54, 13.37s/it]

Generated chart for three_white_soldiers pattern in META


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, TSLA)   1257 non-null   float64
 1   (High, TSLA)    1257 non-null   float64
 2   (Low, TSLA)     1257 non-null   float64
 3   (Open, TSLA)    1257 non-null   float64
 4   (Volume, TSLA)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'TSLA'),
            (  'High', 'TSLA'),
            (   'Low', 'TSLA'),
            (  'Open', 'TSLA'),
            ('Volume', 'TSLA')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  30%|███       | 9/30 [02:01<04:37, 13.24s/it]

Generated chart for three_white_soldiers pattern in TSLA


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, NVDA)   1257 non-null   float64
 1   (High, NVDA)    1257 non-null   float64
 2   (Low, NVDA)     1257 non-null   float64
 3   (Open, NVDA)    1257 non-null   float64
 4   (Volume, NVDA)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'NVDA'),
            (  'High', 'NVDA'),
            (   'Low', 'NVDA'),
            (  'Open', 'NVDA'),
            ('Volume', 'NVDA')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  33%|███▎      | 10/30 [02:15<04:28, 13.41s/it]

Generated chart for three_white_soldiers pattern in NVDA
Generated chart for three_white_soldiers pattern in NVDA


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, AMD)   1257 non-null   float64
 1   (High, AMD)    1257 non-null   float64
 2   (Low, AMD)     1257 non-null   float64
 3   (Open, AMD)    1257 non-null   float64
 4   (Volume, AMD)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'AMD'),
            (  'High', 'AMD'),
            (   'Low', 'AMD'),
            (  'Open', 'AMD'),
            ('Volume', 'AMD')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  37%|███▋      | 11/30 [02:27<04:07, 13.02s/it]

Generated chart for three_white_soldiers pattern in AMD


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, INTC)   1257 non-null   float64
 1   (High, INTC)    1257 non-null   float64
 2   (Low, INTC)     1257 non-null   float64
 3   (Open, INTC)    1257 non-null   float64
 4   (Volume, INTC)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'INTC'),
            (  'High', 'INTC'),
            (   'Low', 'INTC'),
            (  'Open', 'INTC'),
            ('Volume', 'INTC')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  40%|████      | 12/30 [02:39<03:48, 12.68s/it]

Generated chart for three_white_soldiers pattern in INTC


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, JPM)   1257 non-null   float64
 1   (High, JPM)    1257 non-null   float64
 2   (Low, JPM)     1257 non-null   float64
 3   (Open, JPM)    1257 non-null   float64
 4   (Volume, JPM)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'JPM'),
            (  'High', 'JPM'),
            (   'Low', 'JPM'),
            (  'Open', 'JPM'),
            ('Volume', 'JPM')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  43%|████▎     | 13/30 [02:53<03:41, 13.04s/it]

Generated chart for three_white_soldiers pattern in JPM
Generated chart for three_white_soldiers pattern in JPM


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   (Close, V)   1257 non-null   float64
 1   (High, V)    1257 non-null   float64
 2   (Low, V)     1257 non-null   float64
 3   (Open, V)    1257 non-null   float64
 4   (Volume, V)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'V'),
            (  'High', 'V'),
            (   'Low', 'V'),
            (  'Open', 'V'),
            ('Volume', 'V')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4   Volume  1

Processing tickers:  47%|████▋     | 14/30 [03:04<03:17, 12.35s/it]

Generated chart for three_white_soldiers pattern in V


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, MA)   1257 non-null   float64
 1   (High, MA)    1257 non-null   float64
 2   (Low, MA)     1257 non-null   float64
 3   (Open, MA)    1257 non-null   float64
 4   (Volume, MA)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'MA'),
            (  'High', 'MA'),
            (   'Low', 'MA'),
            (  'Open', 'MA'),
            ('Volume', 'MA')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  50%|█████     | 15/30 [03:15<02:59, 11.94s/it]

Generated chart for three_white_soldiers pattern in MA


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, DIS)   1257 non-null   float64
 1   (High, DIS)    1257 non-null   float64
 2   (Low, DIS)     1257 non-null   float64
 3   (Open, DIS)    1257 non-null   float64
 4   (Volume, DIS)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'DIS'),
            (  'High', 'DIS'),
            (   'Low', 'DIS'),
            (  'Open', 'DIS'),
            ('Volume', 'DIS')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  53%|█████▎    | 16/30 [03:26<02:44, 11.73s/it]

Generated chart for three_white_soldiers pattern in DIS
Generated chart for three_white_soldiers pattern in DIS


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, NFLX)   1257 non-null   float64
 1   (High, NFLX)    1257 non-null   float64
 2   (Low, NFLX)     1257 non-null   float64
 3   (Open, NFLX)    1257 non-null   float64
 4   (Volume, NFLX)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'NFLX'),
            (  'High', 'NFLX'),
            (   'Low', 'NFLX'),
            (  'Open', 'NFLX'),
            ('Volume', 'NFLX')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  57%|█████▋    | 17/30 [03:41<02:42, 12.52s/it]

Generated chart for three_white_soldiers pattern in NFLX


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, CSCO)   1257 non-null   float64
 1   (High, CSCO)    1257 non-null   float64
 2   (Low, CSCO)     1257 non-null   float64
 3   (Open, CSCO)    1257 non-null   float64
 4   (Volume, CSCO)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'CSCO'),
            (  'High', 'CSCO'),
            (   'Low', 'CSCO'),
            (  'Open', 'CSCO'),
            ('Volume', 'CSCO')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    12

Processing tickers:  60%|██████    | 18/30 [03:52<02:25, 12.10s/it]

Generated chart for three_white_soldiers pattern in CSCO
Generated chart for three_white_soldiers pattern in CSCO


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, VZ)   1257 non-null   float64
 1   (High, VZ)    1257 non-null   float64
 2   (Low, VZ)     1257 non-null   float64
 3   (Open, VZ)    1257 non-null   float64
 4   (Volume, VZ)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'VZ'),
            (  'High', 'VZ'),
            (   'Low', 'VZ'),
            (  'Open', 'VZ'),
            ('Volume', 'VZ')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  63%|██████▎   | 19/30 [04:03<02:10, 11.82s/it]

Generated chart for three_white_soldiers pattern in VZ
Generated chart for three_white_soldiers pattern in VZ


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   (Close, T)   1257 non-null   float64
 1   (High, T)    1257 non-null   float64
 2   (Low, T)     1257 non-null   float64
 3   (Open, T)    1257 non-null   float64
 4   (Volume, T)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'T'),
            (  'High', 'T'),
            (   'Low', 'T'),
            (  'Open', 'T'),
            ('Volume', 'T')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4   Volume  1

Processing tickers:  67%|██████▋   | 20/30 [04:14<01:55, 11.57s/it]

Generated chart for three_white_soldiers pattern in T
Generated chart for three_white_soldiers pattern in T


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, PFE)   1257 non-null   float64
 1   (High, PFE)    1257 non-null   float64
 2   (Low, PFE)     1257 non-null   float64
 3   (Open, PFE)    1257 non-null   float64
 4   (Volume, PFE)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'PFE'),
            (  'High', 'PFE'),
            (   'Low', 'PFE'),
            (  'Open', 'PFE'),
            ('Volume', 'PFE')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  70%|███████   | 21/30 [04:28<01:52, 12.50s/it]

Generated chart for three_white_soldiers pattern in PFE


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, MRK)   1257 non-null   float64
 1   (High, MRK)    1257 non-null   float64
 2   (Low, MRK)     1257 non-null   float64
 3   (Open, MRK)    1257 non-null   float64
 4   (Volume, MRK)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'MRK'),
            (  'High', 'MRK'),
            (   'Low', 'MRK'),
            (  'Open', 'MRK'),
            ('Volume', 'MRK')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  73%|███████▎  | 22/30 [04:40<01:36, 12.08s/it]

Generated chart for three_white_soldiers pattern in MRK
Generated chart for three_white_soldiers pattern in MRK


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, JNJ)   1257 non-null   float64
 1   (High, JNJ)    1257 non-null   float64
 2   (Low, JNJ)     1257 non-null   float64
 3   (Open, JNJ)    1257 non-null   float64
 4   (Volume, JNJ)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'JNJ'),
            (  'High', 'JNJ'),
            (   'Low', 'JNJ'),
            (  'Open', 'JNJ'),
            ('Volume', 'JNJ')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  77%|███████▋  | 23/30 [04:51<01:22, 11.82s/it]

Generated chart for three_white_soldiers pattern in JNJ
Generated chart for three_white_soldiers pattern in JNJ


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, PG)   1257 non-null   float64
 1   (High, PG)    1257 non-null   float64
 2   (Low, PG)     1257 non-null   float64
 3   (Open, PG)    1257 non-null   float64
 4   (Volume, PG)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'PG'),
            (  'High', 'PG'),
            (   'Low', 'PG'),
            (  'Open', 'PG'),
            ('Volume', 'PG')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  80%|████████  | 24/30 [05:03<01:12, 12.09s/it]

Generated chart for three_white_soldiers pattern in PG


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, KO)   1257 non-null   float64
 1   (High, KO)    1257 non-null   float64
 2   (Low, KO)     1257 non-null   float64
 3   (Open, KO)    1257 non-null   float64
 4   (Volume, KO)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'KO'),
            (  'High', 'KO'),
            (   'Low', 'KO'),
            (  'Open', 'KO'),
            ('Volume', 'KO')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  83%|████████▎ | 25/30 [05:15<00:59, 11.82s/it]

Generated chart for three_white_soldiers pattern in KO
Generated chart for three_white_soldiers pattern in KO


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, PEP)   1257 non-null   float64
 1   (High, PEP)    1257 non-null   float64
 2   (Low, PEP)     1257 non-null   float64
 3   (Open, PEP)    1257 non-null   float64
 4   (Volume, PEP)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'PEP'),
            (  'High', 'PEP'),
            (   'Low', 'PEP'),
            (  'Open', 'PEP'),
            ('Volume', 'PEP')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  87%|████████▋ | 26/30 [05:30<00:51, 12.88s/it]

Generated chart for three_white_soldiers pattern in PEP
Generated chart for three_white_soldiers pattern in PEP


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, WMT)   1257 non-null   float64
 1   (High, WMT)    1257 non-null   float64
 2   (Low, WMT)     1257 non-null   float64
 3   (Open, WMT)    1257 non-null   float64
 4   (Volume, WMT)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'WMT'),
            (  'High', 'WMT'),
            (   'Low', 'WMT'),
            (  'Open', 'WMT'),
            ('Volume', 'WMT')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers:  90%|█████████ | 27/30 [05:42<00:37, 12.50s/it]

Generated chart for three_white_soldiers pattern in WMT


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, HD)   1257 non-null   float64
 1   (High, HD)    1257 non-null   float64
 2   (Low, HD)     1257 non-null   float64
 3   (Open, HD)    1257 non-null   float64
 4   (Volume, HD)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'HD'),
            (  'High', 'HD'),
            (   'Low', 'HD'),
            (  'Open', 'HD'),
            ('Volume', 'HD')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  93%|█████████▎| 28/30 [05:54<00:24, 12.45s/it]

Generated chart for three_white_soldiers pattern in HD
Generated chart for three_white_soldiers pattern in HD


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   (Close, BA)   1257 non-null   float64
 1   (High, BA)    1257 non-null   float64
 2   (Low, BA)     1257 non-null   float64
 3   (Open, BA)    1257 non-null   float64
 4   (Volume, BA)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'BA'),
            (  'High', 'BA'),
            (   'Low', 'BA'),
            (  'Open', 'BA'),
            ('Volume', 'BA')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null   float64
 4

Processing tickers:  97%|█████████▋| 29/30 [06:06<00:12, 12.19s/it]

Generated chart for three_white_soldiers pattern in BA


[*********************100%***********************]  1 of 1 completed


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1257 entries, 2020-04-20 to 2025-04-17
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   (Close, CAT)   1257 non-null   float64
 1   (High, CAT)    1257 non-null   float64
 2   (Low, CAT)     1257 non-null   float64
 3   (Open, CAT)    1257 non-null   float64
 4   (Volume, CAT)  1257 non-null   int64  
dtypes: float64(4), int64(1)
memory usage: 58.9 KB
None MultiIndex([( 'Close', 'CAT'),
            (  'High', 'CAT'),
            (   'Low', 'CAT'),
            (  'Open', 'CAT'),
            ('Volume', 'CAT')],
           names=['Price', 'Ticker'])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1257 entries, 0 to 1256
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Close   1257 non-null   float64
 1   High    1257 non-null   float64
 2   Low     1257 non-null   float64
 3   Open    1257 non-null 

Processing tickers: 100%|██████████| 30/30 [06:18<00:00, 12.62s/it]

Generated chart for three_white_soldiers pattern in CAT
Generated chart for three_white_soldiers pattern in CAT

Pattern detection summary:
doji: 3919
hammer: 960
shooting_star: 818
engulfing: 1320
morning_star: 171
three_white_soldiers: 337
Generated 1623 chart images





Organizing dataset...


Copying train files: 100%|██████████| 1136/1136 [00:40<00:00, 27.93it/s]
Copying val files: 100%|██████████| 324/324 [00:10<00:00, 30.68it/s]
Copying test files: 100%|██████████| 163/163 [00:04<00:00, 35.04it/s]

Dataset created at candlestick_dataset
Training set: 1136 images
Validation set: 324 images
Test set: 163 images





In [10]:
# 2. Train YOLOv8 model
model_returned = train_yolov8_model(
    dataset_yaml=dataset_yaml,
    model_size="n",  # Start with a smaller model for faster training
    epochs=25,
    batch_size=16,
    image_size=640
)

Training YOLOv8n model on candlestick_dataset/dataset.yaml...
Ultralytics 8.3.94 🚀 Python-3.10.16 torch-2.1.0+cu121 CPU (Intel Xeon E5-2680 v4 2.40GHz)
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=yolov8n.pt, data=candlestick_dataset/dataset.yaml, epochs=25, time=None, patience=20, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=cpu, workers=8, project=None, name=train, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_tx

[34m[1mtrain: [0mScanning /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/train/labels... 1136 images, 0 backgrounds, 7 corrupt: 100%|██████████| 1136/1136 [00:01<00:00, 934.11it/s]

[34m[1mtrain: [0mNew cache created: /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/train/labels.cache



[34m[1mval: [0mScanning /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/val/labels... 324 images, 0 backgrounds, 0 corrupt: 100%|██████████| 324/324 [00:00<00:00, 1027.53it/s]

[34m[1mval: [0mNew cache created: /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/val/labels.cache





Plotting labels to runs/detect/train/labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m AdamW(lr=0.001, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to [1mruns/detect/train[0m
Starting training for 25 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       1/25         0G          0      112.6          0          0        640: 100%|██████████| 71/71 [02:18<00:00,  1.95s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:13<00:00,  1.22s/it]


                   all        324        324          0          0          0          0

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       2/25         0G          0      91.61          0          0        640: 100%|██████████| 71/71 [02:14<00:00,  1.90s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:12<00:00,  1.15s/it]


                   all        324        324          0          0          0          0

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       3/25         0G          0      73.64          0          0        640: 100%|██████████| 71/71 [02:14<00:00,  1.89s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:12<00:00,  1.10s/it]


                   all        324        324          0          0          0          0

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       4/25         0G          0      55.07          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.02s/it]


                   all        324        324          0          0          0          0

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       5/25         0G          0      39.79          0          0        640: 100%|██████████| 71/71 [02:14<00:00,  1.90s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.00it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       6/25         0G          0      28.18          0          0        640: 100%|██████████| 71/71 [02:14<00:00,  1.89s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.02s/it]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       7/25         0G          0      19.62          0          0        640: 100%|██████████| 71/71 [02:13<00:00,  1.89s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       8/25         0G          0      13.53          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       9/25         0G          0      9.372          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      10/25         0G  1.214e-08      6.486  7.555e-09          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.87s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      11/25         0G          0      4.568          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.03it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      12/25         0G          0       3.21          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.87s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.04s/it]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      13/25         0G  2.012e-12      2.273  9.609e-13          0        640: 100%|██████████| 71/71 [02:16<00:00,  1.92s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.04s/it]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      14/25         0G          0      1.627          0          0        640: 100%|██████████| 71/71 [02:16<00:00,  1.93s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.04s/it]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      15/25         0G          0      1.181          0          0        640: 100%|██████████| 71/71 [02:15<00:00,  1.91s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:11<00:00,  1.00s/it]

                   all        324        324          0          0          0          0





Closing dataloader mosaic

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      16/25         0G          0     0.8802          0          0        640: 100%|██████████| 71/71 [02:13<00:00,  1.89s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      17/25         0G          0      0.667          0          0        640: 100%|██████████| 71/71 [02:11<00:00,  1.85s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.02it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      18/25         0G          0     0.5193          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      19/25         0G          0     0.4139          0          0        640: 100%|██████████| 71/71 [02:11<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.01it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      20/25         0G          0     0.3421          0          0        640: 100%|██████████| 71/71 [02:11<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      21/25         0G          0     0.2855          0          0        640: 100%|██████████| 71/71 [02:11<00:00,  1.86s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.00it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      22/25         0G          0     0.2516          0          0        640: 100%|██████████| 71/71 [02:11<00:00,  1.85s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.02it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      23/25         0G          0     0.2237          0          0        640: 100%|██████████| 71/71 [02:12<00:00,  1.87s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      24/25         0G          0     0.2076          0          0        640: 100%|██████████| 71/71 [02:10<00:00,  1.84s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.05it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      25/25         0G          0     0.1973          0          0        640: 100%|██████████| 71/71 [02:09<00:00,  1.83s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:10<00:00,  1.06it/s]

                   all        324        324          0          0          0          0






25 epochs completed in 1.005 hours.
Optimizer stripped from runs/detect/train/weights/last.pt, 6.2MB
Optimizer stripped from runs/detect/train/weights/best.pt, 6.2MB

Validating runs/detect/train/weights/best.pt...
Ultralytics 8.3.94 🚀 Python-3.10.16 torch-2.1.0+cu121 CPU (Intel Xeon E5-2680 v4 2.40GHz)
Model summary (fused): 72 layers, 3,006,818 parameters, 0 gradients, 8.1 GFLOPs


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 11/11 [00:09<00:00,  1.14it/s]


                   all        324        324          0          0          0          0
                  doji         54         54          0          0          0          0
                hammer         55         55          0          0          0          0
         shooting_star         53         53          0          0          0          0
             engulfing         63         63          0          0          0          0
          morning_star         37         37          0          0          0          0
  three_white_soldiers         62         62          0          0          0          0
Speed: 0.5ms preprocess, 21.9ms inference, 0.0ms loss, 0.1ms postprocess per image
Results saved to [1mruns/detect/train[0m


In [11]:
# 3. Fine-tune the model
fine_tuned_model = fine_tune_model(
    best_weights_path="runs/detect/train/weights/best.pt",
    dataset_yaml=dataset_yaml,
    epochs=10,
    batch_size=8,
    image_size=640
)

Fine-tuning model runs/detect/train/weights/best.pt...
Ultralytics 8.3.94 🚀 Python-3.10.16 torch-2.1.0+cu121 CPU (Intel Xeon E5-2680 v4 2.40GHz)
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=runs/detect/train/weights/best.pt, data=candlestick_dataset/dataset.yaml, epochs=10, time=None, patience=20, batch=8, imgsz=640, save=True, save_period=-1, cache=False, device=cpu, workers=8, project=None, name=train2, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frame

[34m[1mtrain: [0mScanning /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/train/labels.cache... 1136 images, 0 backgrounds, 7 corrupt: 100%|██████████| 1136/1136 [00:00<?, ?it/s]




[34m[1mval: [0mScanning /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/candlestick_dataset/val/labels.cache... 324 images, 0 backgrounds, 0 corrupt: 100%|██████████| 324/324 [00:00<?, ?it/s]

Plotting labels to runs/detect/train2/labels.jpg... 





[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.001' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m AdamW(lr=0.001, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to [1mruns/detect/train2[0m
Starting training for 10 epochs...
Closing dataloader mosaic

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       1/10         0G          0    0.07963          0          0        640: 100%|██████████| 142/142 [02:11<00:00,  1.08it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       2/10         0G          0    0.04529          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.09it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.05it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       3/10         0G          0    0.01737          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.05it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       4/10         0G          0   0.008866          0          0        640: 100%|██████████| 142/142 [02:08<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  1.98it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       5/10         0G          0    0.00642          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.07it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       6/10         0G          0   0.004818          0          0        640: 100%|██████████| 142/142 [02:08<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.06it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       7/10         0G          0   0.002482          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.03it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       8/10         0G          0   0.001503          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.08it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       9/10         0G          0   0.001346          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.04it/s]

                   all        324        324          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      10/10         0G          0   0.001276          0          0        640: 100%|██████████| 142/142 [02:09<00:00,  1.10it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:10<00:00,  2.02it/s]

                   all        324        324          0          0          0          0






10 epochs completed in 0.390 hours.
Optimizer stripped from runs/detect/train2/weights/last.pt, 6.2MB
Optimizer stripped from runs/detect/train2/weights/best.pt, 6.2MB

Validating runs/detect/train2/weights/best.pt...
Ultralytics 8.3.94 🚀 Python-3.10.16 torch-2.1.0+cu121 CPU (Intel Xeon E5-2680 v4 2.40GHz)
Model summary (fused): 72 layers, 3,006,818 parameters, 0 gradients, 8.1 GFLOPs


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 21/21 [00:09<00:00,  2.18it/s]


                   all        324        324          0          0          0          0
                  doji         54         54          0          0          0          0
                hammer         55         55          0          0          0          0
         shooting_star         53         53          0          0          0          0
             engulfing         63         63          0          0          0          0
          morning_star         37         37          0          0          0          0
  three_white_soldiers         62         62          0          0          0          0
Speed: 0.5ms preprocess, 22.0ms inference, 0.0ms loss, 0.1ms postprocess per image
Results saved to [1mruns/detect/train2[0m


In [14]:
fine_tuned_weights = "runs/detect/train2/weights/best.pt"

In [15]:
 # 4. Test inference on new data
test_tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]

print("\nTesting pattern detection on new charts:")
for ticker in test_tickers:
    print(f"\nAnalyzing {ticker}...")
    chart_path, detections, viz = detect_patterns_from_ticker(
        ticker=ticker,
        model_path=fine_tuned_weights,
        period="2y",
        interval="1d",
        conf_threshold=0.3
    )
    
    if chart_path is None:
        print(f"Failed to generate chart for {ticker}")
        continue
    
    print(f"Chart saved to: {chart_path}")
    
    if len(detections) > 0:
        print(f"Detected {len(detections)} patterns:")
        for i, det in enumerate(detections):
            print(f"  {i+1}. {det['pattern']} (confidence: {det['confidence']:.2f})")
    else:
        print("No patterns detected")

print("\nModel training and testing completed!")
print(f"Trained model path: {fine_tuned_weights}")


Testing pattern detection on new charts:

Analyzing AAPL...


[*********************100%***********************]  1 of 1 completed



image 1/1 /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/inference_charts/AAPL_20250420_174211.jpg: 448x640 (no detections), 59.1ms
Speed: 2.2ms preprocess, 59.1ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)
Chart saved to: inference_charts/AAPL_20250420_174211.jpg
No patterns detected

Analyzing MSFT...


[*********************100%***********************]  1 of 1 completed



image 1/1 /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/inference_charts/MSFT_20250420_174213.jpg: 448x640 (no detections), 38.8ms
Speed: 2.0ms preprocess, 38.8ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)
Chart saved to: inference_charts/MSFT_20250420_174213.jpg
No patterns detected

Analyzing GOOGL...


[*********************100%***********************]  1 of 1 completed



image 1/1 /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/inference_charts/GOOGL_20250420_174215.jpg: 448x640 (no detections), 38.9ms
Speed: 2.1ms preprocess, 38.9ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)
Chart saved to: inference_charts/GOOGL_20250420_174215.jpg
No patterns detected

Analyzing AMZN...


[*********************100%***********************]  1 of 1 completed



image 1/1 /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/inference_charts/AMZN_20250420_174217.jpg: 448x640 (no detections), 38.8ms
Speed: 2.1ms preprocess, 38.8ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)
Chart saved to: inference_charts/AMZN_20250420_174217.jpg
No patterns detected

Analyzing TSLA...


[*********************100%***********************]  1 of 1 completed



image 1/1 /home/jhaveri.r/cs5330-project-candlestick-pattern-detection/src/inference_charts/TSLA_20250420_174219.jpg: 448x640 (no detections), 38.9ms
Speed: 2.3ms preprocess, 38.9ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)
Chart saved to: inference_charts/TSLA_20250420_174219.jpg
No patterns detected

Model training and testing completed!
Trained model path: runs/detect/train2/weights/best.pt
