In [351]:
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sudachipy import tokenizer
from sudachipy import dictionary
from pydatrie import DoubleArrayTrie
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel


In [352]:
# Identifier of the pretrained model; passed to AutoTokenizer.from_pretrained
# and AutoModel.from_pretrained when the tokenizer and encoder are loaded.
model_name = "cl-nagoya/ruri-large"

In [353]:
# Elasticsearch client bound to a node on localhost:9200; the indexing cell
# uses it to store the sample documents in "test-index".
client = Elasticsearch("http://localhost:9200/")

In [354]:

# Build the Sudachi tokenizer once; retokenize_with_sudachi() tokenizes with it.
# `dict=` is used instead of `dict_type=`, which SudachiPy keeps only as a
# deprecated alias and warns about each time a Dictionary is constructed.
tokenizer_obj = dictionary.Dictionary(config_path="./sudachi.json", dict="core").create()
# Split mode C (Sudachi's coarsest segmentation, i.e. the longest units).
# NOTE: `tokenizer` here is the sudachipy module from the import cell. A later
# cell rebinds that name to the Hugging Face tokenizer, so this cell must run
# before that one (re-running it afterwards fails).
mode = tokenizer.Tokenizer.SplitMode.C


  tokenizer_obj = dictionary.Dictionary(config_path="./sudachi.json", dict_type="core").create()


