In [None]:
import os
import sys
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import scipy.io as sio
import anndata as ad
import matplotlib.pyplot as plt

# Working directory for every relative path below (input .h5ad, AD_metadata.csv, figures).
# Defaults to the original project location; set AD_RESULT_DIR to run elsewhere.
AD_RESULT_DIR = os.environ.get("AD_RESULT_DIR", "/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/AD")
os.chdir(AD_RESULT_DIR)

# 1. Cell-level

In [None]:
# Load the AnnData object that carries the per-cell ensemble attention weights, then show its summary
adata = sc.read_h5ad(filename="./ensemble_adata_with_attention.h5ad")
adata

In [None]:
# Export per-cell metadata (obs) for the R analyses in section 2.
# `metadata` is an alias of adata.obs, not a copy.
adata.obs.to_csv("./AD_metadata.csv")
metadata = adata.obs

In [None]:
def _scale_sample_attention(weights: pd.Series) -> pd.Series:
    """Clipped per-sample z-score of log2(weight / uniform weight).

    The uniform weight is 1 / (number of cells in the sample). The z-score uses the
    population SD (ddof=0, as np.std does) and is clipped to [-1, 1].
    """
    log_attn = np.log2(weights / (1 / len(weights)))
    sd = log_attn.std(ddof=0)
    if sd == 0:
        # Single-cell or constant-weight sample: no cell stands out. Report 0 instead of
        # the NaN that the 0/0 division would produce.
        return pd.Series(0.0, index=weights.index)
    return ((log_attn - log_attn.mean()) / sd).clip(-1, 1)


# One grouped pass instead of recomputing a full-length boolean mask for every sample.
# Rows whose sample_id is missing stay NaN, as before.
adata.obs['attn_scaled'] = (
    adata.obs.groupby('sample_id', observed=True)['attention_weight_mean']
    .transform(_scale_sample_attention)
)

In [None]:
# AnnData views split by AD phenotype, one per attention UMAP below
_phenotype = adata.obs['phenotype']
adata1, adata2, adata3, adata4 = (
    adata[_phenotype == label]
    for label in ("Not AD", "Low", "Intermediate", "High")
)


In [None]:
# Global scanpy plotting settings: output directory, resolution, default size and format
sc.settings.verbosity = 1
sc.settings.figdir = '/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/AD/Analysis/Figure/Attn_plot'
sc.settings.set_figure_params(
    dpi=100,
    dpi_save=400,
    fontsize=10,
    figsize=(6, 6),
    facecolor='white',
    format='png',
)
def one_col_lgd(umap, ncol=1):
    """Attach a frameless legend to the right of an axes object.

    Parameters
    ----------
    umap : axes-like object with a ``legend(**kwargs)`` method, e.g. the value returned
        by ``sc.pl.umap(..., show=False)``.
    ncol : int, default 1
        Number of legend columns; the default gives the single-column layout the name implies.

    Returns
    -------
    The created legend, so callers can pass it to ``savefig(bbox_extra_artists=...)``.
    """
    legend = umap.legend(bbox_to_anchor=[1.00, 0.5],
                         loc='center left', ncol=ncol, prop={'size': 6})
    # Hide the legend border
    legend.get_frame().set_linewidth(0.0)
    return legend

In [None]:
# Cell-type UMAP with an external single-column legend, saved as PDF.
# The palette has at least 12 hues, and one per category when there are more cell types,
# so colours do not repeat.
leiden_umap = sc.pl.umap(adata, color=['celltype'],
                         show=False,
                         palette=sns.color_palette("husl", max(12, adata.obs['celltype'].nunique())),
                         legend_fontsize=6, frameon=True, title='celltype')
