<a href="https://colab.research.google.com/github/sreevanimtcs2502/sreevanimtcs2502/blob/main/plsa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from tqdm import tqdm


import kagglehub
# Download the IMDB review CSV and keep a fixed-seed subsample of 3000 rows.
path = kagglehub.dataset_download("lakshmi25npathi/imdb-dataset-of-50k-movie-reviews")
data = pd.read_csv(f"{path}/IMDB Dataset.csv").sample(3000, random_state=42)
texts = data['review'].astype(str).tolist()

# Dense document-term count matrix over a capped vocabulary: English stop
# words are removed, terms must occur in at least 2 documents, and at most
# 1500 features are kept.
vectorizer = CountVectorizer(
    stop_words='english',
    max_features=1500,
    min_df=2,
)
X_all = vectorizer.fit_transform(texts).toarray()
vocab = vectorizer.get_feature_names_out()
D, V = X_all.shape

# Hold out 20% of the documents (by row index) for validation.
train_idx, val_idx = train_test_split(
    np.arange(D),
    test_size=0.2,
    random_state=42,
)
X_train, X_val = X_all[train_idx], X_all[val_idx]

def perplexity_from_params(X, P_w_z, P_z_d):
    """Return the per-token perplexity of count matrix ``X`` under a pLSA model.

    For every document ``d`` the word distribution is the topic mixture
    ``p(w|d) = sum_z P_z_d[d, z] * P_w_z[z, w]``; the result is
    ``exp(-sum_{d,w} X[d, w] * log p(w|d) / X.sum())``.

    Args:
        X: Document-term counts, indexed ``[document, word]``.
        P_w_z: Topic-word probabilities, indexed ``[topic, word]``.
        P_z_d: Document-topic probabilities, indexed ``[document, topic]``.

    Returns:
        The perplexity, or ``nan`` when ``X`` contains no tokens (the
        per-token average is undefined for an empty corpus).
    """
    total = X.sum()
    if total == 0:
        # Avoid a 0/0 division (and its RuntimeWarning) for an empty corpus.
        return float("nan")

    ll = 0.0
    # One document at a time keeps the working set at a single vocabulary row.
    for d, counts in enumerate(X):
        p_w_d = P_z_d[d].dot(P_w_z)
        # Floor the probabilities so that log() never sees zero.
        p_w_d = np.clip(p_w_d, 1e-12, None)
        ll += (counts * np.log(p_w_d)).sum()
    return np.exp(-ll / total)

def _fold_in_topic_mixture(X, P_z_w):
    """Return P(z|d) per row of ``X``: normalised ``sum_w X[d, w] * P_z_w[z, w]``.

    Rows whose scores sum to zero (for example documents with no counts) get
    the uniform mixture ``1 / K``.
    """
    n_topics = P_z_w.shape[0]
    scores = X @ P_z_w.T  # indexed [document, topic]
    totals = scores.sum(axis=1, keepdims=True)
    mixture = np.full(scores.shape, 1.0 / n_topics)
    # Rows where ``totals > 0`` is False keep the uniform value set above.
    np.divide(scores, totals, out=mixture, where=totals > 0)
    return mixture


def run_plsi_em(X_train, X_val, K=10, max_iter=100, smoothing=1e-2,
                tol=1e-4, patience=5, verbose=True, random_state=None):
    """Fit a K-topic pLSA model with EM, early-stopped on validation perplexity.

    Each iteration performs an E-step over the training documents, an M-step
    that re-estimates ``P(w|z)``, ``P(z|d)`` and ``P(z)`` with additive
    smoothing, then scores ``X_val`` by deriving its topic mixtures from
    ``P(z|w)`` and calling ``perplexity_from_params``.

    Args:
        X_train: Training document-term counts, indexed ``[document, word]``.
        X_val: Validation document-term counts over the same vocabulary.
        K: Number of topics.
        max_iter: Maximum number of EM iterations.
        smoothing: Constant added to every expected count in the M-step.
        tol: Minimum drop in validation perplexity that counts as an
            improvement.
        patience: Number of consecutive non-improving iterations after which
            training stops.
        verbose: If true, print the validation perplexity every iteration.
        random_state: Seed (or ``np.random.RandomState`` instance) used for
            the random initialisation; ``None`` draws a fresh, unseeded
            generator.

    Returns:
        ``(best_state, best_val_perp, history)`` where ``best_state`` is a
        tuple of copies ``(P_z, P_w_z, P_z_d_train)`` from the best iteration
        (``None`` if no iteration improved on ``inf``, e.g. ``max_iter < 1``),
        ``best_val_perp`` is that iteration's validation perplexity, and
        ``history`` lists the validation perplexity of every iteration run.
    """
    D_train, V = X_train.shape
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # P(z) is not initialised here: the E-step never reads it and the first
    # M-step computes it from the expected counts.
    P_w_z = rng.dirichlet(np.ones(V), size=K)
    P_z_d_train = rng.dirichlet(np.ones(K), size=D_train)

    best_state = None
    best_val_perp = np.inf
    no_improve = 0
    history = []

    for it in range(1, max_iter + 1):
        # ---- E-step: accumulate expected topic counts ----
        N_k_w = np.zeros((K, V))
        N_d_k = np.zeros((D_train, K))
        N_k = np.zeros(K)

        for d in range(D_train):
            x_dw = X_train[d]
            nonzero_idx = np.nonzero(x_dw)[0]
            if nonzero_idx.size == 0:
                continue
            # Posterior P(z|d,w) for the words present in this document.
            numer = P_z_d_train[d][:, None] * P_w_z[:, nonzero_idx]
            post = numer / (numer.sum(axis=0, keepdims=True) + 1e-12)
            # Weight by the word counts once and reuse for all three tallies.
            weighted = post * x_dw[nonzero_idx]
            topic_mass = weighted.sum(axis=1)
            N_k_w[:, nonzero_idx] += weighted
            N_d_k[d] = topic_mass
            N_k += topic_mass

        # ---- M-step: smoothed, renormalised estimates ----
        P_w_z = N_k_w + smoothing
        P_w_z /= P_w_z.sum(axis=1, keepdims=True)

        P_z_d_train = N_d_k + smoothing
        P_z_d_train /= P_z_d_train.sum(axis=1, keepdims=True)

        P_z = N_k + smoothing
        P_z /= P_z.sum()

        # ---- Validation: P(z|w) is proportional to P(w|z) * P(z) ----
        P_z_w = P_w_z * P_z[:, None]
        P_z_w /= P_z_w.sum(axis=0, keepdims=True) + 1e-12
        P_z_d_val = _fold_in_topic_mixture(X_val, P_z_w)

        val_perp = perplexity_from_params(X_val, P_w_z, P_z_d_val)
        history.append(val_perp)
        if verbose:
            print(f"Iter {it:03d}  val_perplexity={val_perp:.2f}")

        # ---- Early stopping on validation perplexity ----
        if val_perp + tol < best_val_perp:
            best_val_perp = val_perp
            best_state = (P_z.copy(), P_w_z.copy(), P_z_d_train.copy())
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                if verbose:
                    print(f"No improvement for {patience} iters - stopping early.")
                break

    return best_state, best_val_perp, history

