In [27]:
import os

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from transformers import BertJapaneseTokenizer, BertModel
from dotenv import load_dotenv
import torch
import pandas as pd
import base64
from itertools import zip_longest, filterfalse

In [28]:
ENABLE_INDEX=True
IS_TEST=True

In [29]:
load_dotenv(dotenv_path='.local.env')

True

In [30]:
class SentenceBertJapanese:
    """
    https://huggingface.co/sonoisa/sentence-bert-base-ja-mean-tokens-v2
    """

    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


    def encode(self, sentences, batch_size=8):
        all_embeddings = []
        iterator = range(0, len(sentences), batch_size)
        for batch_idx in iterator:
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest", 
                                           truncation=True, return_tensors="pt").to(self.device)
            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')

            all_embeddings.extend(sentence_embeddings)

        # return torch.stack(all_embeddings).numpy()
        return torch.stack(all_embeddings)


MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"  # <- v2です。
model = SentenceBertJapanese(MODEL_NAME)

sentences = ["暴走したAI", "暴走した人工知能"]
sentence_embeddings = model.encode(sentences, batch_size=8)

print("Sentence embeddings:", sentence_embeddings)

Sentence embeddings: tensor([[ 0.1376,  0.5348,  0.0560,  ...,  0.9439,  0.1351,  0.0652],
        [ 0.0604,  0.6181, -0.5663,  ...,  0.0606, -0.8503,  0.1538]],
       grad_fn=<StackBackward0>)


In [31]:
# 何かしら素敵なデータをCSVで用意しておく
df = pd.read_csv('resources/data.csv', engine='python')

In [32]:
titles = df['title'].to_list()
print(len(titles))

20550


In [33]:
ES_URL = os.environ.get("ES_URL")
ES_USER = os.environ.get("ES_USER")
ES_PASS = os.environ.get("ES_PASS")
INDEX_NAME = os.environ.get("ES_INDEX_NAME")
print(ES_URL)
print(ES_USER)
print(ES_PASS)

es = Elasticsearch(ES_URL, basic_auth=(ES_USER, ES_PASS))
es.cat.indices()

https://askd-qa-v2.es.ap-northeast-1.aws.found.io:9243
elastic
cfpewA9RsLfvWcRADa2HEEg7


TextApiResponse('green  open  .internal.alerts-observability.logs.alerts-default-000001    9beEw5cmQR6KwjuPRw26Lw 1 0     0 0   248b   248b\ngreen  open  .internal.alerts-observability.uptime.alerts-default-000001  EwjOj6Y1QEyvCZoBycyVYA 1 0     0 0   248b   248b\ngreen  open  .fleet-file-data-agent-000001                                PkCp0NzMTcCqxJtEVB29dA 1 0     0 0   248b   248b\nyellow open  query_logs-v0.3                                              vL4u4C1XQQexTB1u95_cUg 2 2     0 0   496b   496b\nyellow open  index                                                        8dG5KMxiRLybegWpfvCiWw 2 2     3 0 58.8kb 58.8kb\ngreen  open  .fleet-files-agent-000001                                    Rx9P5HZKRjqHW_EvIbINdQ 1 0     0 0   248b   248b\ngreen  open  .internal.alerts-observability.slo.alerts-default-000001     P_Gr-3uWQamSzAKlTqH8fQ 1 0     0 0   248b   248b\nyellow open  topics-v0.28                                                 0aUKYqzbRmCVflvJMiJkDQ 2 2 50117 0   84mb

In [34]:
# 動作確認用
if IS_TEST:
    titles = titles[:5]
    # titles = ["新型コロナワクチン予防接種について", "胃ガンの手術", '三歳の子供が熱']

In [38]:
def chunk(ln: list, n: int):
    return zip_longest(*[iter(ln)]*n)

In [39]:
index_offset = 0
for chunk in chunk(titles, n=8):
    chunk = filter(None, chunk)
    title_list = list(chunk)
    sentence_embeddings = model.encode(title_list, batch_size=8)
    actions = [{
        "_id": index_offset + _id,
        "_source": {
            "title": title,
            "vector": embedding,
        },
    } for _id, (title, embedding) in enumerate(zip(title_list, sentence_embeddings.tolist()))]
    success, errors = bulk(es, actions, index=INDEX_NAME, refresh=True)
    if len(errors) != 0:
        raise Exception(errors[0])
    index_offset += len(title_list)

In [42]:
search_title = '三歳児の高熱'
search_embeddings = model.encode([search_title], batch_size=8)

query = {
    "query": {
        "term": {
            "title": search_title
        }
    },
    "knn": {
        "field": "vector",
        "query_vector": search_embeddings.tolist()[0],
        "k": 10,
        "num_candidates": 50
    },
    "rank": {
        "rrf": {
            "window_size": 50,
            "rank_constant": 20
        }
    }
}

In [44]:

result = es.search(index=INDEX_NAME, body=query, size=10)
# 検索結果からドキュメントの内容のみ表示
for document in result["hits"]["hits"]:
    print(document["_source"]['title'])

  result = es.search(index=INDEX_NAME, body=query, size=10)
