In [1]:
from pyspark.sql import SparkSession

# Obtain the Spark session explicitly instead of relying on a variable that
# some kernels inject: getOrCreate() returns the already-running session when
# there is one and starts a new one otherwise, so the notebook also works
# after "Restart & Run All" in a plain Python kernel.
spark = SparkSession.builder.getOrCreate()
spark

In [34]:
import numpy as np

import pyspark.sql.functions as F
from pyspark.sql.functions import udf

from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler

In [69]:
# Toy input: a 5x3 matrix of uniform samples in [0, 1).
# A fixed seed makes every downstream number reproducible across kernel
# restarts. NOTE: outputs saved in this notebook before the seed was added
# will differ until the notebook is re-run.
SEED = 42
rng = np.random.default_rng(SEED)
arr = rng.random((5, 3))
arr

array([[0.13126456, 0.84343125, 0.85416719],
       [0.61116831, 0.93498236, 0.62372176],
       [0.43403183, 0.15857956, 0.981562  ],
       [0.02555369, 0.56734115, 0.98358823],
       [0.82257104, 0.69700886, 0.3655278 ]])

In [70]:
# Convert the NumPy matrix to plain Python rows; with no schema given, Spark
# auto-names the columns (_1, _2, _3 in the schema printed further down).
rows = arr.tolist()
df = spark.createDataFrame(rows)

In [73]:
# Pack the three auto-named columns into a single array column named "arr".
source_columns = [F.col(f"_{idx}") for idx in range(1, 4)]
df = df.withColumn("arr", F.array(source_columns))

In [76]:
# Combine the raw columns _1.._3 into one ML vector column ("arr_vec"),
# which is the input format StandardScaler expects.
feature_cols = [f"_{idx}" for idx in range(1, 4)]
vecAssembler = VectorAssembler(inputCols=feature_cols, outputCol="arr_vec")

In [77]:
# Append the assembled vector column to df.
# Any previous copy of the output column is dropped first so the cell can be
# re-run safely: VectorAssembler raises if its output column already exists,
# while DataFrame.drop is a no-op when the column is absent.
df = vecAssembler.transform(df.drop(vecAssembler.getOutputCol()))

In [78]:
# Inspect the schema to confirm the array and vector columns were added.
df.printSchema()

root
 |-- _1: double (nullable = true)
 |-- _2: double (nullable = true)
 |-- _3: double (nullable = true)
 |-- arr: array (nullable = false)
 |    |-- element: double (containsNull = true)
 |-- arr_vec: vector (nullable = true)



In [80]:
# Display the assembled vectors without truncating the cell contents.
df.select("arr_vec").show(n=10, truncate=False)

+-----------------------------------------------------------+
|arr_vec                                                    |
+-----------------------------------------------------------+
|[0.13126455724692954,0.8434312458994674,0.8541671897700044]|
|[0.6111683065527407,0.9349823622851455,0.6237217586275806] |
|[0.4340318281346046,0.15857955558328163,0.9815619995095274]|
|[0.02555368892901677,0.5673411481647872,0.9835882306608179]|
|[0.8225710435721549,0.6970088578919171,0.3655277980475625] |
+-----------------------------------------------------------+



In [81]:
# Standardize each vector component: subtract the mean (withMean) and divide
# by the standard deviation (withStd).
scaler = StandardScaler(
    withMean=True,
    withStd=True,
    inputCol="arr_vec",
    outputCol="scaled_arr",
)

In [82]:
# Fit the scaler on df; the fitted model exposes the statistics inspected
# in the next cells and is used to transform the data.
model = scaler.fit(df)

In [85]:
# Mean vector learned by the fitted scaler (one entry per input column).
model.mean

DenseVector([0.4049, 0.6403, 0.7617])

In [86]:
# Standard-deviation vector learned by the fitted scaler (one entry per input column).
model.std

DenseVector([0.3304, 0.3036, 0.2656])

In [83]:
# Apply the fitted scaler; the result is kept under a new name so df itself
# stays unchanged.
scaled_df = model.transform(df)

In [84]:
# Display the standardized vectors without truncating the cell contents.
scaled_df.select("scaled_arr").show(n=10, truncate=False)

+-------------------------------------------------------------+
|scaled_arr                                                   |
+-------------------------------------------------------------+
|[-0.8282791901263209,0.6691770862936627,0.3481306050061962]  |
|[0.6242676955314593,0.970728187288748,-0.5196013019908413]   |
|[0.0881205192803099,-1.5865876644406904,0.8278300316721146]  |
|[-1.1482391670539374,-0.24020857968776935,0.8354597139453096]|
|[1.2641301423684892,0.18689097054604978,-1.491819048632779]  |
+-------------------------------------------------------------+

