In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, avg, when
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.sql.types import FloatType

# Start (or reuse) the Spark session used by everything below.
spark = (
    SparkSession.builder
    .appName("StockMarketPrediction")
    .getOrCreate()
)

# Read the quote history: the first CSV row supplies the column names and
# column types are inferred from the data.
data_path = "itc.csv"
historical_data = spark.read.csv(path=data_path, inferSchema=True, header=True)

# Print a preview of the loaded rows as a sanity check.
historical_data.show()

# Per-symbol window ordered by date; the lag and rolling features below are
# all computed over it.
window_spec = Window.partitionBy("Symbol").orderBy("Date")

# Close and volume of the preceding row within the same symbol.
historical_data = historical_data.withColumn("Prev_Close", lag("Close", 1).over(window_spec))
historical_data = historical_data.withColumn("Prev_Volume", lag("Volume", 1).over(window_spec))

# Mean of Close over the current row and up to four preceding rows.
# NOTE(review): this window includes the current row's Close, which the label
# below is also derived from -- confirm that look-ahead is intended.
historical_data = historical_data.withColumn("MA_Close_5", avg("Close").over(window_spec.rowsBetween(-4, 0)))

# lag() yields null where there is no preceding value (e.g. the first row of
# each symbol). Drop those rows rather than zero-filling them: a Prev_Close of
# 0 would make the comparison below label the row 1 ("buy") and would feed
# fabricated zero features to the model.
historical_data = historical_data.na.drop(subset=["Prev_Close", "Prev_Volume"])

# Any remaining nulls in numeric columns are still replaced by 0.
historical_data = historical_data.na.fill(0)

# Three-class target derived from the move between consecutive closes.
historical_data = historical_data.withColumn(
    "label",
    when(col("Close") > col("Prev_Close"), 1)  # Buy if today's close is higher than previous close
    .when(col("Close") < col("Prev_Close"), 0)  # Sell if today's close is lower
    .otherwise(2)  # Hold when the close is unchanged
)

# Preview the engineered columns next to the label.
historical_data.select("Date", "Symbol", "Close", "Prev_Close", "MA_Close_5", "label").show()

# Hold out 20% of the rows for testing; the seed is fixed at 42.
train_data, test_data = historical_data.randomSplit(weights=[0.8, 0.2], seed=42)

# Columns packed into the single "features" vector the classifier consumes.
feature_columns = [
    "Prev_Close",
    "Prev_Volume",
    "MA_Close_5",
    "Volume",
    "VWAP",
]
assembler = VectorAssembler(outputCol="features", inputCols=feature_columns)

# Random forest of 100 trees predicting the "label" column.
rf = RandomForestClassifier(
    numTrees=100,
    featuresCol="features",
    labelCol="label",
)

# Assemble features, then classify; fit on the training split only.
pipeline = Pipeline(stages=[assembler, rf])
model = pipeline.fit(train_data)

# Score the held-out rows and preview the predictions.
predictions = model.transform(test_data)
predictions.select("features", "label", "prediction").show()

# Report accuracy on the test split.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="label",
)
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy = {accuracy}")

# Persist the fitted pipeline. Overwrite whatever a previous run left behind:
# a plain save() raises if the target path already exists, which breaks
# re-running this cell.
model.write().overwrite().save("model/")


+----------+------+------+----------+------+------+------+------+------+------+-------+--------------------+------+------------------+-----------+
|      Date|Symbol|Series|Prev Close|  Open|  High|   Low|  Last| Close|  VWAP| Volume|            Turnover|Trades|Deliverable Volume|%Deliverble|
+----------+------+------+----------+------+------+------+------+------+------+-------+--------------------+------+------------------+-----------+
|2000-01-03|   ITC|    EQ|     656.0| 694.0| 708.5| 675.0| 708.5| 708.5|701.81| 562715|     3.9491742195E13|  NULL|              NULL|       NULL|
|2000-01-04|   ITC|    EQ|     708.5| 714.0| 729.0| 694.3|710.65|712.35|714.16| 712637|     5.0893789485E13|  NULL|              NULL|       NULL|
|2000-01-05|   ITC|    EQ|    712.35|716.25| 758.9| 660.0| 731.0| 726.2|732.43|1382149|     1.0123247781E14|  NULL|              NULL|       NULL|
|2000-01-06|   ITC|    EQ|     726.2| 741.0| 784.3| 741.0| 784.3| 784.3|776.63| 721618|     5.6042663465E13|  NULL|   