In [1]:
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.exceptions import OutputParserException
from pydantic import BaseModel, Field
from openai import BadRequestError, RateLimitError

from typing import List
import os 
from os import listdir
from os.path import basename, exists
import re
from functools import reduce
import sys
import time
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

import dill
import pandas as pd
import numpy as np
from SuperSCC import list_files

In [2]:
# Work from the project directory so that the relative input paths used in later cells
# (human_id2symbol.csv, the SuperSCC marker pickle, seurat_res/) resolve correctly.
PROJECT_DIR = "/home/fengtang/jupyter_notebooks/working_script/evulate_feature_selection/2nd_submssion/"
os.chdir(PROJECT_DIR)

# local = "/home/fengtang/jupyter_notebooks/working_script/evulate_feature_selection/2nd_submssion/markers_for_run_GPT_on_local"

In [3]:
# Pydantic schema for the LLM's structured answer. PydanticOutputParser builds the prompt's
# format instructions from this model, so the field names and descriptions below are
# effectively part of the prompt text. Document with `#` comments only: a class docstring
# would likely be emitted into that schema as well and change the prompt.
class Output(BaseModel):
    # One entry per evaluated gene set (parallel lists, indexed by gene set).
    GeneSetName: List[str] = Field(..., description = "The name of the gene set being evaluated")
    # Genes of each set; inner lists are untyped.
    GeneList: List[list] = Field(..., description = "Comma-separated list of genes in the set")
    # Prompt asks for values in [0, 1]; the range is not validated here.
    RelevantGeneRatio: List[float] = Field(..., description = "The proportion of relevant genes in the set")
    BiologicalRelevanceScore: List[float] = Field(..., description = "Derived from Gene Ontology and KEGG pathways, reflecting the biological function of the gene set")
    # Inner lists are untyped because the prompt allows 'None' when a test cannot be run.
    # Presumably aligned element-wise with SetvsSet — TODO confirm against real responses.
    Pvalue: List[list] = Field(..., description = "A value from the statistical comparison")
    # Pairs of gene-set names that were compared (e.g. gene_set1 vs gene_set2).
    SetvsSet: List[list] = Field(..., description = "Gene set names for comparison")
    Summary: List[str] = Field(..., description = "A brief summary of the gene functions or pathway associations for each gene in the gene set")
    Conclusion: List[str] = Field(..., description = "A clear conclusion to indicate the gene set (e.g gene set1) as a better representative of that specific cell type")