# ncol passed explicitly (the helper originally required it)
lgd = one_col_lgd(leiden_umap, 1)
fig = leiden_umap.get_figure()
fig.set_size_inches(5, 5)
fig.savefig(str(sc.settings.figdir) + '/umap_celltype.pdf',
            format='pdf', bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
# Same cell-type UMAP with category labels drawn on the clusters (displayed, not saved)
sc.pl.umap(adata, legend_loc="on data", color=['celltype'])

In [None]:
# Supertype UMAP with a 5-column external legend, saved as PDF.
# The palette covers every category (at least 150 hues), and the title now names the
# plotted column (it previously said 'celltype').
leiden_umap = sc.pl.umap(adata, color=['Supertype'],
                         show=False,
                         palette=sns.color_palette("husl", max(150, adata.obs['Supertype'].nunique())),
                         legend_fontsize=6, frameon=True, title='Supertype')
lgd = one_col_lgd(leiden_umap, 5)
fig = leiden_umap.get_figure()
fig.set_size_inches(5, 5)
fig.savefig(str(sc.settings.figdir) + '/umap_Supertype.pdf',
            format='pdf', bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
# Attention UMAP for "Not AD" cells (adata1), coloured by the clipped per-sample attention z-score
leiden_umap = sc.pl.umap(
    adata1,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of normal',
    frameon=True,
    legend_fontsize=6,
    save="_attn_normal.pdf",
    show=False,
)

In [None]:
# Attention UMAP for "Low" cells (adata2), coloured by the clipped per-sample attention z-score
leiden_umap = sc.pl.umap(
    adata2,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of Low',
    frameon=True,
    legend_fontsize=6,
    save="_attn_Low.pdf",
    show=False,
)

In [None]:
# Attention UMAP for "Intermediate" cells (adata3), coloured by the clipped per-sample attention z-score
leiden_umap = sc.pl.umap(
    adata3,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of Intermediate',
    frameon=True,
    legend_fontsize=6,
    save="_attn_Intermediate.pdf",
    show=False,
)


In [None]:
# Attention UMAP for "High" cells (adata4), coloured by the clipped per-sample attention z-score
leiden_umap = sc.pl.umap(
    adata4,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of High',
    frameon=True,
    legend_fontsize=6,
    save="_attn_High.pdf",
    show=False,
)


# 2. Celltype-level

In [None]:
# Working directory for the relative paths used below; set AD_RESULT_DIR to run elsewhere.
setwd(Sys.getenv("AD_RESULT_DIR", "/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/AD"))

# Attach order determines which package wins when names are masked; keep it as is.
library(tidyr)
library(ggplot2)
library(forestploter)
library(gridExtra)
library(tidyverse)
library(dplyr)
library(broom)
library(ggpubr)
library(randomForest)
library(mice)
library(reshape2)       # dcast()
library(Metrics)
library(ComplexHeatmap)
library(RColorBrewer)
library(fastcluster)
library(ggbeeswarm)
library(circlize)
library(ggrepel)        # geom_text_repel()
library(scales)         # rescale(), pvalue()
library(ggtext)
library(lme4)
library(lmerTest)       # must come after lme4: its lmer() makes anova() report F-tests with "Pr(>F)"
library(ordinal)        # clmm()

### 2.1 Scatter plot: phenotype association vs. donor variation

In [None]:
attnData = read.csv('./AD_metadata.csv')
head(attnData)
colnames(attnData)

nameAll = unique(attnData$celltype)
nameAll = sort(nameAll)
nameAll

nameList = nameAll

# Per sample: median log2 fold of each cell type's attention over the uniform weight
# (1 / number of cells in that sample), then z-scored across cell types within the sample.
# One grouped dplyr pass replaces per-sample re-filtering and a growing rbind.
# as.numeric() keeps `fold` a plain vector (scale() returns a one-column matrix).
sampleFold = attnData %>%
  group_by(sample_id) %>%
  mutate(avgScore = 1 / n()) %>%
  group_by(sample_id, celltype) %>%
  summarise(fold = median(log2(attention_weight_mean / avgScore)), .groups = "drop_last") %>%
  filter(celltype %in% nameList) %>%
  mutate(fold = as.numeric(scale(fold))) %>%
  ungroup() %>%
  transmute(id = sample_id, celltype = celltype, fold = fold) %>%
  as.data.frame()

# Wide table: one row per sample, one column per cell type
sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) = sampleFold.Table$id
sampleFold.Table$id = NULL

sampleInfo = read.csv('/data/wuqinhua/phase_1r/AD/sample_info.csv')
rownames(sampleInfo) = sampleInfo$sample_id
missing_ids = setdiff(rownames(sampleFold.Table), sampleInfo$sample_id)
if (length(missing_ids) > 0) {
  warning(length(missing_ids), " sample(s) have no row in sample_info.csv; their metadata will be NA")
}
# Align sample metadata row-for-row with sampleFold.Table
sampleInfo = sampleInfo[rownames(sampleFold.Table),]
head(sampleFold.Table)

In [None]:
# ===================================================================
# 1. Calculation Score 1: Association with Categorical Phenotypes (Linear Mixed-Effects Model)
# ===================================================================

modeling_data <- merge(sampleFold.Table, sampleInfo, by.x = "row.names", by.y = "sample_id")

rownames(modeling_data) <- modeling_data$Row.names
modeling_data$Row.names <- NULL

modeling_data$phenotype <- as.factor(modeling_data$phenotype)

# Overall F-test p-value of `phenotype` in `<cell> ~ phenotype + (1 | donor_id)`.
# anova() on an lmerTest fit supplies "Pr(>F)". Any fitting or testing failure, or a
# missing p-value column, yields NA instead of breaking the table.
lmm_phenotype_pvalue <- function(cell) {
  formula <- as.formula(paste0("`", cell, "` ~ phenotype + (1 | donor_id)"))
  tryCatch({
    p <- anova(lmer(formula, data = modeling_data))$`Pr(>F)`
    if (length(p) == 0) NA_real_ else p[1]
  }, error = function(e) {
    message("LMM failed for ", cell, ": ", conditionMessage(e))
    NA_real_
  })
}

phenotype_association_scores_lmm <- data.frame(
  cell_type = colnames(sampleFold.Table),
  p_value = vapply(colnames(sampleFold.Table), lmm_phenotype_pvalue, numeric(1), USE.NAMES = FALSE)
) %>%
  mutate(neg_log10_pval = -log10(p_value))

head(phenotype_association_scores_lmm)

# ===================================================================
# 2. Calculation Score 2: Population Volatility (standard deviation considering donor effect)
# ===================================================================
long_format_data <- sampleFold.Table %>%
  rownames_to_column(var = "sample_id") %>%
  pivot_longer(
    cols = -sample_id,
    names_to = "cell_type",
    values_to = "score"
  ) %>%
  left_join(sampleInfo %>% select(sample_id, donor_id), by = "sample_id")

head(long_format_data)

# SD of the donor random intercept in `score ~ (1 | donor_id)` for one cell type, i.e. the
# between-donor spread of its attention score. Returns NA (numeric) when the model cannot be fitted.
donor_effect_sd_for <- function(cell) {
  subset_data <- long_format_data %>%
    filter(cell_type == cell) %>%
    na.omit()
  tryCatch({
    var_corr <- as.data.frame(VarCorr(lmer(score ~ (1 | donor_id), data = subset_data)))
    sd_val <- var_corr[var_corr$grp == "donor_id", "sdcor"]
    if (length(sd_val) == 0) NA_real_ else sd_val[1]
  }, error = function(e) {
    message("Donor-effect LMM failed for ", cell, ": ", conditionMessage(e))
    NA_real_
  })
}

donor_variation_scores <- data.frame(
  cell_type = colnames(sampleFold.Table),
  donor_effect_sd = vapply(colnames(sampleFold.Table), donor_effect_sd_for, numeric(1), USE.NAMES = FALSE)
)

head(donor_variation_scores)

In [None]:
# ===================================================================
# 3. Integrate the scores and create "key scores" for coloring
# ===================================================================
plot_data_combined_score <- merge(
  phenotype_association_scores_lmm,
  donor_variation_scores,
  by = "cell_type"
) %>%
  mutate(
    scaled_pval = rescale(neg_log10_pval, to = c(0, 1)),
    scaled_sd = rescale(donor_effect_sd, to = c(0, 1)),
    keyness_score = scaled_pval * scaled_sd,
    # Label cell types above the 60th percentile of keyness. Cell types with an NA score
    # (a failed model) get an empty label rather than an NA label.
    label_text = ifelse(!is.na(keyness_score) &
                          keyness_score > quantile(keyness_score, 0.6, na.rm = TRUE),
                        cell_type, "")
  )

head(plot_data_combined_score)

# ===================================================================
# 4. Plot
# ===================================================================
combined_score_plot <- ggplot(plot_data_combined_score,
                              aes(x = neg_log10_pval,
                                  y = donor_effect_sd,
                                  color = keyness_score,
                                  label = label_text)) +
  geom_point(size = 6, alpha = 0.8) +
  scale_color_distiller(palette = "Reds", direction = 1, name = "Keyness Score") +
  geom_text_repel(color = "black", size = 4, max.overlaps = Inf,
                  box.padding = 0.6, min.segment.length = 0) +
  # Nominal p = 0.05 threshold
  geom_vline(xintercept = -log10(0.05), linetype = "dashed", color = "gray50") +
  theme_bw(base_size = 12) +
  labs(
    title = "Identifying Key Cell Types by Combined Score (with Donor Effect)",
    subtitle = "Color intensity reflects the composite score of association and variation",
    x = "Strength of Association with Phenotype (-log10 P-value from LMM)",
    # y is the donor random-intercept SD, not the overall SD of the scores
    y = "Between-Donor Variation (SD of donor random intercept)"
  ) +
  theme(
    plot.title = element_text(hjust = 0.5, face = "bold", size = 12),
    plot.subtitle = element_text(hjust = 0.5, size = 10, color = "gray30"),
    legend.position = "right"
  )

ggsave("./Analysis/Figure/Attn_plot/classification_combined_score_plot_with_donor_effect.pdf",
       combined_score_plot,
       width = 8, height = 5, dpi = 300)

print(combined_score_plot)

### 2.2 Violin Plot

In [None]:
attnData = read.csv('./AD_metadata.csv')
head(attnData)
colnames(attnData)

nameAll = unique(attnData$Supertype)
nameAll = sort(nameAll)
nameAll

nameList = nameAll

# Per sample: median log2 fold of each Supertype's attention over the uniform weight
# (1 / number of cells in that sample), then z-scored across Supertypes within the sample.
# The Supertype ends up in a column named `celltype`, as before.
# as.numeric() keeps `fold` a plain vector (scale() returns a one-column matrix).
sampleFold = attnData %>%
  group_by(sample_id) %>%
  mutate(avgScore = 1 / n()) %>%
  group_by(sample_id, Supertype) %>%
  summarise(fold = median(log2(attention_weight_mean / avgScore)), .groups = "drop_last") %>%
  filter(Supertype %in% nameList) %>%
  mutate(fold = as.numeric(scale(fold))) %>%
  ungroup() %>%
  transmute(id = sample_id, celltype = Supertype, fold = fold) %>%
  as.data.frame()

# Wide table: one row per sample, one column per Supertype
sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) = sampleFold.Table$id
sampleFold.Table$id = NULL

