In [1]:
%pip install transformers torch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [41]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import torch.nn.functional as F

class SparseEncoder:
    def __init__(self, model_id="naver/splade_v2_max"):
        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._model = AutoModelForMaskedLM.from_pretrained(model_id)
        # model.evalによって、モデルを評価モードに変更する
        self._model.eval()
        # extract the ID position to text token mappings
        self._idx2token = {idx: token for token, idx in self._tokenizer.get_vocab().items()}

    def encode(self, text):
        # return_tensors='pt'とは、PyTorchテンソルを返すことを指定する
        tokens = self._tokenizer([text], return_tensors='pt')
        with torch.no_grad():
            output = self._model(**tokens)

        # output.logitsとは、モデルの出力であり、各トークンのスコアを表す
        logits = output.logits
        relu_logits = torch.relu(logits)
        log_relu_logits = torch.log(1 + relu_logits)
        attention_mask = tokens.attention_mask
        attention_mask_unsqueezed = attention_mask.unsqueeze(-1)
        logits_x_attention_mask = log_relu_logits * attention_mask_unsqueezed

        # torch.maxをdim=1で適用することで、各トークンの最大スコアを取得する
        max_scores_by_tokens = torch.max(logits_x_attention_mask, dim=1).values
        print('logits', logits.shape, logits)
        print('relu_logits', relu_logits.shape, relu_logits)
        print('log_relu_logits', log_relu_logits.shape, log_relu_logits)
        print('attention_mask', attention_mask.shape, attention_mask)
        print('attention_mask_unsqueezed', attention_mask_unsqueezed.shape)
        print('logits_x_attention_mask', logits_x_attention_mask.shape)
        print('max_scores_by_tokens', max_scores_by_tokens.shape)

        vec = max_scores_by_tokens.squeeze() # squeeze()は次元数が1の次元を削除する
        print('max_scores_by_tokens_squeezed', vec.shape)

        # nonzeros()で非ゼロの値のインデックスを返す
        nonzero_max_values = vec.nonzero()
        print('max_scores_by_tokens', nonzero_max_values.shape, nonzero_max_values[0])
        cols = nonzero_max_values.squeeze().cpu().tolist()
        print('cols', cols)
        # 非ゼロ要素のインデックスで重みを取得する
        weights = vec[cols].cpu().tolist()
        print('weights', weights)
        # use to create a dictionary of token ID to weight
        sparse_dict = dict(zip(cols, weights))
        # map token IDs to human-readable tokens
        sparse_dict_tokens = {
            self._idx2token[idx]: round(weight, 4) for idx, weight in zip(cols, weights)
        }
        # sort so we can see most relevant tokens first
        # .items()でキーと値のペアを取得する
        dict_items = sparse_dict_tokens.items()
        print('dict_items', dict_items)
        sorted_sparse_dict_tokens = sorted(
                dict_items,
                key=lambda item: item[1],
                reverse=True
            )
        print('sorted_sparse_dict_tokens', sorted_sparse_dict_tokens)

        sparse_dict_tokens = {
            k: v for k, v in sorted_sparse_dict_tokens
        }
        return sparse_dict_tokens

In [42]:
sparce_encoder = SparseEncoder()
text =  "Programmed cell death (PCD) is the regulated death of cells within an organism"
sparce_vector = sparce_encoder.encode(text)

logits torch.Size([1, 18, 30522]) tensor([[[ -7.4206,  -7.4528,  -7.6990,  ...,  -6.3517,  -6.4922,  -7.4430],
         [-14.1387, -14.2438, -14.4285,  ..., -12.6452, -12.2198, -17.7459],
         [-13.4026, -13.4891, -13.8986,  ..., -11.5533, -11.8254, -13.5814],
         ...,
         [ -6.9746,  -7.0242,  -7.3337,  ...,  -6.0591,  -5.8105,  -8.1359],
         [ -8.2372,  -8.3405,  -8.6611,  ...,  -7.1867,  -7.2440,  -9.9187],
         [ -8.9823,  -9.0616,  -9.2012,  ...,  -7.4842,  -7.4102,  -8.1561]]])
