In [2]:
%pip install pyarrow torch sentence-transformers

Successfully installed click-8.1.3 cmake-3.26.3 joblib-1.2.0 lit-16.0.5.post0 mpmath-1.3.0 networkx-3.1 nltk-3.8.1 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 pyarrow-12.0.0 scikit-learn-1.2.2 sentence-transformers-2.2.2 sentencepiece-0.1.99 sympy-1.12 threadpoolctl-3.1.0 torch-2.0.1 torchvision-0.15.2 triton-2.0.0
Note: you may need to restart the kernel to use updated packages.


In [3]:
# Global setup: silence warnings, make PySpark importable, then import libraries.
import warnings
# Suppresses *all* Python warnings for the rest of the kernel session, including
# any emitted while importing the libraries below.
warnings.filterwarnings("ignore")

import findspark
# Must run before any `pyspark` import below; findspark presumably locates the
# local Spark installation (SPARK_HOME) and puts pyspark on sys.path.
findspark.init()

from sentence_transformers import SentenceTransformer, util
from pyspark.sql import SparkSession
# NOTE(review): wildcard import; pyspark.sql.functions defines names such as
# `sum`, `max`, `min` and `round` that may shadow the Python builtins. Only
# `col`, `lower`, `trim` and `regexp_replace` are used in the cells shown.
from pyspark.sql.functions import *
from pyspark.sql.types import *

# NOTE(review): `util`, `pa`, `pq`, `pd`, `np` and the pyspark.sql.types names
# are not referenced in the cells shown here.
import pyarrow as pa
import pyarrow.parquet as pq

import pandas as pd
import numpy as np

In [4]:
# Number of single-core executors requested for this session.
# (Spelling `parrallelism` is kept because other cells may refer to it.)
parrallelism = 4

# Applied to the builder in exactly this order.
spark_settings = [
    ("spark.dynamicAllocation.enabled", False),
    ("spark.driver.memory", "4g"),
    ("spark.cores.max", parrallelism),
    ("spark.executor.instances", parrallelism),
    ("spark.executor.cores", 1),
    ("spark.executor.memory", "8g"),
]

session_builder = SparkSession.builder.appName('Mpnet Sentence Embedding')
for conf_key, conf_value in spark_settings:
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.enableHiveSupport().getOrCreate()

# Only surface errors from the JVM side in the notebook output.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [5]:
# Multilingual MPNet sentence encoder; its weights are downloaded on first use.
embedding_model_name = 'paraphrase-multilingual-mpnet-base-v2'
model = SentenceTransformer(embedding_model_name)

