## 1. Calculate per-cell gene expression scores from each group's top-50 attribution genes

In [None]:
import os,gc
import scanpy as sc
import pandas as pd
import anndata as ad
import numpy as np

# Work from the project root so the relative paths used below resolve.
# The default location can be overridden with the PHASE_COVID19_DIR environment variable.
os.chdir(os.environ.get("PHASE_COVID19_DIR", "/data/wuqinhua/phase/covid19"))

In [None]:
## ----------------- TOP50 gene ---------------------
# For each group in `adata.obs['group']` (H / M / S), score every cell of that group
# with the top-N genes ranked by the group's PHASE attribution table. Then write the
# group's obs table (including the new `gene_ex_score` column) to CSV for the R
# correlation analysis.

N_TOP_GENES = 50
GROUPS = ["H", "M", "S"]


def top_attr_genes(attr_csv, n=N_TOP_GENES):
    """Return the names of the `n` genes with the highest `attr_value` in `attr_csv`.

    The CSV is expected to have `gene_name` and `attr_value` columns.
    """
    attr = pd.read_csv(attr_csv)
    return attr.sort_values(by='attr_value', ascending=False).head(n)['gene_name'].tolist()


adata = sc.read('./Alldata_anno.h5ad')
print("read_over")

for group in GROUPS:
    # Materialise the subset: score_genes(copy=False) writes into `.obs`, and writing into
    # an AnnData *view* silently turns it into a copy with an ImplicitModificationWarning.
    adata_group = adata[adata.obs['group'] == group].copy()
    top_genes = top_attr_genes(f"./Analysis_result/Attr_result/attr_{group}_PHASE.csv")
    sc.tl.score_genes(adata_group, top_genes, ctrl_size=50, gene_pool=None, n_bins=25,
                      score_name='gene_ex_score', random_state=0, copy=False, use_raw=None)
    # e.g. group "H" -> gene_ex_scores_Hh.csv (file names expected by the R cells).
    adata_group.obs.to_csv(
        f"./Analysis_result/Conjoint_result/gene_ex_scores_{group}{group.lower()}.csv")
    # Release the per-group copy before building the next one.
    del adata_group
    gc.collect()
print("1_over")


## 2. Correlation scatter plots: per-sample attention fold change vs. gene expression score, per cell type

In [None]:
# Work from the project root so the relative paths used below resolve.
# The default location can be overridden with the PHASE_COVID19_DIR environment variable.
setwd(Sys.getenv("PHASE_COVID19_DIR", unset = "/data/wuqinhua/phase/covid19"))
# Start from an empty R workspace and trigger garbage collection.
rm(list = ls())
gc()

library(cowplot) 
library(ggplot2) 
library(RColorBrewer) 
library(dplyr) 
library(tidyr)
library(ggpubr) 
library(broom)
library(tidyverse)
library(tibble)
library(janitor)
library(ggrepel)
library(tidyr)
library(ggplot2)
library(forestploter)
library(gridExtra)
library(tidyverse)
library(dplyr)
library(broom)
library(ggpubr)
library(randomForest)
library(mice)
library(reshape2)
library(gghalves)
library(cowplot)
library(patchwork)

In [None]:
# Per-sample attention fold change for each cell type.
# In a sample with n cells the uniform attention baseline is 1/n. Each cell's
# log2(attn / baseline) is summarised by its median per cell type, and those medians
# are z-scored (scale()) across the cell types of the same sample.
# Output `attn`: one row per sample (rownames = sample_id, in dcast's sorted order),
# one column per cell type; sample/cell-type combinations that do not occur are NA.
attnData = read.csv('./Analysis_result/Attn_result/attn_cell_PHASE.csv')
nameAll = unique(attnData$predicted_labels)
nameList = sort(nameAll)   # sort() drops NA, so NA-labelled groups are excluded below

# One grouped pipeline instead of a per-sample loop that grew the table with rbind.
sampleFold = attnData %>%
  filter(!is.na(sample_id)) %>%
  group_by(sample_id) %>%
  mutate(avgScore = 1 / n()) %>%               # baseline counts ALL cells of the sample
  group_by(sample_id, predicted_labels) %>%
  summarise(fold = median(log2(attn / avgScore)), .groups = "drop") %>%
  filter(predicted_labels %in% nameList) %>%
  group_by(sample_id) %>%
  # as.numeric(): scale() returns a one-column matrix. NaN when a sample's medians have
  # zero spread (e.g. only one cell type).
  mutate(fold = as.numeric(scale(fold))) %>%
  ungroup() %>%
  transmute(id = sample_id, celltype = predicted_labels, fold = fold) %>%
  as.data.frame()