In [None]:
def llm_compare_gene_set(model, api_key, base_url, temperature, cell_type, SuperSCC_gene_set, Seurat_wilcox_gene_set,
                         Seurat_roc_gene_set = None, Scanpy_t_test_gene_set = None, Scanpy_log_reg_gene_set = None,
                         structure_output = True, system = None, parser = None):
    """Ask an LLM which marker gene set best represents a given cell type.

    Parameters
    ----------
    model, api_key, base_url : str
        Forwarded to ``ChatOpenAI``; ``base_url`` may point at any OpenAI-compatible endpoint.
    temperature : float
        Sampling temperature forwarded to ``ChatOpenAI``.
    cell_type : str
        Cell-type label substituted into the prompt.
    SuperSCC_gene_set, Seurat_wilcox_gene_set : list of str
        Marker genes produced by SuperSCC and by Seurat (Wilcoxon test).
    Seurat_roc_gene_set, Scanpy_t_test_gene_set, Scanpy_log_reg_gene_set : list of str, optional
        Marker genes from the remaining methods. ``None`` (the default) means the method
        was not run. The prompt then receives an explicit "not provided" placeholder
        instead of the literal ``None``. Previously these were required, and callers that
        omitted them failed with ``TypeError`` on every call.
    structure_output : bool
        If True, parse the reply with a ``PydanticOutputParser``. Otherwise return the
        chat model's raw response object.
    system : str, optional
        Custom prompt template. It must use the same placeholders as the default template
        (the gene-set names, ``cell_type`` and ``format_instructions``).
    parser : type, optional
        Pydantic model class describing the expected output; defaults to ``Output``.

    Returns
    -------
    An instance of the pydantic model when ``structure_output`` is True, otherwise the
    object returned by the chat model.
    """
    # max_retries lets the client retry failed API calls itself (e.g. transient rate limits).
    llm = ChatOpenAI(model = model,
                     temperature = temperature,
                     api_key = api_key,
                     base_url = base_url,
                     max_retries = 20)

    if system is None:
        system = """
            Suppose you are an insightful biologist tasked with evaluating multiple gene sets to determine which one better reflects the underlying biological function. 
            You will use both Gene Ontology and KEGG databases to design scoring metrics. 
            If cell type labels are provided, evaluate which gene set is a better representative of that specific cell type. 
            Gene Set Format: Input gene sets can be in gene symbol or Ensembl ID format. If Ensembl IDs are provided, automatically convert them to gene symbols, ensuring the accuracy of the conversion. 
            For comparison bewteen each pair of gene sets, use a statistical test like Fisher’s exact test (or chi-squared test if applicable), ensuring that the calculation detail is shown and accuracy is guaranteed (e.g. make sure 2x2 contingency table is used for Fisher’s exact test ). 
            Evaluation Method (Scoring Metrics):  
                - Relevant Gene Ratio: The proportion of relevant genes in each gene set, should be numeric value range from 0 to 1. 
                - Biological Relevance Score: Derived from Gene Ontology and KEGG pathways, reflecting the biological function of the gene set, should be numeric value range from 0 to 1. 
                - Also the evaluation should be independent of gene set order. Normalize the ratio to account for any differences in gene set size. 
            Output requirements:  
                - GeneSetName: The name of the gene set being evaluated. 
                - GeneList: Comma-separated list of genes in the set.
                - RelevantGeneRatio: The proportion of relevant genes in the set. 
                - Pvalue: A value from the statistical comparison. When statistical test can not be done, should return 'None'.
                - BiologicalRelevanceScore: Based on Gene Ontology and KEGG database associations. 
                - Summary: A brief summary of the gene functions or pathway associations for each gene in the gene set. 
                - SetvsSet: Gene set names for comparison (e.g. gene_set1 vs gene_set2)
                - Conclusion: a clear conclusion to indicate which gene set name as a better representative of that specific cell type and also summarize the reason.

            cell_type: {cell_type}
            
            SuperSCC_gene_set: {SuperSCC_gene_set}
            Seurat_wilcox_gene_set: {Seurat_wilcox_gene_set}
            Seurat_roc_gene_set: {Seurat_roc_gene_set}
            Scanpy_t_test_gene_set: {Scanpy_t_test_gene_set}
            Scanpy_log_reg_gene_set: {Scanpy_log_reg_gene_set}

            <format_instruction>
            {format_instructions}
            </format_instruction>
            """

    # `parser` is a pydantic model class (not a parser instance); wrap it in an output parser.
    output_parser = PydanticOutputParser(pydantic_object = Output if parser is None else parser)

    gene_sets = {
        "SuperSCC_gene_set": SuperSCC_gene_set,
        "Seurat_wilcox_gene_set": Seurat_wilcox_gene_set,
        "Seurat_roc_gene_set": Seurat_roc_gene_set,
        "Scanpy_t_test_gene_set": Scanpy_t_test_gene_set,
        "Scanpy_log_reg_gene_set": Scanpy_log_reg_gene_set,
    }

    prompt = PromptTemplate(template = system,
                            input_variables = ["cell_type", *gene_sets],
                            partial_variables = {"format_instructions": output_parser.get_format_instructions()})

    chain = prompt | llm | output_parser if structure_output else prompt | llm

    # Methods that were not run are passed as None; tell the model explicitly rather than
    # rendering the literal "None" into the prompt as if it were a gene set.
    not_provided = "Not provided (this method was not evaluated; leave it out of the comparison)"
    inputs = {"cell_type": cell_type}
    inputs.update({name: not_provided if genes is None else genes for name, genes in gene_sets.items()})

    return chain.invoke(inputs)

