In [None]:
%pip install torch --index-url https://download.pytorch.org/whl/cpu
%pip install sentence-transformers

In [None]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

from transformers import BertTokenizerFast, TFBertModel, BertConfig
from sentence_transformers import util

import tensorflow as tf
import pyarrow.parquet as pq
import pandas as pd
import numpy as np

In [3]:
parallelism = 4

spark = SparkSession.builder \
                    .appName('BERT Sentence Embedding') \
                    .config("spark.dynamicAllocation.enabled", False) \
                    .config("spark.driver.memory", "6g") \
                    .config("spark.driver.maxResultSize", "4g") \
                    .config("spark.cores.max", parallelism) \
                    .config("spark.executor.instances", parallelism) \
                    .config("spark.executor.cores", 1) \
                    .config("spark.executor.memory", "6g") \
                    .enableHiveSupport() \
                    .getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [None]:
BASE = "hfl/chinese-macbert-base"
MODEL = "bert_model_base.h5"
tokenizer = BertTokenizerFast.from_pretrained(BASE)
model = TFBertModel.from_pretrained(MODEL, config=BertConfig.from_pretrained(BASE))

In [None]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = tf.cast(tf.expand_dims(attention_mask, -1), tf.float32)
    sum_embeddings = tf.reduce_sum(token_embeddings * input_mask_expanded, axis=1)
    sum_mask = tf.math.reduce_sum(input_mask_expanded, axis=1)
    sentence_embeddings = sum_embeddings / tf.math.maximum(sum_mask, 1e-9)
    return tf.math.l2_normalize(sentence_embeddings, axis=1)

def encode(query):
    encoded_input = tokenizer(
        query,
        max_length=128, 
        padding=True,
        truncation=True, 
        return_tensors="tf"
    )
    model_output = model(**encoded_input)
    return mean_pooling(model_output, encoded_input['attention_mask']).numpy()

In [4]:
items_df   = spark.read.table("ruten.item_bert_embeddings").toPandas()
ids        = items_df.item_id.values
sellers    = items_df.seller_nickname.values
categories = items_df.category_name.values
item_names = items_df.item_name.values
embeddings = np.vstack(items_df.embedding.values)

def item_semantic_search(query, n = 5):
    query_embeddings = encode(query)
    similarity = util.cos_sim(query_embeddings, embeddings)[0]
    top_results_indices = np.argsort(similarity)[-n:]
    return pd.DataFrame(
        zip(
            categories[top_results_indices],
            ids[top_results_indices],
            item_names[top_results_indices],
            np.array(similarity[top_results_indices])
        ),
        columns=["category", "item_id", "item_name", "similarity"]
    )

items_df

                                                                                

Unnamed: 0,category_name,item_id,item_name,seller_nickname,embedding
0,濾心/配件,11070507947955,水專家 小t不鏽鋼st 白鐵 填充外殼 不銹鋼外殼 單開 2分牙口 600 支,wowowo41,"[-0.07450886, -0.10841466, -0.017112901, 0.025..."
1,男歌手,11070509171665,黃毛丫頭 seal human being 全新未拆 q867 清倉 下標賣,venus50520,"[0.047684427, 0.16763099, -0.0102732675, -0.00..."
2,其他廚房用品,11070911517968,真空包裝袋160mm 270mm 日期印字機 特價供應,tcp248,"[0.02079058, -0.10171947, -0.011393265, 0.0676..."
3,肉類食品,11071009196544,o 黑牛霸霸 o非常好吃的牛排 牛小排 600g 煎 烤 炒 吃完會笑喔,blickcow,"[-0.008497075, 0.14535776, -0.010477434, -0.02..."
4,變速系統零件,11071013890034,老胡單車精品 日本 shimano deore xt 27段定位指撥分離式變把 sl m77...,sjbikeshop,"[0.002893014, -0.13167614, -0.010606956, 0.128..."
...,...,...,...,...,...
539995,其他汽車精品,22152612754311,愛淨小舖 福士打蠟機藍色底盤,xoxo0717,"[-0.027187163, -0.05704191, -0.018376539, 0.02..."
539996,其他機車零組件,22152618990735,jf asdf66999訂購明細,jf-moto,"[-0.047234442, 0.029590204, -0.012442782, 0.11..."
539997,其他漫畫,22152623216709,2本合售 漫畫書 無章釘 小叮噹彩色長篇 宇宙開拓史 上 下 藤子 f 不二雄 青文 脫頁無...,alixson8,"[0.11017776, 0.11026062, -0.018029204, 0.06759..."
539998,其他家庭雜貨,22152624223809,泡澡桶大人 可折疊浴桶 家用全身成人浴缸汗蒸兒童沐浴盆洗澡盆神器,biboyo,"[-0.017054738, -0.1812868, -0.016909659, -0.02..."


