In [None]:
import os
import sys
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import scipy.io as sio
import anndata as ad
import matplotlib.pyplot as plt

# Every relative path below (the input .h5ad, scanpy's figure dir) resolves against this results folder.
RESULT_DIR = "/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/Age"
os.chdir(RESULT_DIR)

# 1. UMAP

In [None]:
# .obs is expected to carry Tube_id, celltype, Age, Age_group and attention_weight_mean (all used below).
ADATA_PATH = "./ensemble_adata_with_attention.h5ad"
adata = sc.read_h5ad(ADATA_PATH)

In [None]:

# Cell-level metadata table used by the scaling cell below.
# NOTE(review): AnnData.obs appears to hand back the object's own DataFrame, so attnData is an
# alias rather than a copy -- columns later added to attnData (e.g. 'attn_scaled') then also
# appear in adata.obs. Confirm before relying on it.
# The R section reads './metadata.csv' (Tube_id, celltype, attention_weight_mean) from this
# same working directory; uncomment the export below if that file does not exist yet.
attnData = adata.obs
# attnData.to_csv('metadata.csv')

In [None]:
# Global scanpy configuration: output folder for save=..., quiet logging, figure defaults.
sc.settings.figdir = './Analysis/Figure/Attn_plot'
sc.settings.verbosity = 1
sc.settings.set_figure_params(
    dpi=100,
    dpi_save=400,
    fontsize=10,
    figsize=(6, 6),
    facecolor='white',
    format='png',
)

In [None]:
def scale_attention_per_sample(obs, sample_key='Tube_id', score_key='attention_weight_mean', clip=1.0):
    """Per-sample standardized log2 attention enrichment, clipped to [-clip, clip].

    Within every sample (rows sharing ``obs[sample_key]``):
      1. each cell's attention is compared with the uniform baseline 1 / n_cells via log2;
      2. the log-ratios are z-scored with the sample's mean and population std (ddof=0);
      3. the z-scores are clipped to [-clip, clip].

    Parameters
    ----------
    obs : pandas.DataFrame
        Cell-level table (one row per cell), e.g. ``adata.obs``.
    sample_key, score_key : str
        Columns holding the sample id and the raw attention weight.
    clip : float
        Symmetric clipping bound applied after z-scoring.

    Returns
    -------
    pandas.Series
        Aligned to ``obs.index``. Rows with a missing sample id are NaN. Single-cell samples
        yield NaN, because their spread is 0 (0 / 0).
    """
    sample_ids = obs[sample_key]
    # Cells per sample, broadcast back to every row (NaN ids form no group -> NaN).
    n_cells = obs.groupby(sample_key, observed=True)[score_key].transform('size')
    log_attn = np.log2(obs[score_key] / (1 / n_cells))
    by_sample = log_attn.groupby(sample_ids, observed=True)
    centered = log_attn - by_sample.transform('mean')
    # Explicit ddof=0 (population std), matching the np.std convention.
    spread = by_sample.transform(lambda s: s.std(ddof=0))
    return (centered / spread).clip(-clip, clip)


# One vectorized pass over all samples instead of re-masking the whole table per sample.
attnData['attn_scaled'] = scale_attention_per_sample(attnData)

In [None]:

# Copy the scaled attention onto adata.obs (positionally, via .values), then split the cells
# into one subset per age group, in the order A, B, C, D, E.
adata.obs["attn_scaled"] = attnData["attn_scaled"].values
adata1, adata2, adata3, adata4, adata5 = (
    adata[adata.obs['Age_group'] == age_group] for age_group in ("A", "B", "C", "D", "E")
)

### 1.1 UMAP of celltype

In [None]:
# Cell-type UMAP using 24 husl colours (presumably one per annotated cell type -- verify the count).
leiden_umap = sc.pl.umap(
    adata,
    color=['celltype'],
    palette=sns.color_palette("husl", 24),
    title='celltype',
    legend_fontsize=6,
    frameon=True,
    show=False,
    save="_celltype.pdf",
)

### 1.2 UMAP of age

In [None]:
# UMAP coloured by the numeric Age value (viridis colour map).
leiden_umap = sc.pl.umap(
    adata,
    color='Age',
    color_map='viridis',
    title='UMAP of Age',
    legend_fontsize=6,
    frameon=True,
    show=False,
    save="_Age.pdf",
)

### 1.3 UMAP of age group