In [6]:
def id2symbol(reference, query, multi_select = "first"):
    """Translate Ensembl gene IDs to gene symbols using a lookup table.

    Parameters
    ----------
    reference : pandas.DataFrame
        Lookup table with ``gene_id`` and ``gene_name`` columns. An ID may appear more than once.
    query : list-like of str
        Gene IDs to translate, in the desired output order.
    multi_select : {"first", "last", False}
        Forwarded to ``Series.duplicated(keep=...)``. It decides which entry survives when
        an ID maps to several symbols or several IDs map to the same symbol.

    Returns
    -------
    list
        Gene symbols in query order, with duplicates removed. IDs that have no symbol in
        ``reference`` are returned unchanged (the ID itself) instead of as NaN.
    """
    query = list(query)
    if not query:
        return []

    mapped = (
        pd.DataFrame({"gene_id": query})
        .join(reference.set_index("gene_id"), how = "left", on = "gene_id")
        # The join repeats a query row once per matching reference row; give every row a
        # unique label so the masks and fillna below align one-to-one.
        .reset_index(drop = True)
    )
    # Keep a single row per query ID.
    mapped = mapped[~mapped["gene_id"].duplicated(keep = multi_select)]
    # Unmapped IDs would otherwise be NaN. duplicated() treats all NaNs as equal, so every
    # unmapped gene after the first was silently dropped and a float NaN leaked into the
    # gene list. Fall back to the ID itself instead.
    mapped = mapped.assign(gene_name = mapped["gene_name"].fillna(mapped["gene_id"]))
    # Keep a single row per symbol.
    mapped = mapped[~mapped["gene_name"].duplicated(keep = multi_select)]
    return mapped["gene_name"].tolist()

In [7]:
# Ensembl ID -> gene symbol lookup table consumed by id2symbol (needs gene_id / gene_name columns).
ID2SYMBOL_TABLE = "human_id2symbol.csv"
reference = pd.read_csv(ID2SYMBOL_TABLE)

In [8]:
# SuperSCC markers: a mapping {dataset: {cell type: table with a `feature` column}}.
# Each table is replaced in place by its list of gene symbols; lists containing Ensembl
# IDs are translated with id2symbol. Reloading the pickle first keeps this cell re-runnable.
superscc_markers = pd.read_pickle("SuperSCC_default_retrieving_method_top_20_markers.pkl")
for dataset_name, per_cell_type in superscc_markers.items():
    for cell_label, marker_table in per_cell_type.items():
        genes = marker_table.feature.tolist()
        has_ensembl_ids = any([re.search("ENSG\\d+", gene) is not None for gene in genes])
        per_cell_type[cell_label] = id2symbol(reference = reference, query = genes) if has_ensembl_ids else genes

# output the marker for running openAI API on local
# with open(f"{local}/superscc_markers.pkl", "wb") as file:
#     dill.dump(superscc_markers, file)

In [9]:
# Scanpy t-test markers: for each dataset listed in the result table, keep genes with
# log fold change > LOGFC_CUTOFF and adjusted p < PADJ_CUTOFF, rank them by adjusted p,
# and take the top TOP_N_MARKERS per group (Ensembl IDs translated to symbols).
SCANPY_RESULT_TABLE = "/mnt/disk5/zhongmin/superscc/结果位置/结果位置_3.csv"
LOGFC_CUTOFF = 1
PADJ_CUTOFF = 0.05
TOP_N_MARKERS = 20

scanpy_result_table = pd.read_csv(SCANPY_RESULT_TABLE, encoding = "GBK", index_col = 0)
scanpy_t_test_markers = dict()

for dataset_name, t_test_path in zip(scanpy_result_table.index, scanpy_result_table.t_test_path.values):
    de_table = pd.read_csv(t_test_path)
    significant = (
        de_table.loc[(de_table.logfoldchanges > LOGFC_CUTOFF) & (de_table.pvals_adj < PADJ_CUTOFF), :]
        .sort_values("pvals_adj", ascending = True)
    )
    # Register every dataset, even one where no gene passes the filters. Previously such a
    # dataset was never added and the shared-cell-type step raised KeyError.
    per_group = {}
    for group, group_table in significant.groupby("group"):
        markers = group_table.head(TOP_N_MARKERS).names.tolist()
        if any([re.search("ENSG\\d+", gene) is not None for gene in markers]):
            markers = id2symbol(reference, markers)
        per_group[group] = markers
    scanpy_t_test_markers[dataset_name] = per_group

