In [None]:
import os
import sys
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import scipy.io as sio
import anndata as ad
import matplotlib.pyplot as plt

# Run everything from the CRC result directory: the relative paths used below
# (the .h5ad input and the exported CRC_metadata.csv) resolve against it.
os.chdir("/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/CRC")

# 1. Cell-level

In [None]:
# Load the AnnData object from the working directory. Later cells read the
# per-cell columns `sample_id`, `attention_weight_mean` and `phenotype`
# from adata.obs.
adata = sc.read_h5ad("./ensemble_adata_with_attention.h5ad")
adata  # last expression: display the object summary as the cell output

In [None]:
# Export the per-cell annotation table (adata.obs) to CSV; the
# "Celltype-level" section reads this file back with read.csv.
metadata = adata.obs
metadata.to_csv("./CRC_metadata.csv")

In [None]:
def scale_attention_per_sample(obs, sample_col='sample_id',
                               attn_col='attention_weight_mean', clip=1.0):
    """Return a per-cell attention score standardised within each sample.

    For every sample, the raw attention weight is divided by the uniform
    weight ``1 / n_cells`` of that sample, log2-transformed, z-scored with
    the population standard deviation (ddof=0) and clipped to
    ``[-clip, clip]``.

    Parameters
    ----------
    obs : pandas.DataFrame
        Per-cell table holding ``sample_col`` and ``attn_col``.
    sample_col : str
        Column identifying the sample each cell belongs to.
    attn_col : str
        Column holding the raw attention weight of each cell.
    clip : float
        Symmetric bound applied to the z-scores.

    Returns
    -------
    pandas.Series
        Scaled scores aligned to ``obs.index``. Cells with a missing sample
        id or a missing attention weight stay NaN.

    Notes
    -----
    * The mean and standard deviation are computed over finite log-values
      only, so a single zero attention weight (log2 -> -inf) no longer turns
      the whole sample into NaN; such cells end up at ``-clip``.
    * A sample whose finite log-values are all identical (zero spread, e.g.
      a one-cell sample) gets 0 for those cells instead of a 0/0 NaN.
    """
    sample_ids = obs[sample_col]
    attn = obs[attn_col]

    # Number of cells in the sample each cell belongs to (same as len(subset)).
    n_cells = attn.groupby(sample_ids, observed=True).transform('size')

    # log2 fold change of the weight relative to the uniform weight 1 / n_cells.
    with np.errstate(divide='ignore'):
        log_fold = np.log2(attn / (1.0 / n_cells))

    finite = np.isfinite(log_fold)
    by_sample = log_fold.where(finite).groupby(sample_ids, observed=True)
    center = by_sample.transform('mean')
    spread = by_sample.transform(lambda values: values.std(ddof=0))

    z_score = (log_fold - center) / spread
    z_score = z_score.mask((spread == 0) & finite, 0.0)
    return z_score.clip(-clip, clip)


# Vectorised over all samples at once; avoids rebuilding a boolean mask twice
# per sample and does not shadow the builtin `id`.
adata.obs['attn_scaled'] = scale_attention_per_sample(adata.obs)

In [None]:
# One AnnData subset per phenotype group, in the order normal / MMRd / MMRp.
adata1, adata2, adata3 = (
    adata[adata.obs['phenotype'] == label]
    for label in ("normal", "MMRd", "MMRp")
)

In [None]:
# Scanpy session settings: verbosity level 1 and the figure output directory
# (also used explicitly by the fig.savefig call further down).
# NOTE(review): plots saved through the `save=` argument of sc.pl.* presumably
# land in this directory as well - confirm against the scanpy settings docs.
sc.settings.verbosity = 1
sc.settings.figdir = '/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/CRC/Analysis/Figure/Attn_plot'
# Default figure appearance: dpi 100, dpi 400 when saving, font size 10,
# white background, 6x6 figure size, PNG as the default format.
sc.settings.set_figure_params(dpi=100, fontsize=10, dpi_save=400,
    facecolor = 'white', figsize=(6,6), format='png')
def one_col_lgd(umap,ncol):
    """Add a borderless legend anchored at the right edge of a plot.

    Parameters
    ----------
    umap : object exposing ``legend(...)`` (e.g. the axes returned by a plot call)
    ncol : int
        Number of legend columns.

    Returns
    -------
    The legend object returned by ``umap.legend``.
    """
    placement = dict(bbox_to_anchor=[1.00, 0.5], loc='center left')
    lgd = umap.legend(ncol=ncol, prop={'size': 6}, **placement)
    # Zero line width hides the legend frame.
    lgd.get_frame().set_linewidth(0.0)
    return lgd

