<a href="https://colab.research.google.com/github/tomonari-masada/course2024-stats2/blob/main/VI_for_LDA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational inference for latent Dirichlet allocation

In [None]:
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import CountVectorizer
import torch
import transformers
from datasets import load_dataset

# Fix the random seed (through transformers' set_seed helper) so that the
# randomly initialized quantities later in the notebook are repeatable.
transformers.set_seed(1234)

# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs (more slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

* Dataset: https://huggingface.co/datasets/dell-research-harvard/newswire

### データセットの取得

In [None]:
# Load the "dell-research-harvard/newswire" dataset with datasets.load_dataset.
ds = load_dataset("dell-research-harvard/newswire")

* 文書数を調べる。

In [None]:
# Number of documents in the training split. Taking the length of the split
# itself avoids pulling the whole "article" column into memory just to count it.
len(ds["train"])

### データセットの準備
* データセットが大きいので1/50だけ使う。

In [None]:
# Keep every 50th article of the training split to shrink the corpus.
corpus = ds["train"]["article"][::50]
num_docs = len(corpus)
print(f"{num_docs} documents")

* 語彙セットの作成と出現回数の取得

In [None]:
# Build the vocabulary and the document-term count matrix.
# English stop words are removed; min_df / max_df (given as fractions) drop
# terms that appear in fewer than 0.1% or in more than 20% of the documents.
vectorizer = CountVectorizer(stop_words="english", min_df=0.001, max_df=0.2)
# X is a sparse count matrix with one row per document and one column per term.
X = vectorizer.fit_transform(corpus)
# vocab[j] is the term that column j of X counts.
vocab = vectorizer.get_feature_names_out()

In [None]:
# Vocabulary size, i.e. the number of distinct terms kept by the vectorizer.
num_words = len(vocab)
print(f"{num_words} different words")

### 事前分布のパラメータの設定

In [None]:
# Number of topics and the prior parameter alpha: a (1, num_topics) row whose
# entries all equal 1 / num_topics. The leading axis of size 1 lets alpha
# broadcast over a batch of documents when it is added to per-document counts.
num_topics = 50
alpha = torch.ones((1, num_topics), device=device) / num_topics

In [None]:
# Display the prior parameter tensor.
alpha

### 事後分布のパラメータの初期化

In [None]:
# Variational parameters, filled in below from random responsibilities:
#   zeta[d, k] -- per-document topic parameter (alpha plus expected topic counts)
#   phi[w, k]  -- per-topic word weight; every column is normalized to sum to 1
zeta = torch.zeros((num_docs, num_topics), device=device)
phi = torch.zeros((num_words, num_topics), device=device)

batch_size = 1000
for i in tqdm(range(0, num_docs, batch_size)):
  sub_X = torch.tensor(X[i:i+batch_size,:].toarray(), device=device)
  # Random responsibilities: for each (document, word) pair in the batch,
  # a distribution over topics (softmax of Gaussian noise).
  q = torch.softmax(
    torch.randn((sub_X.shape[0], num_words, num_topics), device=device),
    dim=-1,
  )
  # Weight the responsibilities by the word counts once, in place, instead of
  # materializing the large (batch, words, topics) product twice.
  # From here on q[d, w, k] holds the expected count of word w in topic k.
  q *= sub_X.unsqueeze(-1)
  zeta[i:i+batch_size,:] = alpha + q.sum(1)
  phi += q.sum(0)
# Normalize each topic (column) over the vocabulary.
phi /= phi.sum(0, keepdim=True)

In [None]:
# Sanity check: each zeta row is alpha plus the document's word counts spread
# over the topics, so the two printed numbers should differ by alpha.sum().
print(zeta[0].sum().item(), X[0].sum())

In [None]:
# Sanity check: phi was normalized column-wise, so every entry of this
# per-topic sum should be (numerically close to) 1.
phi.sum(0)

### perplexityを算出するヘルパ関数