# Restart driver: every call to run_plsi_em draws its own random
# initialisation, so run it n_restarts times and keep the run with the lowest
# validation perplexity.
n_restarts = 1
histories = []
best_perp = np.inf
best_overall = None
for r in range(n_restarts):
    if r:
        # Only announce the second and later restarts.
        print(f"\nRestart {r+1}/{n_restarts}")
    state, valp, hist = run_plsi_em(
        X_train, X_val,
        K=15, max_iter=100, smoothing=0.1,
        tol=1e-4, patience=7, verbose=True,
    )
    histories.append(hist)
    if not valp < best_perp:
        continue
    best_perp, best_overall = valp, state

print(f"\nBest validation perplexity across restarts: {best_perp:.2f}")
# best_overall is the (P_z, P_w_z, P_z_d_train) tuple returned by run_plsi_em.
P_z_best, P_w_z_best, P_z_d_train_best = best_overall

# Score the selected model on the training documents with their fitted
# topic mixtures.
train_perp = perplexity_from_params(X_train, P_w_z_best, P_z_d_train_best)

# P(z|w) is proportional to P(w|z) * P(z); normalise over topics per word.
P_z_w = (P_w_z_best * P_z_best[:, None])
P_z_w /= P_z_w.sum(axis=0, keepdims=True) + 1e-12

# Derive a topic mixture for every validation document from its word counts,
# using the same rule (including the zero-sum fallback) as run_plsi_em.
P_z_d_val_final = np.zeros((X_val.shape[0], P_z_w.shape[0]))
for i in range(X_val.shape[0]):
    x = X_val[i]
    nz = np.nonzero(x)[0]
    if nz.size == 0:
        # Document has no in-vocabulary words: fall back to uniform.
        P_z_d_val_final[i] = 1.0 / P_z_w.shape[0]
        continue
    P_z_d_val_final[i] = (x[nz] @ P_z_w[:, nz].T)
    s = P_z_d_val_final[i].sum()
    if s <= 0:
        # Guard against 0/0, which would fill the row with NaN.
        P_z_d_val_final[i] = 1.0 / P_z_w.shape[0]
    else:
        P_z_d_val_final[i] /= s
val_perp_final = perplexity_from_params(X_val, P_w_z_best, P_z_d_val_final)

print(f"Train Perplexity (best): {train_perp:.2f}")
print(f"Val   Perplexity (best): {val_perp_final:.2f}")

# Print the highest-probability vocabulary entries of every topic, most
# probable first.
K = P_w_z_best.shape[0]
num_top_words = 10

print("\n--- Top Words per Topic ---")
for k, word_probs in enumerate(P_w_z_best):
    # argsort is ascending, so reverse it and keep the leading entries.
    top_word_indices = np.argsort(word_probs)[::-1][:num_top_words]
    top_words = [vocab[idx] for idx in top_word_indices]
    print(f"Topic {k+1}: {', '.join(top_words)}")

Iter 001  val_perplexity=707.99
Iter 002  val_perplexity=708.20
Iter 003  val_perplexity=707.77
Iter 004  val_perplexity=706.90
Iter 005  val_perplexity=705.67
Iter 006  val_perplexity=704.09
Iter 007  val_perplexity=702.19
Iter 008  val_perplexity=700.01
Iter 009  val_perplexity=697.61
Iter 010  val_perplexity=695.03
Iter 011  val_perplexity=692.33
Iter 012  val_perplexity=689.55
Iter 013  val_perplexity=686.73
Iter 014  val_perplexity=683.90
Iter 015  val_perplexity=681.08
Iter 016  val_perplexity=678.30
Iter 017  val_perplexity=675.57
Iter 018  val_perplexity=672.91
Iter 019  val_perplexity=670.32
Iter 020  val_perplexity=667.83
Iter 021  val_perplexity=665.43
Iter 022  val_perplexity=663.14
Iter 023  val_perplexity=660.94
Iter 024  val_perplexity=658.85
Iter 025  val_perplexity=656.86
Iter 026  val_perplexity=654.97
Iter 027  val_perplexity=653.17
Iter 028  val_perplexity=651.46
Iter 029  val_perplexity=649.83
Iter 030  val_perplexity=648.28
Iter 031  val_perplexity=646.80
Iter 032