sampleInfo = read.csv('/data/wuqinhua/phase_1r/AD/sample_info.csv')
rownames(sampleInfo) = sampleInfo$sample_id
missing_ids = setdiff(rownames(sampleFold.Table), sampleInfo$sample_id)
if (length(missing_ids) > 0) {
  warning(length(missing_ids), " sample(s) have no row in sample_info.csv; their metadata will be NA")
}
# Align sample metadata row-for-row with sampleFold.Table
sampleInfo = sampleInfo[rownames(sampleFold.Table),]
head(sampleFold.Table)

cell_types_attn <- colnames(sampleFold.Table)
# Long format: one row per (sample, cell type) with phenotype and donor; incomplete rows dropped.
# Relies on sampleInfo being row-aligned with sampleFold.Table.
combined_data_atten <- do.call(rbind, lapply(cell_types_attn, function(cell) {
  na.omit(data.frame(
    atten = sampleFold.Table[[cell]],
    group = sampleInfo$phenotype,
    donor_id = sampleInfo$donor_id,
    cell_type = cell
  ))
}))

custom_colors <- c("Not AD" = '#66DD00',"Low"="#77DDFF","Intermediate"="#FFBB66",'High'='#bb7ae9')

# Phenotypes outside custom_colors become NA and are dropped
plot_df_attn <- combined_data_atten %>%
  dplyr::rename(Attn = atten) %>%
  dplyr::mutate(group = factor(group, levels = names(custom_colors))) %>%
  na.omit()