In [None]:
# Age groups A-E; these hex colours are the same ones used for custom_colors in the R section.
age_group_palette = ["#FFCCCC", "#999933", "#B0E57C", "#99CCFF", "#D2B5E1"]
leiden_umap = sc.pl.umap(
    adata,
    color='Age_group',
    palette=age_group_palette,
    title='UMAP of Age_group',
    legend_fontsize=6,
    frameon=True,
    show=False,
    save="_Age_group.pdf",
)

### 1.4 UMAP of Attention score

In [None]:
# Scaled attention (per-sample z-score clipped to [-1, 1]) for the cells of age group A.
leiden_umap = sc.pl.umap(
    adata1,
    color='attn_scaled',
    color_map='viridis',
    title='Attention Score of A',
    legend_fontsize=6,
    frameon=True,
    show=False,
    save="_attn_Age_A.pdf",
)


In [None]:
# Same attention UMAP for the remaining age groups, one PDF per group (B, C, D, E in order).
for age_label, group_adata in (("B", adata2), ("C", adata3), ("D", adata4), ("E", adata5)):
    leiden_umap = sc.pl.umap(
        group_adata,
        color='attn_scaled',
        color_map='viridis',
        title=f'Attention Score of {age_label}',
        legend_fontsize=6,
        frameon=True,
        show=False,
        save=f"_attn_Age_{age_label}.pdf",
    )

# 2. Cell-type-level analysis

In [None]:
# Work from the same results folder as the Python cells, so './metadata.csv' and
# './Analysis/Figure/Attn_plot/' resolve identically in both languages.
# (For a clean workspace, run rm(list = ls()); gc() before this cell.)
age_result_dir <- '/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/Age'
setwd(age_result_dir)

library(tidyr)
library(ggplot2)
library(forestploter)
library(gridExtra)
library(tidyverse)
library(dplyr)
library(broom)
library(ggpubr)
library(randomForest)
library(mice)
library(reshape2)
library(Metrics)
library(ComplexHeatmap)
library(RColorBrewer)
library(fastcluster)
library(ggbeeswarm)
library(circlize)
library(ggrepel)
library(ggpubr)

In [None]:
# Per-cell metadata exported from the Python section (working dir = results folder);
# requires the columns Tube_id, celltype and attention_weight_mean.
attnData = read.csv('./metadata.csv')
head(attnData)
colnames(attnData)

nameAll = sort(unique(attnData$celltype))
nameAll

# Cell types kept in the per-sample table (currently all of them).
nameList = nameAll
idList = unique(attnData$Tube_id)

# Per sample (Tube_id), one standardized attention value per cell type:
#   1. log2 fold of each cell's attention over the uniform baseline 1 / n_cells(sample);
#   2. median of that fold within each cell type;
#   3. z-score of those medians across the sample's cell types (scale(): mean 0, sd 1).
# as.numeric() unwraps the 1-column matrix that scale() returns, so `fold` is a plain
# numeric column. Rows with a missing Tube_id are excluded. dplyr:: is spelled out
# because later-attached packages (e.g. mice) may mask filter/select.
sampleFold = attnData %>%
  dplyr::filter(!is.na(Tube_id)) %>%
  dplyr::group_by(Tube_id) %>%
  dplyr::mutate(log_fold = log2(attention_weight_mean / (1 / dplyr::n()))) %>%
  dplyr::group_by(Tube_id, celltype) %>%
  dplyr::summarise(fold = median(log_fold), .groups = "drop") %>%
  dplyr::filter(celltype %in% nameList) %>%
  dplyr::group_by(Tube_id) %>%
  dplyr::mutate(fold = as.numeric(scale(fold))) %>%
  dplyr::ungroup() %>%
  dplyr::select(id = Tube_id, celltype, fold) %>%
  as.data.frame()

# Wide table: one row per sample (rownames = Tube_id), one column per cell type;
# NA where a cell type is absent from a sample.
sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) = sampleFold.Table$id
sampleFold.Table$id = NULL

# Sample-level info (Age, Age_group, ...) aligned row-by-row to sampleFold.Table.
sampleInfo = read.csv('/home/wuqinhua/Project/PHASE/Age/Info/sample_info.csv')
rownames(sampleInfo) = sampleInfo$Tube_id
sampleInfo = sampleInfo[rownames(sampleFold.Table),]
head(sampleFold.Table)

### 2.1 Association with age vs. between-sample variation (scatter plot)

