In [6]:
import warnings
warnings.filterwarnings("ignore")

import findspark
findspark.init()

from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.stat import Summarizer
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import re

In [2]:
# Value passed to both `spark.cores.max` and `spark.executor.instances` when the
# Spark session is built in the next cell.
# NOTE(review): the name is a misspelling of "parallelism"; it is kept as-is
# because the session-configuration cell refers to it by this exact name.
parrallelism = 8

In [3]:
# Spark settings applied to the builder in the order listed here.
# Dynamic allocation is switched off and the executor count / core cap are both
# taken from `parrallelism`, with one core and 6g of memory per executor.
session_options = {
    "spark.dynamicAllocation.enabled": False,
    "spark.driver.memory": "4g",
    "spark.cores.max": parrallelism,
    "spark.executor.instances": parrallelism,
    "spark.executor.cores": 1,
    "spark.executor.memory": "6g",
}

builder = SparkSession.builder.appName('Mean Categories')
for option_key, option_value in session_options.items():
    builder = builder.config(option_key, option_value)

# Hive support is required because the notebook reads and writes `ruten.*` tables by name.
spark = builder.enableHiveSupport().getOrCreate()

# Keep a handle on the SparkContext and only surface ERROR-level log messages.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
# Distinct (category_name, item_id) pairs from the items table. Marked for caching
# because both embedding pipelines below join against this same DataFrame.
category = (
    spark.read.table("ruten.items")
    .select("category_name", "item_id")
    .distinct()
    .cache()
)


def category_embedding_mean(embedding_table, item_categories):
    """Build the per-category mean embedding for one embedding table.

    The embedding table is inner-joined to ``item_categories`` on ``item_id``,
    its ``embedding`` array column is converted to an ML vector, and the vectors
    are averaged within each ``category_name`` using ``Summarizer.mean``.

    Args:
        embedding_table: Name of a table readable via ``spark.read.table`` that
            has ``item_id`` and ``embedding`` columns.
        item_categories: DataFrame with ``item_id`` and ``category_name`` columns.

    Returns:
        A (lazy) DataFrame with columns ``category_name`` and ``means``, where
        ``means`` is the ML vector produced by ``Summarizer.mean``.
    """
    return (
        spark.read.table(embedding_table)
        .join(item_categories, on="item_id", how="inner")
        .select("category_name", array_to_vector("embedding").alias("embedding"))
        .groupBy("category_name")
        .agg(Summarizer.mean(col("embedding")).alias("means"))
    )


# Same aggregation for both embedding models; only the source table differs.
category_mpnet_mean = category_embedding_mean("ruten.item_mpnet_embeddings", category)
category_roberta_mean = category_embedding_mean("ruten.item_roberta_embeddings", category)

In [7]:
# Persist each per-category mean DataFrame as a table, replacing any existing one.
# The `means` column is converted from an ML vector to a plain array before writing.
for means_df, target_table in (
    (category_mpnet_mean, "ruten.category_mpnet_mean"),
    (category_roberta_mean, "ruten.category_roberta_mean"),
):
    (
        means_df
        .withColumn("means", vector_to_array(col("means")))
        .write
        .mode("overwrite")
        .saveAsTable(target_table)
    )

                                                                                

In [8]:
# Preview up to 10 rows of the saved MPNet category means, cutting cell text at 80 characters.
(
    spark.read
    .table("ruten.category_mpnet_mean")
    .show(n=10, truncate=80)
)

+-------------+--------------------------------------------------------------------------------+
|category_name|                                                                           means|
+-------------+--------------------------------------------------------------------------------+
|         電鑽|[-0.021314089978136785, -0.10547570231519209, -0.015131004432114804, 0.084993...|
|        EPSON|[-0.01817725158769983, -0.03541076516216426, -0.013079896840095827, 0.0838633...|
|     修車工具|[-0.016647793525897635, -0.05932178240893265, -0.013177707850340642, 0.108768...|
|     昆蟲用品|[0.02782750157050651, -0.01444393068217083, -0.016812594513393456, 0.06576453...|
|       風扇式|[-0.07861331445389741, -0.10615482247196611, -0.012547073829550259, 0.1035614...|
|         油品|[-0.052018190908357714, -0.06605940253191672, -0.012483718835735796, 0.103602...|
|     防毒軟體|[0.001110304127520315, 0.012769164728647949, -0.010819803078864085, 0.0101728...|
|     其他用品|[0.026637781565172026, 0.0012971386956

In [9]:
# Preview up to 10 rows of the saved RoBERTa category means, cutting cell text at 80 characters.
(
    spark.read
    .table("ruten.category_roberta_mean")
    .show(n=10, truncate=80)
)

+-------------+--------------------------------------------------------------------------------+
|category_name|                                                                           means|
+-------------+--------------------------------------------------------------------------------+
|         電鑽|[0.03079621376129999, 0.011305086300605801, -0.010168589144345522, 0.00601073...|
|        EPSON|[0.0122165122531219, 0.015443971295002369, 0.004824102293412356, 0.0030927372...|
|     修車工具|[0.02087074320814526, 0.01602600673577771, -0.006578294068127924, 0.014337540...|
|     昆蟲用品|[0.011721802285823984, 0.0423861917540528, 0.004003241466825103, -0.006544369...|
|       風扇式|[0.01368033661415166, 0.006734896217600938, -0.004806658724363041, 0.01217407...|
|         油品|[0.021233216216690682, 0.021986425723601613, 0.002847337595034737, -7.6024434...|
|     防毒軟體|[0.014197460872464483, 0.03902770848786932, 0.010386681545504083, 1.414819508...|
|     其他用品|[0.024933663381152802, 0.0264184833508