# F statistic and p-value for `group` from lmerTest's anova() of
# Attn ~ group + (1 | donor_id) on one cell type's data. If the model fails (e.g. only one
# phenotype level), the result is NA and the cell type is dropped below, instead of aborting
# the whole pipeline.
phenotype_ftest <- function(ct_data) {
  stats <- tryCatch({
    aov_tab <- anova(lmer(Attn ~ group + (1 | donor_id), data = ct_data))
    c(aov_tab[["F value"]][1], aov_tab[["Pr(>F)"]][1])
  }, error = function(e) c(NA_real_, NA_real_))
  tibble(statistic = stats[1], p.value = stats[2])
}

annotation_data_lmm <- plot_df_attn %>%
  group_by(cell_type) %>%
  group_modify(~ phenotype_ftest(.x)) %>%
  ungroup() %>%
  filter(!is.na(p.value)) %>%
  mutate(
    significance = case_when(
      p.value < 0.001 ~ "***",
      p.value < 0.01  ~ "**",
      p.value < 0.05  ~ "*",
      TRUE ~ "ns"
    ),
    title_with_sig = paste0(cell_type, " ", significance),

    stat_label = paste0(
        "F = ", round(statistic, 2),
        ", ",
        scales::pvalue(p.value,
                       accuracy = 0.001,
                       add_p = TRUE)
    )
  )