In [None]:
# UMAP of all cells coloured by `celltype`, with labels drawn on the embedding.
leiden_umap = sc.pl.umap(
    adata,
    color='celltype',
    legend_loc='on data',
    legend_fontsize=6,
    show=False,
)

In [None]:
# UMAP of all cells coloured by `cl295v11SubFull`, with small labels drawn on the embedding.
leiden_umap = sc.pl.umap(
    adata,
    color='cl295v11SubFull',
    legend_loc='on data',
    legend_fontsize=3,
    show=False,
)

In [None]:
# UMAP coloured by `cl295v11SubFull` with a three-column legend placed outside
# the axes, exported as a PDF into the scanpy figure directory.
leiden_umap = sc.pl.umap(adata, color=['cl295v11SubFull'],
    show=False, palette=sns.color_palette("husl", 90),
    legend_fontsize=6, frameon=True, title='celltype')
lgd = one_col_lgd(leiden_umap,3)
fig = leiden_umap.get_figure()
fig.set_size_inches(6, 5)

# fig.savefig does not create missing folders, and this is the first cell that
# writes into the figure directory, so make sure it exists.
figure_dir = str(sc.settings.figdir)
os.makedirs(figure_dir, exist_ok=True)
# The legend sits outside the axes; pass it as an extra artist so the tight
# bounding box includes it.
fig.savefig(os.path.join(figure_dir, 'umap_celltype.pdf'),
            format='pdf', bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
# Scaled attention score on the UMAP, restricted to the "normal" subset (adata1).
leiden_umap = sc.pl.umap(
    adata1,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of normal',
    frameon=True,
    legend_fontsize=6,
    show=False,
    save="_attn_normal.pdf",
)

In [None]:
# Scaled attention score on the UMAP, restricted to the "MMRd" subset (adata2).
leiden_umap = sc.pl.umap(
    adata2,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of MMRd',
    frameon=True,
    legend_fontsize=6,
    show=False,
    save="_attn_MMRd.pdf",
)

In [None]:
# Scaled attention score on the UMAP, restricted to the "MMRp" subset (adata3).
leiden_umap = sc.pl.umap(
    adata3,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of MMRp',
    frameon=True,
    legend_fontsize=6,
    show=False,
    save="_attn_MMRp.pdf",
)

# 2. Celltype-level

In [None]:
# Work from the CRC result directory: './CRC_metadata.csv' and the
# './Analysis/Figure/Attn_plot' outputs used below are relative to it.
setwd('/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/CRC')
# rm(list = ls())
# gc()

library(tidyr)
library(ggplot2)
library(forestploter)
library(gridExtra)
library(tidyverse)
library(dplyr)
library(broom)
library(ggpubr)
library(randomForest)
library(mice)
library(reshape2)
library(Metrics)
library(ComplexHeatmap)
library(RColorBrewer)
library(fastcluster)
library(ggbeeswarm)
library(circlize)
library(ggrepel)
library(ggpubr)
library(scales) 
library(ggtext) 

### 2.1 Scatter trend plot

In [None]:
# Build a sample x cell-type table of attention fold scores at the `celltype`
# level, plus the matching sample annotation.
attnData <- read.csv('./CRC_metadata.csv')
head(attnData)
colnames(attnData)

nameAll <- sort(unique(attnData$celltype))
nameAll

# Cell types to keep; currently all of them.
nameList <- nameAll

idList <- unique(attnData$sample_id)

# One data frame per sample, bound once at the end instead of growing a data
# frame with rbind() inside a loop.
foldPerSample <- lapply(idList, function(sid) {
  # Base indexing keeps the sample id out of dplyr's data mask (a column named
  # like the loop variable would otherwise win).
  attnTmp <- attnData[which(attnData$sample_id == sid), , drop = FALSE]
  # Attention each cell would get if the sample's attention were uniform.
  avgScore <- 1 / nrow(attnTmp)
  foldRes <- attnTmp %>%
    group_by(celltype) %>%
    summarise(res = median(log2(attention_weight_mean / avgScore)))
  dataTmp <- data.frame(id = rep(sid, nrow(foldRes)),
                        celltype = foldRes$celltype,
                        fold = foldRes$res)
  dataTmp_s <- dataTmp %>% filter(celltype %in% nameList)
  # scale() returns a one-column matrix with attributes; store a plain numeric
  # vector so the column behaves normally in rbind() and dcast().
  dataTmp_s$fold <- as.numeric(scale(dataTmp_s$fold))
  dataTmp_s
})
sampleFold <- do.call(rbind, foldPerSample)

# Name the value column explicitly instead of letting dcast() guess it.
sampleFold.Table <- dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) <- sampleFold.Table$id
sampleFold.Table$id <- NULL

