In [2]:
%pip install transformers torch

Note: you may need to restart the kernel to use updated packages.


In [3]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import torch.nn.functional as F

class SparseEncoder:
    """SPLADE-style sparse text encoder built on a masked-language-model head.

    ``encode`` maps a text to a ``{vocabulary token: weight}`` dictionary that
    contains only the non-zero vocabulary dimensions, ordered by descending
    weight.
    """

    def __init__(self, model_id="naver/splade_v2_max"):
        """Load tokenizer and masked-LM model.

        Args:
            model_id: Hugging Face model identifier passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._model = AutoModelForMaskedLM.from_pretrained(model_id)
        # Inference only.
        self._model.eval()
        # Invert the vocabulary so vocabulary indices map back to token strings.
        self._idx2token = {
            idx: token for token, idx in self._tokenizer.get_vocab().items()
        }

    def encode(self, text):
        """Encode a single text into a sparse token -> weight dictionary.

        Args:
            text: The input string.

        Returns:
            dict mapping token string to weight (rounded to 4 decimals),
            containing only non-zero dimensions, sorted by descending weight.
            Returns an empty dict if every dimension is zero.
        """
        # truncation=True keeps over-long inputs within the tokenizer's max length
        # instead of failing inside the model.
        tokens = self._tokenizer([text], return_tensors='pt', truncation=True)
        with torch.no_grad():
            output = self._model(**tokens)

        # SPLADE aggregation: log(1 + ReLU(logits)); padding positions are zeroed by
        # the attention mask, then values are max-pooled over the sequence axis (dim=1).
        activations = torch.log1p(torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1)
        # Exactly one text was tokenized, so take batch row 0 -> 1-D vocabulary vector.
        vec = activations.max(dim=1).values[0]

        # as_tuple=True always yields a 1-D index tensor, so zero or a single
        # non-zero entry is handled the same way as many (a bare squeeze() would
        # collapse a single index to a scalar and break the zip below).
        col_tensor = torch.nonzero(vec, as_tuple=True)[0]
        weights = vec[col_tensor].cpu().tolist()
        cols = col_tensor.cpu().tolist()

        # Map vocabulary indices to human-readable tokens.
        sparse_dict_tokens = {
            self._idx2token[idx]: round(weight, 4) for idx, weight in zip(cols, weights)
        }
        # Most relevant tokens first.
        return dict(
            sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True)
        )

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Sample sentence to inspect which vocabulary tokens the encoder activates.
text =  "Programmed cell death (PCD) is the regulated death of cells within an organism"
sparce_encoder = SparseEncoder()
sparce_vector = sparce_encoder.encode(text)

In [5]:
# Display the token -> weight mapping returned by `encode` (highest weight first).
sparce_vector

{'pc': 2.5905,
 'death': 2.4068,
 'cell': 2.0662,
 'programmed': 2.0624,
 '##d': 1.9653,
 'organism': 1.6084,
 'regulated': 1.4702,
 'die': 1.4701,
 'meaning': 1.4017,
 'computer': 1.2484,
 'regulation': 1.242,
 'set': 1.1262,
 'regulate': 0.809,
 'cells': 0.7754,
 'code': 0.6977,
 'organisms': 0.6679,
 'kill': 0.6577,
 'controlled': 0.5733,
 'within': 0.5619,
 'master': 0.4954,
 'radio': 0.4952,
 'bacteria': 0.4649,
 'goal': 0.4209,
 'is': 0.3682,
 'result': 0.3376,
 'end': 0.3282,
 'determined': 0.255,
 'monitor': 0.2354,
 'transfer': 0.2328,
 'process': 0.1818,
 'penalty': 0.1534,
 'fear': 0.1488,
 'gene': 0.112,
 'cause': 0.1096,
 'pd': 0.0834,
 'happen': 0.0163}

In [8]:
def get_score(vec1, vec2):
  """Score two sparse token -> weight vectors by their dot product.

  Only tokens present in both dictionaries contribute, so the score rises as
  more (and more heavily weighted) tokens are shared.

  Returns:
    float: sum of weight products over shared tokens (0.0 if none are shared).
  """
  # Iterate vec1 in order so the additions happen in the same sequence as a
  # plain accumulator loop; the 0.0 start keeps the result a float.
  shared_products = (weight * vec2[token] for token, weight in vec1.items() if token in vec2)
  return sum(shared_products, 0.0)

# Small sample corpus of documents to search over.
texts = [
    "Global market trends indicate a shift towards renewable energy sources.",
    "Innovative technologies are transforming the healthcare industry.",
    "Educational reforms are essential for future workforce development.",
    "Environmental conservation efforts are increasing to combat climate change.",
    "Artificial intelligence is revolutionizing the way businesses operate.",
]

sparce_encoder = SparseEncoder()
sparse_vectors = [sparce_encoder.encode(t) for t in texts]

# The search query.
target = "Climate action"
target_vector = sparce_encoder.encode(target)

# Score each document exactly once, then rank by descending score.
# A distinct name (`doc`) avoids overwriting the notebook-level `text` from the
# earlier cell.
ranked = sorted(
    ((get_score(vec, target_vector), doc, vec) for doc, vec in zip(texts, sparse_vectors)),
    key=lambda item: item[0],
    reverse=True,
)
# Keep the (document, sparse vector) pairs in ranked order for later inspection.
sorted_scores = [(doc, vec) for _, doc, vec in ranked]
for score, doc, _ in ranked:
    # Token weights are already rounded to 4 decimals, so show scores at the same precision.
    print(f"{score:.4f} {doc}")

8.21236882 Environmental conservation efforts are increasing to combat climate change.
1.72356134 Global market trends indicate a shift towards renewable energy sources.
0.7166878 Artificial intelligence is revolutionizing the way businesses operate.
0.25009236 Educational reforms are essential for future workforce development.
0.16637855999999998 Innovative technologies are transforming the healthcare industry.
