<a href="https://colab.research.google.com/github/tomonari-masada/course2023-stats1/blob/main/03_text_retrieval_with_multinomial_distributions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text retrieval with multinomial distributions

## Description
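* For each text in the collection to be searched, estimate word probabilities by maximum likelihood.
* Compute the likelihood of the query using the word probabilities estimated for each text.
* Sort the texts in the collection by the query likelihood computed this way.
* Confirm that this method does not work very well for retrieval.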
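The steps above can be sketched end to end on a hypothetical toy corpus of pre-tokenized documents. This is a plain-Python illustration, not the notebook's code: the notebook uses `CountVectorizer` and `scipy.stats.multinomial`, while the helper names `mle_word_probs`, `query_log_likelihood` and `rank` are made up here. As in the notebook, query words with zero probability in a document are dropped before the likelihood is computed, and a document sharing no word with the query is ranked last.

```python
import math
from collections import Counter

def mle_word_probs(tokens):
    """Maximum-likelihood word probabilities: count(w) / total token count."""
    counts = Counter(tokens)
    total = sum(counts.values())
    return {w: c / total for w, c in counts.items()}

def query_log_likelihood(query_tokens, probs):
    """Multinomial log-likelihood of the query, keeping only query words
    that have nonzero probability under the document."""
    q = Counter(w for w in query_tokens if w in probs)
    n = sum(q.values())
    if n == 0:
        return -math.inf  # no overlap with the query: rank last
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in q.values())
    return log_coef + sum(c * math.log(probs[w]) for w, c in q.items())

def rank(query_tokens, docs):
    """Return (document indices sorted by descending score, list of scores)."""
    scores = [query_log_likelihood(query_tokens, mle_word_probs(d)) for d in docs]
    order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    return order, scores

# Hypothetical toy corpus and query.
docs = [
    "the rocket reached orbit".split(),
    "the game went to overtime".split(),
    "nasa launched a rocket into orbit".split(),
]
query = "rocket orbit".split()
order, scores = rank(query, docs)
print(order)  # [0, 2, 1]
```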

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multinomial
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

## Dataset

In [None]:
# return_X_y=True returns (texts, labels); only the raw texts are used here.
train_corpus, _ = fetch_20newsgroups(subset="train", return_X_y=True)
test_corpus, _ = fetch_20newsgroups(subset="test", return_X_y=True)

In [None]:
train_corpus[0]

In [None]:
len(train_corpus), len(test_corpus)

In [None]:
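# Bag-of-words counts: after removing English stop words, keep words that occur in
# at least 50 training documents (min_df=50) and in at most 10% of them (max_df=0.1).
vectorizer = CountVectorizer(min_df=50, max_df=0.1, stop_words="english")
# Learn the vocabulary on the training set only, then reuse it for the test set.
X_train = vectorizer.fit_transform(train_corpus).toarray()
X_test = vectorizer.transform(test_corpus).toarray()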
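The `min_df` / `max_df` settings decide which words become columns of `X_train`. The helper below is a hypothetical plain-Python sketch of that document-frequency filter, assuming scikit-learn's documented semantics (an integer `min_df` is an absolute document count, a float `max_df` is a proportion of documents, and both bounds are inclusive). It uses CountVectorizer's default token pattern with lowercasing, and a small custom stop-word set stands in for `stop_words="english"`; it is not how CountVectorizer is implemented internally. Terms are returned sorted, matching the column order reported by `get_feature_names_out`.

```python
import re
from collections import Counter

# CountVectorizer's default token pattern: runs of 2+ word characters.
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")

def filter_vocabulary(docs, min_df, max_df, stop_words=frozenset()):
    """Hypothetical sketch of document-frequency filtering.

    min_df: int, keep terms that occur in at least min_df documents.
    max_df: float, keep terms that occur in at most max_df * len(docs) documents.
    """
    df = Counter()
    for doc in docs:
        # A set, so each document counts a term at most once.
        terms = {t for t in TOKEN_RE.findall(doc.lower()) if t not in stop_words}
        df.update(terms)
    max_count = max_df * len(docs)
    return sorted(t for t, c in df.items() if min_df <= c <= max_count)

docs = ["apple banana", "Apple cherry", "apple banana date", "egg"]
# "apple" is in 3 of 4 documents (> 0.6 * 4 = 2.4), so it is dropped as too common;
# "cherry", "date" and "egg" are in only 1 document (< 2), so they are dropped as too rare.
print(filter_vocabulary(docs, min_df=2, max_df=0.6))  # ['banana']
```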

In [None]:
X_train.shape, X_test.shape

In [None]:
vocabulary = vectorizer.get_feature_names_out()
print(vocabulary)

## Maximum likelihood estimation

In [None]:
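# Maximum-likelihood word probabilities for each training document:
# divide each row of counts by that document's total count.
# A document with no in-vocabulary words would give 0/0 = NaN, so the
# denominator is clamped to 1 (such a document keeps an all-zero row).
row_sums = X_train.sum(axis=1, keepdims=True)
X_train_probs = X_train / np.maximum(row_sums, 1)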

In [None]:
X_train_probs.sum(axis=1)
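Dividing each row by its sum is the closed-form maximizer of the multinomial likelihood. A short derivation for one training document, writing $x_w$ for the count of word $w$, $p_w$ for its probability and $V$ for the vocabulary size (notation introduced here); the multinomial coefficient does not depend on $\mathbf p$ and is dropped:

$$
\begin{aligned}
\hat{\mathbf p} &= \arg\max_{\mathbf p} \sum_{w=1}^{V} x_w \log p_w
\quad \text{subject to} \quad \sum_{w=1}^{V} p_w = 1, \\
\mathcal{L}(\mathbf p, \lambda) &= \sum_{w=1}^{V} x_w \log p_w + \lambda \Bigl(1 - \sum_{w=1}^{V} p_w\Bigr), \\
\frac{\partial \mathcal{L}}{\partial p_w} &= \frac{x_w}{p_w} - \lambda = 0
\;\Rightarrow\; p_w = \frac{x_w}{\lambda}, \\
\sum_{w=1}^{V} p_w = 1 &\;\Rightarrow\; \lambda = \sum_{v=1}^{V} x_v
\;\Rightarrow\; \hat p_w = \frac{x_w}{\sum_{v=1}^{V} x_v}.
\end{aligned}
$$

So every row of `X_train_probs` sums to 1, and every word that does not occur in a document gets $\hat p_w = 0$ for that document, which is why `log_likelihood` has to mask such words out of the query.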

## Retrieval

In [None]:
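def log_likelihood(x_test, x_train_prob):
  # Keep only the query words with nonzero probability in this training document;
  # any other word would make the multinomial probability 0 (log-probability -inf).
  x_test_nonzero = x_test * (x_train_prob > 0)
  # Multinomial whose number of trials is the length of the masked query.
  rv = multinomial(x_test_nonzero.sum(), x_train_prob)
  # Log-probability of the masked query; it is 0.0 when no query word survives (n = 0).
  return rv.logpmf(x_test_nonzero)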
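The function above evaluates the multinomial log-probability $\log P(\mathbf x \mid n, \mathbf p) = \log n! - \sum_w \log x_w! + \sum_w x_w \log p_w$ with $n = \sum_w x_w$, after zeroing out the query words whose probability is 0 in the training document. The plain-Python sketch below spells that formula out without SciPy; `multinomial_logpmf` and `masked_log_likelihood` are hypothetical helper names, and lists stand in for the notebook's NumPy rows. It shows why the mask is needed (one unseen query word sends the log-probability to `-inf`) and why a query with no surviving words scores exactly 0.0.

```python
import math

def multinomial_logpmf(counts, probs):
    """log of n! / prod(c!) * prod(p ** c), with n = sum(counts)."""
    n = sum(counts)
    log_p = math.lgamma(n + 1)  # log n!
    for c, p in zip(counts, probs):
        if c == 0:
            continue  # contributes log 0! + 0 * log p = 0, even when p == 0
        if p == 0:
            return -math.inf  # a positive count on a zero-probability word
        log_p += c * math.log(p) - math.lgamma(c + 1)
    return log_p

def masked_log_likelihood(x_test, x_train_prob):
    """Same masking idea as log_likelihood: zero out query words with p == 0."""
    masked = [c if p > 0 else 0 for c, p in zip(x_test, x_train_prob)]
    return multinomial_logpmf(masked, x_train_prob)

full = multinomial_logpmf([2, 1, 0], [0.5, 0.25, 0.25])     # log(3 * 0.5**2 * 0.25) = log 0.1875
unseen = multinomial_logpmf([1, 1, 0], [0.5, 0.0, 0.5])     # -inf: second word has p = 0
masked = masked_log_likelihood([1, 1, 0], [0.5, 0.0, 0.5])  # log 0.5: only the first word is kept
empty = masked_log_likelihood([0, 3, 0], [0.5, 0.0, 0.5])   # 0.0: nothing survives, n = 0
```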

In [None]:
print(test_corpus[0])

In [None]:
log_likelihood(X_test[0], X_train_probs[0])

In [None]:
log_likelihood(X_test[0], X_train_probs[11313])

In [None]:
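# Score every training document by the log-likelihood of test document 0 used as the query.
score = list()
for i in range(X_train.shape[0]):
  score.append(log_likelihood(X_test[0], X_train_probs[i]))
score = np.array(score)
# A score of exactly 0.0 normally means that no query word occurs in the document (n = 0),
# so such documents are ranked last.
score = np.where(score == 0.0, -np.inf, score)
# Indices of the training documents in descending order of score.
sorted_train_indices = (-score).argsort()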
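The loop above calls `scipy.stats.multinomial` once per training document. As a sketch, the same scores can be computed for all documents at once with NumPy broadcasting; `score_all` is a hypothetical name, and log factorials come from a cumulative sum of logarithms instead of SciPy. One deliberate difference: a document is sent to `-inf` only when no query word survives the mask (n = 0), whereas the loop's `score == 0.0` test would also catch a document whose only matching word has probability 1.

```python
import numpy as np

def score_all(x_test, train_probs):
    """Vectorized query log-likelihoods (hypothetical sketch).

    x_test: (V,) integer query counts; train_probs: (D, V) rows summing to 1.
    Returns (D,) scores, with -inf for documents sharing no word with the query.
    """
    mask = train_probs > 0                 # (D, V): word has nonzero probability
    q = x_test[None, :] * mask             # (D, V): masked query counts per document
    n = q.sum(axis=1)                      # (D,): length of each masked query
    # log k! for k = 0..max(n): log 0! = 0 followed by cumulative sums of log(1..k).
    max_n = int(n.max()) if n.size else 0
    log_fact = np.concatenate(([0.0], np.cumsum(np.log(np.arange(1, max_n + 1)))))
    # Replace p = 0 by 1 before taking logs; those entries have q = 0 anyway.
    log_p = np.log(np.where(mask, train_probs, 1.0))
    scores = log_fact[n] - log_fact[q].sum(axis=1) + (q * log_p).sum(axis=1)
    return np.where(n == 0, -np.inf, scores)

# Hypothetical toy data: 3 documents over a 3-word vocabulary.
probs = np.array([[0.5, 0.5, 0.0],
                  [0.0, 0.0, 1.0],
                  [0.25, 0.25, 0.5]])
query = np.array([1, 1, 0])
scores = score_all(query, probs)
ranking = (-scores).argsort()
print(ranking)  # [0 2 1]
```

With the notebook's arrays, `score_all(X_test[0], X_train_probs)` would replace the loop, at the cost of a few temporary arrays of the same shape as `X_train_probs`.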

In [None]:
sorted_train_indices[0]

In [None]:
print(train_corpus[sorted_train_indices[0]])

In [None]:
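# Query words that contributed to the top document's score
# (words of test document 0 with nonzero probability in that document).
vocabulary[(X_test[0] * (X_train_probs[sorted_train_indices[0]] > 0)) > 0]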
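One way to see why this ranking tends to go wrong: because `log_likelihood` drops the query words a document lacks, a document that matches fewer query words is scored on a shorter query, and the probability of a shorter outcome is typically larger. The toy example below uses a hypothetical four-word vocabulary with hand-computed MLE probabilities; `masked_query_loglik` is an illustrative re-implementation of the same masking, not the notebook's function.

```python
import math

def masked_query_loglik(query_counts, doc_probs):
    """Query log-likelihood after dropping query words with zero probability
    in the document (the same masking as log_likelihood)."""
    kept = [(c, p) for c, p in zip(query_counts, doc_probs) if c > 0 and p > 0]
    n = sum(c for c, _ in kept)
    if n == 0:
        return -math.inf
    return (math.lgamma(n + 1)
            - sum(math.lgamma(c + 1) for c, _ in kept)
            + sum(c * math.log(p) for c, p in kept))

# Hypothetical vocabulary: ["space", "nasa", "orbit", "car"]
query = [1, 1, 1, 0]          # query "space nasa orbit"
doc_a = [0.5, 0.0, 0.0, 0.5]  # MLE for the document "space car"
doc_b = [1/3, 1/3, 1/3, 0.0]  # MLE for the document "space nasa orbit"

score_a = masked_query_loglik(query, doc_a)  # only "space" survives: log 0.5
score_b = masked_query_loglik(query, doc_b)  # all three survive: log(3! / 3**3) = log(6/27)
print(score_a > score_b)  # True
```

Document A, which contains only one of the three query words, outranks document B, which contains all of them.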