In [None]:
# Install the third-party packages this notebook imports below
# (sentence_transformers, tqdm, pyarrow) into the kernel's environment.
%pip install sentence-transformers tqdm pyarrow

In [13]:
import os
# Set before `import tensorflow` further down so the setting is already in the
# environment when TensorFlow is loaded.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import warnings
# NOTE(review): this silences every Python warning for the whole kernel session,
# including deprecation warnings raised by the code in this notebook.
warnings.filterwarnings("ignore")

import findspark
# Called before the pyspark imports below; presumably needed so that pyspark is
# importable in this environment -- confirm against the deployment setup.
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

from transformers import BertTokenizerFast, TFBertModel
from sentence_transformers import util
from tqdm import tqdm

import tensorflow as tf
import pyarrow as pa
import pandas as pd
import numpy as np

In [None]:
# Tokenizer and TensorFlow encoder are loaded from the same pretrained
# checkpoint; binding the name once keeps the two calls in sync.
PRETRAINED_CHECKPOINT = 'hfl/chinese-roberta-wwm-ext'

tokenizer = BertTokenizerFast.from_pretrained(PRETRAINED_CHECKPOINT)
model = TFBertModel.from_pretrained(PRETRAINED_CHECKPOINT)

In [3]:
parallelism = 4

# Spark settings, applied to the builder in this order. Dynamic allocation is
# switched off and `parallelism` single-core executors are requested.
spark_settings = {
    "spark.dynamicAllocation.enabled": False,
    "spark.driver.memory": "4g",
    "spark.cores.max": parallelism,
    "spark.executor.instances": parallelism,
    "spark.executor.cores": 1,
    "spark.executor.memory": "8g",
}

session_builder = SparkSession.builder.appName('Roberta Sentence Embedding')
for setting_name, setting_value in spark_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)

# Hive support is required because the tables below are read via spark.table().
spark = session_builder.enableHiveSupport().getOrCreate()

# Only show errors from Spark in the notebook output.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
goods = spark.table("ruten.goods")
orders = spark.table("ruten.orders")
goods_vs_keywords = spark.table("ruten.goods_vs_keywords")

# Both joins are left joins, so an unmatched side comes back as NULL. A goods
# row is kept when it matched at least one order OR at least one keyword row.
has_order_or_keyword = (
    col("orders.item_id").isNotNull() | col("goods_vs_keywords.gno").isNotNull()
)

# Item-name cleanup, applied inside-out:
#   1. drop `&#...;` character references (raw string: `\w` is a regex escape,
#      not a Python string escape),
#   2. replace every run of characters other than U+4E00-U+9FFF, ASCII letters
#      and digits with a single space,
#   3. trim surrounding whitespace and lowercase.
without_char_refs = regexp_replace(col("item_name"), r'&#\w+;', '')
normalized_item_name = lower(trim(
    regexp_replace(without_char_refs, '[^\u4e00-\u9fffa-zA-Z0-9]+', ' ')
))

# distinct() collapses the duplicates produced by the one-to-many joins and by
# different raw names normalizing to the same text for one item_id.
data = (
    goods.join(orders, goods.item_id == orders.item_id, how="left")
    .join(goods_vs_keywords, goods_vs_keywords.GNO == goods.item_id, how="left")
    .where(has_order_or_keyword)
    .select(goods.item_id, goods.item_name)
    .withColumn("item_name", normalized_item_name)
    .distinct()
)

In [5]:
# Materialize the Spark query and collect the full result onto the driver as a
# pandas DataFrame (columns: item_id, item_name).
item_names_df = data.toPandas()
# Last expression of the cell: render the frame as the cell output.
item_names_df

                                                                                

Unnamed: 0,item_id,item_name
0,10060814036060,二手音樂cd 98度單曲 征服未來
1,10060831298854,纖維彈簧透氣墊 沙發 汽車 電腦椅 輪椅 透氣椅墊 透氣坐墊 透氣座墊 小寵物墊 通風 散熱...
2,10060908825888,hikaru no go棋靈王便條紙
3,10060924029666,曹錦輝 tsao chin hui 2005 topps heritage 317
4,10061017636459,墨水王 評價9000 lxmark epson canon hp 高品質台灣填充墨水
...,...,...
1246576,22152635903312,jspb g5 標準版拉把
1246577,22152636113923,限fb買家陳先生下標購買rgeva貳號機
1246578,22152636835665,fun patch臂章圖鑑 限時預購
1246579,22152637425256,h 日版cd 安室奈美 break the rules song nation lovin it