relu_logits torch.Size([1, 18, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
log_relu_logits torch.Size([1, 18, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
       

In [8]:
sparce_vector

{'pc': 2.5905,
 'death': 2.4068,
 'cell': 2.0662,
 'programmed': 2.0624,
 '##d': 1.9653,
 'organism': 1.6084,
 'regulated': 1.4702,
 'die': 1.4701,
 'meaning': 1.4017,
 'computer': 1.2484,
 'regulation': 1.242,
 'set': 1.1262,
 'regulate': 0.809,
 'cells': 0.7754,
 'code': 0.6977,
 'organisms': 0.6679,
 'kill': 0.6577,
 'controlled': 0.5733,
 'within': 0.5619,
 'master': 0.4954,
 'radio': 0.4952,
 'bacteria': 0.4649,
 'goal': 0.4209,
 'is': 0.3682,
 'result': 0.3376,
 'end': 0.3282,
 'determined': 0.255,
 'monitor': 0.2354,
 'transfer': 0.2328,
 'process': 0.1818,
 'penalty': 0.1534,
 'fear': 0.1488,
 'gene': 0.112,
 'cause': 0.1096,
 'pd': 0.0834,
 'happen': 0.0163}

In [45]:
# 2つのベクトルのスコア計算（類似単語が多いとスコアが高くなる）
def get_score(vec1, vec2):
  score = 0.0
  for k,v in vec1.items():
    if k in vec2:
      score += v * vec2.get(k)
  return score

# 検索対象の文書的なもの
texts = [
    "Global market trends indicate a shift towards renewable energy sources.",
    # グローバル市場の動向は、再生可能エネルギー源へのシフトを示しています。
    "Innovative technologies are transforming the healthcare industry.",
    # 革新的な技術がヘルスケア産業を変革しています。
    "Educational reforms are essential for future workforce development.",
    # 教育改革は将来の労働力開発に不可欠です。
    "Environmental conservation efforts are increasing to combat climate change.",
    # 気候変動と戦うために、環境保全の努力が増加しています。
    "Artificial intelligence is revolutionizing the way businesses operate."
    # 人工知能はビジネスの運営方法を革命的に変えています。
]

sparce_encoder = SparseEncoder()
sparse_vectors = [sparce_encoder.encode(t) for t in texts]

logits torch.Size([1, 13, 30522]) tensor([[[ -8.0085,  -7.9910,  -8.0943,  ...,  -7.2088,  -6.7127,  -7.7589],
         [-12.4164, -12.5163, -12.4281,  ..., -10.3584,  -9.8625, -12.4319],
         [-13.6496, -13.4941, -13.7159,  ..., -11.5093, -11.0709, -12.6855],
         ...,
         [-10.0733, -10.0550, -10.2393,  ...,  -8.6230,  -8.7384,  -8.4065],
         [-10.4967, -10.5393, -10.5337,  ...,  -8.2109,  -8.5428,  -9.7887],
         [-10.1714, -10.2071, -10.2485,  ...,  -8.2812,  -8.2802,  -9.3571]]])
relu_logits torch.Size([1, 13, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
log_relu_logits torch.Size([1, 13, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
       

In [52]:
# enumerate()はインデックスと要素を同時に取得する
for index, vec in enumerate(sparse_vectors):
    print(f"text: {texts[index]}")
    print(f"→ vec: {vec}")

text: Global market trends indicate a shift towards renewable energy sources.
→ vec: {'global': 2.3839, 'market': 2.2082, 'renewable': 2.0742, 'trend': 1.894, 'shift': 1.8732, 'energy': 1.7373, 'trends': 1.4855, 'towards': 1.4805, 'indicate': 1.4703, 'toward': 1.4509, 'worldwide': 1.3056, 'geo': 1.2086, 'alternative': 1.1725, 'sustainable': 1.1586, 'source': 1.0792, 'markets': 1.0402, 'power': 0.9669, 'shifts': 0.9642, 'move': 0.9543, 'movement': 0.8496, 'favor': 0.8446, 'change': 0.7916, 'sector': 0.7582, 'signal': 0.7434, 'policy': 0.739, 'platform': 0.6346, 'network': 0.6067, 'distribution': 0.582, 'focus': 0.5818, 'industry': 0.5598, 'emerging': 0.5099, 'current': 0.5051, 'direction': 0.4683, 'increase': 0.4509, 'decline': 0.4401, 'electric': 0.3907, 'pull': 0.3712, 'sources': 0.3509, 'consumer': 0.3298, 'mean': 0.3266, 'evolution': 0.3024, 'symbol': 0.2938, 'reflect': 0.2883, 'it': 0.2872, 'affect': 0.2373, 'digital': 0.2342, 'mainstream': 0.2227, 'goal': 0.201, 'turn': 0.1918, 'n

In [54]:
# 検索語的なもの
target = "Climate action" # 気候行動
target_vector = sparce_encoder.encode(target)

logits torch.Size([1, 4, 30522]) tensor([[[ -9.3298,  -9.3722,  -9.4416,  ...,  -8.0414,  -8.3697,  -7.7104],
         [-15.7526, -15.7268, -15.6716,  ..., -13.0981, -13.6872, -15.3081],
         [-16.4689, -16.3558, -16.5029,  ..., -12.9810, -14.9521, -14.2455],
         [-13.3109, -13.3634, -13.3491,  ..., -10.7415, -10.9095, -12.1695]]])
relu_logits torch.Size([1, 4, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
log_relu_logits torch.Size([1, 4, 30522]) tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
attention_mask torch.Size([1, 4]) tensor([[1, 1, 1, 1]])
attention_mask_unsqueezed torch.Size([1, 4, 1])
logits_x_attention_mask torch.Size([1, 4, 30522])
max_scores_by_tokens torch.Size([1, 30522])
max_scores_by_tokens_squeezed torch.Size([3052

In [56]:
print(f"target: {target}")
print(f"→ vec: {target_vector}")
sorted_scores = sorted(zip(texts, sparse_vectors), key=lambda x: get_score(x[1], target_vector), reverse=True)
for text, vec in sorted_scores:
  score = get_score(vec, target_vector)
  print(f"{score} {text}")

target: Climate action
→ vec: {'action': 2.5925, 'climate': 2.3456, 'weather': 1.7827, 'movement': 1.732, 'cold': 1.3254, 'act': 1.2159, 'decision': 0.6433, 'violence': 0.6117, 'goal': 0.4778, 'meaning': 0.3478, 'sport': 0.3459, 'it': 0.3334, 'activity': 0.327, 'force': 0.3261, 'democratic': 0.0644, 'move': 0.0629, 'idea': 0.0504, 'global': 0.0001}
8.21236882 Environmental conservation efforts are increasing to combat climate change.
1.72356134 Global market trends indicate a shift towards renewable energy sources.
0.7166878 Artificial intelligence is revolutionizing the way businesses operate.
0.25009236 Educational reforms are essential for future workforce development.
0.16637855999999998 Innovative technologies are transforming the healthcare industry.