In [None]:
def compute_perplexity(quiet=False):
  """Return the perplexity of the corpus under the current `zeta` and `phi`.

  Reads the module-level `X`, `zeta`, `phi`, and `device`. For each document,
  the topic proportions are taken to be its `zeta` row normalized to sum to 1,
  and the probability of a word is the corresponding mixture of the columns
  of `phi`.

  Args:
    quiet: If True, the tqdm progress bar is disabled.

  Returns:
    exp(-(total log-likelihood) / (total token count)) as a Python float.
  """
  batch_size = 1000
  log_likelihood = 0.0
  num_tokens = 0
  for i in tqdm(range(0, X.shape[0], batch_size), disable=quiet):
    sub_X = torch.tensor(X[i:i+batch_size,:].toarray(), device=device)
    sub_zeta = zeta[i:i+batch_size,:]
    normalized_zeta = sub_zeta / sub_zeta.sum(-1, keepdim=True)
    # word_prob[d, w] = sum_k normalized_zeta[d, k] * phi[w, k].
    # A matrix product yields the (batch, words) result directly, without
    # building a (batch, words, topics) intermediate tensor.
    word_prob = normalized_zeta @ phi.t()
    # The 1e-10 offset keeps log() finite where a probability is exactly 0.
    log_likelihood += (sub_X * torch.log(1e-10 + word_prob)).sum()
    num_tokens += sub_X.sum()
  return torch.exp(- log_likelihood / num_tokens).item()

In [None]:
# Perplexity of the randomly initialized parameters, as a baseline.
perplexity = compute_perplexity()
print(f"perplexity = {perplexity:.3f}")

### 事後分布のパラメータを更新するヘルパ関数

In [None]:
def update(phi, quiet=False):
  """Run one pass over the corpus, updating `zeta` in place and returning a new phi.

  For each batch of documents the responsibilities
  q[d, w, k] ∝ phi[w, k] * exp(digamma(zeta[d, k])) are normalized over topics
  and weighted by the word counts. Summing over words gives the new `zeta`
  rows (plus `alpha`); summing over documents accumulates the new phi.

  Reads the module-level `X`, `alpha`, `num_docs`, `num_words`, `num_topics`,
  and `device`, and overwrites the module-level `zeta` batch by batch.

  Args:
    phi: Current (num_words, num_topics) topic-word weights; not modified.
    quiet: If True, the tqdm progress bar is disabled.

  Returns:
    The updated (num_words, num_topics) tensor with every column normalized
    to sum to 1.
  """
  batch_size = 1000
  new_phi = torch.zeros((num_words, num_topics), device=device)
  for i in tqdm(range(0, num_docs, batch_size), disable=quiet):
    sub_X = torch.tensor(X[i:i+batch_size,:].toarray(), device=device)
    q = (
      phi.unsqueeze(0)
      * torch.exp(torch.digamma(zeta[i:i+batch_size,:])).unsqueeze(1)
    )
    q /= q.sum(-1, keepdim=True)
    # Weight the responsibilities by the word counts once, in place, instead of
    # materializing the large (batch, words, topics) product twice.
    # From here on q[d, w, k] holds the expected count of word w in topic k.
    q *= sub_X.unsqueeze(-1)
    zeta[i:i+batch_size,:] = alpha + q.sum(1)
    new_phi += q.sum(0)
  return new_phi / new_phi.sum(0, keepdim=True)

In [None]:
# First update pass (progress bars shown), followed by the resulting perplexity.
phi = update(phi)
perplexity = compute_perplexity()
print(f"perplexity = {perplexity:.3f}")

In [None]:
# Epochs 2 through 30: alternate one update pass with a perplexity evaluation,
# with the progress bars suppressed so only one line is printed per epoch.
for epoch in range(2, 31):
  phi = update(phi, quiet=True)
  perplexity = compute_perplexity(quiet=True)
  print(f"epoch {epoch} | perplexity = {perplexity:.3f}")

### トピック語の表示

In [None]:
# Print, one line per topic, the words with the largest phi weight in that topic.
# topk returns only the leading entries (already in descending order), so the
# whole vocabulary does not have to be sorted for every topic.
num_top_words = min(20, phi.shape[0])
top_word_ids = phi.t().topk(num_top_words, dim=-1).indices.cpu().numpy()
for word_list in vocab[top_word_ids]:
  print(" ".join(word_list))