In [27]:
from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession

# Start a Spark session, or pick up the one that is already active
spark = (
    SparkSession.builder
    .appName("Mock PySpark DataFrame")
    .getOrCreate()
)

# Restore the previously fitted pipeline from its saved path
loaded_pipeline_model = PipelineModel.load("product_info_pipeline")

# Two hand-written rows laid out as (price, category, stock)
test_columns = ["price", "category", "stock"]
test_data = [(200.0, "Furniture", 15), (2.0, "Stationery", 500)]
test_df = spark.createDataFrame(test_data, schema=test_columns)

# Push the rows through every stage of the loaded pipeline
transformed_test_df = loaded_pipeline_model.transform(test_df)

# Print only the assembled feature vectors, without truncating them
transformed_test_df.select("features").show(truncate=False)

+----------------+
|features        |
+----------------+
|[200.0,15.0,1.0]|
|[2.0,500.0,2.0] |
+----------------+



In [28]:
# Echo the raw input rows (a list of (price, category, stock) tuples) for reference
test_data

[(200.0, 'Furniture', 15), (2.0, 'Stationery', 500)]

In [6]:
# Display the first pipeline stage; in the saved run its repr is a
# StringIndexerModel with handleInvalid=error
loaded_pipeline_model.stages[0]

StringIndexerModel: uid=StringIndexer_a542918f3da0, handleInvalid=error

In [7]:
# Display the second pipeline stage; in the saved run its repr is a
# VectorAssembler uid
loaded_pipeline_model.stages[1]

VectorAssembler_b33983c177b7

In [11]:
# Ask the first stage which column it reads ('category' in the saved run)
loaded_pipeline_model.stages[0].getInputCol()

'category'

In [8]:
# Keep the first stage's label sequences in a variable, then display them.
# In the saved run this is a one-element list holding a tuple of three labels.
labelsArray = loaded_pipeline_model.stages[0].labelsArray
labelsArray

[('Electronics', 'Furniture', 'Stationery')]

In [9]:
# List the columns the second stage reads, in order
# (price, stock, category_index in the saved run)
loaded_pipeline_model.stages[1].getInputCols()

['price', 'stock', 'category_index']

In [33]:
import numpy as np

# Rebuild the pipeline's "features" rows with plain numpy, using only the label
# list stored on the fitted first stage, so the result can be compared with the
# vectors printed by the Spark transform.

# Label sequences held by the first pipeline stage; entry 0 is the one used
# here (that stage reads the single input column 'category').
string_indexer_metadata = loaded_pipeline_model.stages[0].labelsArray

# Map each category label to its position in that label sequence
category_to_index = {label: index for index, label in enumerate(string_indexer_metadata[0])}

# dtype=object keeps every cell as its original Python value. Without it numpy
# promotes the mixed float/str/int rows to one string dtype, which silently
# turns the numeric columns into text that must then be parsed back.
test_data_np = np.array(test_data, dtype=object)

# Look up the index of each row's category (column 1). A label that is not in
# the mapping raises KeyError rather than being given a made-up index.
category_indices = np.array([category_to_index[category] for category in test_data_np[:, 1]])

# Stack the columns in the order price, stock, category index — the same order
# as the input columns reported by the second pipeline stage.
features = np.column_stack((
    test_data_np[:, 0].astype(float),  # price
    test_data_np[:, 2].astype(float),  # stock
    category_indices.astype(float),    # category index
))

print("Features:\n", features)


Features:
 [[200.  15.   1.]
 [  2. 500.   2.]]
