# Naive Bayes (Gaussian, Multinomial, Complement, Bernoulli, Categorical) + Out-of-core

Naive Bayes is a family of **probabilistic, generative** models.

It’s popular because it can be:
- **fast** (training is mostly counting)
- **strong** on sparse, high-dimensional data (e.g., text)
- **surprisingly good** even when its assumptions are “wrong”

## Learning goals
By the end you should be able to:
- explain Bayes’ rule and why Naive Bayes is a *generative* classifier
- derive the Naive Bayes decision rule in log-space
- understand the **conditional independence** assumption (and its consequences)
- implement **Gaussian NB** and **Multinomial NB** from scratch and check them against `sklearn`
- know when to use **Bernoulli NB**, **Complement NB**, and **Categorical NB**
- train Naive Bayes **out-of-core** with `partial_fit` on streaming batches

## Table of contents
1. Bayes as “belief update”
2. The naive assumption (conditional independence)
3. Gaussian Naive Bayes (continuous features)
4. Multinomial Naive Bayes (count features)
5. Bernoulli Naive Bayes (binary features)
6. Complement Naive Bayes (imbalanced text)
7. Categorical Naive Bayes (discrete categories)
8. Out-of-core (streaming) fitting


In [None]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio

from dataclasses import dataclass

from scipy.special import logsumexp

from sklearn.datasets import make_blobs
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import (
    GaussianNB,
    MultinomialNB,
    ComplementNB,
    BernoulliNB,
    CategoricalNB,
)

pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

rng = np.random.default_rng(7)


## 1) Bayes as “belief update”

Bayes’ rule is just a way to update beliefs when you see evidence:

$$
P(y \mid x) = \frac{P(x \mid y)\,P(y)}{P(x)}
$$

- $P(y)$ is the **prior** (what you believed before seeing data)
- $P(x \mid y)$ is the **likelihood** (how compatible the data is with a hypothesis)
- $P(y \mid x)$ is the **posterior** (updated belief)
- $P(x)$ is the **evidence**, a normalization constant that is the same for every class $y$

### A helpful mental image
Think of $P(y)$ as a *base rate* and $P(x \mid y)$ as an *evidence multiplier*.

Naive Bayes turns this into a classifier by comparing posteriors across classes.
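Because $P(x)$ does not depend on the class, comparing posteriors only needs the numerator of Bayes' rule:

$$
\hat{y} = \arg\max_c P(y=c \mid x) = \arg\max_c P(x \mid y=c)\,P(y=c)
$$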


In [None]:
# A classic Bayes example: medical test
# Disease prevalence (prior)
P_D = 0.01

# Test quality
sensitivity = 0.95           # P(test+ | disease)
false_positive_rate = 0.05   # P(test+ | no disease)

# Posterior P(disease | test+)
P_pos = sensitivity * P_D + false_positive_rate * (1 - P_D)
P_D_given_pos = sensitivity * P_D / P_pos

print(f"P(disease)               = {P_D:.3f}")
print(f"P(test+ | disease)       = {sensitivity:.3f}")
print(f"P(test+ | no disease)    = {false_positive_rate:.3f}")
print(f"P(disease | test+)       = {P_D_given_pos:.3f}")

fig = go.Figure()
fig.add_trace(go.Bar(x=["prior P(D)", "posterior P(D|+)"], y=[P_D, P_D_given_pos]))
fig.update_layout(title="Bayes update: rare disease + positive test", yaxis_title="probability", width=650, height=420)
fig.show()
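Working the same numbers by hand shows why the posterior stays small: most positive tests come from the much larger healthy population.

$$
P(D \mid +) = \frac{0.95 \times 0.01}{0.95 \times 0.01 + 0.05 \times 0.99} = \frac{0.0095}{0.0590} \approx 0.161
$$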


## 2) The naive assumption (conditional independence)

Naive Bayes assumes:

> given the class $y$, the features $x_1,\dots,x_d$ are conditionally independent.

Mathematically:

$$
P(x \mid y) = \prod_{j=1}^{d} P(x_j \mid y)
$$

This is “naive” because real-world features are often correlated.

### Why it still works often
- The goal of classification is to pick the **argmax** class. Even if probabilities are imperfect, the ranking can be correct.
- Many datasets have “mostly independent enough” signals (especially after preprocessing).
- In text, the bag-of-words representation already discards word order, so also treating word occurrences as independent given the class loses less than it sounds.

### Log-space is your friend
Products become sums:

$$
\log P(y \mid x) = \log P(y) + \sum_{j=1}^{d} \log P(x_j \mid y) + \text{const}
$$

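This avoids numeric underflow (a product of hundreds of probabilities below 1 quickly rounds to `0.0` in floating point) and turns the classifier into a sum of per-feature scores.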
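A minimal sketch of the underflow problem, assuming 1,000 features that each contribute a likelihood of 0.01 (made-up numbers chosen only to force underflow). The direct product collapses to zero, while the log-space sum stays finite and can still be normalized with `logsumexp`.

```python
import numpy as np
from scipy.special import logsumexp

# 1,000 per-feature likelihoods of 0.01 each (illustrative values)
probs = np.full(1000, 0.01)

direct_product = np.prod(probs)   # 1e-2000 is far below float64's range, so this becomes 0.0
log_sum = np.sum(np.log(probs))   # 1000 * log(0.01), about -4605.17

# Two equally likely classes; suppose class 1's log-likelihood is 2 nats higher
log_joint = np.array([log_sum + np.log(0.5), log_sum + 2.0 + np.log(0.5)])
posterior = np.exp(log_joint - logsumexp(log_joint))  # normalize entirely in log space

print(direct_product)
print(log_sum)
print(posterior)
```

The posterior depends only on the *difference* of the log scores, which is why subtracting the `logsumexp` normalizer is safe even when each score is around -4600.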


In [None]:
# Visual demo: correlated features (independence is violated)

# Two classes with correlated 2D Gaussians
n = 700
mean0 = np.array([-1.0, -0.5])
mean1 = np.array([+1.0, +0.5])

cov = np.array([[1.0, 0.85], [0.85, 1.0]])  # strong correlation

X0 = rng.multivariate_normal(mean0, cov, size=n // 2)
X1 = rng.multivariate_normal(mean1, cov, size=n // 2)
X_corr = np.vstack([X0, X1])
y_corr = np.array([0] * (n // 2) + [1] * (n // 2))

corr0 = np.corrcoef(X0.T)[0, 1]
corr1 = np.corrcoef(X1.T)[0, 1]
print(f"Correlation (class 0): {corr0:.3f}")
print(f"Correlation (class 1): {corr1:.3f}")

fig = px.scatter(
    x=X_corr[:, 0],
    y=X_corr[:, 1],
    color=y_corr.astype(str),
    title="Correlated features (violates NB independence)",
    labels={"x": "x1", "y": "x2", "color": "class"},
)
fig.update_traces(marker=dict(size=6, opacity=0.6))
fig.update_layout(width=720, height=470)
fig.show()
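One concrete consequence of the independence assumption is that correlated features are counted as separate pieces of evidence. A minimal sketch, assuming an extreme case: a synthetic 1D feature duplicated into a second, perfectly correlated column. With equal class priors, GaussianNB's log-odds for the duplicated data is exactly twice the single-feature log-odds, so the model becomes overconfident even though no new information was added.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng_dup = np.random.default_rng(0)
n_dup = 500
y_dup = np.repeat([0, 1], n_dup)  # balanced classes, so the prior term cancels in the log-odds
x_dup = np.concatenate([rng_dup.normal(-1.0, 1.0, n_dup), rng_dup.normal(1.0, 1.0, n_dup)])

X_single = x_dup[:, None]               # one feature
X_dup = np.column_stack([x_dup, x_dup])  # the same feature twice (correlation = 1)

nb_single = GaussianNB().fit(X_single, y_dup)
nb_dup = GaussianNB().fit(X_dup, y_dup)

# Log-odds log P(y=1|x) - log P(y=0|x) at a single test point
x_test = np.array([[0.5]])
lo_single = np.diff(nb_single.predict_log_proba(x_test), axis=1)[0, 0]
lo_dup = np.diff(nb_dup.predict_log_proba(np.column_stack([x_test, x_test])), axis=1)[0, 0]

print(f"log-odds, one copy   : {lo_single:.3f}")
print(f"log-odds, two copies : {lo_dup:.3f}")
```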


## 3) Gaussian Naive Bayes (continuous features)

Gaussian NB assumes each feature is normally distributed *within each class*:

$$
(x_j \mid y=c) \sim \mathcal{N}(\mu_{c,j},\,\sigma^2_{c,j})
$$

### Parameter estimation
For each class $c$ and feature $j$:

- $\mu_{c,j}$ is the sample mean of feature $j$ over the training samples of class $c$
- $\sigma^2_{c,j}$ is the corresponding (maximum-likelihood, `ddof=0`) sample variance

### Decision rule (log posterior)
For a sample $x$:

$$
\log P(y=c \mid x) = \log \pi_c + \sum_{j=1}^d \log \mathcal{N}(x_j \mid \mu_{c,j}, \sigma^2_{c,j}) + \text{const}
$$

Where $\pi_c = P(y=c)$ is the class prior.
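Expanding each Gaussian log-density term gives the expression the scratch implementation computes feature by feature:

$$
\log \mathcal{N}(x_j \mid \mu_{c,j}, \sigma^2_{c,j}) = -\tfrac{1}{2}\log\!\left(2\pi\sigma^2_{c,j}\right) - \frac{(x_j - \mu_{c,j})^2}{2\sigma^2_{c,j}}
$$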

Intuition:
> Gaussian NB is like saying: *“In class A, feature 1 tends to be around 3 with some wiggle; in class B, it’s around 7…”* and doing that for each feature independently.


In [None]:
@dataclass
class ScratchGaussianNB:
    var_smoothing: float = 1e-9

    def fit(self, X: np.ndarray, y: np.ndarray):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes_, y_enc = np.unique(y, return_inverse=True)

        n_classes = self.classes_.shape[0]
        n_features = X.shape[1]

        self.class_count_ = np.bincount(y_enc, minlength=n_classes).astype(float)
        self.class_prior_ = self.class_count_ / self.class_count_.sum()

        self.theta_ = np.zeros((n_classes, n_features), dtype=float)  # means
        self.var_ = np.zeros((n_classes, n_features), dtype=float)    # variances

        for c in range(n_classes):
            Xc = X[y_enc == c]
            self.theta_[c] = Xc.mean(axis=0)
            self.var_[c] = Xc.var(axis=0)

        # variance smoothing (like sklearn)
        overall_var = X.var(axis=0).max()  # scalar
        self.epsilon_ = self.var_smoothing * overall_var
        self.var_ = self.var_ + self.epsilon_
        return self

    def _joint_log_likelihood(self, X: np.ndarray) -> np.ndarray:
        X = np.asarray(X, dtype=float)
        # shape: (n_samples, n_classes)
        n_samples = X.shape[0]
        n_classes = self.classes_.shape[0]

        log_prior = np.log(self.class_prior_ + 1e-300)

        jll = np.empty((n_samples, n_classes), dtype=float)
        for c in range(n_classes):
            mean = self.theta_[c]
            var = self.var_[c]
            # sum_j [ -0.5 log(2π var_j) - (x_j - mean_j)^2 / (2 var_j) ]
            log_prob = -0.5 * np.sum(np.log(2.0 * np.pi * var))
            log_prob = log_prob - 0.5 * np.sum(((X - mean) ** 2) / var, axis=1)
            jll[:, c] = log_prior[c] + log_prob
        return jll

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        jll = self._joint_log_likelihood(X)
        log_norm = logsumexp(jll, axis=1, keepdims=True)
        return np.exp(jll - log_norm)

    def predict(self, X: np.ndarray) -> np.ndarray:
        jll = self._joint_log_likelihood(X)
        return self.classes_[np.argmax(jll, axis=1)]


In [None]:
# A clean Gaussian-ish dataset (two blobs)
X_g, y_g = make_blobs(
    n_samples=800,
    centers=[(-2, -1), (2, 1)],
    cluster_std=[1.2, 1.1],
    random_state=7,
)
X_tr_g, X_te_g, y_tr_g, y_te_g = train_test_split(X_g, y_g, test_size=0.3, random_state=7, stratify=y_g)

scratch_gnb = ScratchGaussianNB(var_smoothing=1e-9).fit(X_tr_g, y_tr_g)
sk_gnb = GaussianNB(var_smoothing=1e-9).fit(X_tr_g, y_tr_g)

pred_scratch = scratch_gnb.predict(X_te_g)
pred_sklearn = sk_gnb.predict(X_te_g)

print("Scratch GaussianNB accuracy:", accuracy_score(y_te_g, pred_scratch))
print("sklearn GaussianNB accuracy:", accuracy_score(y_te_g, pred_sklearn))


In [None]:
def plot_proba_boundary_2d(model, X, y, title: str, grid_steps: int = 220):
    x_min, x_max = X[:, 0].min() - 1.0, X[:, 0].max() + 1.0
    y_min, y_max = X[:, 1].min() - 1.0, X[:, 1].max() + 1.0

    xs = np.linspace(x_min, x_max, grid_steps)
    ys = np.linspace(y_min, y_max, grid_steps)
    xx, yy = np.meshgrid(xs, ys)
    grid = np.c_[xx.ravel(), yy.ravel()]

    proba = model.predict_proba(grid)[:, 1].reshape(xx.shape)

    fig = go.Figure()
    fig.add_trace(go.Contour(
        x=xs,
        y=ys,
        z=proba,
        colorscale="RdBu",
        opacity=0.75,
        contours=dict(showlines=False),
        colorbar=dict(title="P(class=1)"),
    ))

    fig.add_trace(go.Scatter(
        x=X[:, 0],
        y=X[:, 1],
        mode="markers",
        marker=dict(color=y, colorscale="Viridis", size=6, line=dict(width=0.5, color="white")),
        name="data",
    ))

    fig.update_layout(title=title, width=760, height=520)
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig


fig1 = plot_proba_boundary_2d(scratch_gnb, X_te_g, y_te_g, "Scratch GaussianNB decision surface")
fig1.show()

fig2 = plot_proba_boundary_2d(sk_gnb, X_te_g, y_te_g, "sklearn GaussianNB decision surface")
fig2.show()


### 3.1 `sklearn` GaussianNB parameters

`GaussianNB(priors=None, var_smoothing=1e-9)`

- `priors`: manually set class priors $\pi_c$. Useful when your training data is not representative of deployment.
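- `var_smoothing`: adds `var_smoothing` times the largest per-feature variance of the training data (stored as `epsilon_`) to every class variance, preventing division by near-zero variances.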

Interpretation of `var_smoothing`:
- too small → can blow up when a feature has tiny variance
- too large → oversmooths and washes out feature differences
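A short sketch of both parameters on made-up blob data (the blob centers and the 90/10 deployment prior are assumptions for illustration). Supplying `priors` only shifts the log-prior term, so the number of class-1 predictions can only shrink when class 1's prior is lowered. The fitted `epsilon_` equals `var_smoothing` times the largest feature variance.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB

# Two overlapping, balanced blobs (300 samples each)
X_pr, y_pr = make_blobs(n_samples=600, centers=[(-1, 0), (1, 0)], cluster_std=1.5, random_state=0)

gnb_default = GaussianNB().fit(X_pr, y_pr)                   # priors estimated from data (~0.5 / 0.5)
gnb_shifted = GaussianNB(priors=[0.9, 0.1]).fit(X_pr, y_pr)  # assume class 1 is rare at deployment

n1_default = int((gnb_default.predict(X_pr) == 1).sum())
n1_shifted = int((gnb_shifted.predict(X_pr) == 1).sum())
print("class_prior_ (shifted):", gnb_shifted.class_prior_)
print("class-1 predictions: default =", n1_default, "| shifted =", n1_shifted)

# epsilon_ is what gets added to every variance
expected_eps = 1e-9 * X_pr.var(axis=0).max()
print("epsilon_:", gnb_default.epsilon_, "expected:", expected_eps)
```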


## 4) Multinomial Naive Bayes (counts / text)

Multinomial NB is the classic choice for **bag-of-words** style inputs.

Think:
- features are **counts** (how many times word *j* appears)
- a document is generated by repeatedly sampling words from a class-specific distribution

### Model
For class $c$, a vocabulary distribution $\theta_c$ over $V$ words:

$$
\theta_{c,j} \ge 0,\quad \sum_{j=1}^V \theta_{c,j} = 1
$$

Given a document count vector $x \in \mathbb{N}^V$:

$$
P(x \mid y=c) \propto \prod_{j=1}^V \theta_{c,j}^{x_j}
$$

Taking logs:

$$
\log P(y=c \mid x) = \log \pi_c + \sum_{j=1}^V x_j \log \theta_{c,j} + \text{const}
$$

### Smoothing (Dirichlet / Laplace)
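Without smoothing, a word never seen in class $c$ during training gets $\theta_{c,j}=0$, so $\log \theta_{c,j} = -\infty$ and any document containing that word gets zero probability under class $c$.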

With additive smoothing ($\alpha>0$):

$$
\theta_{c,j} = \frac{N_{c,j} + \alpha}{\sum_{k=1}^{V} N_{c,k} + \alpha V}
$$

Intuition:
> Smoothing is like saying: *“Even if we haven’t seen the word ‘unicorn’ in spam yet, we won’t assume it’s impossible.”*
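A worked example of the smoothing formula on a hypothetical three-document corpus with $V = 3$ words and $\alpha = 1$. Word 1 never appears in class 0, so the unsmoothed estimate is exactly zero, while the smoothed estimate gives $(3+1, 0+1, 1+1)/(4+3)$ and matches `MultinomialNB.feature_log_prob_`.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

# Tiny corpus: 3 documents, vocabulary of V = 3 words, 2 classes
X_toy = np.array([[2, 0, 1],
                  [1, 0, 0],
                  [0, 3, 1]])
y_toy = np.array([0, 0, 1])
alpha_toy = 1.0
V = X_toy.shape[1]

# N_{0,j}: total count of word j in class 0 -> [3, 0, 1]
N0 = X_toy[y_toy == 0].sum(axis=0)
theta0_mle = N0 / N0.sum()                                  # [0.75, 0.0, 0.25]; log(0) = -inf
theta0_smooth = (N0 + alpha_toy) / (N0.sum() + alpha_toy * V)  # [4/7, 1/7, 2/7]

nb_toy = MultinomialNB(alpha=alpha_toy).fit(X_toy, y_toy)
print("unsmoothed:", theta0_mle)
print("smoothed  :", theta0_smooth)
print("sklearn   :", np.exp(nb_toy.feature_log_prob_[0]))
```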


In [None]:
def make_synthetic_text_dataset(
    n_docs: int = 2000,
    vocab_size: int = 30,
    avg_len: int = 60,
    imbalance: float = 0.5,
    seed: int = 7,
):
    r = np.random.default_rng(seed)

    # Class priors
    p1 = float(imbalance)
    y = (r.random(n_docs) < p1).astype(int)

    # Class-specific word distributions (Dirichlet draws)
    # Make them different by shifting concentration around two different centers.
    base0 = r.random(vocab_size)
    base1 = r.random(vocab_size)
    base0 = base0 / base0.sum()
    base1 = base1 / base1.sum()

    # Sharpen and separate distributions
    theta0 = r.dirichlet(25 * base0 + 1.0)
    theta1 = r.dirichlet(25 * base1 + 1.0)

    # Document lengths
    lengths = r.poisson(lam=avg_len, size=n_docs) + 5

    X = np.zeros((n_docs, vocab_size), dtype=int)
    for i in range(n_docs):
        theta = theta1 if y[i] == 1 else theta0
        X[i] = r.multinomial(n=lengths[i], pvals=theta)

    vocab = [f"w{j:02d}" for j in range(vocab_size)]
    return X, y, theta0, theta1, vocab


X_counts, y_text, theta0, theta1, vocab = make_synthetic_text_dataset(n_docs=3000, vocab_size=40, avg_len=70, imbalance=0.5)
X_tr_t, X_te_t, y_tr_t, y_te_t = train_test_split(X_counts, y_text, test_size=0.3, random_state=7, stratify=y_text)

# Plot the true underlying word probabilities (top words)
idx0 = np.argsort(theta0)[-10:][::-1]
idx1 = np.argsort(theta1)[-10:][::-1]

fig = go.Figure()
fig.add_trace(go.Bar(x=[vocab[i] for i in idx0], y=theta0[idx0], name="class 0"))
fig.add_trace(go.Bar(x=[vocab[i] for i in idx1], y=theta1[idx1], name="class 1"))
fig.update_layout(
    title="Synthetic text: top word probabilities per class (ground truth)",
    barmode="group",
    xaxis_title="word",
    yaxis_title="probability",
    width=900,
    height=450,
)
fig.show()


In [None]:
@dataclass
class ScratchMultinomialNB:
    alpha: float = 1.0
    fit_prior: bool = True

    def fit(self, X: np.ndarray, y: np.ndarray):
        X = np.asarray(X)
        if np.any(X < 0):
            raise ValueError("MultinomialNB expects non-negative counts")

        y = np.asarray(y)
        self.classes_, y_enc = np.unique(y, return_inverse=True)
        n_classes = self.classes_.shape[0]
        n_features = X.shape[1]

        class_count = np.bincount(y_enc, minlength=n_classes).astype(float)

        if self.fit_prior:
            self.class_log_prior_ = np.log(class_count / class_count.sum())
        else:
            self.class_log_prior_ = np.full(n_classes, -np.log(n_classes), dtype=float)

        # feature counts per class
        feature_count = np.zeros((n_classes, n_features), dtype=float)
        for c in range(n_classes):
            feature_count[c] = X[y_enc == c].sum(axis=0)

        smoothed_fc = feature_count + self.alpha
        smoothed_cc = smoothed_fc.sum(axis=1, keepdims=True)
        self.feature_log_prob_ = np.log(smoothed_fc) - np.log(smoothed_cc)
        return self

    def _joint_log_likelihood(self, X: np.ndarray) -> np.ndarray:
        X = np.asarray(X)
        return X @ self.feature_log_prob_.T + self.class_log_prior_[None, :]

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        jll = self._joint_log_likelihood(X)
        return np.exp(jll - logsumexp(jll, axis=1, keepdims=True))

    def predict(self, X: np.ndarray) -> np.ndarray:
        jll = self._joint_log_likelihood(X)
        return self.classes_[np.argmax(jll, axis=1)]


In [None]:
# Compare scratch vs sklearn MultinomialNB
scratch_mnb = ScratchMultinomialNB(alpha=1.0, fit_prior=True).fit(X_tr_t, y_tr_t)
sk_mnb = MultinomialNB(alpha=1.0, fit_prior=True).fit(X_tr_t, y_tr_t)

pred_scratch = scratch_mnb.predict(X_te_t)
pred_sklearn = sk_mnb.predict(X_te_t)

print("Scratch MultinomialNB accuracy:", accuracy_score(y_te_t, pred_scratch))
print("sklearn MultinomialNB accuracy:", accuracy_score(y_te_t, pred_sklearn))

print()
print("Classification report (sklearn MultinomialNB):")
print(classification_report(y_te_t, pred_sklearn, digits=3))


In [None]:
# Effect of smoothing alpha
alphas = np.logspace(-3, 1, 20)
acc = []
for a in alphas:
    m = MultinomialNB(alpha=float(a)).fit(X_tr_t, y_tr_t)
    acc.append(accuracy_score(y_te_t, m.predict(X_te_t)))

fig = go.Figure()
fig.add_trace(go.Scatter(x=alphas, y=acc, mode="lines+markers"))
fig.update_layout(
    title="MultinomialNB: test accuracy vs smoothing alpha",
    xaxis_title="alpha (log scale)",
    yaxis_title="accuracy",
    width=800,
    height=450,
)
fig.update_xaxes(type="log")
fig.show()


### 4.1 `sklearn` MultinomialNB parameters

`MultinomialNB(alpha=1.0, force_alpha=True, fit_prior=True, class_prior=None)`

- `alpha`: additive smoothing strength
- `force_alpha`: if `False` and `alpha` is below `1e-10`, `alpha` is set to `1e-10` for numerical stability; if `True`, `alpha` is used as given
- `fit_prior`: learn class priors from data
- `class_prior`: set priors manually (overrides `fit_prior`)

Rules of thumb:
- try `alpha` in `[0.01, 1.0]` for text
- use `class_prior` when you know deployment base rates differ from training
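A small sketch of how the three prior options change `class_log_prior_`, on synthetic Poisson counts with an assumed 80/20 training split. The word probabilities are identical in all three models; only the prior term differs.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

rng_cp = np.random.default_rng(0)
X_cp = rng_cp.poisson(2.0, size=(100, 5))
y_cp = np.array([0] * 80 + [1] * 20)  # 80/20 split in training

learned = MultinomialNB().fit(X_cp, y_cp)                          # fit_prior=True: [0.8, 0.2]
uniform = MultinomialNB(fit_prior=False).fit(X_cp, y_cp)          # [0.5, 0.5]
manual = MultinomialNB(class_prior=[0.3, 0.7]).fit(X_cp, y_cp)    # overrides the 80/20 frequencies

for name, m in [("learned", learned), ("uniform", uniform), ("manual", manual)]:
    print(f"{name:8s}", np.exp(m.class_log_prior_))
```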


## 5) Bernoulli Naive Bayes (binary features)

Bernoulli NB is like Multinomial NB, but it cares about **presence/absence** rather than counts.

For binary $x_j \in \{0,1\}$:

$$
P(x \mid y=c) = \prod_{j=1}^{V} p_{c,j}^{x_j} (1-p_{c,j})^{1-x_j}
$$

When does Bernoulli NB shine?
- when word **frequency** is less important than word **presence**
- when you want to explicitly model “word not present” as evidence

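In `sklearn`, `BernoulliNB` also has a `binarize` threshold: feature values strictly greater than it become 1 and the rest become 0 (default `binarize=0.0`). Pass `binarize=None` if the input is already binary.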
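A from-scratch sketch of Bernoulli NB in the same spirit as the Gaussian and Multinomial versions. The helper names `fit_scratch_bernoulli_nb` and `scratch_bernoulli_jll` are hypothetical, and the data is random 0/1 features. The smoothed presence probability is $p_{c,j} = (N_{c,j} + \alpha)/(N_c + 2\alpha)$, where $N_{c,j}$ counts class-$c$ samples with feature $j$ present. Unlike Multinomial NB, an absent feature still contributes $\log(1 - p_{c,j})$.

```python
import numpy as np
from scipy.special import logsumexp
from sklearn.naive_bayes import BernoulliNB


def fit_scratch_bernoulli_nb(X, y, alpha=1.0):
    """Hypothetical helper: from-scratch Bernoulli NB fit with empirical priors. X must be 0/1."""
    classes, y_enc = np.unique(y, return_inverse=True)
    n_classes = classes.size
    class_count = np.bincount(y_enc, minlength=n_classes).astype(float)
    # N_{c,j}: number of class-c samples in which feature j is present
    feature_count = np.vstack([X[y_enc == c].sum(axis=0) for c in range(n_classes)]).astype(float)
    # Each feature has two outcomes (present/absent), hence 2 * alpha in the denominator
    p = (feature_count + alpha) / (class_count[:, None] + 2.0 * alpha)
    log_prior = np.log(class_count / class_count.sum())
    return classes, log_prior, np.log(p), np.log1p(-p)


def scratch_bernoulli_jll(X, log_prior, log_p, log_1mp):
    """Hypothetical helper: log prior + sum_j [x_j log p_cj + (1 - x_j) log(1 - p_cj)]."""
    return X @ log_p.T + (1 - X) @ log_1mp.T + log_prior


rng_bern = np.random.default_rng(0)
X_bern = (rng_bern.random((200, 6)) < 0.4).astype(int)
y_bern = (X_bern[:, 0] + X_bern[:, 1] + rng_bern.random(200) > 1.5).astype(int)

classes_bern, log_prior_bern, log_p_bern, log_1mp_bern = fit_scratch_bernoulli_nb(X_bern, y_bern)
jll_bern = scratch_bernoulli_jll(X_bern, log_prior_bern, log_p_bern, log_1mp_bern)
proba_bern = np.exp(jll_bern - logsumexp(jll_bern, axis=1, keepdims=True))

sk_bern = BernoulliNB(alpha=1.0, binarize=None).fit(X_bern, y_bern)
print("Max |P_scratch - P_sklearn|:", np.abs(proba_bern - sk_bern.predict_proba(X_bern)).max())
```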


In [None]:
# Compare BernoulliNB vs MultinomialNB on the same synthetic text

X_tr_bin = (X_tr_t > 0).astype(int)
X_te_bin = (X_te_t > 0).astype(int)

m_mnb = MultinomialNB(alpha=1.0).fit(X_tr_t, y_tr_t)
m_bnb = BernoulliNB(alpha=1.0, binarize=None).fit(X_tr_bin, y_tr_t)

acc_mnb = accuracy_score(y_te_t, m_mnb.predict(X_te_t))
acc_bnb = accuracy_score(y_te_t, m_bnb.predict(X_te_bin))

fig = go.Figure()
fig.add_trace(go.Bar(x=["MultinomialNB (counts)", "BernoulliNB (binary)"], y=[acc_mnb, acc_bnb]))
fig.update_layout(title="Counts vs binary: accuracy comparison", yaxis_title="accuracy", width=700, height=420)
fig.show()

print("MultinomialNB accuracy:", acc_mnb)
print("BernoulliNB accuracy  :", acc_bnb)


## 6) Complement Naive Bayes (imbalanced text)

Complement NB was designed for text classification when classes are imbalanced.

Idea (intuition):
- Instead of modeling “what class *c* looks like”, model “what *not c* looks like” (the **complement**).
- Then classify by picking the class whose complement is *least compatible* with the document.

In practice, `ComplementNB` often improves performance on imbalanced text datasets.
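A from-scratch sketch of the complement idea, based on our reading of scikit-learn's default (`norm=False`) behavior; treat it as an approximation rather than a specification. For each class $c$, word counts are pooled over *all other* classes, smoothed, and normalized; the class score is the document's counts times the *negated* log complement probabilities, so the argmax picks the class whose complement fits the document worst. The helper `complement_weights` and the imbalanced Poisson data are hypothetical. With two or more classes, this sketch (like our reading of `sklearn`) does not add a class prior.

```python
import numpy as np
from scipy.special import softmax
from sklearn.naive_bayes import ComplementNB


def complement_weights(X, y, alpha=1.0):
    """Hypothetical helper: negated log complement word probabilities, one row per class."""
    classes, y_enc = np.unique(y, return_inverse=True)
    # N_{c,j}: count of word j in class c
    feature_count = np.vstack([X[y_enc == c].sum(axis=0) for c in range(classes.size)]).astype(float)
    # Complement counts: word j's count in every class except c, plus smoothing
    comp_count = feature_count.sum(axis=0) - feature_count + alpha
    log_theta_comp = np.log(comp_count / comp_count.sum(axis=1, keepdims=True))
    return classes, -log_theta_comp


rng_cnb = np.random.default_rng(0)
y_cnb = (rng_cnb.random(300) < 0.15).astype(int)  # class 1 is rare
rates = np.where(y_cnb[:, None] == 1, [1, 1, 3, 3, 1, 1], [3, 3, 1, 1, 1, 1])
X_cnb = rng_cnb.poisson(rates)

classes_cnb, w_cnb = complement_weights(X_cnb, y_cnb)
scores_cnb = X_cnb @ w_cnb.T  # higher score = complement less compatible with the document

sk_cnb = ComplementNB(alpha=1.0).fit(X_cnb, y_cnb)
print("Max |P_scratch - P_sklearn|:", np.abs(softmax(scores_cnb, axis=1) - sk_cnb.predict_proba(X_cnb)).max())
```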


In [None]:
# Make an imbalanced dataset (class 1 is rare)
X_counts_imb, y_imb, _, _, _ = make_synthetic_text_dataset(n_docs=6000, vocab_size=40, avg_len=60, imbalance=0.1, seed=11)
X_tr_i, X_te_i, y_tr_i, y_te_i = train_test_split(X_counts_imb, y_imb, test_size=0.3, random_state=7, stratify=y_imb)

m_mnb_i = MultinomialNB(alpha=1.0).fit(X_tr_i, y_tr_i)
m_cnb_i = ComplementNB(alpha=1.0).fit(X_tr_i, y_tr_i)

pred_mnb = m_mnb_i.predict(X_te_i)
pred_cnb = m_cnb_i.predict(X_te_i)

print("Class balance (test):", np.bincount(y_te_i) / y_te_i.size)
print()
print("MultinomialNB report:")
print(classification_report(y_te_i, pred_mnb, digits=3))
print("ComplementNB report:")
print(classification_report(y_te_i, pred_cnb, digits=3))


### 6.1 `sklearn` ComplementNB parameters

`ComplementNB(alpha=1.0, force_alpha=True, fit_prior=True, class_prior=None, norm=False)`

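- `norm`: if `True`, applies a second normalization of the complement weights, as described in the original Complement NB paper; the default `False` skips it.
- `fit_prior` / `class_prior`: only used in the edge case where the training set contains a single class; otherwise the class prior does not enter the decision.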


## 7) Categorical Naive Bayes (discrete categories)

`CategoricalNB` is for features like:
- color ∈ {red, green, blue}
- browser ∈ {chrome, safari, firefox}
- country ∈ {DE, FR, US, ...}

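Each feature is a non-negative integer code `0, 1, ..., k-1` representing a category (string categories can be mapped to such codes with `sklearn.preprocessing.OrdinalEncoder`).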

For each class and feature, we estimate a categorical probability table.

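`CategoricalNB` is **not** the same as one-hot encoding + MultinomialNB.
It estimates a separate categorical distribution for each feature (probabilities sum to 1 within each feature), whereas MultinomialNB on one-hot columns normalizes a single distribution across all columns.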
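A sketch of what "a categorical probability table per class and feature" means, using hypothetical string data encoded with `OrdinalEncoder`. For feature $i$ with $k_i$ categories, the smoothed estimate is $P(x_i = t \mid y = c) = (N_{t,i,c} + \alpha)/(N_c + \alpha k_i)$; the manual tables below are compared against `CategoricalNB.feature_log_prob_`, which holds one `(n_classes, k_i)` array per feature.

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB
from sklearn.preprocessing import OrdinalEncoder

# Hypothetical toy data with string categories: [weather, transport]
X_str_toy = np.array([
    ["sunny", "bike"], ["rainy", "bus"], ["overcast", "car"],
    ["sunny", "car"], ["rainy", "bike"], ["sunny", "bus"],
])
y_cat_toy = np.array([1, 0, 0, 1, 0, 1])

# OrdinalEncoder maps each column's sorted categories to codes 0..k-1
enc_toy = OrdinalEncoder()
X_codes_toy = enc_toy.fit_transform(X_str_toy).astype(int)

alpha_cat = 1.0
nb_cat_toy = CategoricalNB(alpha=alpha_cat).fit(X_codes_toy, y_cat_toy)

# Manual per-feature tables: rows = class 0, class 1; each row sums to 1
manual_tables = []
for i, cats in enumerate(enc_toy.categories_):
    k = len(cats)
    table = np.zeros((2, k))
    for c in (0, 1):
        counts = np.bincount(X_codes_toy[y_cat_toy == c, i], minlength=k)
        table[c] = (counts + alpha_cat) / (counts.sum() + alpha_cat * k)
    manual_tables.append(table)
    print(list(cats))
    print(table)
```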


In [None]:
# Toy categorical dataset
# Features: [weather, transport]
# weather: 0=sunny,1=rainy,2=overcast
# transport: 0=car,1=bus,2=bike
# Label: 1=go_out, 0=stay_in

weather = rng.integers(0, 3, size=800)
transport = rng.integers(0, 3, size=800)

# Make a slightly structured rule with noise
p_go_out = (
    0.15
    + 0.25 * (weather == 0)  # sunny
    + 0.10 * (weather == 2)  # overcast
    + 0.15 * (transport == 2)  # bike
    - 0.15 * (weather == 1)  # rainy
)
p_go_out = np.clip(p_go_out, 0.05, 0.95)

y_cat = (rng.random(800) < p_go_out).astype(int)
X_cat = np.c_[weather, transport]

X_tr_c, X_te_c, y_tr_c, y_te_c = train_test_split(X_cat, y_cat, test_size=0.3, random_state=7, stratify=y_cat)

m_cat = CategoricalNB(alpha=1.0).fit(X_tr_c, y_tr_c)
acc_cat = accuracy_score(y_te_c, m_cat.predict(X_te_c))
print("CategoricalNB accuracy:", acc_cat)

# Visualize predicted P(go_out=1) for each combination
combos = np.array([(w, t) for w in range(3) for t in range(3)])
proba = m_cat.predict_proba(combos)[:, 1]

labels_weather = ["sunny", "rainy", "overcast"]
labels_transport = ["car", "bus", "bike"]

z = proba.reshape(3, 3)

fig = go.Figure(data=go.Heatmap(
    z=z,
    x=labels_transport,
    y=labels_weather,
    colorscale="Blues",
    colorbar=dict(title="P(go_out=1)"),
))
fig.update_layout(title="CategoricalNB: predicted probability table", width=650, height=450)
fig.show()


### 7.1 `sklearn` CategoricalNB parameters

`CategoricalNB(alpha=1.0, force_alpha=True, fit_prior=True, class_prior=None, min_categories=None)`

- `min_categories`: force each feature to have at least this many categories (useful if some categories are missing in training).
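A small sketch of `min_categories` on hypothetical data: training only ever shows codes 0 and 1, but code 2 is declared valid. The model then reserves a smoothed table entry for code 2 (probability $\alpha/(N_c + 3\alpha) = 1/6$ in each class here), so predicting on it works. Because both classes assign it the same probability and the priors are equal, the posterior is 50/50.

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB

# One feature; training data contains only codes 0 and 1
X_mc = np.array([[0], [1], [0], [1], [1], [0]])
y_mc = np.array([0, 0, 1, 1, 1, 0])

nb_mc = CategoricalNB(alpha=1.0, min_categories=3).fit(X_mc, y_mc)
proba_unseen = nb_mc.predict_proba(np.array([[2]]))  # code 2 was never seen in training

print("n_categories_:", nb_mc.n_categories_)
print("P(y | x=2):", proba_unseen)
```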


## 8) Out-of-core naive Bayes model fitting (`partial_fit`)

Sometimes your dataset is too large to fit in memory.

Naive Bayes is great here because you can train it incrementally:

- stream data in batches
- call `partial_fit` repeatedly
- the model updates its sufficient statistics (counts / means / variances)

Important details:
- On the **first** `partial_fit`, you must pass `classes=np.array([...])`.
- `partial_fit` is available for all five variants in this notebook (`GaussianNB`, `MultinomialNB`, `ComplementNB`, `BernoulliNB`, `CategoricalNB`).
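Because the model only accumulates sufficient statistics, streaming should reproduce a single `fit`. A sketch on synthetic data (sizes and batch split are arbitrary assumptions): MultinomialNB's summed counts match exactly. GaussianNB's running means match up to floating-point rounding; its variances are not compared here because the smoothing term is recomputed from each batch.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB, MultinomialNB

rng_pf = np.random.default_rng(0)
X_pf_counts = rng_pf.poisson(3.0, size=(1000, 20))
X_pf_real = rng_pf.normal(size=(1000, 5))
y_pf = rng_pf.integers(0, 2, size=1000)
classes_pf = np.array([0, 1])

# Reference: a single fit on all 1000 samples
mnb_full = MultinomialNB(alpha=0.5).fit(X_pf_counts, y_pf)
gnb_full = GaussianNB().fit(X_pf_real, y_pf)

# Same data streamed in 4 batches of 250
mnb_pf = MultinomialNB(alpha=0.5)
gnb_pf = GaussianNB()
for start in range(0, 1000, 250):
    batch = slice(start, start + 250)
    # classes= is mandatory on the first call; repeating the same array later is accepted
    mnb_pf.partial_fit(X_pf_counts[batch], y_pf[batch], classes=classes_pf)
    gnb_pf.partial_fit(X_pf_real[batch], y_pf[batch], classes=classes_pf)

print("MultinomialNB max |diff|:", np.abs(mnb_pf.feature_log_prob_ - mnb_full.feature_log_prob_).max())
print("GaussianNB means max |diff|:", np.abs(gnb_pf.theta_ - gnb_full.theta_).max())
```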


In [None]:
def make_stream_word_dists(vocab_size: int, seed: int):
    """Fixed class-specific word distributions shared by the training stream AND the test set."""
    r = np.random.default_rng(seed)
    base0 = r.random(vocab_size)
    base1 = r.random(vocab_size)
    base0 = base0 / base0.sum()
    base1 = base1 / base1.sum()
    theta0 = r.dirichlet(30 * base0 + 1.0)
    theta1 = r.dirichlet(30 * base1 + 1.0)
    return theta0, theta1


def stream_synthetic_text_batches(
    theta0: np.ndarray,
    theta1: np.ndarray,
    n_docs: int,
    avg_len: int,
    class_prior: float,
    seed: int,
    batch_size: int,
):
    r = np.random.default_rng(seed)
    vocab_size = theta0.shape[0]

    n_batches = (n_docs + batch_size - 1) // batch_size
    for b in range(n_batches):
        m = min(batch_size, n_docs - b * batch_size)
        y = (r.random(m) < class_prior).astype(int)
        lengths = r.poisson(lam=avg_len, size=m) + 5
        X = np.zeros((m, vocab_size), dtype=int)
        for i in range(m):
            theta = theta1 if y[i] == 1 else theta0
            X[i] = r.multinomial(n=int(lengths[i]), pvals=theta)
        yield X, y


# Training and test data must come from the SAME word distributions,
# otherwise test accuracy says nothing about what the model learned.
theta0_s, theta1_s = make_stream_word_dists(vocab_size=80, seed=999)

# Fixed test set: one batch of 2500 documents, sampled with its own seed
X_test_stream, y_test_stream = next(stream_synthetic_text_batches(
    theta0_s, theta1_s, n_docs=2500, avg_len=70, class_prior=0.2, seed=123, batch_size=2500
))

# Stream training batches
batch_size = 400
stream = stream_synthetic_text_batches(
    theta0_s,
    theta1_s,
    n_docs=12000,
    avg_len=70,
    class_prior=0.2,
    seed=1000,
    batch_size=batch_size,
)

m_stream = MultinomialNB(alpha=0.5)
classes = np.array([0, 1])

seen = 0
checkpoints = []
accs = []

for X_batch, y_batch in stream:
    if seen == 0:
        m_stream.partial_fit(X_batch, y_batch, classes=classes)
    else:
        m_stream.partial_fit(X_batch, y_batch)

    seen += X_batch.shape[0]

    if seen % (batch_size * 2) == 0:
        y_pred = m_stream.predict(X_test_stream)
        checkpoints.append(seen)
        accs.append(accuracy_score(y_test_stream, y_pred))

fig = go.Figure()
fig.add_trace(go.Scatter(x=checkpoints, y=accs, mode="lines+markers"))
fig.update_layout(
    title="Out-of-core MultinomialNB: accuracy vs streamed samples",
    xaxis_title="# samples seen",
    yaxis_title="accuracy on fixed test set",
    width=850,
    height=450,
)
fig.show()


## Summary: choosing the right Naive Bayes

- **GaussianNB**: continuous features; surprisingly strong baseline for numeric data.
- **MultinomialNB**: count data (text word counts, event counts).
- **BernoulliNB**: binary features (word present/absent).
- **ComplementNB**: often better than MultinomialNB on **imbalanced** text.
- **CategoricalNB**: discrete categorical features (integer-coded categories).

## Exercises
1. Create a dataset where features are highly correlated and see how GaussianNB degrades.
2. For MultinomialNB, plot how `alpha` changes the *top words* for each class.
3. For BernoulliNB, compare `binarize=0`, `binarize=1`, and manual binarization.
4. Stream batches with `partial_fit` and compare to a single `fit`.
