## Library Import

In [12]:
import pymongo

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import numpy as np
from tqdm import tqdm

from sklearn.gaussian_process import kernels,GaussianProcessRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import TimeSeriesSplit

import math,os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import TimeSeriesSplit

import plotly.graph_objects as go

import warnings
# Suppress all warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Lift pandas' display limits (column count, column width, total width) so wide frames print in full.
pd.set_option('display.max_columns', None)  
pd.set_option('display.max_colwidth', None) 
pd.set_option('display.width', None)


## Data Fetching

In [13]:
class fetch_and_split_data:
    """Load the daily rows of one stock symbol from MongoDB and split them into train/test frames.

    Intended call order: ``fetch_data()`` -> ``split_data()`` ->
    ``get_train_data()`` / ``get_test_data()``.
    """

    def __init__(self,
                 symbol,
                 db_name="local",
                 collection_name="technical_stock_data",
                 mongo_uri="mongodb://localhost:27017/"):
        """
        Initializes the MongoDB connection and prepares the collection for the stock data.

        Args:
            symbol (str): The stock symbol to filter the data (e.g., 'NVDA').
            db_name (str): Name of the MongoDB database.
            collection_name (str): Name of the collection inside the database.
            mongo_uri (str): MongoDB connection URI (default is localhost).
        """
        self.mongo_uri = mongo_uri
        self.client = pymongo.MongoClient(self.mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        self.symbol = symbol
        self.df = None
        # Both split results start as None so the getters can raise their
        # documented ValueError (not AttributeError) when called too early.
        self.df_train = None
        self.df_test = None

    def fetch_data(self):
        """Fetches the symbol's daily documents from MongoDB into ``self.df``.

        Raises:
            ValueError: If the query returns no documents for the symbol.
        """
        query = {"symbol": self.symbol, "interval": "daily"}
        records = list(self.collection.find(query))

        if not records:
            raise ValueError(f"No data found for symbol: {self.symbol}")

        # NOTE(review): rows keep the order the cursor returns them in;
        # split_data() slices from the end, so this presumably relies on the
        # collection being stored in chronological order - confirm.
        self.df = pd.DataFrame(records)

    def split_data(self, window_size=252 * 3, test_size=100):
        """Splits the trailing rows of ``self.df`` into training and testing frames.

        Only the last ``window_size`` rows are used: the final ``test_size``
        rows become the test set and the rows before them (inside the window)
        become the training set.

        Args:
            window_size (int): Number of trailing rows considered in total.
            test_size (int): Number of trailing rows held out for testing.

        Raises:
            ValueError: If no data has been fetched yet, or if the sizes do
                not satisfy ``window_size > test_size > 0``.
        """
        if self.df is None:
            raise ValueError("Data not loaded. Call fetch_data() first.")
        # test_size == 0 would turn ``[-0:]`` into "every row", so reject it.
        if test_size <= 0 or window_size <= test_size:
            raise ValueError("window_size must be greater than test_size, and test_size must be positive.")

        # Rows inside the window that precede the held-out tail.
        self.df_train = self.df.iloc[-window_size:-test_size]

        # The most recent ``test_size`` rows.
        self.df_test = self.df.iloc[-test_size:]

    def get_train_data(self):
        """Returns the training data.

        Raises:
            ValueError: If ``split_data()`` has not produced a training frame yet.
        """
        if self.df_train is None:
            raise ValueError("Data not loaded or split. Call fetch_data() and split_data() first.")
        return self.df_train

    def get_test_data(self):
        """Returns the test data.

        Raises:
            ValueError: If ``split_data()`` has not produced a test frame yet.
        """
        if self.df_test is None:
            raise ValueError("Test data not available. Call split_data() first.")
        return self.df_test


## Check Data Integrity

In [14]:
#Compute the missing value ratio 
def missing_values(df):
    """Report the share of missing entries per column of ``df``.

    Prints 'No missing values' when no column has a positive missing ratio;
    otherwise shows a bar chart of the ratio (sorted ascending) for the
    columns that do.
    """
    null_ratio = df.isnull().sum() / len(df)
    incomplete = null_ratio[null_ratio.values > 0].sort_values()

    # Nothing to plot when every column is complete.
    if incomplete.empty:
        print('No missing values')
        return

    # Visualize the missing value ratio
    plt.figure(figsize=(5, 5), dpi=100)
    sns.barplot(x=incomplete.index, y=incomplete.values)
    plt.xticks(rotation=90)
    plt.title('Features Missing Ratio')
    plt.show()
    

## Data Processing

In [15]:
class prepare_data:
    """Turn raw technical-indicator rows into a numeric feature frame with a return target.

    One instance is meant to be used for the training frame first and then
    for the test frame, so that both are encoded with the same fitted
    one-hot encoder and end up with the same dummy columns.
    """

    def __init__(self, exclude_columns=None,
                 save_dir='/Users/yiukitcheung/Documents/Projects/Stocks/train_data_repository'):
        """
        Initializes the preprocessor.

        Args:
            exclude_columns (list): Numeric columns meant to be excluded from a
                log transformation. The value is stored, but ``preprocess``
                currently applies no log transformation, so it has no effect.
            save_dir (str | None): Directory the preprocessed frame is written
                to by ``preprocess``. Pass ``None`` to skip saving.
        """
        self.exclude_columns = exclude_columns if exclude_columns else ['MACD', 'MACD_SIGNAL', 'MACD_HIST']
        self.save_dir = save_dir
        # Default output name so save_data() also works before preprocess() ran.
        self.filename = 'train_data.parquet'
        # handle_unknown='ignore' lets a category that only appears in the test
        # frame be encoded as all zeros instead of aborting the transform.
        self.ohe = OneHotEncoder(drop='first', handle_unknown='ignore')

    def preprocess(self, df, train=True):
        """
        Preprocess the dataframe by performing the following steps:
        - Drop 'date' column and rows with missing values
        - Convert candlestick/MACD/alert columns to categorical types
        - One-hot encode categorical columns (fit on the training frame,
          reused for later frames)
        - Keep float64/int64 columns and append the encoded columns
        - Add a timestamp column (the positional row index)
        - Add the 'log_daily_return' target column
        - Save the result as parquet when ``save_dir`` is set

        Args:
            df (pd.DataFrame): The input dataframe to preprocess.
            train (bool): True for the training frame (fits the encoder and
                saves as 'train_data.parquet'); False for a test frame
                (reuses the fitted encoder and saves as 'test_data.parquet').

        Returns:
            pd.DataFrame: The preprocessed dataframe.
        """
        self.filename = 'train_data.parquet' if train else 'test_data.parquet'

        # Step 1: Drop 'date' column and handle missing values
        df = df.drop(columns=['date'], errors='ignore')  # Avoids error if 'date' is missing
        df = df.dropna()

        # Step 2: Convert relevant columns to category
        alert_columns = df.columns[df.columns.str.contains('Alert')]
        categorical_columns = ["CandleStickType", "Incremental_High", "MACD_GOLDEN_CROSS", *alert_columns]
        df = df.astype({column: 'category' for column in categorical_columns})

        # Step 3: One-hot encode categorical columns.
        # Fit only on the training frame; refitting on the test frame would
        # derive the dummy columns from whichever categories happen to occur
        # there, so train and test could end up with different columns.
        categorical_df = df.select_dtypes(include=['category'])
        if train or not hasattr(self.ohe, "categories_"):
            # Also covers a test frame processed before any training frame.
            encoded = self.ohe.fit_transform(categorical_df)
        else:
            encoded = self.ohe.transform(categorical_df)
        encoded_df = pd.DataFrame(encoded.toarray(),
                                  columns=self.ohe.get_feature_names_out(categorical_df.columns))

        # Step 4: Concatenate numeric and encoded categorical data.
        # The index is reset so both parts share the same 0..n-1 row labels.
        numeric_df = df.select_dtypes(include=['float64', 'int64'])
        numeric_df = numeric_df.reset_index(drop=True)

        df = pd.concat([numeric_df, encoded_df], axis=1)

        # Step 5: Add timestamp column
        df['timestamp'] = df.index

        # Create target variable
        df = self.create_target(df)

        # Save the prepared data for model training
        if self.save_dir is not None:
            self.save_data(df, self.save_dir)

        return df

    def create_target(self, df):
        """Adds the 'log_daily_return' column to ``df`` in place and returns it.

        Each value is the log return, in percent, of a row's 'close' relative
        to the previous row's 'close'; the first row is therefore NaN.
        NOTE(review): this is the return of the current row, not a
        forward-shifted "next day" value - confirm that downstream code
        shifts it if a next-day target is intended.
        """
        df['log_daily_return'] = np.log(df['close'].pct_change() + 1) * 100

        return df

    def save_data(self, df, file_path):
        """Saves the dataframe as a parquet file named ``self.filename`` inside ``file_path``."""
        file_path = os.path.join(file_path, self.filename)
        df.to_parquet(file_path, index=False)

In [16]:
# Test fetch_and_split_data: pull the NVDA rows from MongoDB and split them
# into the training and test frames.
StockDataPreprocessor = fetch_and_split_data('NVDA')
StockDataPreprocessor.fetch_data()
StockDataPreprocessor.split_data()
df = StockDataPreprocessor.get_train_data()
test_df = StockDataPreprocessor.get_test_data()

# Test prepare_data: the same preprocessor instance handles both frames.
DataPreprocessor = prepare_data()

# train=True (the default) saves as train_data.parquet; train=False saves as test_data.parquet.
df = DataPreprocessor.preprocess(df)
test_df = DataPreprocessor.preprocess(test_df,train=False)

In [17]:
df

Unnamed: 0,169EMA,BodyDiff,low,high,close_t-1,open,close_t-2,MACD_HIST,8EMA,close_t-3,169EMA_Upper,13EMA,atr,MACD,volume,close,169EMA_Lower,MACD_SIGNAL,144EMA,Incremental_High_1.0,dual_channel_Alert_0,dual_channel_Alert_1,MACD_GOLDEN_CROSS_1.0,382_Alert_1,Engulf_Alert_0,Engulf_Alert_1,CandleStickType_red,MACD_Alert_0,MACD_Alert_1,timestamp,log_daily_return
0,17.574974,0.341415,21.852482,22.496377,21.903399,22.102053,21.209587,-0.133755,21.940106,21.076820,18.453723,21.943350,0.637189,0.270444,248555000,22.443468,16.696225,0.404198,18.170701,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,
1,17.627541,0.010981,21.823533,22.111040,22.443468,22.032175,21.903399,-0.117002,21.963006,21.209587,18.508918,21.957608,0.638637,0.257946,217655000,22.043156,16.746164,0.374948,18.224114,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1,-1.799747
2,17.674545,0.049914,21.288454,21.761641,22.043156,21.672793,22.443468,-0.130608,21.887422,21.903399,18.558272,21.909790,0.640915,0.211688,245215000,21.622879,16.790818,0.342296,18.270994,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,2,-1.925018
3,17.709709,0.514119,20.615608,21.382294,21.622879,21.177645,22.043156,-0.196633,21.615446,22.443468,18.595195,21.731752,0.643446,0.096504,343069000,20.663527,16.824224,0.293138,18.303994,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,3,-4.538183
4,17.742323,0.422275,20.431925,20.980983,20.663527,20.904116,21.622879,-0.241976,21.363533,22.043156,18.629439,21.553193,0.645481,-0.009333,218394000,20.481840,16.855206,0.232644,18.334034,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,4,-0.883150
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
651,64.622433,3.916343,83.373013,88.316184,82.618149,83.803941,79.663643,-0.419388,83.506737,82.409180,67.853555,84.106055,3.756030,-0.872801,551011000,87.720284,61.391312,-0.453413,67.516289,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,651,5.992378
652,64.894432,0.161975,85.251699,87.977240,87.720284,87.580311,82.618149,-0.043159,84.447970,79.663643,68.139153,84.625516,3.761530,-0.507361,388971000,87.742287,61.649710,-0.464202,67.795268,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,652,0.025080
653,65.147291,0.837857,86.285532,88.804107,87.742287,87.225369,87.720284,0.112690,84.878980,82.618149,68.404656,84.877230,3.759968,-0.323340,363709000,86.387512,61.889927,-0.436030,68.051713,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,653,-1.556083
654,65.357642,2.035662,81.241373,85.985580,86.387512,85.062738,87.742287,-0.006014,84.467446,87.720284,68.625524,84.612922,3.757597,-0.443547,559863000,83.027077,62.089760,-0.437533,68.258270,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,654,-3.967635