# Keep only cell types that have a test result; facet titles carry the significance stars
plot_df_attn_final <- plot_df_attn %>%
  left_join(annotation_data_lmm %>% select(cell_type, title_with_sig), by = "cell_type") %>%
  filter(!is.na(title_with_sig))


p_facet_classification_updated <- ggplot(plot_df_attn_final,
                                         aes(x = group, y = Attn, fill = group)) +

  geom_violin(trim = FALSE, alpha = 0.8) +
  geom_boxplot(width = 0.15, fill = "white", alpha = 0.7, outlier.shape = NA) +

  # F / p annotation in the top-left corner of each facet (matched via title_with_sig)
  geom_text(
    data = annotation_data_lmm,
    aes(label = stat_label),
    x = -Inf, y = Inf,
    hjust = -0.1, vjust = 2,
    size = 3.5,
    fontface = "italic",
    inherit.aes = FALSE
  ) +

  scale_fill_manual(values = custom_colors, name = "Phenotype") +

  facet_wrap(~title_with_sig, scales = "free",ncol = 10) +

  labs(
    title = "Attention Score Distribution (accounting for Donor Effect)",
    y = "Standardized Attention Score"
  ) +

  theme_classic(base_size = 12) +
  theme(
    panel.grid.major.x = element_blank(),
    panel.grid.minor = element_blank(),
    strip.background = element_rect(fill = "#EFEFEF", color = NA),
    strip.text = element_text(size = 10, face = "bold", color = "#333333"),
    axis.text.x = element_text(size = 9, angle = 45, hjust = 1),
    axis.text.y = element_text(size = 8),
    axis.title = element_text(size = 10, face = "bold"),
    legend.position = "right",
    axis.title.x = element_blank()
  )

ggsave("./Analysis/Figure/Attn_plot/attn_classification_violin_plot_with_donor_effect_Supertype.pdf",
       p_facet_classification_updated, width = 30, height = 45, dpi = 300)

print(p_facet_classification_updated)

### 2.3 Ordered logistic regression

In [None]:
attnData = read.csv('./AD_metadata.csv')
head(attnData)
colnames(attnData)

nameAll = unique(attnData$Supertype)
nameAll = sort(nameAll)
nameAll

nameList = nameAll

# Per sample: median log2 fold of each Supertype's attention over the uniform weight
# (1 / number of cells in that sample), then z-scored across Supertypes within the sample.
# The Supertype ends up in a column named `celltype`, as before.
# as.numeric() keeps `fold` a plain vector (scale() returns a one-column matrix).
sampleFold = attnData %>%
  group_by(sample_id) %>%
  mutate(avgScore = 1 / n()) %>%
  group_by(sample_id, Supertype) %>%
  summarise(fold = median(log2(attention_weight_mean / avgScore)), .groups = "drop_last") %>%
  filter(Supertype %in% nameList) %>%
  mutate(fold = as.numeric(scale(fold))) %>%
  ungroup() %>%
  transmute(id = sample_id, celltype = Supertype, fold = fold) %>%
  as.data.frame()

# Wide table: one row per sample, one column per Supertype
sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) = sampleFold.Table$id
sampleFold.Table$id = NULL

