In [1]:
import os
from pathlib import Path

# Get $HOME (kept for reference; not used in the cells shown here)
home = str(Path.home())

# $SPARK_HOME: prefer the environment variable so the notebook runs on other
# machines; fall back to the original hard-coded install location when it is
# unset or empty.
spark_home = os.environ.get('SPARK_HOME') or '/home/xgboost/spark-3.0.0-bin-hadoop2.7'

import findspark
# Point findspark at this Spark install so that `import pyspark` below resolves.
findspark.init(spark_home)

import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.linalg import SparseVector, DenseVector, Vectors, VectorUDT

In [2]:
# Local Spark session with 4 worker threads ('local[4]'); getOrCreate reuses a
# session that is already running in this kernel instead of starting a second one.
spark = (
    SparkSession.builder
    .master('local[4]')
    .getOrCreate()
)
# RDD-level entry point, used for `parallelize` when building the toy data.
sc = spark.sparkContext

In [3]:
# Toy data: 4 rows of 5 integers. The zeros matter: they let VectorAssembler
# produce a sparse vector for some rows later on.
x = [
    [1, 2, 0, 0, 0],
    [1, 2, 3, 0, 0],
    [1, 2, 4, 5, 0],
    [1, 2, 2, 5, 6],
]
df = sc.parallelize(x).toDF(list('abcde'))
df.show()

+---+---+---+---+---+
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  1|  2|  0|  0|  0|
|  1|  2|  3|  0|  0|
|  1|  2|  4|  5|  0|
|  1|  2|  2|  5|  6|
+---+---+---+---+---+



In [4]:
# Pack the five numeric columns into a single vector column named 'features'.
# handleInvalid='keep' keeps rows that contain invalid (null/NaN) values and
# emits NaN for them instead of raising (Spark's default is 'error').
input_cols = ['a', 'b', 'c', 'd', 'e']
vec_assembler = VectorAssembler(
    inputCols=input_cols,
    outputCol='features',
    handleInvalid='keep',
)

In [5]:
# One-stage pipeline (just the assembler): fit on the toy frame, then apply it.
pipeline = Pipeline(stages=[vec_assembler])
pipeline_model = pipeline.fit(df)

res2 = pipeline_model.transform(df)
res2.show()

+---+---+---+---+---+--------------------+
|  a|  b|  c|  d|  e|            features|
+---+---+---+---+---+--------------------+
|  1|  2|  0|  0|  0| (5,[0,1],[1.0,2.0])|
|  1|  2|  3|  0|  0|[1.0,2.0,3.0,0.0,...|
|  1|  2|  4|  5|  0|[1.0,2.0,4.0,5.0,...|
|  1|  2|  2|  5|  6|[1.0,2.0,2.0,5.0,...|
+---+---+---+---+---+--------------------+



In [6]:
def toDense(v):
    """Convert an ML vector (sparse or dense) into a DenseVector.

    The assembler can emit a SparseVector for zero-heavy rows (the first row
    of the previous output is shown as `(5,[0,1],[1.0,2.0])`). This forces a
    uniform dense representation. A null input is returned as None instead of
    failing with AttributeError inside the UDF.
    """
    if v is None:
        return None
    return DenseVector(v.toArray())


# NOTE(review): this is a Python UDF (row-by-row serialization). On Spark >= 3.1,
# pyspark.ml.functions.array_to_vector(vector_to_array(col)) avoids it; this
# notebook targets Spark 3.0.0, so the UDF is kept.
toDenseUdf = F.udf(toDense, VectorUDT())

res2.withColumn('features', toDenseUdf('features')).show()

+---+---+---+---+---+--------------------+
|  a|  b|  c|  d|  e|            features|
+---+---+---+---+---+--------------------+
|  1|  2|  0|  0|  0|[1.0,2.0,0.0,0.0,...|
|  1|  2|  3|  0|  0|[1.0,2.0,3.0,0.0,...|
|  1|  2|  4|  5|  0|[1.0,2.0,4.0,5.0,...|
|  1|  2|  2|  5|  6|[1.0,2.0,2.0,5.0,...|
+---+---+---+---+---+--------------------+



In [7]:
# Shut down through the session object that was created; this also stops the
# underlying SparkContext (`sc`).
spark.stop()