sampleInfo <- read.csv('/data/wuqinhua/phase_1r/CRC/sample_info.csv')
rownames(sampleInfo) <- sampleInfo$sample
# Indexing by an unknown row name silently yields an all-NA row; report it.
missingSamples <- setdiff(rownames(sampleFold.Table), rownames(sampleInfo))
if (length(missingSamples) > 0) {
  warning("Samples without an entry in sample_info.csv: ",
          paste(missingSamples, collapse = ", "))
}
# Put the annotation rows in the same order as the score table.
sampleInfo <- sampleInfo[rownames(sampleFold.Table), ]
head(sampleFold.Table)

In [None]:
# ===================================================================
# 1. Score one: association with the categorical phenotype (Kruskal-Wallis test)
# ===================================================================
# Kruskal-Wallis test of each cell type's sample-level score across the
# phenotype groups; one p-value per column of sampleFold.Table.
sampleInfo$phenotype <- as.factor(sampleInfo$phenotype)
sampleInfo <- sampleInfo[rownames(sampleFold.Table), ]

# Returns the Kruskal-Wallis p-value, or NA (with a warning) when the test
# cannot be computed, e.g. when all non-missing scores fall into one group.
# Without this, a single such cell type aborts the whole cell.
kruskal_p_or_na <- function(score, group, label) {
  tryCatch(
    kruskal.test(score ~ group)$p.value,
    error = function(e) {
      warning("Kruskal-Wallis test skipped for '", label, "': ",
              conditionMessage(e), call. = FALSE)
      NA_real_
    }
  )
}

association_cell_types <- colnames(sampleFold.Table)
phenotype_association_scores_class <- data.frame(
  cell_type = association_cell_types,
  kruskal_wallis_p_value = vapply(
    association_cell_types,
    function(cell) {
      kruskal_p_or_na(sampleFold.Table[[cell]], sampleInfo$phenotype, cell)
    },
    numeric(1),
    USE.NAMES = FALSE  # named values would become the data frame's row names
  ),
  stringsAsFactors = FALSE
)

phenotype_association_scores_class <- phenotype_association_scores_class %>%
  mutate(neg_log10_pval = -log10(kruskal_wallis_p_value))

head(phenotype_association_scores_class)


# ===================================================================
# 2. Score two: overall variability across samples (standard deviation)
# ===================================================================

# Standard deviation of every cell type's score across samples, ignoring
# missing values; one row per column of sampleFold.Table.
variation_cell_types <- colnames(sampleFold.Table)
overall_variation_scores <- data.frame(
  cell_type = variation_cell_types,
  overall_sd = vapply(
    variation_cell_types,
    function(ct) sd(sampleFold.Table[[ct]], na.rm = TRUE),
    numeric(1),
    USE.NAMES = FALSE
  ),
  stringsAsFactors = FALSE
)
head(overall_variation_scores)


# ===================================================================
# 3. Integrate the scores and create "key scores" for coloring
# ===================================================================

# Join the two per-cell-type scores, rescale each to [0, 1] and multiply them
# into one "keyness" score. Only cell types above the 60th percentile of that
# score receive a text label.
scores_merged <- merge(phenotype_association_scores_class,
                       overall_variation_scores,
                       by = "cell_type")

plot_data_combined_score <- mutate(
  scores_merged,
  scaled_pval = rescale(neg_log10_pval, to = c(0, 1)),
  scaled_sd = rescale(overall_sd, to = c(0, 1)),
  keyness_score = scaled_pval * scaled_sd,
  label_text = ifelse(keyness_score > quantile(keyness_score, 0.6, na.rm = TRUE), cell_type, "")
)

head(plot_data_combined_score)


# ===================================================================
# 4. Plot
# ===================================================================