In [41]:
item_semantic_search("samsung galaxy s22")

Unnamed: 0,category,item_id,item_name,similarity
0,A、M、N系列,22103792510349,samsung galaxy a20 sm a205gn 故障機 零件機 豐0123,0.822868
1,A、M、N系列,22105865863589,二手商品 samsung galaxy a7 sm a700yd 三星 智慧型手機,0.82796
2,A、M、N系列,22118299819293,samsung galaxy a70,0.829029
3,電視維修零組件,22040561588494,samsung ua55d8000ym s240labmb3v0 7 邏輯板,0.831385
4,A、M、N系列,22008293327761,三星 samsung galaxy a8 2016 施華洛世奇 手機殼,0.83252


In [27]:
category_df    = spark.read.table("ruten.category_bert_mean").toPandas()
category_names = category_df.category_name.values
category_means = np.vstack(category_df.means.values).astype(np.float32)

def category_semantic_search(query, n = 5):
    query_embeddings = model.encode(query)
    similarity = util.cos_sim(query_embeddings, category_means)[0]
    top_results_indices = np.argsort(similarity)[-n:]
    return pd.DataFrame(
        zip(
            category_names[top_results_indices],
            np.array(similarity[top_results_indices])
        ),
        columns=["category_name", "similarity"]
    )

category_df

                                                                                

Unnamed: 0,category_name,means
0,電鑽,"[-0.021314089978136785, -0.10547570231519209, ..."
1,EPSON,"[-0.01817725158769983, -0.03541076516216426, -..."
2,修車工具,"[-0.016647793525897635, -0.05932178240893265, ..."
3,昆蟲用品,"[0.02782750157050651, -0.01444393068217083, -0..."
4,風扇式,"[-0.07861331445389741, -0.10615482247196611, -..."
...,...,...
295,日本,"[0.03946300028958035, 0.08576908184163877, -0...."
296,菸灰缸,"[0.02161824415006262, -0.020578239525831415, -..."
297,其他食品,"[0.032953422213161616, -0.005706825999513139, ..."
298,其他音響設備,"[-0.04809040116088981, -0.12225969023885935, -..."


In [40]:
category_semantic_search("samsung galaxy s22")

Unnamed: 0,category_name,similarity
0,螢幕保護貼,0.586334
1,手機保護殼,0.59149
2,SONY XPERIA,0.591997
3,SAMSUNG,0.670447
4,A、M、N系列,0.711966


In [46]:
seller_df = spark.read.table("ruten.seller_bert_mean").toPandas()
seller_names = seller_df.seller_nickname.values
seller_means = np.vstack(seller_df.means.values).astype(np.float32)

def seller_semantic_search(query, n = 5):
    query_embeddings = model.encode(query)
    similarity = util.cos_sim(query_embeddings, seller_means)[0]
    top_results_indices = np.argsort(similarity)[-n:]
    return pd.DataFrame(
        zip(
            seller_names[top_results_indices],
            np.array(similarity[top_results_indices])
        ),
        columns=["category_name", "similarity"]
    )

seller_df

                                                                                

Unnamed: 0,seller_nickname,means
0,01diro,"[-0.056929703801870346, -0.04218580946326256, ..."
1,0424634075,"[-0.06633256189525127, -0.08288780320435762, -..."
2,0433kink,"[-0.06585595346987247, -0.16937275975942614, -..."
3,0435477,"[0.06272678822278976, -0.10116132348775864, -0..."
4,06100921,"[0.01666453063632522, -0.038587991465364804, -..."
...,...,...
43230,z0912322341,"[0.033563136123120785, -0.06951586715877056, -..."
43231,z45678862,"[-0.002749911043792963, 0.005743665620684624, ..."
43232,zk4833,"[0.06047833226621151, -0.009724953398108482, -..."
43233,znshyutw,"[0.03702324256300926, -0.047632236033678055, -..."


In [47]:
seller_semantic_search("samsung galaxy s22")

Unnamed: 0,category_name,similarity
0,elvis001,0.809088
1,tupolev0312,0.817331
2,tsung4398,0.822877
3,sanpony999,0.829029
4,kenhung88,0.854194