In [None]:
# Align sample metadata (rows) to the per-sample attention table; abort if any sample is missing.
if (!all(rownames(sampleFold.Table) %in% rownames(sampleInfo))) {
  stop("Not all samples in sampleFold.Table are recorded in sampleInfo.")
}
sampleInfo <- sampleInfo[rownames(sampleFold.Table), ]

# --- 1. Calculate Score One: Correlation with Age (Phenotype Association) ---
# Pearson correlation between each cell type's standardized attention and Age across samples.
# The column keeps the name `correlation_rho_with_age`, but the value is Pearson's r.
# A cell type is skipped when cor.test() cannot be computed:
#   - sd() is NA (fewer than two non-missing samples) or 0 (no spread);
#   - it has fewer than 3 complete (value, Age) pairs.
# NOTE: p-values are not adjusted for testing many cell types.
min_complete_pairs <- 3
phenotype_association_scores <- data.frame(
  cell_type = character(),
  correlation_rho_with_age = numeric(),
  p_value = numeric()
)

for (cell in colnames(sampleFold.Table)) {
  cell_values <- sampleFold.Table[[cell]]
  # isTRUE() makes an NA sd count as "no spread" instead of erroring inside if().
  if (!isTRUE(sd(cell_values, na.rm = TRUE) > 0)) next
  if (sum(complete.cases(cell_values, sampleInfo$Age)) < min_complete_pairs) next

  test_result <- cor.test(cell_values, sampleInfo$Age, method = "pearson")

  phenotype_association_scores <- rbind(
    phenotype_association_scores,
    data.frame(
      cell_type = cell,
      # unname(): the estimate is named "cor", which would otherwise leak into the row names
      correlation_rho_with_age = unname(test_result$estimate),
      p_value = test_result$p.value
    )
  )
}
head(phenotype_association_scores)


# --- 2. Calculate Score Two: Group Stability/Volatility (Overall Variation) ---
# Between-sample SD of each cell type's standardized attention. It is computed for every cell type;
# the merge() below keeps only those that also received a correlation.
overall_variation_scores <- data.frame(
  cell_type = colnames(sampleFold.Table),
  overall_sd = unname(vapply(sampleFold.Table, sd, numeric(1), na.rm = TRUE))
)
head(overall_variation_scores)


combined_scores_final <- merge(phenotype_association_scores,
                               overall_variation_scores,
                               by = "cell_type")

# `label`: cell type name when the (unadjusted) correlation p-value is < 0.05, else "".
combined_scores_final <- combined_scores_final %>%
  mutate(
    label = ifelse(p_value < 0.05, cell_type, "")
  )

head(combined_scores_final)

In [None]:
# Cut-offs: medians (across cell types) of |r| with Age and of the between-sample SD.
rho_threshold <- median(abs(combined_scores_final$correlation_rho_with_age), na.rm = TRUE)
sd_threshold <- median(combined_scores_final$overall_sd, na.rm = TRUE)

# Label only the cell types above both medians (the top-right quadrant of the plot).
plot_data_final <- combined_scores_final %>%
  mutate(
    direction = ifelse(correlation_rho_with_age > 0, "Positive", "Negative"),

    label_text = ifelse(
      abs(correlation_rho_with_age) > rho_threshold & overall_sd > sd_threshold,
      cell_type,
      ""
    )
  )


threshold_scatter_plot <- ggplot(plot_data_final,
                                 aes(x = abs(correlation_rho_with_age),
                                     y = overall_sd,
                                     color = direction,
                                     label = label_text)) +

  geom_vline(xintercept = rho_threshold, linetype = "dashed", color = "gray50") +
  geom_hline(yintercept = sd_threshold, linetype = "dashed", color = "gray50") +

  geom_point(size = 4, alpha = 0.8) +

  geom_text_repel(color = "black", size = 4, max.overlaps = Inf,
                  box.padding = 0.6, min.segment.length = 0) +

  scale_color_manual(
    name = "Direction of Correlation",
    values = c("Positive" = "#E64B35", "Negative" = "#3C5488")
  ) +

  annotate("text", x = Inf, y = Inf, label = "High Correlation & High Variation\n(Key Targets)",
           hjust = 1.1, vjust = 1.2, color = "black", fontface = "bold", size = 4) +

  theme_bw() +
  labs(
    title = "Identifying Key Cell Types by Association and Variation",
    # The scores come from cor.test(method = "pearson"), so the axis shows |Pearson's r|.
    x = "Strength of Association with Age (Absolute Pearson's r)",
    y = "Overall Variation (Standard Deviation)"
  ) +
  theme(
    plot.title = element_text(hjust = 0.5, face = "bold", size = 16),
    legend.position = "bottom"
  )

