predict() silently refits the factor matrix F on new data #64

@roi-meir

Description

When predict(fit, newdata) is called to project new data points onto a previously fitted topic model, the internal helper project_poisson_nmf calls fit_poisson_nmf without setting update.factors = NULL. As a result, fit_poisson_nmf uses its default update.factors = seq(1, ncol(X)) and runs full factor updates on the new data, overwriting the F matrix that was estimated on the training data. Only L should be updated during prediction; F must be held fixed.
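
To make "update L only, hold F fixed" concrete, here is a minimal base-R sketch. It does not call fastTopics and is not its algorithm: it applies plain multiplicative (EM) updates for the Poisson model, and project_loadings, F_train and X_new are hypothetical names introduced only for illustration.

```r
# Sketch only: project new counts onto a FIXED factor matrix.
# Uses simple multiplicative (EM) updates, not the fastTopics update rules.
set.seed(1)
m <- 50; k <- 3; n <- 20
F_train <- matrix(rexp(m * k), m, k)   # stand-in for factors learned on training data (m x k)
L_true  <- matrix(rexp(n * k), n, k)   # loadings used only to simulate new counts
X_new   <- matrix(rpois(n * m, tcrossprod(L_true, F_train)), n, m)

# Hypothetical helper: updates the loadings L and only reads F.
project_loadings <- function(X, F, numiter = 50, eps = 1e-8) {
  L <- matrix(1, nrow(X), ncol(F))
  for (iter in seq_len(numiter)) {
    R <- X / pmax(tcrossprod(L, F), eps)          # counts divided by fitted rates (n x m)
    L <- L * sweep(R %*% F, 2, colSums(F), "/")   # EM update for L with F held fixed
  }
  L
}

# Poisson log-likelihood of X under rates L %*% t(F)
loglik <- function(X, L, F) sum(dpois(X, tcrossprod(L, F), log = TRUE))

F_before <- F_train
L_new    <- project_loadings(X_new, F_train)
identical(F_train, F_before)   # the factors are the same object values as before
loglik(X_new, L_new, F_train)  # fit of the projected loadings
```

Because F is only read inside the loop, the loadings estimated for new samples refer to the same topics as the training fit, which is the behaviour expected from predict().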

Root cause
In R/predict.R, the call inside project_poisson_nmf was:

BEFORE (buggy)

fit <- fit_poisson_nmf(X, fit0 = fit, numiter = numiter, ...)
Without an explicit update.factors = NULL, the default update.factors = seq(1, ncol(X)) is forwarded to the update loop, which updates every row of F using the test samples. The fix is:

AFTER (fixed)

fit <- fit_poisson_nmf(X, fit0 = fit, numiter = numiter, update.factors = NULL, ...)
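
The mechanism behind the BEFORE/AFTER difference is ordinary R default-argument behaviour, which the following self-contained sketch mimics. The functions fit_sketch, wrapper_buggy and wrapper_fixed are hypothetical stand-ins, not fastTopics code; only the argument name and its default are taken from the description above.

```r
# Stand-in for a fitting function whose default updates every factor row.
fit_sketch <- function(X, update.factors = seq(1, ncol(X)), ...) {
  update.factors   # return which rows of F would be updated
}

# Wrapper that forgets to override the default (mirrors the BEFORE call).
wrapper_buggy <- function(X, ...) fit_sketch(X, ...)

# Wrapper that freezes the factors explicitly (mirrors the AFTER call).
wrapper_fixed <- function(X, ...) fit_sketch(X, update.factors = NULL, ...)

X <- matrix(0, nrow = 2, ncol = 4)
wrapper_buggy(X)   # default kicks in: all 4 factor rows are selected
wrapper_fixed(X)   # NULL: no factor rows are selected
```

One consequence of hard-coding the argument in the wrapper: if a caller also passes update.factors through `...`, R stops with a "formal argument matched by multiple actual arguments" error, so the override cannot be undone silently.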

Note: fit_poisson_nmf also calls rescale.fit() before the update loop. This rescales F and L jointly (preserving their product) for numerical stability and happens identically in both cases — it is not the source of the bug.
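
The claim that a joint rescaling is harmless can be checked with a few lines of base R. This is a sketch of the general idea under the assumption that the rescaling multiplies each column of F by a positive constant and divides the matching column of L by the same constant; the scale factors d below are arbitrary and are not the values rescale.fit() would choose.

```r
set.seed(1)
F <- matrix(runif(12), 4, 3)   # m x k factors
L <- matrix(runif(15), 5, 3)   # n x k loadings
d <- c(2, 0.5, 10)             # arbitrary positive per-topic scale factors

F2 <- sweep(F, 2, d, "*")      # scale each column of F up
L2 <- sweep(L, 2, d, "/")      # scale the matching column of L down

# The fitted rates L %*% t(F) are unchanged by the joint rescaling.
all.equal(tcrossprod(L, F), tcrossprod(L2, F2))

# Column-normalised F is also unchanged, so it is a scale-free way to compare F.
norm_col <- function(M) sweep(M, 2, colSums(M), "/")
all.equal(norm_col(F), norm_col(F2))
```

Both comparisons return TRUE, which is why a check on column-normalised F isolates real changes to the factors from rescaling effects.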

Reproducible example

library(fastTopics)
set.seed(1)

# Simulate and split data
dat   <- simulate_multinom_gene_data(175, 1200, k = 3)
train <- dat$X[1:100, ]
test  <- dat$X[101:175, ]

# Fit topic model on training data
fit <- init_poisson_nmf(train, F = dat$F, init.method = "random")
fit <- fit_poisson_nmf(train, fit0 = fit, verbose = "none")
fit <- poisson2multinom(fit)

# Recreate project_poisson_nmf internals to compare both behaviours
F_pois  <- multinom2poisson(fit)$F
numiter <- 20
n <- nrow(test); k <- ncol(F_pois)
L0 <- matrix(1/k, n, k)
rownames(L0) <- rownames(test); colnames(L0) <- colnames(F_pois)
fit_init <- init_poisson_nmf(test, F = F_pois, L = L0)

# FIXED: update.factors = NULL — F structure is frozen
fit_fixed   <- fit_poisson_nmf(test, fit0 = fit_init, numiter = numiter,
                               update.factors = NULL, verbose = "none")

# BUGGY: no update.factors — F is refitted on test data
fit_unfixed <- fit_poisson_nmf(test, fit0 = fit_init, numiter = numiter,
                               verbose = "none")

# Compare F matrices (column-normalised = scale-free, structure only)
norm_col    <- function(M) sweep(M, 2, colSums(M), "/")
F_pois_n    <- norm_col(F_pois)
F_fixed_n   <- norm_col(fit_fixed$F)
F_unfixed_n <- norm_col(fit_unfixed$F)
cat("Max |F_fixed_n   - F_pois_n|:", max(abs(F_fixed_n   - F_pois_n)),  "\n")
cat("Max |F_unfixed_n - F_pois_n|:", max(abs(F_unfixed_n - F_pois_n)),  "\n")

# Confirm installed predict() follows the buggy path
L_fixed   <- poisson2multinom(fit_fixed)$L
L_unfixed <- poisson2multinom(fit_unfixed)$L
Ltest     <- predict(fit, test, numiter = numiter, verbose = "none")
cat("predict() matches BUGGY?", isTRUE(all.equal(Ltest, L_unfixed)), "\n")
cat("predict() matches FIXED?", isTRUE(all.equal(Ltest, L_fixed)),   "\n")
cat("Max |L_buggy - L_fixed|:", max(abs(L_unfixed - L_fixed)),        "\n")

Output:

Max |F_fixed_n   - F_pois_n| : 0          # F structure unchanged
Max |F_unfixed_n - F_pois_n| : 0.0028     # F structure changed by test updates
predict() matches BUGGY?  TRUE            # installed predict() has the bug
predict() matches FIXED?  FALSE
Max |L_buggy - L_fixed|: 0.3855           # topic proportions differ by up to ~0.39
