<a href="https://colab.research.google.com/github/tomonari-masada/course2023-stats2/blob/main/VB_for_LDA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install 'portalocker>=2.0.0'



* ここでランタイムを再起動する（インストールした portalocker を読み込ませるため）。

In [2]:
from sklearn.feature_extraction.text import CountVectorizer
import torch
from torchtext.datasets import IMDB

# Fix the global torch RNG so the random initialisations below are reproducible.
torch.manual_seed(0)
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [3]:
# Collect the raw review texts of the IMDB training split (labels are discarded).
train_iter = IMDB(split='train')
corpus = [text for _label, text in train_iter]

In [4]:
# Bag-of-words counts: drop English stop words, words appearing in fewer than
# 200 documents, and words appearing in more than 10% of documents.
vectorizer = CountVectorizer(
  stop_words="english",
  min_df=200,
  max_df=0.1,
)
X = vectorizer.fit_transform(corpus)
vocab = vectorizer.get_feature_names_out()

In [5]:
# Turn the sparse count matrix from CountVectorizer into a dense float32 tensor
# on the device selected above (previously hard-coded to "cuda", which failed
# on CPU-only runtimes).
X = torch.tensor(X.toarray(), dtype=torch.float32, device=device)

In [6]:
# (number of documents, vocabulary size) after vectorisation.
X.shape

torch.Size([25000, 1869])

In [7]:
D, W = X.shape  # number of documents, vocabulary size
K = 20  # number of topics

In [8]:
# Symmetric Dirichlet prior on document-topic proportions: alpha_k = 1 / K,
# created directly on the selected device (was hard-coded to "cuda").
alpha = torch.ones(K, device=device) / K

In [9]:
# Random topic-word matrix, each column (topic) normalised to sum to 1.
# Drawn on the CPU generator so the values depend only on torch.manual_seed,
# then moved to the selected device (was hard-coded to "cuda").
phi = torch.rand((W, K))
phi = phi / phi.sum(0, keepdim=True)
phi = phi.to(device)

In [10]:
# Initial variational Dirichlet parameters (D, K): each document's token count
# spread evenly over the K topics, plus the prior (alpha broadcasts over rows).
# Note: `train` re-initialises q randomly and only uses zeta's shape/dtype/device.
zeta = alpha + X.sum(dim=1, keepdim=True) / K

In [11]:
# Number of documents processed at once by train/perplexity (matches their default).
batch_size = 1000

In [12]:
def train(X, zeta, phi, alpha, device=None, batch_size=1000, n_inner_iter=10):
  """Run one epoch of variational Bayes for LDA.

  For each mini-batch of documents, the per-(document, word) topic
  responsibilities q are initialised at random and refined by
  `n_inner_iter` fixed-point updates:
      zeta_dk = alpha_k + sum_w x_dw * q_dwk
      q_dwk   ∝ phi_wk * exp(digamma(zeta_dk))
  The topic-word matrix is then re-estimated from the pseudo-counts
  x_dw * q_dwk accumulated over all batches.

  Args:
    X: (D, W) document-term count matrix.
    zeta: (D, K) current variational Dirichlet parameters. Only its shape,
      dtype and device are used (as the template for the output).
    phi: (W, K) topic-word probabilities.
    alpha: (K,) Dirichlet prior on document-topic proportions.
    device: device the random q is moved to; defaults to X.device
      (previously hard-coded to "cuda", which failed on CPU).
    batch_size: number of documents per mini-batch.
    n_inner_iter: number of fixed-point updates of q per batch.

  Returns:
    (new_zeta, new_phi): updated (D, K) variational parameters and the
    (W, K) topic-word matrix with columns normalised to sum to 1.
  """
  if device is None:
    device = X.device
  K = alpha.shape[0]
  D, W = X.shape
  log_phi = (phi + 1e-10).log()  # epsilon guards against log(0)
  new_phi = torch.zeros_like(phi)
  new_zeta = torch.zeros_like(zeta)
  for start in range(0, D, batch_size):
    batch = X[start:start + batch_size]
    counts = batch.unsqueeze(-1)  # (B, W, 1), broadcasts over topics
    # Random init drawn on the CPU generator (then moved) so it depends only
    # on the torch.manual_seed state.
    q = torch.rand((batch.shape[0], W, K)).to(device)
    q = q / q.sum(2, keepdim=True)
    for _ in range(n_inner_iter):
      batch_zeta = (q * counts).sum(1) + alpha  # (B, K)
      logits = log_phi.unsqueeze(0) + torch.digamma(batch_zeta).unsqueeze(1)
      q = torch.softmax(logits, dim=2)  # numerically stable normalisation over topics
    pseudo_count = q * counts
    new_zeta[start:start + batch_size] = pseudo_count.sum(1) + alpha
    new_phi += pseudo_count.sum(0)
  new_phi = new_phi / new_phi.sum(0, keepdim=True)
  return new_zeta, new_phi