In [6]:
def mean_pooling(model_output, attention_mask):
    """Mask-weighted mean of token embeddings, L2-normalized per row.

    Args:
        model_output: Model result whose element ``[0]`` holds the per-token
            embeddings (presumably shaped (batch, seq_len, hidden) -- confirm
            against the model in use).
        attention_mask: Mask for the same batch; positions with value 0 do not
            contribute to the average.

    Returns:
        Tensor of sentence embeddings, reduced over axis 1 and normalized to
        unit L2 length along axis 1.
    """
    per_token = model_output[0]
    # Cast to float and add a trailing axis so the mask broadcasts across the
    # embedding dimension.
    mask = tf.expand_dims(tf.cast(attention_mask, tf.float32), axis=-1)
    masked_total = tf.reduce_sum(per_token * mask, axis=1)
    # Lower-bound the count so an all-zero mask cannot cause division by zero.
    token_count = tf.math.maximum(tf.reduce_sum(mask, axis=1), 1e-9)
    return tf.math.l2_normalize(masked_total / token_count, axis=1)

In [7]:
batch_size = 1000
item_names = item_names_df.item_name.tolist()

# Consecutive slices of at most `batch_size` names; the final slice holds the
# remainder, so the number of batches equals ceil(len(item_names) / batch_size).
batches = [
    item_names[start : start + batch_size]
    for start in range(0, len(item_names), batch_size)
]
num_batches = len(batches)

In [8]:
embeddings = []

# Embed one batch at a time. Each batch is padded to its own longest sequence
# and truncated at 128 tokens; tqdm advances once per batch.
for batch in tqdm(batches, total=num_batches):
    encoded_input = tokenizer(
        batch,
        max_length=128,
        padding=True,
        truncation=True,
        return_tensors="tf",
    )
    model_output = model(**encoded_input)
    batch_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # extend() iterates the 2-D array row by row, appending one vector per item.
    embeddings.extend(batch_embeddings.numpy())

# Batches were built in row order, so embeddings line up with the frame's rows.
item_names_df['embedding'] = embeddings

100%|██████████| 1247/1247 [1:21:41<00:00,  3.93s/it]


In [9]:
# Render the frame again, now including the `embedding` column added above.
item_names_df

Unnamed: 0,item_id,item_name,embedding
0,10060814036060,二手音樂cd 98度單曲 征服未來,"[0.017046705, 0.06566689, 0.012368624, 0.01565..."
1,10060831298854,纖維彈簧透氣墊 沙發 汽車 電腦椅 輪椅 透氣椅墊 透氣坐墊 透氣座墊 小寵物墊 通風 散熱...,"[0.029443927, 0.010932564, -0.006758356, 0.021..."
2,10060908825888,hikaru no go棋靈王便條紙,"[0.013676894, 0.014121184, 0.046186905, 0.0211..."
3,10060924029666,曹錦輝 tsao chin hui 2005 topps heritage 317,"[-0.0232483, 0.026301824, 0.03411942, 0.027663..."
4,10061017636459,墨水王 評價9000 lxmark epson canon hp 高品質台灣填充墨水,"[0.01797593, 0.011586101, -0.014641057, 0.0164..."
...,...,...,...
1246576,22152635903312,jspb g5 標準版拉把,"[0.016783109, 0.006942966, -0.017080883, 0.029..."
1246577,22152636113923,限fb買家陳先生下標購買rgeva貳號機,"[0.04315454, 0.01662747, 0.016533075, 0.019895..."
1246578,22152636835665,fun patch臂章圖鑑 限時預購,"[0.0015115625, 0.05576491, 0.038529985, 0.0143..."
1246579,22152637425256,h 日版cd 安室奈美 break the rules song nation lovin it,"[0.009355322, 0.040769435, 0.02524915, 0.02601..."


In [14]:
# Persist only the item_id -> embedding mapping. The column added to the frame
# is named 'embedding' (singular); selecting 'embeddings' raised a KeyError.
item_names_df[['item_id', 'embedding']].to_parquet("/tmp/roberta_embeddings.parquet", engine='pyarrow', compression='snappy')

In [None]:
# Upload the locally written parquet file to HDFS.
!hdfs dfs -copyFromLocal /tmp/roberta_embeddings.parquet "/ruten/roberta_embeddings.parquet"