<a href="https://colab.research.google.com/github/tomonari-masada/course2024-stats2/blob/main/VI_for_LDA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational inference for latent Dirichlet allocation

In [None]:
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import CountVectorizer
import torch
import transformers
from datasets import load_dataset

# Seed the Python, NumPy and PyTorch random number generators so that the
# random initialization of the variational parameters below is repeatable.
transformers.set_seed(1234)

# Use the first CUDA GPU when one is present; otherwise fall back to the CPU
# so that the notebook still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

* 使用するデータセット: https://huggingface.co/datasets/dell-research-harvard/newswire

### データセットの取得

In [None]:
# Load the newswire dataset by its Hugging Face Hub identifier; the cells below
# use its "train" split and the "article" text column.
ds = load_dataset("dell-research-harvard/newswire")

Resolving data files:   0%|          | 0/100 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

* 文書数を調べる。

In [None]:
# Number of documents in the train split. `num_rows` reads the count from the
# dataset metadata instead of materializing every article string just to
# measure the length of the resulting list.
ds["train"].num_rows

2719607

### データセットの準備
* データセットが大きいので1/50だけ使う。

In [None]:
# Keep every `subsample_stride`-th article. Selecting the rows first means only
# the retained articles are decoded into Python strings, rather than loading
# the whole "article" column and slicing it afterwards.
subsample_stride = 50
train_split = ds["train"]
corpus = list(
  train_split.select(range(0, train_split.num_rows, subsample_stride))["article"]
)
num_docs = len(corpus)
print(f"{num_docs} documents")

54393 documents


* 語彙セットの作成と出現回数の取得

In [None]:
# Bag-of-words counts. English stop words are removed; because min_df / max_df
# are floats they are document-frequency proportions, so a word is kept only if
# it occurs in at least 0.1% and at most 20% of the documents.
vectorizer = CountVectorizer(stop_words="english", min_df=0.001, max_df=0.2)
# X: sparse document-term count matrix with one row per document in `corpus`.
X = vectorizer.fit_transform(corpus)
# vocab[j] is the word that column j of X counts.
vocab = vectorizer.get_feature_names_out()

In [None]:
# Vocabulary size after the stop-word and document-frequency filtering above.
num_words = len(vocab)
print(f"{num_words} different words")

10140 different words


### 事前分布のパラメータの設定

In [None]:
# Number of latent topics.
num_topics = 50
# Prior parameter: a (1, num_topics) row with every entry equal to
# 1 / num_topics, so the entries sum to 1. The leading singleton dimension lets
# it broadcast over the documents of a batch.
alpha = torch.ones((1, num_topics), device=device) / num_topics

In [None]:
# Display the prior parameter to confirm its value, shape and device.
alpha

tensor([[0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200,
         0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200,
         0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200,
         0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200,
         0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200, 0.0200,
         0.0200, 0.0200, 0.0200, 0.0200, 0.0200]], device='cuda:0')

### 事後分布のパラメータの初期化

In [None]:
# Variational parameters:
#   zeta[d, k] - per-document topic parameter (num_docs x num_topics)
#   phi[w, k]  - per-topic word distribution  (num_words x num_topics),
#                each column normalized to sum to 1 after the loop.
zeta = torch.zeros((num_docs, num_topics), device=device)
phi = torch.zeros((num_words, num_topics), device=device)

batch_size = 1000
for i in tqdm(range(0, num_docs, batch_size)):
  sub_X = torch.tensor(X[i:i+batch_size,:].toarray(), device=device)
  # Random responsibilities: for every (document, word) pair, a softmax over
  # topics of standard-normal noise, so each pair's topic weights sum to 1.
  q = torch.randn((sub_X.shape[0], num_words, num_topics), device=device)
  q = torch.softmax(q, dim=-1)
  # Weight the responsibilities by the word counts once, in place. From here on
  # q holds expected topic counts; both reductions below reuse it instead of
  # each building its own (docs x words x topics) product.
  q *= sub_X.unsqueeze(-1)
  zeta[i:i+batch_size,:] = alpha + q.sum(1)
  phi += q.sum(0)
