In [1]:
import warnings
warnings.filterwarnings("ignore")

import findspark
findspark.init()

from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.stat import Summarizer
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import re

In [2]:
# Degree of parallelism for the Spark session configured in the next cell, where
# this value is used for both "spark.cores.max" and "spark.executor.instances".
# NOTE(review): the identifier is a misspelling of "parallelism"; it is left as-is
# because the session-setup cell refers to it by this exact name.
parrallelism = 8

In [3]:
# Assemble the session builder first, then create (or reuse) the session.
# Dynamic allocation is switched off and the executor layout is pinned explicitly:
# `parrallelism` executors with one core each, and `parrallelism` cores in total.
session_builder = (
    SparkSession.builder
    .appName('Mean Categories')
    .config("spark.dynamicAllocation.enabled", False)
    .config("spark.driver.memory", "4g")
    .config("spark.cores.max", parrallelism)
    .config("spark.executor.instances", parrallelism)
    .config("spark.executor.cores", 1)
    .config("spark.executor.memory", "6g")
    .enableHiveSupport()  # the cells below read and write Hive tables (ruten.*)
)
spark = session_builder.getOrCreate()

# Keep a handle on the SparkContext and lower log verbosity to errors only.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/06/07 11:18:48 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [5]:
# Distinct (category_name, item_id) pairs, aliased as "i" so the category column
# can be referenced as "i.category_name" after the join. Cached for reuse.
items_table = spark.read.table("ruten.items")
category = (
    items_table
    .select("category_name", "item_id")
    .distinct()
    .alias("i")
    .cache()
)

# Attach each item's category to its BERT embedding (inner join on item_id).
item_embeddings = spark.read.table("ruten.item_bert_embeddings")
embeddings_with_category = item_embeddings.join(category, on="item_id", how="inner")

# Summarizer.mean aggregates ML vectors, so the array column is converted first.
category_vectors = embeddings_with_category.select(
    "i.category_name",
    array_to_vector("embedding").alias("embedding"),
)

# One row per category holding the mean of its items' embedding vectors.
mean_embedding = Summarizer.mean(col("embedding")).alias("means")
category_bert_mean = category_vectors.groupBy("i.category_name").agg(mean_embedding)

In [6]:
# Convert the ML vector column back to a plain array before persisting, then
# replace the contents of the Hive table with the freshly computed means.
category_bert_mean_arrays = category_bert_mean.withColumn(
    "means", vector_to_array(col("means"))
)
(
    category_bert_mean_arrays.write
    .mode("overwrite")
    .saveAsTable("ruten.category_bert_mean")
)

                                                                                

In [8]:
# Read the saved table back and preview the first 10 rows, truncating each cell
# to 80 characters so the long "means" arrays stay readable.
saved_category_means = spark.read.table("ruten.category_bert_mean")
saved_category_means.show(n=10, truncate=80)

+--------------+--------------------------------------------------------------------------------+
| category_name|                                                                           means|
+--------------+--------------------------------------------------------------------------------+
|  其他事務用品|[0.015318099806923842, -0.06862425321653287, -0.01495216189873185, 0.06262199...|
|        南北貨|[0.02332762520182198, -0.006393144300872907, -0.013879214193070342, 0.0454430...|
|      多段變頻|[0.07199920323972056, -0.06529106213045878, -0.015270731338711244, 0.02465232...|
|布料、布飾用品|[0.022684920082191057, -0.08251138266299793, -0.01717377680467156, 0.07631546...|
|        BL小說|[0.06470077777137676, 0.09910147837150007, -0.01709956380849084, 0.0616534070...|
|      其他漫畫|[0.06163656314060818, 0.10261786172115997, -0.015232874225089697, 0.065954661...|
|國中、國小用書|[0.022476823187816082, 0.08254795025494631, -0.016955095660812845, 0.04092600...|
|引擎、車組零件|[-0.04060564998260812, -0.09625831360981695,