<a href="https://colab.research.google.com/github/tomonari-masada/course2023-nlp/blob/main/assignment20230923.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

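# Assignment 20230923
* Using TF-IDF vectors, write a function that takes any text in 20 newsgroups and returns the 10 texts most similar to it.
* Count how many of those 10 texts belong to the same class as the original text.
* Do the same for every text, and compute the average number of same-class texts among the 10 most similar texts.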
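
To make the task concrete before touching the real data, here is a minimal sketch on a toy matrix. The names `top_k_similar`, `X_toy` and `labels_toy` are made up for this illustration, and the rows are normalized by hand so that a dot product is a cosine similarity.

```python
import numpy as np

def top_k_similar(X, query_idx, k=10):
    """Return indices of the k rows of X most similar to row query_idx (itself excluded)."""
    sims = X @ X[query_idx]      # cosine similarity when rows have unit length
    sims[query_idx] = -np.inf    # never return the query itself
    order = np.argsort(-sims, kind="stable")
    return order[:k]

# Toy data standing in for TF-IDF rows: 4 "documents", 3 "terms".
X_toy = np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
X_toy /= np.linalg.norm(X_toy, axis=1, keepdims=True)  # L2-normalize each row
labels_toy = np.array([0, 0, 1, 2])

neighbors = top_k_similar(X_toy, query_idx=0, k=2)
# Number of neighbors that share the query's class label.
same_class = int((labels_toy[neighbors] == labels_toy[0]).sum())
print(neighbors.tolist(), same_class)
```

The rest of the notebook does the same thing for all 11,314 texts at once, which is why the similarity computation is batched.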

In [1]:
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.datasets import fetch_20newsgroups

# With no arguments, fetch_20newsgroups loads the training subset.
newsgroups = fetch_20newsgroups()

In [2]:
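def number_normalizer(tokens):
  # Replace every token that starts with a digit with the placeholder "#NUMBER".
  return ("#NUMBER" if token[0].isdigit() else token for token in tokens)

class NumberNormalizingVectorizer(TfidfVectorizer):
  # A TfidfVectorizer whose tokenizer maps all numeric tokens to one shared feature.
  def build_tokenizer(self):
    tokenizer = super().build_tokenizer()
    return lambda doc: list(number_normalizer(tokenizer(doc)))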
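
The effect of the normalizer can be checked without scikit-learn. The sketch below is an approximation: it assumes the vectorizer's default `token_pattern` (`(?u)\b\w\w+\b`), and it lowercases the text by hand, a step the vectorizer performs separately. The sample sentence is invented.

```python
import re

def number_normalizer(tokens):
    # Same rule as above: any token starting with a digit becomes "#NUMBER".
    return ("#NUMBER" if token[0].isdigit() else token for token in tokens)

# Assumed to match scikit-learn's default token pattern (words of 2+ characters).
token_pattern = re.compile(r"(?u)\b\w\w+\b")

doc = "Sold 1200 units of the 486DX in 1993 for 35 dollars"
tokens = list(number_normalizer(token_pattern.findall(doc.lower())))
print(tokens)
```

Note that `486dx` is also collapsed into `#NUMBER`, because the rule only looks at the first character of each token.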

In [3]:
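# Rows are L2-normalized by default (norm='l2'), so the dot product of two rows is their cosine similarity.
vectorizer = NumberNormalizingVectorizer(stop_words='english', min_df=5)
# Convert to a dense array so that plain NumPy matrix products can be used below.
X = vectorizer.fit_transform(newsgroups.data).toarray()
X.shape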

(11314, 23427)

## Speeding up with mini-batches
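
A rough size estimate shows why the similarities are computed in batches of queries rather than as one matrix. The figures below assume 8-byte float64 values (the vectorizer's default dtype) and the `X.shape` printed above. The sorted index array that `np.argsort` returns for each batch is not counted.

```python
n_docs, n_terms = 11314, 23427   # X.shape from above
bytes_per_value = 8              # float64 (assumed default dtype)
batch_size = 1000

dense_X = n_docs * n_terms * bytes_per_value
full_similarity = n_docs * n_docs * bytes_per_value
one_batch = batch_size * n_docs * bytes_per_value

sizes = [
    ("dense X", dense_X),
    ("full similarity matrix", full_similarity),
    ("one batch of similarities", one_batch),
]
for name, size in sizes:
    print(f"{name}: {size / 1e9:.2f} GB")
```

Under these assumptions one batch of 1,000 queries needs roughly a tenth of a gigabyte for its similarities, compared with about one gigabyte for the full matrix.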

In [4]:
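def find_similar_texts(text_idx, batch_size=1000, num_texts=10):
  # Dot products of L2-normalized rows, i.e. cosine similarities of each query to every text.
  similarities = X[text_idx:text_idx+batch_size] @ X.T
  # Sort every row from most to least similar.
  sorted_indices = np.argsort(- similarities, axis=-1)
  # The last batch may contain fewer than batch_size queries.
  num_queries = X[text_idx:text_idx+batch_size].shape[0]
  query_indices = np.arange(text_idx, text_idx + batch_size)[:num_queries]
  # Each query should rank itself first; an exact duplicate can tie with it and take first place.
  anomaly_indices = (sorted_indices[:,0] != query_indices)
  for idx in anomaly_indices.nonzero()[0]:
    print(f"Text {idx+text_idx} is as similar to text {sorted_indices[idx,0]} as to itself.")
    # Swap columns 0 and 1 to put the query back in front (assumes it was ranked second).
    sorted_indices[idx,1], sorted_indices[idx,0] = sorted_indices[idx,0], sorted_indices[idx,1]
    print("----", sorted_indices[idx,0:num_texts+1].tolist())
    print("----", [f"{s:.3f}" for s in similarities[idx,sorted_indices[idx,0:num_texts+1]]])
  # Drop column 0 (the query itself) and keep the next num_texts neighbors.
  return sorted_indices[:,1:num_texts+1]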
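
The function above sorts all 11,314 similarities in every row even though only the first 11 are used. One possible alternative, shown here as a sketch rather than a drop-in replacement, is to select the top candidates with `np.argpartition` and sort only those. The helper name `top_k_per_row` and the random test matrix are made up for the illustration.

```python
import numpy as np

def top_k_per_row(similarities, k):
    """Indices of the k largest entries in each row, from most to least similar."""
    # argpartition moves the k largest entries (in no particular order) into the first k columns,
    candidates = np.argpartition(-similarities, k - 1, axis=-1)[:, :k]
    # so only those k candidates per row need a full sort.
    candidate_scores = np.take_along_axis(similarities, candidates, axis=-1)
    order = np.argsort(-candidate_scores, axis=-1)
    return np.take_along_axis(candidates, order, axis=-1)

rng = np.random.default_rng(0)
sims = rng.random((5, 50))                   # stand-in for one batch of similarities
fast = top_k_per_row(sims, k=11)
slow = np.argsort(-sims, axis=-1)[:, :11]    # what the full sort would return
print(np.array_equal(fast, slow))
```

When several entries are exactly tied, as with the duplicate texts reported below, the order among the tied entries may differ from the full sort, so the self-match check would still be needed.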

In [5]:
from tqdm import tqdm

num_texts = 10
batch_size = 1000

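num_correct_answers = list()
# Step through the corpus one batch of queries at a time.
for text_idx in tqdm(range(0, len(newsgroups.target), batch_size)):
  similar_texts = find_similar_texts(text_idx, batch_size=batch_size, num_texts=num_texts)
  # Class labels of the retrieved neighbors, one row per query.
  prediction = newsgroups.target[similar_texts]
  # Class label of each query as a column, so the comparison broadcasts across its neighbors.
  ground_truth = newsgroups.target[text_idx:text_idx+batch_size].reshape(-1,1)
  num_correct_answers.append((prediction == ground_truth).sum())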

  8%|▊         | 1/12 [00:12<02:20, 12.74s/it]

Text 14 is as similar to text 5392 as to itself.
---- [14, 5392, 8800, 10716, 9863, 11012, 9247, 8516, 10574, 141, 516]
---- ['1.000', '1.000', '0.320', '0.320', '0.317', '0.312', '0.311', '0.311', '0.310', '0.307', '0.307']


 17%|█▋        | 2/12 [00:25<02:08, 12.88s/it]

Text 1063 is as similar to text 4635 as to itself.
---- [1063, 4635, 8556, 1768, 8469, 4955, 10109, 5054, 5976, 9673, 1029]
---- ['1.000', '1.000', '0.314', '0.307', '0.292', '0.287', '0.286', '0.279', '0.275', '0.268', '0.260']


 25%|██▌       | 3/12 [00:38<01:56, 12.97s/it]

Text 2726 is as similar to text 2217 as to itself.
---- [2726, 2217, 5955, 3027, 6727, 9956, 5192, 1929, 7499, 5179, 8506]
---- ['1.000', '1.000', '0.395', '0.376', '0.370', '0.259', '0.258', '0.246', '0.243', '0.236', '0.232']


 58%|█████▊    | 7/12 [01:27<01:02, 12.59s/it]

Text 6518 is as similar to text 3665 as to itself.
---- [6518, 3665, 8800, 10716, 8516, 9863, 10574, 9247, 141, 11012, 5045]
---- ['1.000', '1.000', '0.592', '0.591', '0.589', '0.586', '0.585', '0.579', '0.570', '0.568', '0.567']


 75%|███████▌  | 9/12 [01:53<00:37, 12.58s/it]

Text 8543 is as similar to text 10304 as to itself.
---- [8543, 10304, 8048, 3049, 5597, 9888, 6372, 9046, 5008, 10015, 6284]
---- ['1.000', '1.000', '0.448', '0.298', '0.269', '0.259', '0.248', '0.243', '0.234', '0.215', '0.212']
Text 8701 is as similar to text 9845 as to itself.
---- [8701, 9845, 5156, 10716, 8800, 7511, 1979, 9863, 4352, 11012, 9247]
---- ['1.000', '1.000', '0.249', '0.220', '0.220', '0.220', '0.216', '0.215', '0.214', '0.214', '0.212']


 83%|████████▎ | 10/12 [02:04<00:24, 12.19s/it]

Text 9511 is as similar to text 5106 as to itself.
---- [9511, 5106, 9107, 3482, 10257, 10273, 10059, 8307, 8545, 1162, 10404]
---- ['1.000', '1.000', '0.262', '0.249', '0.249', '0.228', '0.204', '0.179', '0.175', '0.171', '0.156']
Text 9989 is as similar to text 800 as to itself.
---- [9989, 800, 9440, 9139, 7522, 1634, 924, 1454, 10452, 4164, 11069]
---- ['1.000', '1.000', '0.326', '0.233', '0.180', '0.172', '0.168', '0.167', '0.150', '0.149', '0.136']


 92%|█████████▏| 11/12 [02:17<00:12, 12.36s/it]

Text 10777 is as similar to text 2002 as to itself.
---- [10777, 2002, 3893, 11104, 11083, 7043, 5284, 3973, 9350, 6343, 2168]
---- ['1.000', '1.000', '0.359', '0.350', '0.277', '0.246', '0.246', '0.241', '0.232', '0.218', '0.197']


100%|██████████| 12/12 [02:21<00:00, 11.79s/it]


In [6]:
np.array(num_correct_answers).sum()

70081

In [7]:
np.array(num_correct_answers).sum() / len(newsgroups.target)

6.19418419657062
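
The count and the average above rely on NumPy broadcasting: each query's label, stored as a column, is compared against the labels of all its neighbors. A small worked example with invented labels and neighbor indices (two neighbors per query instead of ten) shows the shapes involved.

```python
import numpy as np

# Hypothetical labels for 6 texts and, for the first 3 queries, their 2 nearest neighbors.
target = np.array([0, 0, 1, 1, 2, 2])
similar_texts = np.array([
    [1, 2],   # neighbors of text 0
    [0, 4],   # neighbors of text 1
    [4, 5],   # neighbors of text 2
])

prediction = target[similar_texts]          # shape (3, 2): labels of the neighbors
ground_truth = target[:3].reshape(-1, 1)    # shape (3, 1): label of each query
matches = (prediction == ground_truth)      # broadcasts to shape (3, 2)

per_query = matches.sum(axis=1)             # same-class neighbors for each query
average = matches.sum() / len(ground_truth)
print(per_query.tolist(), average)
```

In the notebook the same division gives 70081 / 11314, about 6.19 same-class texts among the 10 nearest neighbors.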

## Speeding up with PyTorch

In [8]:
import torch

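def subroutine(text_idx, batch_size):
  # Copy X to the GPU as float32 (this happens on every call), multiply there,
  # then bring the result back to the CPU as a float64 NumPy array.
  X_cuda = torch.tensor(X, dtype=torch.float32).to("cuda")
  return (
      X_cuda[text_idx:text_idx+batch_size] @ X_cuda.T
      ).cpu().numpy().astype(np.double)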

In [9]:
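def find_similar_texts(text_idx, batch_size=1000, num_texts=10):
  # Same as the NumPy version, except that the similarities are computed on the GPU.
  similarities = subroutine(text_idx, batch_size)
  sorted_indices = np.argsort(- similarities, axis=-1)
  # The last batch may contain fewer than batch_size queries.
  num_queries = X[text_idx:text_idx+batch_size].shape[0]
  query_indices = np.arange(text_idx, text_idx + batch_size)[:num_queries]
  # Each query should rank itself first; an exact duplicate can tie with it and take first place.
  anomaly_indices = (sorted_indices[:,0] != query_indices)
  for idx in anomaly_indices.nonzero()[0]:
    print(f"Text {idx+text_idx} is as similar to text {sorted_indices[idx,0]} as to itself.")
    # Swap columns 0 and 1 to put the query back in front (assumes it was ranked second).
    sorted_indices[idx,1], sorted_indices[idx,0] = sorted_indices[idx,0], sorted_indices[idx,1]
    print("----", sorted_indices[idx,0:num_texts+1].tolist())
    print("----", [f"{s:.3f}" for s in similarities[idx,sorted_indices[idx,0:num_texts+1]]])
  # Drop column 0 (the query itself) and keep the next num_texts neighbors.
  return sorted_indices[:,1:num_texts+1]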

In [10]:
from tqdm import tqdm

num_texts = 10
batch_size = 1000

num_correct_answers = list()
for text_idx in tqdm(range(0, len(newsgroups.target), batch_size)):
  similar_texts = find_similar_texts(text_idx, batch_size=batch_size, num_texts=num_texts)
  prediction = newsgroups.target[similar_texts]
  ground_truth = newsgroups.target[text_idx:text_idx+batch_size].reshape(-1,1)
  num_correct_answers.append((prediction == ground_truth).sum())

  8%|▊         | 1/12 [00:02<00:25,  2.32s/it]

Text 14 is as similar to text 5392 as to itself.
---- [14, 5392, 8800, 10716, 9863, 11012, 9247, 8516, 10574, 141, 516]
---- ['1.000', '1.000', '0.320', '0.320', '0.317', '0.312', '0.311', '0.311', '0.310', '0.307', '0.307']


 17%|█▋        | 2/12 [00:04<00:20,  2.08s/it]

Text 1063 is as similar to text 4635 as to itself.
---- [1063, 4635, 8556, 1768, 8469, 4955, 10109, 5054, 5976, 9673, 1029]
---- ['1.000', '1.000', '0.314', '0.307', '0.292', '0.287', '0.286', '0.279', '0.275', '0.268', '0.260']


 25%|██▌       | 3/12 [00:06<00:17,  1.97s/it]

Text 2726 is as similar to text 2217 as to itself.
---- [2726, 2217, 5955, 3027, 6727, 9956, 5192, 1929, 7499, 5179, 8506]
---- ['1.000', '1.000', '0.395', '0.376', '0.370', '0.259', '0.258', '0.246', '0.243', '0.236', '0.232']


 58%|█████▊    | 7/12 [00:14<00:10,  2.05s/it]

Text 6518 is as similar to text 3665 as to itself.
---- [6518, 3665, 8800, 10716, 8516, 9863, 10574, 9247, 141, 11012, 5045]
---- ['1.000', '1.000', '0.592', '0.591', '0.589', '0.586', '0.585', '0.579', '0.570', '0.568', '0.567']


 75%|███████▌  | 9/12 [00:17<00:05,  1.94s/it]

Text 8543 is as similar to text 10304 as to itself.
---- [8543, 10304, 8048, 3049, 5597, 9888, 6372, 9046, 5008, 10015, 6284]
---- ['1.000', '1.000', '0.448', '0.298', '0.269', '0.259', '0.248', '0.243', '0.234', '0.215', '0.212']
Text 8701 is as similar to text 9845 as to itself.
---- [8701, 9845, 5156, 10716, 8800, 7511, 1979, 9863, 4352, 11012, 9247]
---- ['1.000', '1.000', '0.249', '0.220', '0.220', '0.220', '0.216', '0.215', '0.214', '0.214', '0.212']


 83%|████████▎ | 10/12 [00:19<00:03,  1.92s/it]

Text 9511 is as similar to text 5106 as to itself.
---- [9511, 5106, 9107, 3482, 10257, 10273, 10059, 8307, 8545, 1162, 10404]
---- ['1.000', '1.000', '0.262', '0.249', '0.249', '0.228', '0.204', '0.179', '0.175', '0.171', '0.156']
Text 9989 is as similar to text 800 as to itself.
---- [9989, 800, 9440, 9139, 7522, 1634, 924, 1454, 10452, 4164, 11069]
---- ['1.000', '1.000', '0.326', '0.233', '0.180', '0.172', '0.168', '0.167', '0.150', '0.149', '0.136']


 92%|█████████▏| 11/12 [00:21<00:01,  1.91s/it]

Text 10777 is as similar to text 2002 as to itself.
---- [10777, 2002, 3893, 11104, 11083, 7043, 5284, 3973, 9350, 6343, 2168]
---- ['1.000', '1.000', '0.359', '0.350', '0.277', '0.246', '0.246', '0.241', '0.232', '0.218', '0.197']


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


In [11]:
np.array(num_correct_answers).sum()

70081

In [12]:
np.array(num_correct_answers).sum() / len(newsgroups.target)

6.19418419657062