# Turn the accumulated expected counts into per-topic word distributions.
phi /= phi.sum(0, keepdim=True)

  0%|          | 0/55 [00:00<?, ?it/s]

In [None]:
# Sanity check: the responsibilities sum to 1 over topics, so the first row of
# zeta should sum to the first document's token count plus alpha.sum() (= 1).
print(zeta[0].sum().item(), X[0].sum())

555.0 554


In [None]:
# Sanity check: each topic's column of phi should sum to 1 over the vocabulary.
phi.sum(0)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0')

### perplexityを算出するヘルパ関数

In [None]:
def compute_perplexity(quiet=False):
  """Compute the per-token perplexity of the corpus under the current model.

  For each document the corresponding row of the global `zeta` is normalized
  to sum to 1 and used as topic proportions; the probability of word w in
  document d is then sum_k normalized_zeta[d, k] * phi[w, k].

  Reads the globals `X`, `zeta`, `phi` and `device` and modifies none of them.

  Args:
    quiet: If True, the tqdm progress bar is disabled.

  Returns:
    exp(-(total log-likelihood) / (total token count)) as a Python float.
  """
  batch_size = 1000
  log_likelihood = 0.0
  num_tokens = 0
  for i in tqdm(range(0, X.shape[0], batch_size), disable=quiet):
    sub_X = torch.tensor(X[i:i+batch_size,:].toarray(), device=device)
    zeta_batch = zeta[i:i+batch_size,:]
    normalized_zeta = zeta_batch / zeta_batch.sum(-1, keepdim=True)
    # (docs x topics) @ (topics x words): the same topic mixture as broadcasting
    # to (docs x words x topics) and summing over topics, but without
    # materializing that three-dimensional intermediate.
    word_prob = normalized_zeta @ phi.t()
    # The 1e-10 offset keeps log() finite when a word has zero probability.
    log_likelihood += (sub_X * torch.log(1e-10 + word_prob)).sum()
    num_tokens += sub_X.sum()
  return torch.exp(- log_likelihood / num_tokens).item()

In [None]:
# Perplexity of the randomly initialized model, as a baseline for the updates.
perplexity = compute_perplexity()
print(f"perplexity = {perplexity:.3f}")

  0%|          | 0/55 [00:00<?, ?it/s]

perplexity = 4260.828


### 事後分布のパラメータを更新するヘルパ関数

In [None]:
def update(phi, quiet=False, num_inner_iters=3, batch_size=1000):
  """Run one pass over the corpus and return the re-estimated `phi`.

  For every mini-batch of documents, the responsibilities
      q[d, w, k] proportional to phi[w, k] * exp(digamma(zeta[d, k]))
  (normalized over topics k) and the batch's rows of `zeta` are refined
  alternately `num_inner_iters` times. The count-weighted responsibilities of
  the last refinement are then accumulated into the new `phi`.

  The responsibilities are never materialized. With
      e = exp(digamma(zeta_batch))     (docs x topics)
      norm = e @ phi.T                 (docs x words, the normalizer of q)
      ratio = sub_X / norm             (docs x words)
  the two reductions used here factor as
      sum_w sub_X[d, w] * q[d, w, k] = e[d, k]   * (ratio @ phi)[d, k]
      sum_d sub_X[d, w] * q[d, w, k] = phi[w, k] * (ratio.T @ e)[w, k]
  which needs only two-dimensional tensors.

  Side effect: the global `zeta` is overwritten in place, batch by batch.
  Also reads the globals `X`, `alpha`, `num_docs`, `num_words`, `num_topics`
  and `device`.

  Args:
    phi: (num_words, num_topics) tensor of current per-topic word
      distributions; it is not modified.
    quiet: If True, the tqdm progress bar is disabled.
    num_inner_iters: Number of alternating refinements per mini-batch;
      must be at least 1.
    batch_size: Number of documents processed per mini-batch.

  Returns:
    A new (num_words, num_topics) tensor whose columns each sum to 1.

  Raises:
    ValueError: If `num_inner_iters` is smaller than 1.
  """
  if num_inner_iters < 1:
    raise ValueError("num_inner_iters must be at least 1")
  new_phi = torch.zeros((num_words, num_topics), device=device)
  for i in tqdm(range(0, num_docs, batch_size), disable=quiet):
    # Counts are converted to phi's floating dtype so they can enter the matmuls.
    sub_X = torch.tensor(
      X[i:i+batch_size,:].toarray(), dtype=phi.dtype, device=device
    )
    for _ in range(num_inner_iters):
      exp_digamma = torch.exp(torch.digamma(zeta[i:i+batch_size,:]))
      norm = exp_digamma @ phi.t()
      ratio = sub_X / norm
      zeta[i:i+batch_size,:] = alpha + exp_digamma * (ratio @ phi)
    # Uses exp_digamma / ratio from the last refinement, i.e. the
    # responsibilities computed just before the final zeta update.
    new_phi += phi * (ratio.t() @ exp_digamma)
  return new_phi / new_phi.sum(0, keepdim=True)