sampleInfo = read.csv('/data/wuqinhua/phase_1r/AD/sample_info.csv')
rownames(sampleInfo) = sampleInfo$sample_id
missing_ids = setdiff(rownames(sampleFold.Table), sampleInfo$sample_id)
if (length(missing_ids) > 0) {
  warning(length(missing_ids), " sample(s) have no row in sample_info.csv; their metadata will be NA")
}
# Align sample metadata row-for-row with sampleFold.Table
sampleInfo = sampleInfo[rownames(sampleFold.Table),]
head(sampleFold.Table)

In [None]:
# Ordered outcome for the cumulative link mixed models: Not AD < Low < Intermediate < High
phenotype_levels <- c("Not AD", "Low", "Intermediate", "High")
sampleInfo$phenotype_ordered <- ordered(sampleInfo$phenotype, levels = phenotype_levels)

print(levels(sampleInfo$phenotype_ordered))
print(class(sampleInfo$phenotype_ordered))

# One row per (sample, Supertype): attention z-score plus donor and ordered phenotype;
# rows with any missing value are dropped
ordinal_data_long <- sampleFold.Table %>%
  tibble::rownames_to_column("sample_id") %>%
  pivot_longer(-sample_id, names_to = "Supertype", values_to = "attn_value") %>%
  left_join(select(sampleInfo, sample_id, donor_id, phenotype_ordered), by = "sample_id") %>%
  na.omit()

print(head(ordinal_data_long))

In [None]:
cell_type_list <- unique(ordinal_data_long$Supertype)
results_list <- list()
MIN_SAMPLES_PER_LEVEL <- 10   # every phenotype level needs at least this many samples
MIN_TOTAL_SAMPLES <- 50       # and the cell type needs at least this many samples overall

# Optimiser settings shared by every fit (loop-invariant)
model_control <- clmm.control(
  maxIter = 1000,
  gradTol = 1e-5
)

# Fit phenotype_ordered ~ attn_value + (1 | donor_id) on one cell type's data.
# Returns a one-row data.frame: odds ratio per unit attention, Wald 95% CI from confint(),
# and the z-test p-value. Returns NULL (and prints why) when the fit, its summary,
# or the CI cannot be obtained.
fit_ordinal_attention <- function(subset_data, ct) {
  model_fit <- try({
    clmm(
      phenotype_ordered ~ attn_value + (1 | donor_id),
      data = subset_data,
      control = model_control
    )
  }, silent = TRUE)
  if (inherits(model_fit, "try-error")) {
    cat(paste("celltype:", ct, "Model fitting failed (it may still not converge. You can try to increase maxIter)\n"))
    return(NULL)
  }

  coef_table <- try(summary(model_fit)$coefficients, silent = TRUE)
  if (inherits(coef_table, "try-error") || !("attn_value" %in% rownames(coef_table))) {
    cat(paste("celltype:", ct, "The model summary is unavailable and has been skipped.\n"))
    return(NULL)
  }

  conf_int_result <- try(
    confint(model_fit, parm = "attn_value"),
    silent = TRUE
  )
  if (inherits(conf_int_result, "try-error") || nrow(conf_int_result) == 0) {
    cat(paste("celltype:", ct, "The confidence interval cannot be calculated and has been skipped.\n"))
    return(NULL)
  }

  # unname() keeps the coefficient names out of the row names
  data.frame(
    Supertype = ct,
    odds_ratio = unname(exp(coef_table["attn_value", "Estimate"])),
    ci_low = unname(exp(conf_int_result[1, 1])),
    ci_high = unname(exp(conf_int_result[1, 2])),
    p_value = unname(coef_table["attn_value", "Pr(>|z|)"]),
    stringsAsFactors = FALSE
  )
}

for (ct in cell_type_list) {
  subset_data <- ordinal_data_long %>% filter(Supertype == ct)
  # table() of the ordered factor counts every level, including empty ones
  level_counts <- table(subset_data$phenotype_ordered)
  if (any(level_counts < MIN_SAMPLES_PER_LEVEL) || nrow(subset_data) < MIN_TOTAL_SAMPLES) {
    cat(paste("celltype:", ct, "The data is insufficient. Skipped.\n"))
    next
  }

  cat(paste("Processing cell type:", ct, "\n"))
  # Assigning NULL (skip/failure) leaves results_list without an entry for ct
  results_list[[ct]] <- fit_ordinal_attention(subset_data, ct)
}