Downloading (…)9e268/.gitattributes: 100%|██████████| 690/690 [00:00<00:00, 3.03MB/s]
Downloading (…)_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 1.08MB/s]
Downloading (…)f2cd19e268/README.md: 100%|██████████| 3.77k/3.77k [00:00<00:00, 19.3MB/s]
Downloading (…)cd19e268/config.json: 100%|██████████| 723/723 [00:00<00:00, 3.76MB/s]
Downloading (…)ce_transformers.json: 100%|██████████| 122/122 [00:00<00:00, 722kB/s]
Downloading pytorch_model.bin: 100%|██████████| 1.11G/1.11G [00:12<00:00, 87.6MB/s]
Downloading (…)nce_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 293kB/s]
Downloading (…)tencepiece.bpe.model: 100%|██████████| 5.07M/5.07M [00:00<00:00, 66.6MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 239/239 [00:00<00:00, 1.52MB/s]
Downloading (…)9e268/tokenizer.json: 100%|██████████| 9.08M/9.08M [00:01<00:00, 7.38MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 402/402 [00:00<00:00, 2.38MB/s]
Downloading (…)d19e268/modules.json: 100%|██

In [6]:
goods = spark.table("ruten.goods")
orders = spark.table("ruten.orders")
goods_vs_keywords = spark.table("ruten.goods_vs_keywords")

# Keep goods that were ordered at least once OR appear in the keyword mapping.
# Left-semi joins only test whether a match exists, so an item with many orders
# and many keyword rows is not multiplied into (#orders x #keyword rows) rows
# that `distinct()` would then have to collapse again.
ordered_goods = goods.join(orders, goods.item_id == orders.item_id, how="left_semi")
keyword_goods = goods.join(goods_vs_keywords, goods_vs_keywords.GNO == goods.item_id, how="left_semi")

# Name cleaning:
#   1. drop numeric HTML character references such as "&#39;"
#      (raw string: `\w` is a regex class, not a Python escape)
#   2. replace every run of characters that are not CJK ideographs
#      (U+4E00-U+9FFF), ASCII letters or digits with a single space
#   3. trim and lower-case
clean_item_name = lower(trim(
    regexp_replace(
        regexp_replace(col("item_name"), r'&#\w+;', ''),
        '[^\u4e00-\u9fffa-zA-Z0-9]+', ' '
    )
))

# Items matching both sources appear twice after the union; distinct() removes them.
data = (
    ordered_goods.select("item_id", "item_name")
    .union(keyword_goods.select("item_id", "item_name"))
    .withColumn("item_name", clean_item_name)
    .distinct()
)

data.printSchema()

root
 |-- item_id: long (nullable = true)
 |-- item_name: string (nullable = true)



In [7]:
# Collect every distinct (item_id, item_name) row onto the driver as pandas and
# display it; the rendered output lists roughly 1.25M rows.
(item_names_df := data.toPandas())

                                                                                

Unnamed: 0,item_id,item_name
0,10060814036060,二手音樂cd 98度單曲 征服未來
1,10060831298854,纖維彈簧透氣墊 沙發 汽車 電腦椅 輪椅 透氣椅墊 透氣坐墊 透氣座墊 小寵物墊 通風 散熱...
2,10060908825888,hikaru no go棋靈王便條紙
3,10060924029666,曹錦輝 tsao chin hui 2005 topps heritage 317
4,10061017636459,墨水王 評價9000 lxmark epson canon hp 高品質台灣填充墨水
...,...,...
1246576,22152635903312,jspb g5 標準版拉把
1246577,22152636113923,限fb買家陳先生下標購買rgeva貳號機
1246578,22152636835665,fun patch臂章圖鑑 限時預購
1246579,22152637425256,h 日版cd 安室奈美 break the rules song nation lovin it


In [9]:
# `item_name` is nullable in the Spark schema; None values would presumably fail
# during tokenization, so they are encoded as empty strings. This keeps exactly
# one embedding per row, in row order, for the column assignment that follows.
# NOTE(review): the saved progress output (313 batches) looks far too small for
# the ~1.25M rows collected above -- the outputs may come from different runs.
embeddings = model.encode(
    item_names_df['item_name'].fillna('').tolist(),
    show_progress_bar=True,
)

Batches: 100%|██████████| 313/313 [00:20<00:00, 15.43it/s]


In [10]:
# Attach one embedding per row as a plain Python list of floats; `assign`
# builds the frame with the new column and the name is rebound to it.
item_names_df = item_names_df.assign(embeddings=embeddings.tolist())
item_names_df.head()

Unnamed: 0,item_id,item_name,embeddings
0,10060814036060,二手音樂cd 98度單曲 征服未來,"[0.025065604597330093, 0.0536174438893795, -0...."
1,10060831298854,纖維彈簧透氣墊 沙發 汽車 電腦椅 輪椅 透氣椅墊 透氣坐墊 透氣座墊 小寵物墊 通風 散熱...,"[0.030799880623817444, -0.07103081047534943, -..."
2,10060908825888,hikaru no go棋靈王便條紙,"[0.064357228577137, 0.016072463244199753, -0.0..."
3,10060924029666,曹錦輝 tsao chin hui 2005 topps heritage 317,"[-0.08539468050003052, 0.13184863328933716, -0..."
4,10061017636459,墨水王 評價9000 lxmark epson canon hp 高品質台灣填充墨水,"[-0.028074568137526512, -0.014946662820875645,..."


In [11]:
# Persist only the id -> embedding mapping as a local snappy-compressed Parquet file.
local_parquet_path = "/tmp/mpnet_embeddings.parquet"
embedding_columns = ['item_id', 'embeddings']
item_names_df.loc[:, embedding_columns].to_parquet(local_parquet_path, engine='pyarrow', compression='snappy')

In [None]:
!hdfs dfs -copyFromLocal /tmp/mpnet_embeddings.parquet "/ruten/mpnet_embeddings.parquet"