In [None]:
# First pass over the corpus: update() refreshes the global zeta in place and
# returns the re-estimated phi, after which perplexity is measured again.
phi = update(phi)
perplexity = compute_perplexity()
print(f"perplexity = {perplexity:.3f}")

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

perplexity = 4254.392


In [None]:
# Total number of passes over the corpus, counting the single update() call in
# the previous cell as epoch 1. The recorded output of this cell runs through
# epoch 30, so the loop bound is set to reproduce that log.
num_epochs = 30
for epoch in range(2, num_epochs + 1):
  phi = update(phi, quiet=True)
  perplexity = compute_perplexity(quiet=True)
  print(f"epoch {epoch} | perplexity = {perplexity:.3f}")

epoch 2 | perplexity = 4231.360
epoch 3 | perplexity = 4182.535
epoch 4 | perplexity = 4093.403
epoch 5 | perplexity = 3935.176
epoch 6 | perplexity = 3668.212
epoch 7 | perplexity = 3344.582
epoch 8 | perplexity = 3106.495
epoch 9 | perplexity = 2969.667
epoch 10 | perplexity = 2890.713
epoch 11 | perplexity = 2842.496
epoch 12 | perplexity = 2811.491
epoch 13 | perplexity = 2790.680
epoch 14 | perplexity = 2776.213
epoch 15 | perplexity = 2765.849
epoch 16 | perplexity = 2758.223
epoch 17 | perplexity = 2752.499
epoch 18 | perplexity = 2748.121
epoch 19 | perplexity = 2744.726
epoch 20 | perplexity = 2742.044
epoch 21 | perplexity = 2739.886
epoch 22 | perplexity = 2738.138
epoch 23 | perplexity = 2736.711
epoch 24 | perplexity = 2735.539
epoch 25 | perplexity = 2734.565
epoch 26 | perplexity = 2733.754
epoch 27 | perplexity = 2733.056
epoch 28 | perplexity = 2732.463
epoch 29 | perplexity = 2731.952
epoch 30 | perplexity = 2731.514


### トピック語の表示

In [None]:
# For every topic (one row of phi transposed), take the indices of the 20
# highest-probability words, then print each topic's words on one line.
top_word_ids = phi.t().argsort(descending=True)[:, :20].cpu().numpy()
for topic_words in vocab[top_word_ids]:
  print(" ".join(topic_words))

united states state soviet 000 court president york american police time night war men government jury week reported city year
000 farm year wheat president farmers party cent price house republican crop agriculture senate committee government corn state prices senator
york 000 american year city st time war night men years air 10 states united old army world miles plane
german war british american tho london air united states great soviet time says york french press army reported world men
army government president war general mccarthy troops state american united party men minister reported police greek states night german city
president state york house party general city governor castro national white time men states government secretary united night 000 union
united states president american war 000 foreign secretary government conference peace state world nations meeting agreement french soviet kissinger talks
police city 000 night men persons man street miles building area york 