In [1]:
from pyspark.ml.feature import (
    OneHotEncoder,
    StringIndexer,
    VectorAssembler,
    Word2Vec,
)
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    array,
    col,
    collect_list,
    concat_ws,
    explode,
    lit,
    lower,
    size,
    split,
    when,
)

In [2]:
spark = SparkSession.builder.appName("bk-imp-features").getOrCreate()

23/05/07 05:35:53 WARN Utils: Your hostname, workspace resolves to a loopback address: 127.0.1.1; using 11.11.1.73 instead (on interface eth0)
23/05/07 05:35:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/05/07 05:35:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
metadata_df = (
    spark.read.parquet("../data/sampled_data")
    .select("product_id", "category", "description")
    .drop_duplicates()
)

In [4]:
metadata_df.show(3)



+----------+--------------------+--------------------+
|product_id|            category|         description|
+----------+--------------------+--------------------+
|B000RN04VU|[Automotive, Moto...|[The rugged unive...|
|B013T0NR4I|[Automotive, Exte...|[Proudly made in ...|
|B019HALHL4|[Automotive, Moto...|[Fitment: Honda C...|
+----------+--------------------+--------------------+
only showing top 3 rows



                                                                                

## Category

In [5]:
df = metadata_df

In [6]:
# Replace empty or null values in the category array with ["Unknown"]
df = df.withColumn(
    "category",
    when(size(col("category")) == 0, lit(["Unknown"])).otherwise(
        col("category")
    ),
)

# Create a new DataFrame by exploding the category array column
exploded_df = df.select(
    col("product_id"),
    col("category"),
    explode(col("category")).alias("single_category"),
    col("description"),
)

# Use StringIndexer to convert the category column into a column of numerical indices
indexer = StringIndexer(inputCol="single_category", outputCol="category_index")
indexed_df = indexer.fit(exploded_df).transform(exploded_df)

# Use OneHotEncoder to convert the numerical indices into a vector of binary features
encoder = OneHotEncoder(inputCol="category_index", outputCol="category_vec")
encoded_df = encoder.fit(indexed_df).transform(indexed_df)

# Group by product_id and aggregate the category_vec column into a list
grouped_df = encoded_df.groupBy("product_id", "category", "description").agg(
    collect_list("category_vec").alias("category_vec_list")
)

In [7]:
metadata_df = grouped_df

## Description

In [8]:
df = metadata_df

In [9]:
# Combine description array into a string and replace empty or null values with "Unknown"
df = df.withColumn(
    "description",
    when(size(col("description")) == 0, lit(["Unknown"])).otherwise(
        col("description")
    ),
)
df = df.withColumn("description", lower(concat_ws(" ", col("description"))))

# Split the description string into words and convert to lowercase
exploded_df = df.withColumn("words", split(lower(col("description")), "\W+"))

# Train a Word2Vec model on the words column
word2vec = Word2Vec(
    vectorSize=100, minCount=5, inputCol="words", outputCol="word2vec_100"
)
model = word2vec.fit(exploded_df)

# Use the Word2Vec model to transform the words column into a vector column
exploded_df = model.transform(exploded_df).select(
    "product_id",
    "category",
    "description",
    "category_vec_list",
    "word2vec_100",
)

23/05/07 05:36:05 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

In [10]:
metadata_df = exploded_df

In [11]:
metadata_df.show(3)

+----------+--------------------+--------------------+--------------------+--------------------+
|product_id|            category|         description|   category_vec_list|        word2vec_100|
+----------+--------------------+--------------------+--------------------+--------------------+
|B00006HNRY|[Automotive, Tool...|shipping depth: 2...|[(662,[0],[1.0]),...|[0.01941847761830...|
|B00009WC6A|[Automotive, Inte...|the scraper and p...|[(662,[0],[1.0]),...|[0.00362994347123...|
|B0000AS5QB|[Automotive, Tool...|coleman 200001648...|[(662,[0],[1.0]),...|[0.02458205162302...|
+----------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows



In [24]:
metadata_df.coalesce(1).write.parquet("../data/produc_features")