# Wide table: samples x cell types.
sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(sampleFold.Table) = sampleFold.Table$id
sampleFold.Table$id = NULL

attn = sampleFold.Table
head(attn)


In [1]:
# Cell-level gene-expression score tables, one CSV per group (H, M, S).
ex_Hh = read.csv('./Analysis_result/Conjoint_result/gene_ex_scores_Hh.csv')
ex_Mm = read.csv('./Analysis_result/Conjoint_result/gene_ex_scores_Mm.csv')
ex_Ss = read.csv('./Analysis_result/Conjoint_result/gene_ex_scores_Ss.csv')

# Stack them into one table; rbind requires the three files to share column names.
exData = Reduce(rbind, list(ex_Hh, ex_Mm, ex_Ss))
head(exData)

In [None]:
# Per-sample median gene-expression score for each cell type.
# Output `ex` (same object as `ex_sampleFold.Table`): one row per sample
# (rownames = sample_id, in dcast's sorted order), one column per cell type;
# sample/cell-type combinations that do not occur are NA.
# NOTE(review): unlike `attn`, these medians are not z-scored within each sample —
# confirm this asymmetry is intended.
nameAll = unique(exData$predicted_labels)
nameList = sort(nameAll)   # sort() drops NA, so NA-labelled groups are excluded below

# One grouped summarise instead of a per-sample rbind loop. The value column keeps the
# name `fold` for the dcast layout, although it holds a median score, not a fold change.
sampleFold = exData %>%
  filter(!is.na(sample_id)) %>%
  group_by(sample_id, predicted_labels) %>%
  summarise(fold = median(gene_ex_score), .groups = "drop") %>%
  filter(predicted_labels %in% nameList) %>%
  transmute(id = sample_id, celltype = predicted_labels, fold = fold) %>%
  as.data.frame()

# Wide table: samples x cell types.
ex_sampleFold.Table = dcast(sampleFold, id ~ celltype, value.var = "fold")
rownames(ex_sampleFold.Table) = ex_sampleFold.Table$id
ex_sampleFold.Table$id = NULL
ex = ex_sampleFold.Table
head(ex)

In [None]:
# Align sample metadata and expression scores to the sample order of `attn`, then draw
# one attention-vs-expression scatter (regression line + Pearson r) per cell type,
# coloured by sample group, and a combined grid of all panels.
sampleInfo = read.csv('./COVID19_sample_condition_560.csv')
rownames(sampleInfo) = sampleInfo$sample_id
sampleInfo = sampleInfo[rownames(attn),]
# drop = FALSE keeps `ex` a data.frame even when it has a single cell-type column.
ex = ex[rownames(attn), , drop = FALSE]

label_df <- data.frame(sample = sampleInfo$sample_id, label = sampleInfo$group)

custom_colors <- c("H" = "#A2D8A2", "M" = "#B8D8F1", "S" = "#F5B54F")
plot_dir <- "./Plot/Conjoint_plot/cor_plot"
# ggsave() does not create missing output directories by default.
dir.create(plot_dir, recursive = TRUE, showWarnings = FALSE)

# attn[, celltype] errors for a cell type absent from `attn`, so plot only shared ones.
plot_celltypes <- intersect(colnames(ex_sampleFold.Table), colnames(attn))
skipped_celltypes <- setdiff(colnames(ex_sampleFold.Table), colnames(attn))
if (length(skipped_celltypes) > 0) {
  message("Skipping cell types without attention scores: ",
          paste(skipped_celltypes, collapse = ", "))
}

plot_list <- list()
for (celltype in plot_celltypes) {
  merged_data <- data.frame(
    sample = rownames(attn),
    attn = attn[, celltype],
    ex = ex[, celltype],
    label = label_df$label
  )

  p <- ggscatter(merged_data, x = 'attn', y = 'ex', title = paste("Cell type:", celltype),
                 add = "reg.line", color = 'label',
                 add.params = list(color = "#4B4B4B", fill = "lightgray"),
                 cor.coeff.args = list(method = "pearson", size = 4),
                 conf.int = TRUE, cor.coef = TRUE, cor.method = "pearson",
                 size = 1, font.label = c(5, 'plain')) +
    scale_color_manual(values = custom_colors)

  plot_list[[celltype]] <- p

  # A path separator in a cell-type name would point into a non-existent subdirectory.
  file_tag <- gsub("[/\\\\]", "_", celltype)
  ggsave(paste0(plot_dir, "/correlation_", file_tag, ".png"), plot = p, width = 5, height = 4)
}

combine_plot = wrap_plots(plot_list, ncol = 8)
ggsave("./Plot/Conjoint_plot//correlation_all.png", combine_plot, width = 25, height = 28)