# Scatter plot of association strength (x) against overall variation (y);
# point colour encodes the combined keyness score and only the top-scoring
# cell types are labelled.
combined_score_plot <- ggplot(plot_data_combined_score,
                              aes(x = neg_log10_pval,
                                  y = overall_sd,
                                  color = keyness_score,
                                  label = label_text)) +

  geom_point(size = 6, alpha = 0.8) +
  scale_color_distiller(palette = "Reds", direction = 1, name = "Keyness Score") +
  geom_text_repel(color = "black", size = 4, max.overlaps = Inf,
                  box.padding = 0.6, min.segment.length = 0) +
  # Vertical reference line at p = 0.05.
  geom_vline(xintercept = -log10(0.05), linetype = "dashed", color = "gray50") +
  theme_bw(base_size = 12) +

  labs(
    title = "Identifying Key Cell Types by Combined Score ",
    subtitle = "Color intensity reflects the composite score of association and variation",
    x = "Strength of Association with Phenotype (-log10 P-value)",
    y = "Overall Variation (Standard Deviation)"
  ) +

  theme(
    plot.title = element_text(hjust = 0.5, face = "bold", size = 12),
    plot.subtitle = element_text(hjust = 0.5, size = 10, color = "gray30"),
    legend.position = "right"
  )

# Make sure the output folder exists; the save fails on a fresh checkout
# otherwise.
combined_score_file <- "./Analysis/Figure/Attn_plot/classification_combined_score_plot.pdf"
dir.create(dirname(combined_score_file), recursive = TRUE, showWarnings = FALSE)
ggsave(combined_score_file,
       combined_score_plot,
       width = 6, height = 5, dpi = 300)

print(combined_score_plot)

### 2.2 Violin Plot

In [None]:
# Rebuild the sample x cell-type fold table, this time at the
# `cl295v11SubFull` annotation level.
nameAll <- sort(unique(attnData$cl295v11SubFull))
nameAll

# Cell types to keep; currently all of them.
nameList <- nameAll

idList <- unique(attnData$sample_id)

# One data frame per sample, bound once at the end instead of growing a data
# frame with rbind() inside a loop.
foldPerSample <- lapply(idList, function(sid) {
  # Base indexing keeps the sample id out of dplyr's data mask (a column named
  # like the loop variable would otherwise win).
  attnTmp <- attnData[which(attnData$sample_id == sid), , drop = FALSE]
  # Attention each cell would get if the sample's attention were uniform.
  avgScore <- 1 / nrow(attnTmp)
  foldRes <- attnTmp %>%
    group_by(cl295v11SubFull) %>%
    summarise(res = median(log2(attention_weight_mean / avgScore)))
  dataTmp <- data.frame(id = rep(sid, nrow(foldRes)),
                        celltype = foldRes$cl295v11SubFull,
                        fold = foldRes$res)
  dataTmp_s <- dataTmp %>% filter(celltype %in% nameList)
  # scale() returns a one-column matrix with attributes; store a plain numeric
  # vector so the column behaves normally in rbind() and dcast().
  dataTmp_s$fold <- as.numeric(scale(dataTmp_s$fold))
  dataTmp_s
})
sampleFold <- do.call(rbind, foldPerSample)

# Name the value column explicitly instead of letting dcast() guess it.
sampleFold.Table <- dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) <- sampleFold.Table$id
sampleFold.Table$id <- NULL

sampleInfo <- read.csv('/data/wuqinhua/phase_1r/CRC/sample_info.csv')
rownames(sampleInfo) <- sampleInfo$sample
# Indexing by an unknown row name silently yields an all-NA row; report it.
missingSamples <- setdiff(rownames(sampleFold.Table), rownames(sampleInfo))
if (length(missingSamples) > 0) {
  warning("Samples without an entry in sample_info.csv: ",
          paste(missingSamples, collapse = ", "))
}
# Put the annotation rows in the same order as the score table.
sampleInfo <- sampleInfo[rownames(sampleFold.Table), ]
head(sampleFold.Table)


# Long-format table with one row per (sample, cell type): the score, the
# sample's phenotype and the cell type name. Rows with missing values are
# dropped per cell type.
cell_types_attn <- colnames(sampleFold.Table)

# Build one piece per cell type and bind them in a single call rather than
# re-copying a growing data frame on every iteration.
attn_pieces <- lapply(cell_types_attn, function(cell) {
  na.omit(data.frame(
    atten = sampleFold.Table[[cell]],
    group = sampleInfo$phenotype,  # relies on sampleInfo being row-aligned with sampleFold.Table
    cell_type = cell
  ))
})
combined_data_atten <- if (length(attn_pieces) > 0) {
  do.call(rbind, attn_pieces)
} else {
  data.frame()
}

# Fill colours keyed by phenotype; the order of the names is also used as the
# factor level order of `group`.
custom_colors <- c(
  "normal" = '#66DD00',
  "MMRd" = "#77DDFF",
  "MMRp" = "#FFBB66"
)

