In [13]:
import warnings
# NOTE(review): with no category/module argument this silences every Python
# warning for the rest of the kernel session (including deprecation warnings
# from PySpark); consider narrowing the filter.
warnings.filterwarnings("ignore")

import findspark
# Presumably locates the local Spark installation and makes `pyspark`
# importable; it is deliberately called before the pyspark imports below.
findspark.init()

from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.stat import Summarizer
from pyspark.sql import SparkSession
# NOTE(review): this wildcard import shadows Python builtins such as sum, min,
# max, abs and round; the cells shown here only use `col` and `trim` from it.
from pyspark.sql.functions import *

# NOTE(review): pq, pd, np and re are not used in the cells shown here.
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import re

In [2]:
# Parallelism setting passed to the Spark session configuration.
PARALLELISM = 8
# Backward-compatible alias for the original (misspelled) name, which other
# cells in this notebook still reference.
parrallelism = PARALLELISM

In [3]:
# Create (or reuse) the Spark session: a fixed executor layout sized by
# `parrallelism` (one core per executor, dynamic allocation disabled) and Hive
# support enabled so the `ruten.*` tables can be read and written.
spark = (
    SparkSession.builder
    .appName('Mean Categories')
    .config("spark.dynamicAllocation.enabled", False)
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "6g")
    .config("spark.executor.cores", 1)
    .config("spark.executor.instances", parrallelism)
    .config("spark.cores.max", parrallelism)
    .enableHiveSupport()
    .getOrCreate()
)

sc = spark.sparkContext
# Only surface ERROR-level messages from Spark in the notebook output.
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [8]:
goods = spark.read.table("ruten.goods")
items = spark.read.table("ruten.items")

# Distinct (seller_id, trimmed seller_nickname, item_id) rows for items present
# in both `goods` and `items`. Aliased as "s" so callers can refer to
# "s.seller_nickname"; cached because it is joined once per embedding table.
sellers = (
    goods.join(items, goods.item_id == items.item_id)
    .select(goods.seller_id, trim(goods.seller_nickname).alias("seller_nickname"), items.item_id)
    .distinct()
    .alias("s")
    .cache()
)


def seller_mean_embedding(table_name, sellers_df):
    """Average the item embeddings of each seller from one embedding table.

    Parameters
    ----------
    table_name : str
        Fully qualified table with an ``item_id`` column and an array-typed
        ``embedding`` column (as required by ``array_to_vector``).
    sellers_df : pyspark.sql.DataFrame
        Seller/item mapping with ``item_id`` and ``seller_nickname`` columns,
        aliased as ``"s"``.

    Returns
    -------
    pyspark.sql.DataFrame
        One row per ``seller_nickname`` with the element-wise mean of that
        seller's item embeddings as an ML vector in ``means``.
    """
    # NOTE(review): grouping is by trimmed nickname, not seller_id, so distinct
    # seller_ids that share a nickname are merged into one mean — confirm intent.
    return (
        spark.read.table(table_name)
        .join(sellers_df, on="item_id", how="inner")
        .select("s.seller_nickname", array_to_vector("embedding").alias("embedding"))
        .groupBy("s.seller_nickname")
        .agg(Summarizer.mean(col("embedding")).alias("means"))
    )


seller_bert_mean = seller_mean_embedding("ruten.item_bert_embeddings", sellers)
seller_mpnet_mean = seller_mean_embedding("ruten.item_mpnet_embeddings", sellers)
seller_roberta_mean = seller_mean_embedding("ruten.item_roberta_embeddings", sellers)

In [14]:
# Persist each per-seller mean as `ruten.seller_<model>_mean`, overwriting any
# previous run. The ML vector column is converted back to a plain array column
# (vector_to_array's default dtype is float64) before saving.
for model_name, seller_mean in [
    ("bert", seller_bert_mean),
    ("mpnet", seller_mpnet_mean),
    ("roberta", seller_roberta_mean),
]:
    (
        seller_mean
        .withColumn("means", vector_to_array(col("means")))
        .write.mode("overwrite")
        .saveAsTable(f"ruten.seller_{model_name}_mean")
    )

                                                                                

In [17]:
# Preview the saved MPNet means: first 10 rows, cell text cut at 80 characters.
spark.read.table("ruten.seller_mpnet_mean").show(n=10, truncate=80)

+---------------+--------------------------------------------------------------------------------+
|seller_nickname|                                                                           means|
+---------------+--------------------------------------------------------------------------------+
|         01diro|[-0.056929703801870346, -0.04218580946326256, -0.01768401823937893, 0.0588154...|
|     0424634075|[-0.06633256189525127, -0.08288780320435762, -0.012077237013727427, 0.0842352...|
|       0433kink|[-0.06585595346987247, -0.16937275975942614, -0.016787812610467273, 0.1343832...|
|        0435477|[0.06272678822278976, -0.10116132348775864, -0.017282670363783836, 0.04581346...|
|       06100921|[0.01666453063632522, -0.038587991465364804, -0.012625398869016048, 0.0953165...|
|         0723lc|[-0.040193647146224976, -0.06348846852779388, -0.012742171064019203, 0.065037...|
|     0902553630|[0.07662977526585261, -0.06022613992293676, -0.014840484596788883, 0.07983168...|
|    09026

                                                                                

In [16]:
# Preview the saved RoBERTa means: first 10 rows, cell text cut at 80 characters.
spark.read.table("ruten.seller_roberta_mean").show(n=10, truncate=80)

+---------------+--------------------------------------------------------------------------------+
|seller_nickname|                                                                           means|
+---------------+--------------------------------------------------------------------------------+
|         01diro|[-0.009230716153979301, 0.0041698310524225235, -0.015844877809286118, 0.00848...|
|     0424634075|[0.006412898132111877, 0.009556040167808533, 0.005233170231804252, 0.00587945...|
|       0433kink|[0.045148008130490774, -0.008256737433839588, -0.003538392561798294, -0.00754...|
|        0435477|[0.06136468052864075, 0.03238063305616379, -0.026400094851851463, 0.007108700...|
|       06100921|[0.034363214889990876, 0.027591509262806384, 0.002779285608116409, 0.01748632...|
|         0723lc|[0.03015187941491604, 0.020539559423923492, -0.010590951889753342, 0.00507627...|
|     0902553630|[0.023145358078181744, 0.0031898897917320332, -0.014856083629031975, 0.005086...|
|    09026

                                                                                