# output the marker for running openAI API on local
# with open(f"{local}/scanpy_t_test_markers.pkl", "wb") as file:
#     dill.dump(scanpy_t_test_markers, file)

In [11]:
# Seurat Wilcoxon markers: one CSV per dataset in seurat_res/. Keep genes with
# avg_log2FC > LOGFC_CUTOFF and adjusted p < PADJ_CUTOFF, rank them by adjusted p,
# and take the top TOP_N_MARKERS per cluster (Ensembl IDs translated to symbols).
SEURAT_RESULT_DIR = "seurat_res"
LOGFC_CUTOFF = 1
PADJ_CUTOFF = 0.05
TOP_N_MARKERS = 20

files = list_files(path = SEURAT_RESULT_DIR, pattern = ".+csv$")
seurat_wilcox_markers = dict()

for marker_file in files:
    de_table = pd.read_csv(marker_file)
    significant = (
        de_table.loc[(de_table.p_val_adj < PADJ_CUTOFF) & (de_table.avg_log2FC > LOGFC_CUTOFF)]
        .sort_values("p_val_adj", ascending = True)
    )
    # Escaped and anchored suffix; the previous pattern's bare "." matched any character.
    dataset_name = re.sub(r"_seurat_feature\.csv$", "", basename(marker_file))

    # Register every dataset, even one where no gene passes the filters, so later
    # lookups by dataset name cannot raise KeyError.
    per_cluster = {}
    for cluster, cluster_table in significant.groupby("cluster"):
        markers = cluster_table.head(TOP_N_MARKERS).gene.values.tolist()
        if any([re.search("ENSG\\d+", gene) is not None for gene in markers]):
            markers = id2symbol(reference, markers)
        per_cluster[cluster] = markers
    seurat_wilcox_markers[dataset_name] = per_cluster

# output the marker for running openAI API on local
# with open(f"{local}/seurat_wilcox_markers.pkl", "wb") as file:
#     dill.dump(seurat_wilcox_markers, file)

In [None]:
# Cell types that have markers from all three methods (SuperSCC, Scanpy t-test, Seurat
# Wilcoxon), per dataset. A dataset missing from the Scanpy or Seurat results gets an
# empty set and a warning instead of aborting the cell with KeyError.
shared_cell_type = dict()
for dataset_name, superscc_per_type in superscc_markers.items():
    missing_methods = [method for method, results in (("Scanpy t-test", scanpy_t_test_markers),
                                                      ("Seurat wilcox", seurat_wilcox_markers))
                       if dataset_name not in results]
    if missing_methods:
        print(f"Warning: dataset '{dataset_name}' has no {' / '.join(missing_methods)} markers; no shared cell types.")
    shared_cell_type[dataset_name] = (
        set(superscc_per_type.keys())
        & set(scanpy_t_test_markers.get(dataset_name, {}).keys())
        & set(seurat_wilcox_markers.get(dataset_name, {}).keys())
    )

In [None]:
# Top-20 marker comparison: one LLM call per (temperature, dataset, cell type). Each
# structured result is pickled under <home>/tempeature_<T>/<dataset>/; failures go to a
# per-cell-type log file in the same directory.
model = "deepseek-v3"  # alternatives used: qwen-max, gpt-4.1-mini
# Read the key from the environment rather than hard-coding it ("#####" is the old placeholder).
api_key = os.environ.get("DASHSCOPE_API_KEY", "#####")
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
TOP_N = 20

# Previously this pointed at ".../10_markers", the same root the top-10 cell writes to,
# so the two runs overwrote each other's temperature-0.5 results.
home = f"/home/fengtang/jupyter_notebooks/working_script/evulate_feature_selection/2nd_submssion/evaluation_res/original_prompt/DeepSeek-V3/{TOP_N}_markers"