phenotype_levels <- names(custom_colors)
plot_df_attn <- dplyr::rename(combined_data_atten, Attn = atten)
plot_df_attn <- dplyr::mutate(plot_df_attn,
                              group = factor(group, levels = phenotype_levels))
# Phenotypes outside the three levels became NA above and are removed here.
plot_df_attn <- na.omit(plot_df_attn)



# Per-cell-type annotation table: a Kruskal-Wallis test of Attn across the
# phenotype groups, its p-value and H statistic, a significance marker, the
# facet title (cell type + marker) and the text label drawn inside each facet.
annotation_data <- plot_df_attn %>%
  group_by(cell_type) %>%
  summarise(
    # The test object is wrapped in a list so it fits in one cell per group;
    # the next two columns read from that list within the same summarise().
    kruskal_test_res = list(kruskal.test(Attn ~ group)),
    p_value = kruskal_test_res[[1]]$p.value,
    h_statistic = kruskal_test_res[[1]]$statistic,
    .groups = 'drop'
  ) %>%
  mutate(
    # Star notation by p-value threshold; anything else (>= 0.05) is "ns".
    significance = case_when(
      p_value < 0.001 ~ "***",
      p_value < 0.01  ~ "**",
      p_value < 0.05  ~ "*",
      TRUE ~ "ns"
    ),
    title_with_sig = paste0(cell_type, " ", significance),

    # Text of the form "H = <statistic>, <formatted p-value>".
    stat_label = paste0(
        "H = ", round(h_statistic, 2),     # H statistic rounded to 2 decimals
        ", ",                               # separator
        scales::pvalue(p_value,             # p-value formatted to 3 decimals with a "p" prefix
                       accuracy = 0.001, 
                       add_p = TRUE)
    )
  )
# Attach the facet title (cell type + significance marker) to every row and
# drop rows for which no title is available.
plot_df_attn_final <- plot_df_attn %>%
  left_join(annotation_data %>% dplyr::select(cell_type, title_with_sig), by = "cell_type") %>%
  filter(!is.na(title_with_sig))

# Facet grid: 11 columns and at least 8 rows (the original layout). The row
# count grows when there are more than 88 panels, where a fixed 11 x 8 grid
# would make facet_wrap() stop with an error.
facet_ncol <- 11
n_panels <- length(union(plot_df_attn_final$title_with_sig,
                         annotation_data$title_with_sig))
facet_nrow <- max(8, ceiling(n_panels / facet_ncol))

# One violin + box plot facet per cell type, comparing the standardized
# attention score between phenotypes; the test statistic is printed in the
# top-left corner of each facet.
p_facet_classification_updated <- ggplot(plot_df_attn_final,
                                         aes(x = group, y = Attn, fill = group)) +

  geom_violin(trim = FALSE, alpha = 0.8) +
  geom_boxplot(width = 0.1, fill = "white", alpha = 0.7, outlier.shape = NA) +

  # x = -Inf / y = Inf pins the label to the top-left corner of each panel.
  geom_text(
    data = annotation_data,
    aes(label = stat_label),
    x = -Inf, y = Inf,
    hjust = -0.1, vjust = 2,
    size = 2,
    fontface = "italic",
    inherit.aes = FALSE
  ) +

  scale_fill_manual(values = custom_colors, name = "Phenotype") +
  facet_wrap(~title_with_sig, scales = "free", ncol = facet_ncol, nrow = facet_nrow) +

  labs(
    title = "Attention Score Distribution by Phenotype and Cell Type",
    y = "Standardized Attention Score"
  ) +

  theme_classic(base_size = 12) +
  theme(
    panel.grid.major.x = element_blank(),
    panel.grid.minor = element_blank(),
    strip.background = element_rect(fill = "#EFEFEF", color = NA),
    strip.text = element_text(size = 11, face = "bold", color = "#333333"),
    axis.text.x = element_text(size = 8, angle = 45, hjust = 1),
    axis.text.y = element_text(size = 8),
    axis.title = element_text(size = 8, face = "bold"),
    legend.position = "right",
    axis.title.x = element_blank()
  )

# Make sure the output folder exists; the save fails on a fresh checkout
# otherwise.
violin_plot_file <- "./Analysis/Figure/Attn_plot/attn_classification_violin_plot.pdf"
dir.create(dirname(violin_plot_file), recursive = TRUE, showWarnings = FALSE)
ggsave(violin_plot_file,
       p_facet_classification_updated, width = 30, height =25, dpi = 300)

print(p_facet_classification_updated)