In [1]:
import warnings
warnings.filterwarnings("ignore")

import findspark
findspark.init()

from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.stat import Summarizer
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import re

In [2]:
parrallelism = 8

In [3]:
spark = SparkSession.builder \
                    .appName('Mean Categories') \
                    .config("spark.dynamicAllocation.enabled", False) \
                    .config("spark.driver.memory", "4g") \
                    .config("spark.cores.max", parrallelism) \
                    .config("spark.executor.instances", parrallelism) \
                    .config("spark.executor.cores", 1) \
                    .config("spark.executor.memory", "6g") \
                    .enableHiveSupport() \
                    .getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/06/07 11:18:48 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [5]:
category = spark.read.table("ruten.items").select("category_name", "item_id").distinct().alias("i").cache()

category_bert_mean = ( spark.read.table("ruten.item_bert_embeddings")
                      .join(category, on="item_id", how="inner" )
                      .select("i.category_name", array_to_vector("embedding").alias("embedding") )
                      .groupBy("i.category_name").agg(Summarizer.mean(col("embedding")).alias("means")) )

category_mpnet_mean = ( spark.read.table("ruten.item_mpnet_embeddings")
                      .join(category, on="item_id", how="inner" )
                      .select("i.category_name", array_to_vector("embedding").alias("embedding") )
                      .groupBy("i.category_name").agg(Summarizer.mean(col("embedding")).alias("means")) )

category_roberta_mean = ( spark.read.table("ruten.item_roberta_embeddings")
                      .join(category, on="item_id", how="inner" )
                      .select("i.category_name", array_to_vector("embedding").alias("embedding") )
                      .groupBy("i.category_name").agg(Summarizer.mean(col("embedding")).alias("means")) )

In [6]:
category_bert_mean \
    .withColumn("means", vector_to_array(col("means")) ) \
    .write.mode("overwrite") \
    .saveAsTable("ruten.category_bert_mean")

category_mpnet_mean \
    .withColumn("means", vector_to_array(col("means")) ) \
    .write.mode("overwrite") \
    .saveAsTable("ruten.category_mpnet_mean")

category_roberta_mean \
    .withColumn("means", vector_to_array(col("means")) ) \
    .write.mode("overwrite") \
    .saveAsTable("ruten.category_roberta_mean")

                                                                                

In [7]:
spark.read.table("ruten.category_bert_mean").show(10, truncate=80)

+----------------+--------------------------------------------------------------------------------+
|   category_name|                                                                           means|
+----------------+--------------------------------------------------------------------------------+
|其他登山露營裝備|[-0.01130470043232018, -0.010176716403420494, 0.011288885955397975, -0.002144...|
|        穿戴系列|[-0.01203853264599241, -0.011944506361470741, 0.013148004361007931, -0.002781...|
|           1:144|[-0.016146805694259088, -0.014371978527081863, 0.013816663178287135, -0.00254...|
|        套裝盒組|[-0.013986995813175042, -0.0053478428896143815, 0.01463188558291397, -0.00575...|
|      素面POLO衫|[-0.011080838239891725, -0.007788765488792685, 0.011456314304768176, -0.00214...|
|  電子零件、材料|[-0.01231464171752948, -0.00702363626656497, 0.012938657890587757, -0.0035582...|
|    其他汽車零件|[-0.013964041612507375, -0.007590736771159745, 0.013124226760831017, -0.00300...|
|        肉類食品|[-0.00963452894543

In [8]:
spark.read.table("ruten.category_mpnet_mean").show(10, truncate=80)

+--------------+--------------------------------------------------------------------------------+
| category_name|                                                                           means|
+--------------+--------------------------------------------------------------------------------+
|  其他事務用品|[0.015318099806923842, -0.06862425321653287, -0.01495216189873185, 0.06262199...|
|        南北貨|[0.02332762520182198, -0.006393144300872907, -0.013879214193070342, 0.0454430...|
|      多段變頻|[0.07199920323972056, -0.06529106213045878, -0.015270731338711244, 0.02465232...|
|布料、布飾用品|[0.022684920082191057, -0.08251138266299793, -0.01717377680467156, 0.07631546...|
|        BL小說|[0.06470077777137676, 0.09910147837150007, -0.01709956380849084, 0.0616534070...|
|      其他漫畫|[0.06163656314060818, 0.10261786172115997, -0.015232874225089697, 0.065954661...|
|國中、國小用書|[0.022476823187816082, 0.08254795025494631, -0.016955095660812845, 0.04092600...|
|引擎、車組零件|[-0.04060564998260812, -0.09625831360981695,

In [9]:
spark.read.table("ruten.category_roberta_mean").show(10, truncate=80)

+--------------+--------------------------------------------------------------------------------+
| category_name|                                                                           means|
+--------------+--------------------------------------------------------------------------------+
|  其他事務用品|[0.020256369170148104, 0.019319534728361926, -0.002031191696291541, 0.0089650...|
|        南北貨|[0.035391496221875846, 0.03844949997405114, -0.010247700171351182, 0.01589613...|
|      多段變頻|[0.025965876576532446, 0.030605224953609726, -0.004405737421984625, 0.0134910...|
|布料、布飾用品|[0.034184457729992505, 0.030938423830260815, 0.0010962391452267182, 0.0091732...|
|        BL小說|[0.019612097124749112, 0.031712931973347715, 0.014908200593612614, 0.01273087...|
|      其他漫畫|[0.01580074914411529, 0.033089502161621605, 0.022509570164805454, 0.014541025...|
|國中、國小用書|[8.149879320586605E-4, 0.023566311881067337, -0.011255109917981941, 0.0152919...|
|引擎、車組零件|[0.013603736883026598, 0.014209323022127652,