In [13]:
def perplexity(X, zeta, phi, alpha, device=None, batch_size=1000):
  """Per-token perplexity of X under the current variational LDA estimate.

  Document-topic proportions are estimated as exp(digamma(zeta)) normalised
  per document, word probabilities as theta_d @ phi.T, and the result is
  exp(-sum_dw x_dw * log p_dw / sum_dw x_dw).

  Args:
    X: (D, W) document-term count matrix.
    zeta: (D, K) variational Dirichlet parameters.
    phi: (W, K) topic-word probabilities.
    alpha: unused; kept for signature symmetry with `train`.
    device: unused; X, zeta and phi are expected to share a device already.
    batch_size: number of documents per chunk.

  Returns:
    0-dim tensor holding the perplexity.
  """
  log_likelihood = 0.0
  n_tokens = 0.0
  for start in range(0, X.shape[0], batch_size):
    batch = X[start:start + batch_size]
    theta = torch.digamma(zeta[start:start + batch_size]).exp()
    theta = theta / theta.sum(-1, keepdim=True)
    word_proba = theta @ phi.T
    # xlogy returns 0 where the count is 0, avoiding 0 * log(0) = nan for
    # words whose predicted probability underflows to zero.
    log_likelihood += torch.xlogy(batch, word_proba).sum()
    n_tokens += batch.sum()
  return (-log_likelihood / n_tokens).exp()

In [14]:
def print_topic_words(phi, prefix="", vocab=None, n_words=20):
  """Print the `n_words` highest-probability words of every topic.

  Each printed line is `prefix k word_1 ... word_n`, with words in
  descending order of phi[:, k].

  Args:
    phi: (W, K) topic-word matrix.
    prefix: value printed at the start of every line.
    vocab: array of W words supporting numpy fancy indexing; defaults to
      the notebook-level `vocab` (the previous, implicit behaviour).
    n_words: number of words shown per topic.
  """
  if vocab is None:
    vocab = globals()["vocab"]
  # (n_words, K) word indices, best first, moved to CPU for numpy indexing.
  top_ids = phi.argsort(dim=0, descending=True)[:n_words].cpu().numpy()
  for k, topic_words in enumerate(vocab[top_ids].T):
    print(prefix, k, ' '.join(topic_words))

In [15]:
# Epoch counter, kept in its own cell so re-running the training cell
# continues the numbering instead of restarting it.
epoch = 1

In [16]:
# Run 100 VB epochs, reporting perplexity and top topic words after each one.
# batch_size/device come from the config cells above (previously they were
# never passed, so changing them had no effect).
for _ in range(100):
  zeta, phi = train(X, zeta, phi, alpha, device=device, batch_size=batch_size)
  print(f"epoch {epoch} | perplexity={perplexity(X, zeta, phi, alpha, batch_size=batch_size):.3f}")
  print_topic_words(phi, prefix=f"    {epoch}")
  print("-"*80)
  epoch += 1

epoch 1 | perplexity=1299.381
    1 0 away feel series minutes role comedy family screen sense kind black fun hard place performances mind john reason night men
    1 1 book ll believe actor star hard read music american sure set true wasn said horror completely wonderful job ending screen
    1 2 family original comedy music believe horror feel shows series money ending episode guy instead death hard performance played house minutes
    1 3 stupid effects believe especially played playing probably ending trying play plays doing father instead wasn watched game low house wife
    1 4 worst minutes comes kids instead money guy especially horror version tv sense seeing music fun half boring watched waste girl
    1 5 worst comedy money play shows looks played probably truly fun woman set night maybe written low especially human came hard
    1 6 original woman actor series performance trying having home feel episode effects comedy goes book kids away maybe night beautiful played
    1 7 