In [49]:
import csv
import os
from typing import List

import numpy as np
import torch
from IPython.display import Image
from transformers.tokenization_bert_japanese import BertJapaneseTokenizer
from transformers import BertModel
from sklearn.metrics.pairwise import cosine_similarity

In [15]:
# Japanese tokenizer and pretrained BERT encoder, both loaded from the same
# checkpoint name so the tokenizer's ids match the model's vocabulary.
MODEL_NAME = 'cl-tohoku/bert-base-japanese'
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

In [4]:
# configs
CORPUS_PATH = 'corpus.txt'
# The headline CSV lives outside the repo. Set the ILLUST_YA_CSV_PATH environment
# variable to point at your copy; the default is the original author's local path.
ILLUST_YA_CSV_PATH = os.environ.get(
    'ILLUST_YA_CSV_PATH', '/Users/rikeda/Downloads/hackday/illust.csv'
)

In [34]:
# Trial run: use the Irasutoya illustration headlines that Gucchi collected.
# Each CSV row is expected to hold the headline text in column 1 and the image
# URL in column 2 (column 0 is not used here).
def encode_headline(text: str) -> torch.Tensor:
    """Tokenize one headline into a (1, seq_len) tensor of token ids, wrapped in [CLS]/[SEP]."""
    encoded = tokenizer.batch_encode_plus([text], pad_to_max_length=True, add_special_tokens=True)
    return torch.tensor(encoded['input_ids'])


# Decode explicitly as UTF-8 (the headlines are Japanese; the platform default
# encoding differs across OSes). The csv module requires newline=''.
# NOTE(review): assumes the CSV is UTF-8 — confirm if it was exported elsewhere.
with open(ILLUST_YA_CSV_PATH, 'r', encoding='utf-8', newline='') as f:
    reader = csv.reader(f)
    # One (token-id tensor, image URL) pair per headline; blank lines yield [] and are skipped.
    illusts = [(encode_headline(row[1]), row[2]) for row in reader if row]
illusts[:3]

[(tensor([[    2, 16682,     5,  2890,     5,  1480,    23,  1896,   758,    24,
             3]]), 'https://4.bp.blogspot.com/-jN5C5Y4xN9s/V2zF20D4INI/AAAAAAAA8AQ/sVYu-ezC_3UKtX4n75AaHUqEskfQQnS7QCLcB/s150/money_character_strong_yen.png'), (tensor([[    2, 25654,  1624,     5,  4307,     3]]), 'https://2.bp.blogspot.com/-57P4j4ba84E/VyNdckfKDVI/AAAAAAAA6No/00EjX61RCQwiFfvQjglB7M55X2xsO65MQCLcB/s400/medicine_kaze.png'), (tensor([[    2, 11122,  2932,     5,  4307,     3]]), 'https://3.bp.blogspot.com/-D1ojODeU7lg/U0pS6kDirII/AAAAAAAAe_Y/LU00iN_kGs0/s400/rose_blue.png')]


In [36]:
# Sanity-check tokenization: map the first headline's ids back to tokens.
first_headline_ids = illusts[0][0][0]
tokenizer.convert_ids_to_tokens(first_headline_ids)

['[CLS]', 'お金', 'の', '単位', 'の', 'キャラクター', '(', '強い', '円', ')', '[SEP]']

In [45]:
def get_bert_sentence_vec(model, input_ids: torch.Tensor) -> np.ndarray:
    """Return the final-layer [CLS] hidden state for each input sequence.

    Args:
        model: A BERT-style model called as ``model(input_ids)`` whose first
            output is the last layer's hidden states, shaped
            (batch, seq_len, hidden_size).
        input_ids: Token ids, shaped (batch, seq_len), or a single unbatched
            sequence shaped (seq_len,). Anything ``torch.as_tensor`` accepts.

    Returns:
        A NumPy array shaped (batch, hidden_size): one vector per sequence.
        For unbatched input, batch is 1.

    NOTE(review): the [CLS] vector of a BERT that was not fine-tuned for
    similarity is usually a weak sentence embedding. Mean pooling or a
    dedicated sentence-embedding model may rank headlines better.
    """
    input_ids = torch.as_tensor(input_ids)
    if input_ids.dim() == 1:
        # Accept a bare id sequence by adding the batch axis the model expects.
        input_ids = input_ids.unsqueeze(0)
    # Inference only: skip building an autograd graph for every encoded headline.
    with torch.no_grad():
        outputs = model(input_ids)
    last_hidden_states = outputs[0]  # first element is the last layer's hidden states
    # Position 0 is the [CLS] token the tokenizer prepends. Move to CPU so
    # .numpy() also works when the model runs on GPU.
    return last_hidden_states[:, 0, :].detach().cpu().numpy()
# Smoke test: the first headline should yield a single (1, hidden_size) vector.
first_headline_vec = get_bert_sentence_vec(model, illusts[0][0])
first_headline_vec.shape

(1, 768)

In [59]:
# Query text: a news paragraph (Tokyo COVID-19 case counts) to find matching illustrations for.
query_text = "東京都内で25日、新型コロナウイルスの感染者が新たに４８人確認されたことがわかった。１日当たりの感染者は２４日（５５人）を下回ったが、直近１週間の平均は約３９人と高水準が続いている。"
sentence = tokenizer.batch_encode_plus([query_text])
input_ids = torch.tensor(sentence['input_ids'])
sentence_vec = get_bert_sentence_vec(model, input_ids)

# One [CLS] vector per headline, stacked row-wise into a matrix.
illusts_sentence_vec = np.array([
    get_bert_sentence_vec(model, headline_ids)[0, :]
    for headline_ids, _image_url in illusts
])

In [69]:
# Cosine similarity of the single query row against every headline vector.
similarity = cosine_similarity(sentence_vec, illusts_sentence_vec)[0]
# Headline indices ordered from most to least similar; keep the top three.
descending_indices = np.argsort(similarity)[::-1]
top3_indices = descending_indices[:3]

best_idx = top3_indices.tolist()[0]
print(similarity[best_idx])
print(tokenizer.convert_ids_to_tokens(illusts[best_idx][0][0]))
Image(url=illusts[best_idx][1])

0.7013853
['[CLS]', '一', '##昨', '##々', '##日', '[SEP]']


In [72]:
# Runner-up: score, headline tokens, and illustration.
second_idx = top3_indices.tolist()[1]
print(similarity[second_idx])
print(tokenizer.convert_ids_to_tokens(illusts[second_idx][0][0]))
Image(url=illusts[second_idx][1])

0.67197865
['[CLS]', '家具', 'が', '倒れ', 'て', '窓', 'を', '割', '##る', 'イラスト', '(', '事故', ')', '[SEP]']


In [73]:
# Third-best match: score, headline tokens, and illustration.
third_idx = top3_indices.tolist()[2]
print(similarity[third_idx])
print(tokenizer.convert_ids_to_tokens(illusts[third_idx][0][0]))
Image(url=illusts[third_idx][1])

0.6691485
['[CLS]', '地域', 'を', 'また', '##い', 'だ', '感染', '拡大', 'の', 'イラスト', '(', '感染', '地域', 'から', ')', '[SEP]']
