# NMF-based sample grouping
Related to Figure 1 and Supplementary Figure 1.



**Purpose.** This notebook performs non-negative matrix factorization (NMF) on a precomputed sample-by-feature matrix to derive sample groups and visualize the resulting programs, and (optionally) evaluates survival differences across the NMF groups.

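**Input.** `./data/nmf_inputs.RData`, an author-prepared R workspace (the objects it must contain are listed under *Data loading*).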

**Outputs.** All figures/tables/models are saved under `./outputs/nmf/`; the individual file names are defined in the CONFIG cell below.


In [None]:
## =========================
## CONFIG (paths + key params)
## =========================
# Session options: keep strings as character vectors and use a wide console.
options(stringsAsFactors = FALSE, width = 120)

# ---- Reproducibility ----
# Seed handed to set.seed() before the stochastic steps of this notebook.
SEED <- 550

# ---- Directories (all relative to the working directory) ----
INPUT_DIR  <- file.path(".", "data")
OUTPUT_DIR <- file.path(".", "outputs", "nmf")
# Create the output folder (and parents) up front; no warning if it already exists.
dir.create(OUTPUT_DIR, showWarnings = FALSE, recursive = TRUE)

# ---- Input ----
# Provenance, as recorded by the authors:
# - CIBERSORTx deconvolution of TCGA-LIHC bulk RNA-seq, run against a previously
#   annotated scRNA-seq reference (cell-type signatures defined upstream).
# - The raw CIBERSORTx output is a sample × cell-type fraction matrix.
# - NMF is run on 'scale_ratio', a transformed/scaled version of that fraction
#   matrix produced offline (non-negative; either samples × features or
#   features × samples).
#
# Minimum content of the .RData file:
#   - scale_ratio: numeric matrix/data.frame (samples × features OR features × samples)
#   - matrix_merged_full: data.frame with columns: sampleID, days, event
INPUT_RDATA <- file.path(INPUT_DIR, "nmf_inputs.RData")

# ---- NMF parameters ----
RANKS_TO_TRY    <- 2:15   # candidate ranks for the rank survey
NMF_RANK_FINAL  <- 4      # rank of the final model
NMF_NRUN_SURVEY <- 100    # runs per rank during the survey
NMF_NRUN_FINAL  <- 500    # runs for the final fit
NMF_METHOD      <- "lee"  # value passed as nmf(method = ...)

# ---- Output files (all inside OUTPUT_DIR) ----
# Figures
OUT_RANK_SURVEY_PDF <- file.path(OUTPUT_DIR, "nmf_rank_survey.pdf")
OUT_HEATMAP_PDF     <- file.path(OUTPUT_DIR, "nmf_heatmap.pdf")
OUT_KM_PDF          <- file.path(OUTPUT_DIR, "KM_nmf_group.pdf")
# Fitted model (file name carries the final rank)
OUT_MODEL_RDS       <- file.path(OUTPUT_DIR, sprintf("nmf_model_rank%d.rds", NMF_RANK_FINAL))
# Tables
OUT_ASSIGNMENT_CSV  <- file.path(OUTPUT_DIR, "TCGA_NMF_group_assignment.csv")
OUT_MEDIAN_CSV      <- file.path(OUTPUT_DIR, "median_survival_by_nmf_group.csv")


In [None]:
## =========================
## Reproducibility: seeds + packages + versions
## =========================
# Pin the RNG algorithms explicitly so results do not depend on session defaults.
# sample.kind = "Rounding" selects the sampler that was R's default before 3.6.0
# (R emits a warning about its non-uniformity).
# NOTE(review): presumably kept to reproduce results generated with an older R —
# confirm before switching to "Rejection", since that changes every random draw.
RNGkind(kind = "Mersenne-Twister", normal.kind = "Inversion", sample.kind = "Rounding")
set.seed(SEED)

# Attach all packages used by the notebook, without their startup banners.
suppressPackageStartupMessages({
  library(NMF)
  library(dplyr)
  library(tibble)
  library(readr)
  library(ComplexHeatmap)
  library(circlize)
  library(grid)
  library(survival)
  library(survminer)
})

# Full session record (R version, platform, attached/loaded packages).
cat("Session info:\n")
print(sessionInfo())

# Compact version table for the key analysis packages.
# vapply() guarantees one character string per package (type-stable), whereas
# sapply() would silently change its return type on unexpected input.
pkgs <- c("NMF","ComplexHeatmap","circlize","dplyr","readr","survival","survminer")
ver_tbl <- tibble(
  package = pkgs,
  version = vapply(pkgs, function(p) as.character(packageVersion(p)), character(1))
)
print(ver_tbl)


## Data loading

This notebook expects an author-prepared `./data/nmf_inputs.RData` that contains:
- `scale_ratio`: numeric matrix (either **samples × features** or **features × samples**)
- `matrix_merged_full`: sample-level metadata table with **`sampleID`**, **`days`**, **`event`** (0/1 or "Dead"/"Alive")

The code below detects which axis corresponds to samples and converts the NMF input into **features × samples** (`X`).


In [None]:
## =========================
## Load inputs + minimal sanity checks
## =========================
# The .RData file is loaded into a private environment rather than straight into
# the global workspace:
#   - objects stored in the file can no longer silently overwrite notebook
#     variables that happen to share a name (e.g. SEED, X, or the output paths);
#   - the presence checks refer to what the file itself provides, not to
#     leftovers from an earlier session that exists() would also accept.
# Only the two required objects are copied into the workspace afterwards.
if (!file.exists(INPUT_RDATA)) {
  stop("Input file not found: ", INPUT_RDATA, call. = FALSE)
}

input_env <- new.env(parent = emptyenv())
loaded_objects <- load(INPUT_RDATA, envir = input_env)  # load() returns the object names

missing_objects <- setdiff(c("scale_ratio", "matrix_merged_full"), loaded_objects)
if (length(missing_objects) > 0) {
  stop(
    "Required object(s) missing from ", INPUT_RDATA, ": ",
    paste(missing_objects, collapse = ", "),
    call. = FALSE
  )
}

scale_ratio        <- get("scale_ratio", envir = input_env, inherits = FALSE)
matrix_merged_full <- get("matrix_merged_full", envir = input_env, inherits = FALSE)

# Drop the temporary helpers so they do not linger in the workspace.
rm(input_env, loaded_objects, missing_objects)

# Bring sample identifiers into a comparable form: coerce the input to
# character and replace every literal "." with "-"
# (e.g. "TCGA.AB.1234" becomes "TCGA-AB-1234"). NA entries stay NA.
normalize_id <- function(x) {
  gsub(".", "-", as.character(x), fixed = TRUE)
}

# Coerce both inputs to plain data.frames (name checking disabled on request).
scale_ratio <- as.data.frame(scale_ratio, check.names = FALSE)
stopifnot(!is.null(rownames(scale_ratio)) || !is.null(colnames(scale_ratio)))

matrix_merged_full <- as.data.frame(matrix_merged_full, check.names = FALSE)
required_cols <- c("sampleID", "days", "event")
missing_cols <- setdiff(required_cols, colnames(matrix_merged_full))
stopifnot(length(missing_cols) == 0)

# Normalized metadata IDs are the reference against which both axes are compared.
meta_ids <- normalize_id(matrix_merged_full$sampleID)

row_ids <- if (!is.null(rownames(scale_ratio))) normalize_id(rownames(scale_ratio)) else character(0)
col_ids <- if (!is.null(colnames(scale_ratio))) normalize_id(colnames(scale_ratio)) else character(0)

# Fraction of names on each axis that are known sample IDs.
row_match <- if (length(row_ids) > 0) mean(row_ids %in% meta_ids) else 0
col_match <- if (length(col_ids) > 0) mean(col_ids %in% meta_ids) else 0

# The axis with the larger overlap is taken to be the sample axis; a tie goes
# to the rows. X always ends up as features x samples.
if (row_match >= col_match) {
  # scale_ratio: samples x features  -> transpose to features x samples
  X <- t(as.matrix(scale_ratio))
  sample_ids <- rownames(scale_ratio)
} else {
  # scale_ratio: features x samples  -> keep as-is
  X <- as.matrix(scale_ratio)
  sample_ids <- colnames(scale_ratio)
}

stopifnot(!is.null(colnames(X)))

# Every column of X must have a metadata row; this also catches the case where
# neither axis matched the metadata at all.
sample_ids_norm <- normalize_id(colnames(X))
match_idx <- match(sample_ids_norm, meta_ids)
stopifnot(!any(is.na(match_idx)))

# Reorder the metadata so that row i describes column i of X, and store the
# IDs exactly as they appear in X.
meta <- matrix_merged_full[match_idx, , drop = FALSE]
meta$sampleID <- colnames(X)

# Follow-up time. A factor must go through character first: as.numeric() on a
# factor returns the internal level codes, not the recorded values.
if (is.factor(meta$days)) {
  meta$days <- as.character(meta$days)
}
meta$days  <- as.numeric(meta$days)

# Event indicator. Textual values are mapped explicitly:
#   "Dead" / "1" -> 1,  "Alive" / "0" -> 0  (case-insensitive, whitespace trimmed).
# Anything else is rejected instead of being counted as "no event", which is
# what a plain `== "Dead"` comparison would do (e.g. character "1", "dead",
# or a misspelled label would all silently become 0).
if (is.character(meta$event) || is.factor(meta$event)) {
  event_label <- trimws(as.character(meta$event))
  event_key   <- tolower(event_label)
  event_code  <- rep(NA_integer_, length(event_label))
  event_code[event_key %in% c("dead", "1")]  <- 1L
  event_code[event_key %in% c("alive", "0")] <- 0L

  unknown_labels <- unique(event_label[is.na(event_code) & !is.na(event_label)])
  if (length(unknown_labels) > 0) {
    stop(
      "Unrecognized value(s) in 'event': ",
      paste(unknown_labels, collapse = ", "),
      " (expected 0/1 or \"Dead\"/\"Alive\")",
      call. = FALSE
    )
  }

  meta$event <- event_code
  rm(event_label, event_key, event_code, unknown_labels)
}
meta$event <- as.integer(meta$event)

# Missing or out-of-range event codes are not accepted.
stopifnot(all(meta$event %in% c(0L, 1L)))


## Core analysis

1) **Rank survey** over `RANKS_TO_TRY` (cophenetic / dispersion, etc.)  
2) Fit **final NMF** at `NMF_RANK_FINAL` and assign each sample to the maximal coefficient component (`nmf_group`)  
3) Visualize standardized feature patterns as a heatmap (features × samples, split by `nmf_group`)  
4) (Optional) Kaplan–Meier survival curves by `nmf_group`


In [None]:
## =========================
## 1) NMF rank survey
## =========================
set.seed(SEED)

# Fit NMF at every candidate rank (NMF_NRUN_SURVEY runs each) to compare ranks.
estim <- nmf(X, rank = RANKS_TO_TRY, nrun = NMF_NRUN_SURVEY, method = NMF_METHOD)

# Write the survey plot to PDF.
# - The device is closed in `finally`, so a plotting error cannot leave the PDF
#   device open (which would swallow every later plot of the session).
# - plot() may return a plot object that is only drawn when printed; it is
#   printed explicitly when (and only when) top-level auto-printing would have
#   printed it, so the result no longer depends on how the code is executed.
pdf(OUT_RANK_SURVEY_PDF, width = 10, height = 8, useDingbats = FALSE)
invisible(tryCatch(
  local({
    drawn <- withVisible(plot(estim))
    if (drawn$visible) print(drawn$value)
  }),
  finally = dev.off()
))


In [None]:
## =========================
## 2) Fit final NMF + group assignment
## =========================
set.seed(SEED)

# Final factorization at the chosen rank.
nmf_fit <- nmf(X, rank = NMF_RANK_FINAL, nrun = NMF_NRUN_FINAL, method = NMF_METHOD)

# Persist the fitted model right away.
saveRDS(nmf_fit, file = OUT_MODEL_RDS)

# Coefficient matrix of the fit; its columns have to line up with those of X.
H <- coef(nmf_fit)
stopifnot(all(colnames(X) == colnames(H)))

# A sample (column of H) belongs to the component with its largest
# coefficient, giving an index in 1..K. Only the groups that actually occur
# become factor levels.
nmf_group <- apply(H, MARGIN = 2, FUN = which.max)
meta$nmf_group <- factor(nmf_group, levels = sort(unique(nmf_group)))

# Two-column export table: sample ID and integer group label.
nmf_assignment <- tibble(
  sample_id = meta$sampleID,
  nmf_group = as.integer(levels(meta$nmf_group))[meta$nmf_group]
)

write_csv(nmf_assignment, file = OUT_ASSIGNMENT_CSV)


In [None]:
## =========================
## 3) Heatmap (standardized patterns)
## =========================
# Standardize each feature across samples (features x samples). The z-scores
# are additionally divided by 4, so the plotted values (legend "z") are z / 4.
# NOTE(review): the reason for the factor 4 is not visible here — confirm.
z_mat <- t(scale(t(X))) / 4

# Assign each feature to the basis component on which it loads most, then sort
# features by component and, within a component, by decreasing loading.
W <- basis(nmf_fit)
feat_group <- apply(W, 1, which.max)
feat_order <- order(feat_group, -apply(W, 1, max))

# Order samples by NMF group, then by within-group coefficient strength
# (largest coefficient of each column of H, descending).
sample_order <- order(meta$nmf_group, -apply(H, 2, max))

z_mat_plot <- z_mat[feat_order, sample_order, drop = FALSE]
feat_group_plot <- factor(feat_group[feat_order], levels = sort(unique(feat_group)))
nmf_group_plot <- meta$nmf_group[sample_order]

# Column annotation bar showing the NMF group of every sample.
ha <- HeatmapAnnotation(
  nmf_group = nmf_group_plot,
  annotation_name_side = "left",
  annotation_legend_param = list(nmf_group = list(title = "NMF Group"))
)

# Rows/columns keep the order computed above (no clustering); the heatmap is
# split into blocks by feature group (rows) and sample group (columns).
ht <- Heatmap(
  z_mat_plot,
  name            = "z",
  top_annotation  = ha,
  column_split    = nmf_group_plot,
  row_split       = feat_group_plot,
  cluster_rows    = FALSE,
  cluster_columns = FALSE,
  row_title       = NULL,
  column_title    = NULL,
  row_names_gp    = grid::gpar(fontsize = 8),
  column_names_gp = grid::gpar(fontsize = 5)
)

# Closing the device in `finally` guarantees the PDF device is released even
# if drawing fails; otherwise later plots would be redirected into this file.
pdf(OUT_HEATMAP_PDF, width = 10, height = 8, useDingbats = FALSE)
invisible(tryCatch(draw(ht), finally = dev.off()))


In [None]:
## =========================
## 4) Survival analysis by NMF group (Kaplan–Meier)
## =========================
# Keep only the columns needed for survival analysis and drop samples without
# follow-up time or event status. The dplyr verbs are namespaced so that a
# later-attached package exporting select()/filter() cannot hijack them.
df_surv <- meta %>%
  dplyr::select(sampleID, days, event, nmf_group) %>%
  dplyr::filter(!is.na(days), !is.na(event))

# Kaplan–Meier estimate per NMF group.
fit <- survfit(Surv(days, event) ~ nmf_group, data = df_surv)

# Survival curves with the p-value and a risk table below the curves.
p <- ggsurvplot(
  fit,
  data      = df_surv,
  pval      = TRUE,
  risk.table= TRUE,
  legend.title = "NMF Group"
)

# Export plot
# - newpage = FALSE: printing a ggsurvplot into a fresh pdf() device otherwise
#   starts a second page and leaves the first page of the file blank.
# - dev.off() runs in `finally`, so the device is closed even if printing fails.
pdf(OUT_KM_PDF, width = 8, height = 7, useDingbats = FALSE)
invisible(tryCatch(print(p, newpage = FALSE), finally = dev.off()))

# Export median survival table (days): strip the "nmf_group=" prefix from the
# strata labels and keep the median with its confidence limits.
med_tbl <- surv_median(fit) %>%
  dplyr::mutate(group = sub("^nmf_group=", "", strata)) %>%
  dplyr::select(group, median, lower, upper)

write_csv(med_tbl, OUT_MEDIAN_CSV)