In [355]:
# Load the pretrained tokenizer and encoder named by `model_name`.
# NOTE: this rebinds `tokenizer`, shadowing the `sudachipy.tokenizer` module
# imported in the first cell. The Sudachi setup cell reads
# `tokenizer.Tokenizer.SplitMode`, so it has to be executed before this cell.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [356]:
def get_token_attentions(text) -> dict[str, float]:
    """Return an attention weight for every word of ``text``.

    The text is encoded with the module-level ``tokenizer`` (truncated to 512
    tokens) and passed through the module-level ``model``. From the last
    layer, the attention paid by position 0 (assumed to be the leading special
    token such as [CLS]) to each position is averaged over all heads.

    The first and last tokens (assumed to be special tokens) are dropped.
    Word pieces prefixed with ``##`` are glued onto the preceding piece and
    their weights are summed. A word that occurs several times in the text
    gets the sum of the weights of all its occurrences, instead of keeping
    only the last one.

    Args:
        text: Input string to analyse.

    Returns:
        Mapping from reconstructed word to its accumulated attention weight,
        in order of first appearance. Empty when the text has no tokens
        between the two special tokens.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # [-1] -> last layer, [0, :, 0] -> first batch item, every head, query
    # position 0; averaging over dim 0 collapses the heads.
    attentions = outputs.attentions[-1][0, :, 0].mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    token_attentions: dict[str, float] = {}
    current_word = ""
    current_weight = 0.0

    # .tolist() yields plain Python floats, so no per-item tensor conversion.
    for token, weight in zip(tokens[1:-1], attentions[1:-1].tolist()):
        if token.startswith("##"):
            # Continuation piece: extend the word currently being built.
            current_word += token[2:]
            current_weight += weight
            continue

        if current_word:
            # Accumulate so repeated words do not overwrite earlier weights.
            token_attentions[current_word] = (
                token_attentions.get(current_word, 0.0) + current_weight
            )
        current_word = token
        current_weight = weight

    # Flush the word that was still open when the loop ended.
    if current_word:
        token_attentions[current_word] = (
            token_attentions.get(current_word, 0.0) + current_weight
        )

    return token_attentions

In [357]:
def retokenize_with_sudachi(tokens, text):
    """
    Regroup word-level weights by Sudachi morpheme.

    ``text`` is split with the module-level ``tokenizer_obj`` using the
    module-level split ``mode``. Each resulting surface form is given the sum
    of the weights of every key in ``tokens`` that occurs as a substring of
    that surface form.

    The shared module-level tokenizer is reused rather than rebuilding the
    Sudachi dictionary on every call.

    Args:
        tokens: Mapping from word to weight (e.g. from get_token_attentions).
        text: The original text the weights were computed for.

    Returns:
        Mapping from Sudachi surface form to summed weight. A surface form
        that contains none of the keys maps to 0.
    """
    sudachi_tokens = [m.surface() for m in tokenizer_obj.tokenize(text, mode)]

    result = {}
    for token in sudachi_tokens:
        # Matching is by substring, so one key may contribute to several
        # surface forms (e.g. a character shared by two compound words).
        result[token] = sum(
            (float(weight) for piece, weight in tokens.items() if piece in token),
            0.0,
        )

    return result
    

In [364]:
# Word-level attention for the sample sentence, printed heaviest first.
result = get_token_attentions("qdrantが開発した新しいランキングアルゴリズムであるBM42を試します。")
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

ランキング: 0.2108
qdrant: 0.1261
試し: 0.1148
アルゴリズム: 0.1063
42: 0.0750
ます: 0.0523
を: 0.0492
。: 0.0436
BM: 0.0301
ある: 0.0272
新しい: 0.0171
開発: 0.0149
で: 0.0058
た: 0.0056
が: 0.0054
し: 0.0051


In [358]:
example_text = "半夏厚朴湯と柴胡加竜骨牡蛎湯の併用"

# Stage 1: weights per word as reconstructed from the model's word pieces.
result = get_token_attentions(example_text)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

print("----------")

# Stage 2: the same weights regrouped by Sudachi morpheme.
result = retokenize_with_sudachi(result, example_text)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

併用: 0.1649
半夏: 0.1020
湯: 0.0594
厚朴: 0.0523
柴胡: 0.0436
牡蛎: 0.0323
竜骨: 0.0310
加: 0.0213
の: 0.0156
と: 0.0149
----------
半夏厚朴湯: 0.2138
柴胡加竜骨牡蛎湯: 0.1876
併用: 0.1649
の: 0.0156
と: 0.0149


  tokenizer_obj = dictionary.Dictionary(config_path="./sudachi.json", dict_type="core").create()


In [359]:
# Synonym table consulted by token_expantion(): each key maps to the
# alternative surface form that should receive the same weight. The two
# entries are mutual synonyms (both terms mean "trigger finger").
synonyms = DoubleArrayTrie(
    {
        "ばね指": "弾発指",
        "弾発指": "ばね指"
    }
)

def token_expantion(tokens) -> dict[str, float]:
    """
    Add synonyms of the given tokens to the weight mapping.

    Every token keeps its weight. If the module-level ``synonyms`` table has
    an entry for a token, the synonym is added with that token's weight.
    When the synonym is already present (from the input or from an earlier
    expansion), the larger of the two weights is kept, so expansion never
    lowers or overwrites an existing weight.

    Args:
        tokens: Mapping from token to weight.

    Returns:
        A new mapping containing the original tokens plus their synonyms;
        the input mapping is not modified.
    """
    # Start from a copy so original weights are in place before any synonym
    # is considered.
    result = dict(tokens)
    for token, weight in tokens.items():
        syn = synonyms.get(token)
        if syn is not None:
            result[syn] = max(result.get(syn, weight), weight)
    return result

In [360]:
example_text = "ばね指の症状について"

# Stage 1: weights per word as reconstructed from the model's word pieces.
result = get_token_attentions(example_text)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

print("----------")

# Stage 2: regroup the weights by Sudachi morpheme.
result = retokenize_with_sudachi(result, example_text)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

print("----------")

# Stage 3: add synonyms from the synonym table.
result = token_expantion(result)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")


ばね指: 0.5203
症状: 0.1462
の: 0.0684
に: 0.0675
つい: 0.0506
て: 0.0496
----------
ばね指: 0.5203
症状: 0.1462
の: 0.0684
に: 0.0675
つい: 0.0506
て: 0.0496
----------
ばね指: 0.5203
弾発指: 0.5203
症状: 0.1462
の: 0.0684
に: 0.0675
つい: 0.0506
て: 0.0496


  tokenizer_obj = dictionary.Dictionary(config_path="./sudachi.json", dict_type="core").create()


In [361]:
example_text = "qdrantが開発した新しいランキングアルゴリズムであるBM42を試します。"

# Word-level attention for the sample sentence, printed heaviest first.
result = get_token_attentions(example_text)
ranked = sorted(result.items(), key=lambda pair: pair[1], reverse=True)
for word, weight in ranked:
    print(f"{word}: {weight:.4f}")

ランキング: 0.2108
qdrant: 0.1261
試し: 0.1148
アルゴリズム: 0.1063
42: 0.0750
ます: 0.0523
を: 0.0492
。: 0.0436
BM: 0.0301
ある: 0.0272
新しい: 0.0171
開発: 0.0149
で: 0.0058
た: 0.0056
が: 0.0054
し: 0.0051


In [362]:
# Sample documents for the indexing cell, which stores them in Elasticsearch
# under ids 1..3 in list order.
texts = [
    "qdrantが開発した新しいランキングアルゴリズムであるBM42を試します。",
    "検索ランキングで使われるBM25とは？",
    "局所麻酔(キシロカイン)アレルギー及び迷走神経反射について"
]

In [363]:
# Index every sample text under a 1-based id together with its word weights:
#   title         - the raw text
#   joined_tokens - the words joined by single spaces
#   tokens        - the word -> weight mapping
for doc_id, title in enumerate(texts, start=1):
    tokens = get_token_attentions(title)
    resp = client.index(
        index="test-index",
        id=doc_id,
        document={
            "title": title,
            "joined_tokens": ' '.join(tokens),
            "tokens": tokens,
        },
    )


# Ask Elasticsearch to explain how document 2 of test-index scores against a
# bool/should query made of two script_score clauses (terms "BM" and "検索").
# Each clause filters on a match in `joined_tokens`, adds a `term` clause on
# `tokens`, and divides the resulting _score by _termStats.docFreq().getSum().
curl -X GET \
  http://localhost:9200/test-index/_explain/2 \
  -H 'Content-Type: application/json' \
  -H 'cache-control: no-cache' \
  -d '{
    "query": {
        "bool": {
            "should": [
                {
                    "script_score": {
                        "query": {
                            "bool": {
                                "filter": {
                                    "match": {
                                        "joined_tokens": "BM"
                                    }
                                },
                                "should": [
                                    {
                                        "term": {
                                            "tokens": {
                                                "value": "BM"
                                            }
                                        }
                                    }
                                ]
                            }
                        },
                        "script": {
                            "source": "return _score / _termStats.docFreq().getSum() "
                        }
                    }
                },
                {
                    "script_score": {
                        "query": {
                            "bool": {
                                "filter": {
                                    "match": {
                                        "joined_tokens": "検索"
                                    }
                                },
                                "should": [
                                    {
                                        "term": {
                                            "tokens": {
                                                "value": "検索"
                                            }
                                        }
                                    }
                                ]
                            }
                        },
                        "script": {
                            "source": "return _score / _termStats.docFreq().getSum() "
                        }
                    }
                }
            ]
        }
    }
}'