for tempeture in [0.1, 0.5, 0.9]:
    for dataset in shared_cell_type:
        # Absolute output paths instead of nested os.chdir() calls, so an error or
        # interrupt never leaves the kernel stranded in a sub-directory.
        out_dir = os.path.join(home, f"tempeature_{tempeture}", f"{dataset}")
        os.makedirs(out_dir, exist_ok = True)

        for cell_type in shared_cell_type[dataset]:
            print(f"Processing with cell type '{cell_type}' in dataset '{dataset}'")
            # "/" would act as a path separator inside the output file name.
            safe_cell_type = re.sub("/", "or", cell_type)

            try:
                res = llm_compare_gene_set(model = model,
                                           api_key = api_key,
                                           base_url = base_url,
                                           temperature = tempeture,
                                           cell_type = cell_type,
                                           SuperSCC_gene_set = superscc_markers[dataset][cell_type][:TOP_N],
                                           Seurat_wilcox_gene_set = seurat_wilcox_markers[dataset][cell_type][:TOP_N],
                                           # These methods were not run. Passing them explicitly keeps the
                                           # call valid: they used to be omitted, so every call raised TypeError.
                                           Seurat_roc_gene_set = None,
                                           Scanpy_t_test_gene_set = scanpy_t_test_markers[dataset][cell_type][:TOP_N],
                                           Scanpy_log_reg_gene_set = None,
                                           structure_output = True)
                with open(os.path.join(out_dir, f"{safe_cell_type}_{dataset}_llm_evaulation.pkl"), "wb") as file:
                    dill.dump(res, file)
                time.sleep(10)
            except Exception as err:
                # `Exception` rather than a bare except, so KeyboardInterrupt still stops the run.
                # Record the reason as well as the fact of failure.
                with open(os.path.join(out_dir, f"{safe_cell_type}_{dataset}_log.txt"), "a") as file:
                    file.write(f"Failure on cell type '{cell_type}' in dataset '{dataset}' \n")
                    file.write(f"    {type(err).__name__}: {err}\n")
        

In [None]:
# Top-10 marker comparison: same protocol as the top-20 run, but each gene set is
# truncated to its first TOP_N markers. Output root and settings are defined here
# instead of relying on `home` and the working directory leaked from the previous cell.
model = "deepseek-v3"  # alternatives used: qwen-max, gpt-4.1-mini
# Read the key from the environment rather than hard-coding it ("#####" is the old placeholder).
api_key = os.environ.get("DASHSCOPE_API_KEY", "#####")
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
TOP_N = 10

home = f"/home/fengtang/jupyter_notebooks/working_script/evulate_feature_selection/2nd_submssion/evaluation_res/original_prompt/DeepSeek-V3/{TOP_N}_markers"

for tempeture in [0.5]:
    for dataset in shared_cell_type:
        # Absolute output paths instead of nested os.chdir() calls.
        out_dir = os.path.join(home, f"tempeature_{tempeture}", f"{dataset}")
        os.makedirs(out_dir, exist_ok = True)

        for cell_type in shared_cell_type[dataset]:
            print(f"Processing with cell type '{cell_type}' in dataset '{dataset}'")
            # "/" would act as a path separator inside the output file name.
            safe_cell_type = re.sub("/", "or", cell_type)

            try:
                res = llm_compare_gene_set(model = model,
                                           api_key = api_key,
                                           base_url = base_url,
                                           temperature = tempeture,
                                           cell_type = cell_type,
                                           SuperSCC_gene_set = superscc_markers[dataset][cell_type][:TOP_N],
                                           Seurat_wilcox_gene_set = seurat_wilcox_markers[dataset][cell_type][:TOP_N],
                                           # These methods were not run; pass them explicitly so the call is valid.
                                           Seurat_roc_gene_set = None,
                                           Scanpy_t_test_gene_set = scanpy_t_test_markers[dataset][cell_type][:TOP_N],
                                           Scanpy_log_reg_gene_set = None,
                                           structure_output = True)
                with open(os.path.join(out_dir, f"{safe_cell_type}_{dataset}_llm_evaulation.pkl"), "wb") as file:
                    dill.dump(res, file)
                # No pause between calls in this run (the top-20 run sleeps 10 s per call).
            except Exception as err:
                # `Exception` rather than a bare except, so KeyboardInterrupt still stops the run.
                with open(os.path.join(out_dir, f"{safe_cell_type}_{dataset}_log.txt"), "a") as file:
                    file.write(f"Failure on cell type '{cell_type}' in dataset '{dataset}' \n")
                    file.write(f"    {type(err).__name__}: {err}\n")
        