if (length(results_list) > 0) {
  model_results_df <- do.call(rbind, results_list)
  rownames(model_results_df) <- NULL
} else {
  model_results_df <- data.frame()
  cat("No model was successfully fitted for all cell types.\n")
}

In [None]:
# model_results_df comes from the previous cell, which already handles an empty results_list.
# Rebuilding it here with do.call(rbind, ...) would fail on an empty list.
if (nrow(model_results_df) == 0) {
  stop("No ordinal model results to plot.")
}

forest_plot_data <- model_results_df %>%
  mutate(
    cell_type_sorted = fct_reorder(Supertype, odds_ratio),
    significance = ifelse(p_value < 0.05, "p < 0.05", "p ≥ 0.05")
  )

# Supertype -> parent cell type, for faceting
supertype_map <- attnData %>%
  select(Supertype, celltype) %>%
  distinct()

effect_plot_data <- left_join(forest_plot_data, supertype_map, by = "Supertype") %>%
  mutate(
    log_or = log(odds_ratio),
    log_ci_low = log(ci_low),
    log_ci_high = log(ci_high),
    label_color = ifelse(p_value < 0.05, "red", "black")
  ) %>%
  na.omit()

celltype_order <- c("Lamp5", "Pax6", "Sncg", "Vip", "Sst", "Sst Chodl", "Pvalb", "Pvalb and chandelier",
                    "L2/3 IT", "L4 IT", "L5 IT", "L5 ET", "L6 IT", "L6 CT", "L6b", "L5/6 NP", "L6 IT Car3",
                    "Astrocyte", "OPC", "Oligodendrocyte", "VLMC", "Microglia and immune", "Endothelial")
# Known cell types in the preferred order, followed by any others present.
# Previously those others were turned into NA.
present_celltypes <- unique(as.character(effect_plot_data$celltype))
existing_levels <- c(intersect(celltype_order, present_celltypes),
                     sort(setdiff(present_celltypes, celltype_order)))
effect_plot_data$celltype <- factor(effect_plot_data$celltype, levels = existing_levels)

effect_plot_data <- effect_plot_data %>%
  arrange(celltype, Supertype)

axis_label_colors <- effect_plot_data$label_color

effect_plot_data$Supertype <- factor(effect_plot_data$Supertype, levels = unique(effect_plot_data$Supertype))

# Colour each x label by its own significance. A colour vector in element_text() is applied
# from position 1 in every facet, so with facet_grid(scales = "free_x") the colours were
# misaligned. Instead, each label is wrapped in an HTML span that ggtext renders.
label_color_by_supertype <- setNames(effect_plot_data$label_color,
                                     as.character(effect_plot_data$Supertype))
color_axis_labels <- function(breaks) {
  escaped <- gsub("&", "&amp;", breaks, fixed = TRUE)
  escaped <- gsub("<", "&lt;", escaped, fixed = TRUE)
  escaped <- gsub(">", "&gt;", escaped, fixed = TRUE)
  sprintf("<span style='color:%s'>%s</span>", label_color_by_supertype[breaks], escaped)
}

final_plot_no_errorbars <- ggplot(
  data = effect_plot_data,
  aes(x = Supertype, y = log_or, fill = celltype)
) +
  geom_col(width = 0.7) +

  geom_hline(yintercept = 0, linetype = "solid", color = "black", linewidth = 0.5) +

  scale_x_discrete(labels = color_axis_labels) +

  facet_grid(. ~ celltype, scales = "free_x", space = "free_x") +

  labs(
    y = "Effect Size (log Odds Ratio)",
    x = NULL
  ) +

  theme_classic(base_size = 12) +
  theme(
    axis.text.x = ggtext::element_markdown(angle = 90, vjust = 0.5, hjust = 1),
    axis.ticks.x = element_blank(),
    legend.position = "none",
    strip.background = element_rect(fill = "grey90", color = NA),
    strip.text = element_text(face = "bold", size = 12),
    panel.spacing.x = unit(0, "lines"),
    panel.border = element_rect(color = "grey70", linetype = "dashed", fill = NA)
  )

ggsave("./Analysis/Figure/Attn_plot/final_plot_no_errorbars.pdf",
       final_plot_no_errorbars, width = 25, height = 8, dpi = 300)

print(final_plot_no_errorbars)