In [2]:
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
import torch
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import datetime
import os

  from .autonotebook import tqdm as notebook_tqdm


## roberta 模型

In [2]:
# 開始測量(時間)
startime = datetime.datetime.now()

df_ = pd.read_csv('./translated_agoda_comments_all_city.csv', header=0)

# 刪除 '綜合評論' 欄位中含有 NaN 的列
df_ = df_.dropna(subset=['綜合評論'])

# 篩選評論數大於等於 30 的飯店
df_filtered = df_.groupby('飯店名稱').filter(lambda x: len(x) >= 30)

# df_filtered = df_filtered[0:100000]

df_filtered = df_filtered.reset_index(drop=True)

# 初始化 RoBERTa-wwm-ext 模型和分詞器
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
model = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext')

# 定義每個批次的大小
batch_size = 5000

# 計算需要進行的批次數
n_batches = (len(df_filtered) // batch_size) + 1

# 開始進行批次處理
for batch_idx in range(n_batches):
    print(f"Processing batch {batch_idx + 1} of {n_batches}...")
    
    start_idx = batch_idx * batch_size

    end_idx = (batch_idx + 1) * batch_size

    # 評論文本
    sentences = list(df_filtered.loc[start_idx:end_idx-1, '綜合評論'])

    # 如果 sentences 是空的，則跳過該批次
    if not sentences:
        continue

    # 初始化一個空的 list 用於儲存飯店資訊
    hotel_info_list = []

    # 獲取每個文本的嵌入向量，並儲存對應的飯店資訊
    embeddings = []
    for idx, sentence in enumerate(sentences):
        if isinstance(sentence, str):
            tokens = tokenizer.tokenize(sentence)  # 分詞
            inputs = tokenizer(sentence, return_tensors="pt")
        else:
            print(f"Sentence at index {idx} is not a string. Skipping...")
            continue

        # 檢查序列長度是否超過 512
        if len(inputs["input_ids"][0]) > 512:
            print("The sentence is too long. Skipping...")
            continue

        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.last_hidden_state.mean(dim=1).numpy().flatten())
        hotel_info_list.append(df_filtered.iloc[start_idx + idx].to_dict())

    # 讀取原有的飯店資訊和 embeddings
    if os.path.exists('agoda_roberta_hotel_info_綜合評論.csv') and os.path.exists('agoda_roberta_embeddings_綜合評論.npy'):
        df_hotel_info_old = pd.read_csv('agoda_roberta_hotel_info_綜合評論.csv')
        embeddings_old = np.load('agoda_roberta_embeddings_綜合評論.npy')

    else:
        df_hotel_info_old = pd.DataFrame()
        embeddings_old = np.array([]).reshape(0,768)

    # 儲存飯店資訊與 embeddings
    df_hotel_info_new = pd.DataFrame(hotel_info_list)
    df_hotel_info = pd.concat([df_hotel_info_old, df_hotel_info_new], ignore_index=True)
    df_hotel_info.to_csv('agoda_roberta_hotel_info_綜合評論.csv', index=False)

    embeddings_new = np.array(embeddings)
    embeddings = np.vstack([embeddings_old, embeddings_new])
    np.save('agoda_roberta_embeddings_綜合評論.npy', embeddings)
    
    print(f"第{batch_idx + 1}批次儲存完畢")
    
# 結束測量
endtime = datetime.datetime.now()

# 輸出結果
print("執行時間：", endtime - startime)

  df_ = pd.read_csv('./translated_agoda_comments_all_city.csv', header=0)
Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Processing batch 1 of 21...
Sentence at index 1026 is not a string. Skipping...
Sentence at index 1641 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 4046 is not a string. Skipping...
Sentence at index 4441 is not a string. Skipping...
第1批次儲存完畢
Processing batch 2 of 21...
Sentence at index 309 is not a string. Skipping...
Sentence at index 358 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 2133 is not a string. Skipping...
The sentence is too long. Skipping...
Sentence at index 2343 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 4233 is not a string. Skipping...
The sentence is too long. Skipping...
第2批次儲存完畢
Proc

  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第5批次儲存完畢
Processing batch 6 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 2812 is not a string. Skipping...
Sentence at index 2813 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
第6批次儲存完畢
Processing batch 7 of 21...
The sentence is too long

  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第8批次儲存完畢
Processing batch 9 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 1234 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 3546 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 4479 is not a string. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第9批次儲存完畢
Processing batch 10 of 21...
Sentence at index 373 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 2335 is not a string. Skipping...
The sentence is too long. Skipping...
Sentence at index 4184 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
第10批次儲存完畢
Processing batch 11 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 3035 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第11批次儲存完畢
Processing batch 12 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第12批次儲存完畢
Processing batch 13 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 846 is not a string. Skipping...
Sentence at index 891 is not a string. Skipping...
The sentence is too long. Skipping...
Sentence at index 1130 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 4283 is not a string. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第13批次儲存完畢
Processing batch 14 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第14批次儲存完畢
Processing batch 15 of 21...
Sentence at index 130 is not a string. Skipping...
The sentence is too long. Skipping...
Sentence at index 1737 is not a string. Skipping...
Sentence at index 1892 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 4444 is not a string. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第15批次儲存完畢
Processing batch 16 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 3877 is not a string. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第16批次儲存完畢
Processing batch 17 of 21...
Sentence at index 0 is not a string. Skipping...
Sentence at index 28 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第17批次儲存完畢
Processing batch 18 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第18批次儲存完畢
Processing batch 19 of 21...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 3266 is not a string. Skipping...
Sentence at index 3355 is not a string. Skipping...
Sentence at index 4985 is not a string. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第19批次儲存完畢
Processing batch 20 of 21...
Sentence at index 174 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
Sentence at index 2266 is not a string. Skipping...
Sentence at index 3221 is not a string. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...
The sentence is too long. Skipping...


  df_hotel_info_old = pd.read_csv('agoda_hotel_info_roberta_綜合評論.csv')


第20批次儲存完畢
Processing batch 21 of 21...
執行時間： 1:53:54.740255