ggsave("./Analysis/Figure/Attn_plot/threshold_scatter_plot.pdf", threshold_scatter_plot, width = 8, height = 6)
print(threshold_scatter_plot)

### 2.2 R²: age vs. attention by cell type

In [None]:
# Long table: one row per (sample, cell type) with standardized attention, Age and Age_group;
# rows with any missing value are dropped.
cell_types_attn <- colnames(sampleFold.Table)
combined_data_atten <- dplyr::bind_rows(lapply(cell_types_attn, function(cell) {
  na.omit(data.frame(
    atten = sampleFold.Table[[cell]],
    group = sampleInfo$Age_group,
    age   = sampleInfo$Age,
    cell_type = cell
  ))
}))

custom_colors <- c("A" = "#FFCCCC", "B" = "#999933", "C" = "#B0E57C", "D" = "#99CCFF", "E" = "#D2B5E1")

# Age_group values outside A-E become NA in factor() and are removed by na.omit().
plot_df_attn <- combined_data_atten %>%
  dplyr::select(age, atten, group, cell_type) %>%
  dplyr::rename(Age = age, Attn = atten, Age_group = group) %>%
  dplyr::mutate(Age_group = factor(Age_group, levels = names(custom_colors))) %>%
  na.omit()

# Facet titles carry significance stars from the (unadjusted) Pearson p-values.
significance_labels <- phenotype_association_scores %>%
  dplyr::mutate(
    significance = dplyr::case_when(
      p_value < 0.001 ~ "***",
      p_value < 0.01 ~ "**",
      p_value < 0.05 ~ "*",
      TRUE ~ "ns"
    ),
    title_with_sig = paste0(cell_type, " ", significance)
  ) %>%
  dplyr::select(cell_type, title_with_sig)

# Cell types without a correlation (skipped upstream) fall back to the bare name.
plot_df_attn_with_sig <- plot_df_attn %>%
  dplyr::left_join(significance_labels, by = "cell_type") %>%
  dplyr::mutate(title_with_sig = ifelse(is.na(title_with_sig), cell_type, title_with_sig))

p_facet_attn_with_r2 <- ggplot(plot_df_attn_with_sig, aes(x = Age, y = Attn)) +

  geom_point(aes(color = Age_group), alpha = 0.65, size = 1.3) +

  # NOTE: the drawn trend is a LOESS smoother, while the R^2 / p annotation below is the
  # linear Pearson fit -- the two can disagree visually.
  geom_smooth(
    method = "loess", se = TRUE,
    color = "#666666", fill = "#C9C9C9",
    alpha = 0.22, linewidth = 0.8
  ) +

  # after_stat() replaces the deprecated ..rr.label.. / ..p.label.. notation.
  ggpubr::stat_cor(
    method = "pearson",
    aes(label = after_stat(paste(rr.label, p.label, sep = "~`,`~"))),
    label.x.npc = "left",
    label.y.npc = "top",
    hjust = 0,
    size = 3.2
  ) +

  scale_color_manual(values = custom_colors, name = "Age Group") +

  facet_wrap(~title_with_sig, scales = "free", labeller = label_value) +

  labs(
    title = "Standardized Attn vs Age by Cell Type (with significance markers)",
    x = "Age",
    y = "Standardized Attn"
  ) +

  theme_classic(base_size = 10) +
  theme(
    panel.background = element_rect(fill = "white", color = NA),
    plot.background  = element_rect(fill = "white", color = NA),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    strip.background = element_rect(fill = "#EFEFEF", color = NA),
    strip.text = element_text(size = 9, face = "bold", color = "#333333"),
    axis.text.x = element_text(size = 8, color = "#333333", hjust = 1),
    axis.text.y = element_text(size = 8, color = "#333333"),
    axis.title = element_text(size = 10, face = "bold"),
    legend.position = "right",
    legend.title = element_text(size = 9, face = "bold"),
    legend.text  = element_text(size = 8),
    panel.spacing = unit(0.8, "lines")
  )

ggsave("./Analysis/Figure/Attn_plot/attn_age_scatter_r2_facet_with_sig.pdf",
       p_facet_attn_with_r2, width = 20, height = 18, dpi = 300)

print(p_facet_attn_with_r2)