In [2]:
import os
import yaml
import nbimporter
from datetime import datetime, date
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType,  DateType, TimestampType
from pyspark.sql.functions import to_date, to_timestamp
from lab_table_manager import TableManager
import yfinance as yf
import time
import random
from lab_pg_database_manager import PGDatabaseManager





def fetch_yfinance_record(symbol_date_pairs):
    """Fetch the price history of one symbol from yfinance as plain tuples.

    Args:
        symbol_date_pairs: A ``(symbol, start_date)`` pair. ``start_date`` is
            passed unchanged to ``Ticker.history(start=...)``; the ``end``
            argument is today's date.

    Returns:
        A list with one tuple per history row, laid out as
        ``(Date, Open, High, Low, Close, Volume, Dividends, Stock Splits,
        symbol, import_time)``. ``import_time`` is ``datetime.now()`` in ISO
        format, taken once per call. On any error the exception is printed
        and an empty list is returned instead of raising.
    """
    # Bind `symbol` before the try block: if the argument cannot be unpacked
    # into exactly two values, the handler below must still be able to format
    # its message (previously it crashed with UnboundLocalError).
    symbol = symbol_date_pairs
    try:
        symbol, start_date = symbol_date_pairs

        # Fetch stock data using yfinance
        quote = yf.Ticker(symbol)
        current_date = date.today()
        hist = quote.history(start=start_date, end=current_date)

        # Turn the Date index into a regular column and keep only the date part
        hist = hist.reset_index()
        hist["Date"] = hist["Date"].dt.date

        # Limit and stabilize the fields (and their order) of hist
        hist = hist[['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Dividends', 'Stock Splits']]

        # Append symbol and import_time to each row
        import_time = datetime.now().isoformat()
        record_list = [
            tuple(row) + (symbol, import_time) for row in hist.itertuples(index=False)
        ]

        # Pause for a random 0.1-0.9 s before returning, spacing out
        # consecutive requests
        time.sleep(random.uniform(0.1, 0.9))

        return record_list

    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return []  # Return an empty list on error

In [None]:
import yfinance as yf
from datetime import date
import time
import random
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DateType, DoubleType, StringType

# Schema of one price record. The field order mirrors the tuples built by
# fetch_yfinance_record: the eight history columns followed by the symbol and
# the import timestamp (kept as an ISO-format string). Every field is nullable.
# NOTE(review): Volume is declared as a double; confirm that is intended if the
# upstream values are whole numbers.
record_schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("Date", DateType()),
        ("Open", DoubleType()),
        ("High", DoubleType()),
        ("Low", DoubleType()),
        ("Close", DoubleType()),
        ("Volume", DoubleType()),
        ("Dividends", DoubleType()),
        ("Stock Splits", DoubleType()),
        ("Symbol", StringType()),
        ("ImportTime", StringType()),
    )
])



def xparallel_fetch_yfinance_record(spark, symbol_date_pairs, record_schema):
    """Fetch yfinance history for many symbols in parallel on Spark.

    Args:
        spark: An active SparkSession.
        symbol_date_pairs: Iterable of ``(symbol, start_date)`` pairs; each
            pair is handed to ``fetch_yfinance_record`` on a worker.
        record_schema: StructType applied to the resulting DataFrame. When
            ``None``, Spark infers the schema from the data (columns are then
            named ``_1`` ... ``_n`` and an empty result cannot be inferred).

    Returns:
        A DataFrame with one row per fetched record. If building it fails, the
        error is printed and an empty DataFrame with ``record_schema`` (or with
        no columns when no schema was given) is returned.

    Note:
        With a schema supplied Spark does not sample the RDD, so the fetch
        runs when an action is executed on the returned DataFrame; errors
        raised on the workers surface at that point, not here.
    """
    try:
        # Distribute the pairs across the workers; every pair expands into
        # zero or more record tuples.
        record_rdd = spark.sparkContext.parallelize(symbol_date_pairs).flatMap(fetch_yfinance_record)

        if record_schema is None:
            return spark.createDataFrame(record_rdd)

        # Spark's schema verification rejects a Python int for a DoubleType
        # field, so values in those positions are converted to float first.
        double_positions = frozenset(
            position
            for position, field in enumerate(record_schema.fields)
            if isinstance(field.dataType, DoubleType)
        )

        def _conform_to_schema(record):
            # Runs on the workers; captures only the frozenset of positions.
            return tuple(
                float(value) if position in double_positions and value is not None else value
                for position, value in enumerate(record)
            )

        return spark.createDataFrame(record_rdd.map(_conform_to_schema), schema=record_schema)
    except Exception as e:
        print(f"Error paralleling fetch: {e}")
        # An empty dataset needs an explicit schema; without one Spark raises
        # "can not infer schema from empty dataset".
        fallback_schema = record_schema if record_schema is not None else StructType([])
        return spark.createDataFrame([], schema=fallback_schema)
  
