In [None]:
#parameters for papermill to inject
# gse: GEO series accession to reanalyze.
# working_dir: directory the notebook chdir()s into; it must exist, so the "" default
# fails the os.path.isdir() check in the next cell unless papermill injects a value.
gse = "GSE247175"
working_dir = ""

In [None]:
import os

# Move into the papermill-injected working directory; report the offending value
# so a bad injection is easy to diagnose from the executed notebook.
if os.path.isdir(working_dir):
    os.chdir(working_dir)
else:
    raise ValueError(f"not a valid directory: {working_dir!r}")

In [None]:
#no need to refresh kernel when changes are made to the helper scripts
# Mode 2 reloads imported modules before each cell runs, so helper-module edits are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
# Silence all warnings in the rendered report.
# NOTE(review): this also hides deprecation/convergence warnings from the analysis libraries.
warnings.filterwarnings('ignore')

In [None]:
import pandas as pd
import sys
from IPython.display import display,FileLink, Markdown, HTML, Image

In [None]:
#set to "" during production. Change only if you want to save images in a different directory.
# Output directory for figures/tables; os.path.join("", name) == name, so "" means the current directory.
resource_path = "" 

In [None]:
import contextlib

@contextlib.contextmanager
def suppress_output(stdout=True, stderr=True, dest='/dev/null'):
    ''' Temporarily redirect sys.stdout and/or sys.stderr to the file ``dest``.

    Usage:
    with suppress_output():
        print('hi')

    Args:
        stdout: if True, redirect sys.stdout while the block runs.
        stderr: if True, redirect sys.stderr while the block runs.
        dest: path opened in append mode to receive the output ('/dev/null' discards it).

    The original streams are restored, and ``dest`` is closed, even if the block raises.
    Only writes that go through sys.stdout / sys.stderr are redirected.
    '''
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    # `with` guarantees the sink file is closed; the previous version leaked one handle per call.
    with open(dest, 'a') as sink:
        if stdout:
            sys.stdout = sink
        if stderr:
            sys.stderr = sink
        try:
            yield
        finally:
            if stdout:
                sys.stdout = saved_stdout
            if stderr:
                sys.stderr = saved_stderr

In [None]:
from Bio import Entrez
from dotenv import load_dotenv

load_dotenv()

# Fail fast with a clear message; assigning None into os.environ would raise an opaque TypeError.
entrez_email = os.getenv('ENTREZ_EMAIL')
if not entrez_email:
    raise ValueError("ENTREZ_EMAIL is not set; define it in the environment or in a .env file.")
os.environ['ENTREZ_EMAIL'] = entrez_email

Entrez.email = entrez_email

# Resolve the GSE accession to its numeric GEO DataSets (gds) id; close the HTTP handle afterwards.
id_handle = Entrez.esearch(db="gds", term=f"{gse}[Accession]", retmax=1)
try:
    id_record = Entrez.read(id_handle)
finally:
    id_handle.close()

if not id_record["IdList"]:
    raise ValueError(f"No GEO DataSets (gds) record found for {gse}.")
gds_id = id_record["IdList"][0]

In [None]:
# Summary record for the series (taxon, PubMedIds, Samples, summary text); close the handle after parsing.
stream = Entrez.esummary(db='gds', id=gds_id)
try:
    record = Entrez.read(stream)
finally:
    stream.close()

In [None]:
# Lower-cased Entrez 'taxon' -> species key sent to the ARCHS4 sample service.
map_species = {"homo sapiens": "human", "mus musculus": "mouse"}

In [None]:
# Only human and mouse studies are supported; report the actual taxon instead of a bare KeyError.
if record[0]['taxon'].lower() not in map_species:
    raise ValueError(f"Unsupported species '{record[0]['taxon']}'; expected one of: {', '.join(map_species)}.")
species = map_species[record[0]['taxon'].lower()]

In [None]:
# Discard studies with no PubMed citation: an empty PubMedIds list means there is nothing to cite.
if not record[0]['PubMedIds']:
    raise ValueError("No PubMed citation found for this study.")

In [None]:
# Use the first linked PubMed ID as the study's citation.
pmid = int(record[0]['PubMedIds'][0])

In [None]:
# Keep only studies with between 6 and 24 samples (inclusive).
if not 6 <= len(record[0]['Samples']) <= 24:
    raise ValueError("Number of samples need to be within 6-24.")

In [None]:
def fetch_pubmed_metadata(pmid):
    """Fetch citation metadata for one PubMed article through Entrez efetch.

    Args:
        pmid: PubMed identifier (int or str) of the article.

    Returns:
        dict with keys "title", "journal", "year", "volume", "issue", "pages",
        "authors" (list of "Last, F. M." strings, or the group name for collective
        authors) and "doi" ("" when the record lists none). Missing optional fields
        are returned as "".
    """
    import re  # local import: the notebook's shared `import re` only happens in a later cell

    handle = Entrez.efetch(db="pubmed", id=pmid, retmode="xml")
    try:
        records = Entrez.read(handle)
    finally:
        handle.close()
    article = records['PubmedArticle'][0]['MedlineCitation']['Article']
    journal_issue = article['Journal']['JournalIssue']

    def format_author(author):
        # Consortium/group authors carry CollectiveName instead of LastName/Initials.
        if 'LastName' not in author:
            return str(author.get('CollectiveName', ''))
        initials = author.get('Initials', '')
        if not initials:
            return str(author['LastName'])
        # "JA" -> "J. A."
        return f"{author['LastName']}, {'. '.join(initials)}."

    def extract_year(pub_date):
        # Some records only carry a free-text MedlineDate (e.g. "2019 Nov-Dec") instead of Year.
        year = pub_date.get('Year', '')
        if year:
            return year
        match = re.search(r'\d{4}', str(pub_date.get('MedlineDate', '')))
        return match.group(0) if match else ''

    def extract_doi(art):
        # ELocationID entries are tagged with an EIdType attribute ("doi", "pii", ...).
        for location in art.get('ELocationID', []):
            if getattr(location, 'attributes', {}).get('EIdType') == 'doi':
                return str(location)
        return ''

    # Drop entries that had neither a LastName nor a CollectiveName.
    authors = [name for name in (format_author(a) for a in article.get('AuthorList', [])) if name]

    return {
        "title": article['ArticleTitle'],
        "journal": article['Journal']['Title'],
        "year": extract_year(journal_issue['PubDate']),
        "volume": journal_issue.get('Volume', ''),
        "issue": journal_issue.get('Issue', ''),
        "pages": article.get('Pagination', {}).get('MedlinePgn', ''),
        "authors": authors,
        "doi": extract_doi(article),
    }

In [None]:
# Citation metadata (title, journal, year, authors, ...) for the study's publication.
pmdict = fetch_pubmed_metadata(pmid)

In [None]:
def format_apa_citation(metadata, pmid=None):
    """Format article metadata as an APA-style reference string.

    Args:
        metadata (dict): "authors" (list of formatted names), "year", "title",
            "journal", "volume", "issue", "pages", and optionally "doi".
        pmid: optional PubMed ID, appended as "PMID: <pmid>" when no DOI is available.

    Returns:
        str: the citation, with the journal name in Markdown italics (*...*).
        Empty volume/issue/pages are omitted rather than leaving "()" or dangling
        commas; a missing year is written as "n.d."; with no authors the title
        leads the reference.
    """
    authors = list(metadata.get('authors') or [])

    # Up to 20 authors are all listed; beyond that, the first 19, an ellipsis, then the last.
    if not authors:
        author_str = ''
    elif len(authors) == 1:
        author_str = authors[0]
    elif len(authors) <= 20:
        author_str = ', '.join(authors[:-1]) + ', & ' + authors[-1]
    else:
        author_str = ', '.join(authors[:19]) + ', ... ' + authors[-1]

    year = metadata.get('year') or 'n.d.'
    title = metadata['title']
    volume = metadata.get('volume') or ''
    issue = metadata.get('issue') or ''
    pages = metadata.get('pages') or ''

    # Source element: *Journal*, volume(issue), pages -- skipping the parts that are missing.
    source_parts = [f"*{metadata['journal']}*"]
    volume_issue = f"{volume}({issue})" if issue else volume
    if volume_issue:
        source_parts.append(volume_issue)
    if pages:
        source_parts.append(pages)
    source = ', '.join(source_parts)

    if author_str:
        citation = f"{author_str} ({year}). {title} {source}."
    else:
        # Without authors, the title moves into the author position.
        citation = f"{title} ({year}). {source}."

    if metadata.get('doi'):
        citation += f" https://doi.org/{metadata['doi']}"
    elif pmid:
        citation += f" PMID: {pmid}"

    return citation


In [None]:
# Pass the PMID so the citation still carries an identifier when PubMed lists no DOI.
citation = format_apa_citation(pmdict, pmid=pmid)

In [None]:
import json
from datetime import datetime

#write a json file as a catalog list.
metadata_path = "metadata.json"

entry = {
    "id": gse, 
    "author": ", ".join(pmdict['authors']),
    # PubMed may provide no parsable year; store null instead of crashing on int('').
    "year": int(pmdict['year']) if str(pmdict['year']).isdigit() else None,
    "species": species,
    "title": pmdict['title'],
    "pmid": pmid,
    "num_samps": len(record[0]['Samples']),
    "samples": ", ".join(sorted([w['Accession'] for w in record[0]['Samples']])),
    "citation": citation,
    "timestamp": datetime.now().isoformat()
}

with open(metadata_path, "w") as f:
    json.dump(entry, f, indent=4)


In [None]:
# Byline: "First et al." for multi-author papers, the sole author's name otherwise,
# and no byline at all when PubMed listed no authors (indexing [0] would fail).
title_authors = pmdict['authors']
if len(title_authors) > 1:
    byline = f" by {title_authors[0]} et al.,"
elif title_authors:
    byline = f" by {title_authors[0]},"
else:
    byline = ","
display(Markdown(f"# **Reanalysis of \"{pmdict['title']}\"{byline} {pmdict['journal']}, {pmdict['year']}**"))

In [None]:
link = f"https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc={gse}"
display(Markdown(f"{citation}"))
# Last expression of the cell: Jupyter renders the returned HTML object as the cell output.
HTML(f'<a href="{link}" target="_blank">Visit GEO accession page</a>')

In [None]:
from google import genai

# Fail fast with a clear message; assigning None into os.environ would raise an opaque TypeError.
# (Never print the key: rendered notebook outputs get shared.)
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("GOOGLE_API_KEY is not set; define it in the environment or in a .env file.")
os.environ["GOOGLE_API_KEY"] = google_api_key

client = genai.Client(api_key=google_api_key)

# Give the model readable names rather than a Python list repr.
author_list = ", ".join(pmdict['authors'])
prompt = f'''
  You are an expert academic writer. Your task is to reformat the provided research information into a concise abstract around 250 words following this exact template:

  "In this study, <FIRST AUTHOR> et al. [1] profiled <CELLS AND CONDITIONS> to further our understanding of <TOPIC>. The reanalysis of this dataset includes <FILL IN>"

  Here is the contextual information:
  Authors: {author_list}
  Title: {pmdict['title']}
  Summary: {record[0]['summary']}

  In the reanalysis explanation, use the following information: the reanalysis is a full RNA-seq analysis pipeline that consists of: UMAP[2], PCA[3], t-SNE[4] plots of the samples; clustergram heatmap; differential gene expression analysis
  for each pair of control and perturbation samples; Enrichment analysis for each gene signature using Enrichr [5, 6, 7]; Transcription factor analysis of gene signatures
  using ChEA3 [8] ; Reverser and mimicker drug match analysis using L2S2 [9] and DRUG-seqr [10], both FDA and non-FDA approved. Results are provided as tables in addition to bar charts.

  Please write the reanalysis as a complete paragraph with smoothly transitioning sentences. Use consistent, present tense.
  Do not omit or change the ordering of the reference numbers.
  Do not change the references; insert them, in square brackets, where indicated.

  Now, generate the abstract strictly following the template. Do not include any other text or introductory/concluding remarks.
'''

response = client.models.generate_content(
    model="gemini-2.0-flash", contents=prompt
)

## **Abstract**

In [None]:
# A blank line ("\n\n") is needed for Markdown to start a new paragraph; a single "\n" would
# run the attribution note into the last sentence of the abstract.
display(Markdown(f"{response.text}\n\n*This abstract was generated with the assistance of Gemini 2.0 Flash.*"))

## **Methods**

*RNA-seq alignment*

Gene count matrices were obtained from ARCHS4 [11], which preprocessed the raw FASTQ data using the Kallisto [12] pseudoalignment algorithm and the STAR aligner.

*Gene matrix processing* 

The raw gene matrix was filtered to remove genes that do not have an average of at least 3 reads across the samples. It was then quantile, log2, and z-score normalized. A regex-based function was used to infer whether individual samples belong to a “control” or a “perturbation” group by processing the metadata associated with each sample. 

*Dimensionality Reduction Visualization*

Three dimensionality reduction techniques were applied to the processed expression matrices: UMAP [2], PCA [3], and t-SNE [4]. UMAP was calculated with the UMAP Python package, and PCA and t-SNE were calculated with the Scikit-Learn Python library. The samples were then represented on 2D scatterplots.

*Clustergram Heatmap*

As a preliminary step, the 1000 genes with the highest expression variance were selected. Clustergram heatmaps were then generated from this subset. Two versions of the clustergram are provided: an interactive one generated by Clustergrammer [13] and a publication-ready alternative.

*Differentially Expressed Genes Calculation and Volcano Plot*

Differentially expressed genes between the control and perturbation samples were calculated using Limma Voom [14]. The logFC and -log10p values of each gene were visualized as a volcano scatterplot. Upregulated and downregulated genes were selected according to these criteria: p < 0.05 and |logFC| > 1.0.

*Enrichr Enrichment Analysis*

The upregulated and downregulated sets were separately submitted to Enrichr [5, 6, 7]. These sets were compared against libraries from ChEA [8], ARCHS4 [11], Reactome Pathways [15], MGI Mammalian Phenotype [16], Gene Ontology Biological Processes [17], GWAS Catalog [18], KEGG [19, 20, 21], and WikiPathways [22]. The top matched terms from each library and their respective -log10p values were visualized as barplots.

*ChEA3 Transcription Factor Analysis*

The upregulated and downregulated sets were separately submitted to ChEA3 [8]. These sets were compared against the libraries ARCHS4 Coexpression [11], GTEx Coexpression [23], Enrichr [5, 6, 7], ENCODE ChIP-seq [24, 25], ReMap ChIP-seq [26], and Literature-mined ChIP-seq. The top matched TFs were ranked according to their average score across each library and represented as barplots.

*L2S2 and Drug-seqr drug analysis*

The top 500 upregulated and top 500 downregulated genes were submitted simultaneously to identify reverser and mimicker molecules, both FDA and non-FDA approved, from the L2S2 [9] and Drug-seqr [10] databases. The top matched molecules were compiled into tables and visualized as barplots. 


In [None]:
# Running numbers for table/figure captions, and the static formats every figure is saved in.
tab_num, fig_num = 1, 1
save_formats = ['png', 'svg', 'jpeg']

In [None]:
import archs4py as a4
#file_path = a4.download.counts(species, path="", version="latest") #uncomment if the file is not downloaded yet
# ARCHS4 count files are species-prefixed ("human_gene_v2...", "mouse_gene_v2..."); pick the one
# matching this study instead of always reading the human file. The directory defaults to the
# original local path and can be overridden with the ARCHS4_DIR environment variable.
archs4_dir = os.getenv("ARCHS4_DIR", "/home/ajy20/projects/8--auto-playbook-geo-reports")
file = os.path.join(archs4_dir, f"{species}_gene_v2.latest.h5")

In [None]:
# Per-sample GEO metadata for this series from the local ARCHS4 file; later cells index it by
# GSM accession (metadata.loc[samps]) and read its title / characteristics_ch1 columns.
metadata = a4.meta.series(file, gse)

In [None]:
%%capture
import re
import nltk
from nltk.corpus import stopwords

# %%capture suppresses this cell's output (e.g. messages from nltk.download).
nltk.download('stopwords')
# Extra tokens ignored when deriving group labels: generic metadata field words and single digits.
# NOTE(review): 'experiement' looks like a misspelling of 'experiment' (also listed) - possibly
# deliberate to catch misspelled GEO metadata; 'batch' is listed twice (harmless once in a set).
words_to_remove = ['experiement', 'tissue', 'type', 'batch', 'treatment', 'experiment', 'patient', 'batch', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# English stopwords plus the extra tokens above.
stopwords_plus = set(stopwords.words('english') + (words_to_remove))
# Characters replaced by spaces before splitting metadata text into words.
pattern = r'[-,_.:]'


In [None]:
terms_to_remove = ["cell line", "cell type", "genotype", "treatment"]

# Whole-word alternation of the escaped field labels, e.g. \b(cell\ line|cell\ type|...)\b
pattern1 = r"\b(" + "|".join(re.escape(term) for term in terms_to_remove) + r")\b"

# Strip field labels, collapse whitespace, replace separator characters with spaces,
# lowercase, then drop stopwords.
metadata["cleaned_characteristics"] = (
    metadata["characteristics_ch1"]
    .str.replace(pattern1, "", flags=re.IGNORECASE, regex=True)
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
    .apply(lambda raw: re.sub(pattern, " ", raw).strip().lower())
    .apply(lambda text: " ".join(word for word in text.split() if word not in stopwords_plus))
)

In [None]:
# Group key from the sample title: drop digits (replicate numbers), replace separator
# characters with spaces, lowercase, and remove stopwords.
metadata['clean_title'] = (
    metadata['title']
    .apply(lambda raw: re.sub('[0-9]+', '', raw))
    .apply(lambda raw: re.sub(pattern, " ", raw).strip().lower())
    .apply(lambda text: " ".join(word for word in text.split() if word not in stopwords_plus))
)

In [None]:
# Samples sharing the same cleaned title form one experimental group.
groups = metadata.groupby('clean_title')

In [None]:
# Tokens whose presence in a group label marks it as a control group.
ctrl_words = {
    # genotype controls
    'wt', 'wild', 'wildtype',
    # generic spellings of "control"
    'control', 'cntrl', 'ctrl', 'ctl', 'ctr',
    # knockdown controls
    'shctrl', 'sictrl', 'sicontrol',
    # treatment / vehicle controls
    'untreated', 'unstimulated', 'uninfected', 'dmso', 'vehicle', 'naive',
    # clinical controls
    'normal', 'healthy',
}

In [None]:
# Map each title-derived group label to its GSM accessions.
groupings = {}
for label, group in groups:
    # Every group must contain exactly 3 or 4 samples.
    if not 3 <= len(group) <= 4:
        raise ValueError("Study does not have 3-4 samples per group")
    groupings[label] = group['geo_accession'].tolist()

In [None]:
# Title-derived labels; a label is a control when any of its words is a control keyword.
title_conditions = list(groupings.keys())
title_ctrl = [cond for cond in title_conditions if set(cond.split()) & ctrl_words]

In [None]:
og_labels = {}          # characteristics-derived label -> title-derived label
labled_groupings = {}   # characteristics-derived label -> GSM accessions

for label, samps in groupings.items():
    # Tokenize each sample's characteristics: lowercase, replace separators, drop stopwords.
    token_sets = [
        {w for w in re.sub(pattern, ' ', text.lower()).split() if w not in stopwords_plus}
        for text in metadata.loc[samps, 'characteristics_ch1']
    ]
    # The group's condition is the set of tokens shared by all of its samples.
    shared = set.intersection(*token_sets)
    # Sort the words: set iteration order over strings changes between runs (hash
    # randomization), which made labels -- and the DGE file names built from them -- unstable.
    condition = ' '.join(sorted(shared))
    # NOTE(review): groups with identical shared tokens collide here and the later one wins.
    labled_groupings[condition] = samps
    og_labels[condition] = label

ch1_conditions = list(labled_groupings.keys())
# A characteristics label is a control when any of its words is a control keyword.
ch1_ctrl = [cond for cond in ch1_conditions if set(cond.lower().split()) & ctrl_words]

In [None]:
def check_eligibility(conditions, ctrl_conditions):
    """Return True when a control assignment is usable for the analysis.

    There must be 1 or 2 control conditions, and the number of controls must
    differ from the number of conditions (not every group may be a control).
    """
    n_ctrl = len(ctrl_conditions)
    return 1 <= n_ctrl <= 2 and n_ctrl != len(conditions)

In [None]:
# Evaluate both labelling strategies (sample titles vs. characteristics) independently.
title_eligibility = check_eligibility(title_conditions, title_ctrl)
ch1_eligibility = check_eligibility(ch1_conditions, ch1_ctrl)

In [None]:
def compare_groups(title_ctrl, ch1_ctrl, og_labels):
    """Check that title- and characteristics-based labelling chose the same control groups.

    Args:
        title_ctrl: control labels derived from sample titles.
        ch1_ctrl: control labels derived from sample characteristics.
        og_labels: maps each characteristics label to the title label of the same group.

    Returns:
        bool: True when the ch1 controls, translated to title labels, are exactly title_ctrl.
    """
    # The previous check compared each ch1 control's samples with the samples of the
    # same group looked up by title; that was always equal (and ignored title_ctrl),
    # so disagreement between the two strategies went undetected.
    ch1_as_title = {og_labels[c] for c in ch1_ctrl}
    return ch1_as_title == set(title_ctrl)

In [None]:
# Pick the labelling strategy: titles when both are usable and agree, otherwise
# whichever single strategy is eligible; fail when neither is.
if ch1_eligibility and title_eligibility:
    if not compare_groups(title_ctrl, ch1_ctrl, og_labels):
        raise Exception("Group Assignment Failed")
    ctrl_conditions, conditions = title_ctrl, title_conditions
elif ch1_eligibility:
    ctrl_conditions, conditions = ch1_ctrl, ch1_conditions
    groupings = labled_groupings
elif title_eligibility:
    ctrl_conditions, conditions = title_ctrl, title_conditions
else:
    raise Exception("Group Assignment Failed")

In [None]:
import time
import requests

# Seconds allowed per HTTP call so a stalled connection cannot hang the notebook forever.
REQUEST_TIMEOUT = 60

# Step 1: submit a job that builds the count matrix for this study's GSM samples.
url = "https://maayanlab.cloud/sigpy/data/samples"
data = {
    "gsm_ids": [w['Accession'] for w in record[0]['Samples']],
    "species": species
}
# Distinct name so the Gemini `response` from the abstract cell is not overwritten.
submit_response = requests.post(url, json=data, timeout=REQUEST_TIMEOUT)
submit_response.raise_for_status()
task_id = submit_response.json()['task_id']

# Step 2: poll for status with exponential backoff.
status_url = f"https://maayanlab.cloud/sigpy/data/samples/status/{task_id}"

max_attempts = 10
sleep_sec = 2

for attempt in range(max_attempts):
    status_response = requests.get(status_url, timeout=REQUEST_TIMEOUT)
    status_response.raise_for_status()
    status = status_response.json().get('status')

    if status == 'SUCCESS':
        break
    elif status == 'FAILURE':
        raise ValueError("Error getting matrix from ARCHS4.")

    time.sleep(sleep_sec)
    sleep_sec *= 1.5
else:
    raise TimeoutError("Timed out waiting for ARCHS4 task to finish.")

# Step 3: download the result; fail on HTTP errors instead of saving an error page as matrix.zip.
download_url = f"https://maayanlab.cloud/sigpy/data/samples/download/{task_id}"
download_response = requests.get(download_url, timeout=REQUEST_TIMEOUT)
download_response.raise_for_status()
with open("matrix.zip", "wb") as f:
    f.write(download_response.content)


In [None]:
import zipfile

zip_file_path = 'matrix.zip'

# Unpack into the current directory; the first archive member is read as the count table.
with zipfile.ZipFile(zip_file_path) as archive:
    extracted_files = archive.namelist()
    archive.extractall()

# First column holds the gene identifiers; use it as an unnamed index.
gene_matrix = pd.read_csv(extracted_files[0], sep='\t', index_col=0)
gene_matrix.index.name = None

os.remove(zip_file_path)

In [None]:
# Preview of the raw count matrix (genes as rows, GSM samples as columns).
gene_matrix.head(5)

In [None]:
display(Markdown(f"**Table {tab_num}**: This is a preview of the first 5 rows of the raw RNA-seq expression matrix from {gse}."))
tab_num += 1
# extractall() in the loading cell wrote the matrix into the current directory, not resource_path.
display(FileLink(extracted_files[0], result_html_prefix="Download raw counts: "))

In [None]:
# Keep genes whose mean count across samples is at least 3. This also removes all-zero
# genes, since a mean >= 3 implies a positive total.
filtered_matrix = gene_matrix[gene_matrix.mean(axis=1) >= 3]

In [None]:
from maayanlab_bioinformatics.normalization.log import log2_normalize
from maayanlab_bioinformatics.normalization.zscore import zscore_normalize 
from maayanlab_bioinformatics.normalization.quantile_legacy import quantile_normalize

def normalize(gene_counts):
    """Quantile-normalize, then log2-transform, then z-score a filtered count matrix."""
    return zscore_normalize(log2_normalize(quantile_normalize(gene_counts)))


In [None]:
%%capture
# %%capture hides any console output produced while normalizing.
norm_matrix = normalize(filtered_matrix) #normalized counts

In [None]:
def annotate_matrix(expr_df, groupings, ctrl_conditions):
    """Pair an expression matrix with per-sample control/perturbation labels.

    Args:
        expr_df: expression matrix, returned unchanged under 'count'.
        groupings: dict mapping a condition label to its list of sample IDs.
        ctrl_conditions: condition labels that are controls.

    Returns:
        dict with 'count' (expr_df) and 'annotations' (DataFrame indexed by sample with a
        'group' column of "control" or "perturbation", in groupings order).
    """
    sample_labels = {
        sample: ("control" if condition in ctrl_conditions else "perturbation")
        for condition, samples in groupings.items()
        for sample in samples
    }
    annotations = pd.DataFrame.from_dict(sample_labels, orient='index', columns=['group'])
    return {'count': expr_df, 'annotations': annotations}

In [None]:
# {'count': normalized matrix, 'annotations': per-sample "control"/"perturbation" labels}
annotated_norm_matrix = annotate_matrix(norm_matrix, groupings, ctrl_conditions)

#print(annotated_norm_matrix['annotations'])

In [None]:
# Round before casting: counts may be non-integer estimates (Kallisto, per Methods), and a bare
# astype('int64') would truncate them toward zero. Filtered but not normalized, for limma-voom.
annotated_matrix = annotate_matrix(filtered_matrix.round().astype('int64'), groupings, ctrl_conditions)

# **Results**

In [None]:
save_html = True #if true, render plotly graphs as html and embed with ipython. else, use fig.show()
use_fig_plot = False #if true, render matplotlib graphs using show(), else it will render the saved pngs.
# NOTE(review): in the cells shown here the HTML is only written to disk (the embed calls are
# commented out) and the saved PNGs are displayed; use_fig_plot is not referenced in these cells.

## **Dimensionality Reduction**

In [None]:
import pandas as pd
import plotly.express as px
import dash_bio
import sklearn
from umap import UMAP
import numpy as np
import os
import kaleido

# import plotly.io as pio
# from IPython.display import Image, display

def plot(gene_counts, annotations, save_formats, n_components=2, decomp="pca", save_html=False, save_path=""):
    """
    Plot a 2D or 3D dimensionality reduction (PCA, t-SNE, or UMAP) of the samples and save it.

    Args:
        gene_counts (pd.DataFrame): A normalized expression matrix with samples as columns
            (it is transposed internally so that samples become rows).
        annotations (pd.DataFrame): Indexed by sample; its "group" column colors the points.
        save_formats (Iterable[str]): Image extensions to write (e.g. "png", "svg"); each is
            saved as "<decomp>.<ext>" in save_path.
        n_components (int): Number of components to plot (2 or 3).
        decomp (str): Which decomposition to use: "pca", "tsne", or "umap" (case-insensitive).
        save_html (bool): If True, also write "<decomp>.html"; otherwise call fig.show().
        save_path (str): Directory the output files are written to.

    Returns:
        None

    Raises:
        ValueError: If n_components is not 2 or 3, or decomp is not a supported method.
    """
    # Import the estimators explicitly: a bare `import sklearn` does not guarantee that the
    # sklearn.decomposition / sklearn.manifold submodules are loaded.
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE

    decomp = decomp.lower()
    gene_counts = gene_counts.T  # Ensure samples are rows

    if n_components not in (2, 3):
        raise ValueError("n_components must be either 2 or 3.")

    # Align annotations to the sample order of gene_counts
    annotations = annotations.reindex(gene_counts.index)

    if decomp == "pca":
        model = PCA(n_components=n_components)
        projections = model.fit_transform(gene_counts)
        total_variance = model.explained_variance_ratio_.sum() * 100
        title = f'PCA of Gene Expression (Total Variance explained: {total_variance:.2f}%)'
        labels = {str(i): f'PC{i+1} (variance explained: {model.explained_variance_ratio_[i]*100:.2f}%)' for i in range(n_components)}

    elif decomp == "tsne":
        # t-SNE needs perplexity < number of samples; these studies have only 6-24 samples.
        perplexity = min(30, gene_counts.shape[0] - 1)
        model = TSNE(n_components=n_components, random_state=42, perplexity=perplexity)
        projections = model.fit_transform(gene_counts.values)
        title = 't-SNE of Gene Expression'
        labels = {str(i): f't-SNE {i+1}' for i in range(n_components)}

    elif decomp == "umap":
        model = UMAP(n_components=n_components, random_state=42)
        projections = model.fit_transform(gene_counts.values)
        title = 'UMAP of Gene Expression'
        labels = {str(i): f'UMAP {i+1}' for i in range(n_components)}

    else:
        raise ValueError("decomp must be one of: 'pca', 'tsne', or 'umap'.")

    if n_components == 2:
        # 2D figures intentionally omit the title (the notebook caption describes the plot).
        fig = px.scatter(
            projections, x=0, y=1, color=annotations["group"],
            labels=labels, 
            hover_name=gene_counts.index,
        )
        fig.update_layout(
            width=700,
            height=500,
            plot_bgcolor="rgba(0,0,0,0)",
            showlegend=True,
            legend=dict(
                title="",
                font=dict(
                    size=14,
                    family='Arial'
                )
            ),
        )
        fig.update_xaxes(
            showline=True,           
            linecolor="black",       
            linewidth=2,             
            showgrid=False,          
            zeroline=False,
            title=dict(
                font=dict(
                    size=20,
                    family='Arial'
                )
            )       
        )
        fig.update_yaxes(
            # Equal axis scaling so distances are not visually distorted.
            scaleanchor="x",
            scaleratio=1,
            showline=True,
            linecolor="black",
            linewidth=2,
            showgrid=False,
            zeroline=False,
            title=dict(
                font=dict(
                    size=20,
                    family='Arial'
                ),
                standoff=5
            )
        )
    else: #support for 3D plots
        fig = px.scatter_3d(
            projections, x=0, y=1, z=2, color=annotations["group"],
            title=title, labels=labels, width=600, height=600
        )
        fig.update_scenes(
            aspectmode="cube",
            xaxis=dict(
                showline=True,
                linecolor='lightgrey',
                showbackground=False,
                gridcolor='lightgrey',
                zerolinecolor='black'
            ),
            yaxis=dict(
                showline=True,
                linecolor='lightgrey',
                showbackground=False,
                gridcolor='lightgrey',
                zerolinecolor='black'
            ),
            zaxis=dict(
                showline=True,
                linecolor='lightgrey',
                showbackground=False,
                gridcolor='lightgrey',
                zerolinecolor='black'
            ),
        )

    fig.update_traces(marker=dict(size=20))
    
    for f in save_formats:
        fig_name = decomp + '.' + f
        fig.write_image(os.path.join(save_path, fig_name), scale=2)

    if save_html:
        fig_name = decomp + ".html"
        fig.write_html(os.path.join(save_path, fig_name))
    else:
        fig.show()

def plot_clustergram(gene_counts):
    '''
    Build a reusable dash_bio Clustergram (a Plotly Figure) for display in a Jupyter notebook.

    Rows and columns are clustered with cosine distance and average linkage; row labels
    are hidden and values are colored on a blue-white-red scale.

    Args: 
        gene_counts (DataFrame): a normalized, filtered gene count matrix.
    
    Returns:
        plotly.graph_objs._figure.Figure: A clustergram in the form of a Plotly Figure
    '''
    # Diverging palette: blue (low) -> white (middle) -> red (high).
    diverging_scale = [
        [0.0, '#636EFA'],
        [0.5, '#FFFFFF'],
        [1.0, '#EF553B'],
    ]
    dendrogram_thresholds = {'row': 250, 'col': 700}

    return dash_bio.Clustergram(
        data=gene_counts,
        column_labels=list(gene_counts.columns.values),
        row_labels=list(gene_counts.index),
        color_threshold=dendrogram_thresholds,
        hidden_labels='row',
        height=800,
        width=600,
        color_map=diverging_scale,
        row_dist="cosine",
        col_dist="cosine",
        link_method="average",
        paper_bg_color='#FFFFFF',
    )

def plot_volcano(deg, threshold, save_formats, save_name = "volcano", save_html=False, save_path=""):
    """
    Draw a volcano plot (logFC vs. -log10 p) of a differential expression table and save it.

    Args:
        deg (pd.DataFrame): DGE table with 'logFC' and 'P.Value' columns, indexed by gene.
            NOTE: modified in place -- 'significance' and '-log10p' columns are added.
        threshold (float): |logFC| cutoff; genes with P.Value < 0.05 and logFC below
            -threshold / above +threshold are labelled "downregulated" / "upregulated".
        save_formats (Iterable[str]): image extensions to write as "<save_name>.<ext>".
        save_name (str): file stem of the saved figures.
        save_html (bool): if True write "<save_name>.html"; otherwise call fig.show().
        save_path (str): output directory.
    """
    deg['significance'] = "insignificant"
    deg.loc[(deg['P.Value']<0.05) & (deg['logFC']<-threshold), 'significance'] = "downregulated"
    deg.loc[(deg['P.Value']<0.05) & (deg['logFC']>threshold), 'significance'] = "upregulated"

    # Clip at the smallest positive float: a p-value that underflowed to 0 would give
    # -log10p = inf and make the y-axis range below infinite.
    deg['-log10p'] = -np.log10(deg['P.Value'].clip(lower=np.finfo(float).tiny))

    fig = px.scatter(
        deg,
        x='logFC',
        y='-log10p',
        color='significance',
        color_discrete_map={
            'insignificant': 'black',
            'upregulated': 'red',
            'downregulated': 'blue'
        },
        hover_name=deg.index,
    )
    #hide insignificant from the legend.
    for trace in fig['data']:
        if trace['name'] == 'insignificant':
            trace['showlegend']=False

    fig.update_layout(
        width=600,
        height=800,
        plot_bgcolor="rgba(0,0,0,0)",
        legend=dict(
            title="",
            font=dict(
                size=14,
                family='Arial'
            )
        ),
        font=dict(
            family='Arial'
        )
    )

    # Symmetric x-range around 0 so up- and downregulated sides are comparable.
    max_abs_x = max(abs(deg['logFC'].min()), abs(deg['logFC'].max()))+0.5
    max_y = deg['-log10p'].max()

    fig.update_xaxes(
        range=[-max_abs_x, max_abs_x],
        zerolinecolor="black",
        zerolinewidth=1,
        gridcolor="lightgrey",
        gridwidth=1,
        title=dict(
            font=dict(
                size=20,
            )
        )
    )
    fig.update_yaxes(
        range=[0, max_y * 1.05],
        zerolinecolor="black",
        zerolinewidth=1,
        gridcolor="lightgrey",
        gridwidth=1,
        title=dict(
            font=dict(
                size=20,
            )
        )
    )
    
    for fmt in save_formats:
        fig_name = f"{save_name}.{fmt}"
        fig.write_image(os.path.join(save_path, fig_name), scale=2)

    if save_html:
        fig.write_html(os.path.join(save_path, f"{save_name}.html"))
    else:
        fig.show()


### **UMAP**

In [None]:
plot(
    annotated_norm_matrix['count'],
    annotated_norm_matrix['annotations'],
    save_formats=save_formats,
    n_components=2,
    decomp="umap",
    save_html=save_html,
    save_path=resource_path,
)
# Show the static PNG that plot() just wrote.
display(Image(os.path.join(resource_path, "umap.png"), width=700))
display(Markdown(f"**Figure {fig_num}**: This figure displays a 2D scatter plot of a UMAP decomposition of the sample data. Each point represents an individual sample, colored by its experimental group."))
fig_num += 1

for fmt in save_formats:
    display(FileLink(os.path.join(resource_path, f"umap.{fmt}"), result_html_prefix=f"Download UMAP figure as {fmt}: "))


### **PCA**

In [None]:
plot(
    annotated_norm_matrix['count'],
    annotated_norm_matrix['annotations'],
    save_formats=save_formats,
    n_components=2,
    decomp="pca",
    save_html=save_html,
    save_path=resource_path,
)
# Show the static PNG that plot() just wrote.
display(Image(os.path.join(resource_path, "pca.png"), width=700))
display(Markdown(f"**Figure {fig_num}**: This figure displays a 2D scatter plot of a PCA decomposition of the sample data. Each point represents an individual sample, colored by its experimental group."))
fig_num += 1

for fmt in save_formats:
    display(FileLink(os.path.join(resource_path, f"pca.{fmt}"), result_html_prefix=f"Download PCA figure as {fmt}: "))

### **t-SNE**

In [None]:
plot(
    annotated_norm_matrix['count'],
    annotated_norm_matrix['annotations'],
    save_formats=save_formats,
    n_components=2,
    decomp="tsne",
    save_html=save_html,
    save_path=resource_path,
)
# Show the static PNG that plot() just wrote.
display(Image(os.path.join(resource_path, "tsne.png"), width=700))
display(Markdown(f"**Figure {fig_num}**: This figure displays a 2D scatter plot using a t-SNE decomposition of the sample data. Each point represents an individual sample, colored by its experimental group."))
fig_num += 1

for fmt in save_formats:
    display(FileLink(os.path.join(resource_path, f"tsne.{fmt}"), result_html_prefix=f"Download t-SNE figure as {fmt}: "))

## **Clustergram Heatmaps**

In [None]:
from maayanlab_bioinformatics.normalization.filter import filter_by_var
# Keep the 1000 most variable genes for the heatmap.
norm_t1000 = filter_by_var(annotated_norm_matrix['count'], top_n=1000, axis=1)
# Relabel the GSM columns with sample titles by accession lookup. The matrix column order
# comes from the download service and is not guaranteed to match metadata's row order,
# so positional assignment could silently mislabel samples.
norm_t1000.columns = metadata.loc[norm_t1000.columns, 'title'].tolist()

In [None]:
# Tab-separated top-1000-gene matrix; this file is uploaded to Clustergrammer in the next cell.
t1000_path = os.path.join(resource_path, 'expression_matrix_top1000_genes.txt')
norm_t1000.to_csv(t1000_path, sep='\t')

In [None]:
import requests, json
clustergrammer_url = 'https://maayanlab.cloud/clustergrammer/matrix_upload/'

# Close the upload handle once sent, bound the wait, and fail on HTTP errors
# (otherwise an error-page body would be used as the viewer link).
with open(t1000_path, 'rb') as matrix_file:
    r = requests.post(clustergrammer_url, files={'file': matrix_file}, timeout=120)
r.raise_for_status()
# The response body is the URL of the interactive viewer.
link = r.text

In [None]:
from IPython.display import IFrame
# Embed the Clustergrammer viewer URL returned by the upload in the previous cell.
display(IFrame(link, width="600", height="800"))
display(Markdown(f"**Figure {fig_num}**: The figure contains an interactive heatmap displaying gene expression for each sample in the RNA-seq dataset. Every row of the heatmap represents a gene, every column represents a sample, and every cell displays normalized gene expression values. The heatmap additionally features color bars beside each column which represent prior knowledge of each sample, such as the tissue of origin or experimental treatment."))
fig_num+=1

In [None]:
# Static (Plotly) clustergram of the top-1000-gene matrix.
clustergram = plot_clustergram(norm_t1000)

In [None]:
# Save the clustergram in every configured static format (width 600, scale 2).
for fmt in save_formats:
    file_name = os.path.join(resource_path, f"clustergram.{fmt}")
    clustergram.write_image(file_name, width=600, scale = 2)

In [None]:
# Interactive HTML rendering, kept for reference but currently disabled:
# if save_html:
#     cluster_path = os.path.join(resource_path, "clustergram.html")
#     clustergram.write_html(cluster_path)
#     display(HTML(cluster_path))
# else:
#     clustergram.show()

# Display the static PNG written by the previous cell.
display(Image(os.path.join(resource_path, "clustergram.png"), width=700))

In [None]:
# Capitalized to match the other figure captions.
display(Markdown(f"**Figure {fig_num}**: This figure is a clustergram produced with the graphing library Plotly. It sacrifices some interactivity for a more polished look."))
fig_num += 1

# Download links for the static clustergram files saved earlier.
for fmt in save_formats:
    file_name = os.path.join(resource_path, f"clustergram.{fmt}")
    display(FileLink(file_name, result_html_prefix=f"Download as {fmt}: "))

## **Differentially Expressed Genes Calculation and Volcano Plots**

In [None]:
from maayanlab_bioinformatics.dge.limma_voom import limma_voom_differential_expression

# Run limma-voom for every (control condition, other condition) pair.
# `sig_names` only receives comparisons whose table was computed, is non-empty and was
# written, so every name in it has a matching entry in `dges` and a TSV under resource_path.
sig_names = []
matrix = annotated_matrix['count']
dges = {}

# Unordered pairs already handled, so A-vs-B and B-vs-A are not both computed.
seen = set()
for condition in ctrl_conditions:
    for condition2 in conditions:
        pair = frozenset((condition, condition2))
        if condition == condition2 or pair in seen:
            continue
        seen.add(pair)

        sig_name = f'{condition}-vs-{condition2}.tsv'
        try:
            with suppress_output():
                dge = limma_voom_differential_expression(
                    matrix[groupings[condition]],
                    matrix[groupings[condition2]],
                    voom_design=True,
                )
            if dge.empty:
                print('Empty dge returned for', sig_name)
                continue
            rounded_cols = ['logFC', 'AveExpr', 't', 'B']
            dge[rounded_cols] = dge[rounded_cols].round(2)
            dges[sig_name] = dge
            dge.to_csv(os.path.join(resource_path, sig_name), sep='\t')
            # Register the comparison only after it succeeded; downstream cells index
            # dges[sig_name] for every entry of sig_names.
            sig_names.append(sig_name)
        except Exception as e:
            print(e)
            print('Error computing:', sig_name)

In [None]:
# Preview each differential-expression table and link the TSV written by the limma-voom cell.
for sig_name in sig_names:
    # A comparison whose limma-voom run failed or came back empty has no table in `dges`.
    if sig_name not in dges:
        print('No DGE table available for', sig_name)
        continue

    dge_path = os.path.join(resource_path, sig_name)
    table = dges[sig_name]
    display(table.head(5))
    display(Markdown(f"**Table {tab_num}**: This is a preview of the first 5 rows of the differentially expressed gene table calculated by Limma Voom."))
    display(FileLink(dge_path, result_html_prefix="Download DGE table: "))
    tab_num+=1

In [None]:
# log2 fold-change cutoff: the next cell keeps genes with logFC > threshold (up) or
# logFC < -threshold (down) and passes it to plot_volcano.
# NOTE(review): the CHEA3 cell further down rebinds the global `threshold` to 3, so re-running
# the volcano cell after that one silently uses a cutoff of 3.
threshold=1.0

In [None]:
upreg = {}
downreg = {}

upreg_t500 = {}
downreg_t500 = {}

for sig_name in sig_names:
    # Comparisons that failed in the limma-voom cell have no entry in `dges`.
    if sig_name not in dges:
        print('No DGE results for', sig_name, '- skipping volcano plot.')
        continue
    dge = dges[sig_name]
    comparison = sig_name.replace(".tsv", "")

    # Nominal significance filter shared by every gene list below.
    significant = dge.loc[dge['P.Value'] < 0.05]

    # Significant genes beyond the log2 fold-change cutoff in each direction.
    upreg[comparison] = significant.loc[significant['logFC'] > threshold].index.tolist()
    downreg[comparison] = significant.loc[significant['logFC'] < -threshold].index.tolist()

    # Up to 500 significant genes with the largest positive / most negative logFC.
    # Restricting the sign keeps down-regulated genes out of the "up" list (and vice versa)
    # when fewer than 500 significant genes move in that direction.
    upreg_t500[comparison] = (
        significant.loc[significant['logFC'] > 0]
        .sort_values(by="logFC", ascending=False)[:500]
        .index.tolist()
    )
    downreg_t500[comparison] = (
        significant.loc[significant['logFC'] < 0]
        .sort_values(by="logFC")[:500]
        .index.tolist()
    )

    save_name = f"{comparison}_volcano"

    display(Markdown(f"**{comparison}**"))
    plot_volcano(dge, threshold=threshold, save_formats=save_formats, save_name = save_name, save_html=save_html, save_path=resource_path)
    display(Image(os.path.join(resource_path, f"{save_name}.png"), width=700))
    display(Markdown(f"**Figure {fig_num}**: The figure contains an interactive scatter plot which displays the log2-fold changes and statistical significance of each gene calculated by performing a differential gene expression analysis for the comparison {comparison}. Every point in the plot represents a gene. Red points indicate significantly up-regulated genes, blue points indicate down-regulated genes."))
    fig_num+=1

    for fmt in save_formats:
        file_name=os.path.join(resource_path, f"{save_name}.{fmt}")
        display(FileLink(file_name, result_html_prefix=f"Download volcano plot as {fmt}: "))

In [None]:
# Comparison labels without the ".tsv" suffix; these are the keys of upreg/downreg.
sig_names_clean = list(map(lambda tsv_name: tsv_name.replace('.tsv', ''), sig_names))

## **Enrichr: Enrichment Analysis**

In [None]:
import pandas as pd 
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import seaborn as sns
import time
from matplotlib.ticker import MaxNLocator
from IPython.display import display,FileLink, Markdown

# Optional mapping from Enrichr term names to display labels; enrichr_figure substitutes a
# term's label when the term is a key here. Empty means terms are shown as returned.
annot_dict = {}

# Function to get Enrichr Results
# Takes a gene list and Enrichr libraries as input
def Enrichr_API(enrichr_gene_list, all_libraries):
    """Run an Enrichr enrichment of one gene list against several libraries.

    The gene list is uploaded once (``addList``) and the returned list id is queried
    against each library (``enrich``). Only the top 5 terms per library are kept.

    Parameters
    ----------
    enrichr_gene_list : iterable of str
        Gene symbols, joined with newlines for the upload.
    all_libraries : iterable of str
        Enrichr gene-set library names, sent as ``backgroundType``.

    Returns
    -------
    list
        ``[all_terms, all_pvalues, all_adjusted_pvalues, short_id, library_success]``.
        The first three are parallel lists holding one inner list (top-5 values) per
        library in ``library_success``; ``short_id`` is the Enrichr share id as a string.

    Raises
    ------
    Exception
        If the upload or any enrichment request returns a non-OK HTTP status.
    """
    enrichr_base_url = 'https://maayanlab.cloud/Enrichr'

    all_terms = []
    all_pvalues = []
    all_adjusted_pvalues = []
    library_success = []

    # The uploaded list does not depend on the library, so it is submitted only once
    # (it used to be re-uploaded for every library).
    payload = {
        'list': (None, '\n'.join(enrichr_gene_list)),
        'description': (None, 'Example gene list'),
    }
    response = requests.post(f'{enrichr_base_url}/addList', files=payload)
    if not response.ok:
        raise Exception('Error analyzing gene list')
    data = json.loads(response.text)
    user_list_id = data['userListId']
    short_id = data['shortId']

    for library_name in all_libraries:
        # Keep a short pause between API calls, as before.
        time.sleep(0.5)
        response = requests.get(
            f'{enrichr_base_url}/enrich',
            params={'userListId': user_list_id, 'backgroundType': library_name},
        )
        if not response.ok:
            raise Exception('Error fetching enrichment results')
        try:
            results_df = pd.DataFrame(json.loads(response.text)[library_name][0:5])
            # Result columns read: 1 = term name, 2 = p-value, 6 = adjusted p-value.
            all_terms.append(list(results_df[1]))
            all_pvalues.append(list(results_df[2]))
            all_adjusted_pvalues.append(list(results_df[6]))
            library_success.append(library_name)
        except (KeyError, IndexError, TypeError, ValueError) as exc:
            # Missing library key, no results (empty frame has no columns) or malformed
            # JSON: report and continue with the remaining libraries.
            print('Error for ' + library_name + ' library' + f' ({exc!r})')

    return([all_terms,all_pvalues,all_adjusted_pvalues,str(short_id),library_success])



def _plot_enrichr_library(ax, terms, pvalues, adjusted_pvalues, library_name,
                          bar_color, bar_color_not_sig, edgecolor, linewidth,
                          title_size, label_size, tick_size, annot_size):
    """Draw one library's terms on ``ax`` as horizontal -log10(p) bars labelled in place.

    Bars with p < 0.05 use ``bar_color``; others use ``bar_color_not_sig``. Each bar gets the
    term name (or its ``annot_dict`` label) and its p-value; a ``*`` before the p-value marks
    terms whose adjusted p-value is < 0.05.
    """
    neg_log_p = np.log10(pvalues) * -1
    bar_colors = [bar_color if (p < 0.05) else bar_color_not_sig for p in pvalues]
    sns.barplot(x=neg_log_p, y=terms, ax=ax, palette=bar_colors, edgecolor=edgecolor, linewidth=linewidth)
    ax.get_yaxis().set_visible(False)
    ax.set_title(library_name.replace('_', ' '), fontsize=title_size)
    ax.set_xlabel('-Log10(p-value)', fontsize=label_size)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.tick_params(axis='x', which='major', labelsize=tick_size)
    # Integer ticks would leave the axis bare when every bar is shorter than 1; use 0.1 steps.
    if max(neg_log_p) < 1:
        ax.xaxis.set_ticks(np.arange(0, max(neg_log_p), 0.1))
    for ii, term in enumerate(terms):
        label = annot_dict.get(term, term)
        p_text = np.format_float_scientific(pvalues[ii], precision=2)
        separator = '  *' if adjusted_pvalues[ii] < 0.05 else '  '
        # Write the label just right of x=0 on the bar's row (hidden y axis).
        ax.text(max(ax.get_xlim()) / 200, ii, separator.join([label, p_text]),
                ha='left', wrap=True, fontsize=annot_size)
    ax.patch.set_edgecolor('black')
    ax.patch.set_linewidth(2)


def enrichr_figure(all_terms,all_pvalues, all_adjusted_pvalues, plot_names, all_libraries, fig_format, bar_color, show_plot=True):
    """Plot Enrichr results (one bar-plot panel per library) and save the figure.

    Parameters
    ----------
    all_terms, all_pvalues, all_adjusted_pvalues : list of list
        Per-library term names, p-values and adjusted p-values (as returned by Enrichr_API).
    plot_names : iterable of str
        Output paths; the current figure is saved to each with ``bbox_inches='tight'``.
    all_libraries : list of str
        Library names, parallel to the lists above; used as panel titles.
    fig_format :
        Unused; kept for backward compatibility (formats come from ``plot_names``).
    bar_color : str
        Colour of significant (p < 0.05) bars. With ``'lightgrey'`` the non-significant bars
        become white with black edges so the two remain distinguishable.
    show_plot : bool, default True
        Show the figure; otherwise close it after saving.

    Layout: one library is drawn on the current axes; otherwise a grid with two columns and
    ceil(n/2) rows is used, and the unused last cell is removed when n is odd.

    Raises
    ------
    ValueError
        If ``all_libraries`` is empty (nothing to plot).
    """
    n_libs = len(all_libraries)
    if n_libs == 0:
        raise ValueError('No Enrichr library results to plot.')

    # Bar colors
    if bar_color != 'lightgrey':
        bar_color_not_sig, edgecolor, linewidth = 'lightgrey', None, 0
    else:
        bar_color_not_sig, edgecolor, linewidth = 'white', 'black', 1
    bar_style = dict(bar_color=bar_color, bar_color_not_sig=bar_color_not_sig,
                     edgecolor=edgecolor, linewidth=linewidth)

    if n_libs == 1:
        _plot_enrichr_library(plt.gca(), all_terms[0], all_pvalues[0], all_adjusted_pvalues[0],
                              all_libraries[0], title_size=26, label_size=25, tick_size=20,
                              annot_size=26, **bar_style)
    else:
        n_rows = int(np.ceil(n_libs / 2))
        # Even library counts below 5 (2 or 4) use a slightly narrower canvas.
        fig_width = 7 if (n_libs % 2 == 0 and n_libs < 5) else 8
        fig, axes = plt.subplots(n_rows, 2, figsize=[fig_width, int(2 * n_rows)])

        if n_rows == 1:
            # A single row returns a 1-D array of axes.
            panel_axes = list(axes)
            annot_size, top = 36, 4.5
        else:
            # Library i goes to row i // 2, column i % 2.
            panel_axes = [axes[i // 2, i % 2] for i in range(n_libs)]
            annot_size, top = 30, 4.8

        for i, library_name in enumerate(all_libraries):
            _plot_enrichr_library(panel_axes[i], all_terms[i], all_pvalues[i], all_adjusted_pvalues[i],
                                  library_name, title_size=36, label_size=35, tick_size=30,
                                  annot_size=annot_size, **bar_style)

        # NOTE: top/right > 1 place the panels beyond the nominal canvas; the saved files rely
        # on bbox_inches='tight' to include them.
        plt.subplots_adjust(top=top, right=4.7, wspace=0.03, hspace=0.2)

        # An odd number of libraries leaves the last grid cell empty.
        if n_rows * 2 != n_libs:
            fig.delaxes(axes[n_rows - 1, 1])

    # Save results
    for plot_name in plot_names:
        plt.savefig(plot_name, bbox_inches='tight')

    # Show plot
    if show_plot:
        plt.show()
    else:
        plt.close()

In [None]:
# Enrichr libraries queried for every comparison: a shared set plus two species-specific
# pathway libraries, alphabetised.
enrichr_libraries = [
    "ChEA_2022",
    "ARCHS4_TFs_Coexp",
    "Reactome_Pathways_2024",
    "MGI_Mammalian_Phenotype_Level_4_2024",
    "GO_Biological_Process_2025",
    "GWAS_Catalog_2023",
]
if species not in ("human", "mouse"):
    raise Exception("Species not supported.")
enrichr_libraries += (
    ["WikiPathways_2024_Human", "KEGG_2021_Human"]
    if species == "human"
    else ["WikiPathways_2024_Mouse", "KEGG_2019_Mouse"]
)
enrichr_libraries = sorted(enrichr_libraries)

# Enrichr figures are saved in the same formats as the other figures.
figure_file_format = save_formats

# Colour of significant bars (enrichr_figure's bar_color argument).
color = "tomato"

### **Upregulated Set**

In [None]:
from IPython.display import Image
# Enrichr analysis of the up-regulated genes of each comparison.
for sig_name in sig_names_clean:
    up_genes = upreg.get(sig_name, [])
    display(Markdown(f"#### **{sig_name}**"))
    # An empty list has nothing to enrich (and the upload would presumably be rejected);
    # skip it instead of aborting the remaining comparisons.
    if len(up_genes) == 0:
        display(Markdown("No up-regulated genes passed the significance and fold-change thresholds; Enrichr analysis skipped."))
        continue

    up_file_name = sig_name + ' up_enrichr_results'
    final_output_file_names_up = [str(os.path.join(resource_path, up_file_name+'.'+file_type)) for file_type in figure_file_format]
    uresults = Enrichr_API(up_genes, enrichr_libraries)
    if len(uresults[4]) == 0:
        display(Markdown("No Enrichr library returned results for this gene set."))
        continue

    enrichr_figure(uresults[0],uresults[1],uresults[2],final_output_file_names_up, uresults[4],figure_file_format, color, show_plot=False)
    display(Image(final_output_file_names_up[0], width=600)) # first saved format; assumed to be PNG
    display(Markdown(f'**Figure {fig_num}**: This figure contains several barplots depicting enrichment analysis results on the upregulated gene set. Each barplot corresponds to an individual library from Enrichr, and the top matching terms by p-value are depicted in each. Statistically significant terms are represented as red bars while others are represented as gray. Access your Enrichment results here: ' + str('https://amp.pharm.mssm.edu/Enrichr/enrich?dataset='+ uresults[3])))
    fig_num+=1

    for name in final_output_file_names_up:
        display(FileLink(name, result_html_prefix=f"Download figure as {name[name.rfind('.')+1:]}:"))

### **Downregulated Set**

In [None]:
#downregulated results
# Enrichr analysis of the down-regulated genes of each comparison.
for sig_name in sig_names_clean:
    dn_genes = downreg.get(sig_name, [])
    display(Markdown(f"#### **{sig_name}**"))
    # An empty list has nothing to enrich (and the upload would presumably be rejected);
    # skip it instead of aborting the remaining comparisons.
    if len(dn_genes) == 0:
        display(Markdown("No down-regulated genes passed the significance and fold-change thresholds; Enrichr analysis skipped."))
        continue

    dn_file_name = sig_name + ' dn_enrichr_results'
    final_output_file_names_dn = [str(os.path.join(resource_path, dn_file_name+'.'+file_type)) for file_type in figure_file_format]
    dresults = Enrichr_API(dn_genes, enrichr_libraries)
    if len(dresults[4]) == 0:
        display(Markdown("No Enrichr library returned results for this gene set."))
        continue

    enrichr_figure(dresults[0],dresults[1],dresults[2],final_output_file_names_dn, dresults[4],figure_file_format, color, show_plot=False)
    display(Image(final_output_file_names_dn[0], width=600)) # first saved format; assumed to be PNG
    # The link must use this list's short id (dresults), not the up-regulated one.
    display(Markdown(f'**Figure {fig_num}**: This figure contains several barplots depicting enrichment analysis results on the downregulated gene set. Each barplot corresponds to an individual library from Enrichr, and the top matching terms by p-value are depicted in each. Statistically significant terms are represented as red bars while others are represented as gray. Access your Enrichment results here: ' + str('https://amp.pharm.mssm.edu/Enrichr/enrich?dataset='+ dresults[3])))
    fig_num+=1

    for name in final_output_file_names_dn:
        display(FileLink(name, result_html_prefix=f"Download figure as {name[name.rfind('.')+1:]}:"))

## **CHEA3: Transcription Factor Enrichment Analysis**

In [None]:
import json
import requests
import numpy as np
from time import sleep
from tabulate import tabulate
from IPython.display import HTML, display, Image, FileLink, Markdown
import plotly.graph_objects as go
import kaleido
import os

# Number of top-ranked TFs shown per CHEA3 table/chart (read by display_tables, display_charts, mean_rank_bar).
num_tfs = 10
# CHEA3 minimum-library count (3): a TF must appear in at least this many libraries to be plotted.
# NOTE(review): this rebinds the global `threshold` that the volcano cell uses as its log2
# fold-change cutoff (1.0); re-running that cell afterwards would use 3 instead.
threshold =3

def get_chea3_results(gene_set, query_name):
    """Submit ``gene_set`` to the CHEA3 enrichment API and return the decoded JSON response.

    Waits one second after a successful call before returning.

    Raises
    ------
    Exception
        If the HTTP response status is not OK.
    """
    request_body = json.dumps({
        'gene_set': gene_set,
        'query_name': query_name
    })
    reply = requests.post('https://maayanlab.cloud/chea3/api/enrich/', data=request_body)
    if reply.ok:
        sleep(1)
        return json.loads(reply.text)
    raise Exception('Error analyzing gene list')

# Function for displaying tables
def display_tables(lib, description, results):
    """Show the top CHEA3 hits of each library as an HTML table and save them as TSV.

    Parameters
    ----------
    lib : iterable of str
        Library keys of ``results`` to render.
    description : dict
        Maps each library key to a caption shown below its table.
    results : dict
        CHEA3 response; each ``results[libname]`` is a list of hit dicts with keys
        'Rank', 'TF', 'Intersect', 'Set length', 'FET p-value', 'FDR', 'Odds Ratio',
        'Overlapping_Genes'.

    Writes ``<libname with spaces replaced by _>.tsv`` in the working directory.
    """
    headers = ['Rank',
               'TF',
               'Overlap',
               'FET p-value',
               'FDR',
               'Odds Ratio',
               'Overlapping Genes']

    for libname in lib:
        display(HTML(f'<h3>{libname}</h3>'))

        # One row per returned hit. Building the list from the hits (instead of pre-filling
        # num_tfs placeholder zeros) keeps tabulate working when fewer hits are returned.
        table = [
            [hit['Rank'],
             hit['TF'],
             f"{hit['Intersect']}/{hit['Set length']}",
             hit['FET p-value'],
             hit['FDR'],
             hit['Odds Ratio'],
             f"{', '.join(hit['Overlapping_Genes'].split(',')[0:10])}, ..."]
            for hit in results[libname][0:num_tfs]
        ]

        display(HTML(tabulate(table, headers, tablefmt='html')))

        display(HTML(f'<h5>{description[libname]}</h5>'))

        tsv_name = f"{libname.replace(' ', '_')}.tsv"
        with open(tsv_name, 'w') as tsv_file:
            tsv_file.write(tabulate(table, headers, tablefmt='tsv'))
        display(HTML(f'<a href="{tsv_name}">Download table in .tsv</a>'))
        
        
# Function for displaying the individual library bar charts
def display_charts(libs, description, results):
    """Show a horizontal bar chart of -log10(FET p-value) for the top TFs of each library.

    The best-ranked TF is placed last in the bar list, and the x range is padded by 5% of
    the score spread on both sides. The caption ``description[libname]`` follows each chart.
    """
    for libname in libs:
        display(HTML(f'<h3>{libname}</h3>'))

        hits = results[libname]
        tf_names = [hit['TF'] for hit in hits][0:num_tfs][::-1]
        p_values = [float(hit['FET p-value']) for hit in hits][0:num_tfs][::-1]

        neg_log_p = -np.log10(p_values)

        lowest, highest = min(neg_log_p), max(neg_log_p)
        padding = (highest - lowest) * 0.05

        bar = go.Bar(name=libname,
                     x=neg_log_p,
                     y=tf_names,
                     marker=go.bar.Marker(color='rgb(255,127,80)'),
                     orientation='h')
        libfig = go.Figure(data=bar)
        libfig.update_layout(
            title={
                'text': 'Bar Chart of Scores based on FET p-values',
                'y': 0.87,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            # \u2081\u2080 are the subscript digits "10"
            xaxis_title='-log\u2081\u2080(FET p-value)',
            yaxis_title='Transcription Factors',
            font=dict(size=16, color='black')
        )
        libfig.update_xaxes(range=[lowest - padding, highest + padding])
        libfig.show()

        display(HTML(f'<h5>{description[libname]}</h5>'))
        
def indexfinder(lib_score_list, value):
    """Return the 1-based position of the first ``value`` in ``lib_score_list``, counting only
    non-zero entries that precede it; ``None`` if ``value`` does not occur.
    """
    nonzero_before = 0
    for score in lib_score_list:
        if score == value:
            return nonzero_before + 1
        nonzero_before += score != 0
    return None

def _chea3_rank_lookup(sorted_values):
    """Map each value of an ascending list to its 1-based rank among the non-zero entries.

    Same result as ``indexfinder(sorted_values, v)`` for every ``v``, built in one pass:
    a value's first occurrence gets 1 + the number of non-zero entries before it.
    """
    lookup = {}
    position = 1
    for value in sorted_values:
        lookup.setdefault(value, position)
        if value != 0:
            position += 1
    return lookup


def mean_rank_bar(results, save_name, save_formats, save_html=False, save_path="", min_libraries=3):
    """Plot the top CHEA3 TFs as a stacked bar chart of their average rank across libraries.

    Parameters
    ----------
    results : dict
        CHEA3 response; only ``results['Integrated--meanRank']`` is read. Each entry's
        ``'Library'`` field is a ``';'``-separated list of ``'<library>,<rank>'`` pairs.
        The entries are copied, so the caller's response is not modified.
    save_name : str
        File name (without extension) for the saved figure(s).
    save_formats : iterable of str
        Formats written with ``fig.write_image`` into ``save_path``.
    save_html : bool, default False
        If True, also write ``<save_name>.html``; otherwise show the figure inline.
    save_path : str, default ""
        Output directory.
    min_libraries : int, default 3
        Minimum number of libraries a TF must appear in to be plotted. Previously read
        from the global ``threshold``, a name the notebook also uses for the logFC cutoff.

    At most ``num_tfs`` (module-level) TFs are plotted; fewer if not enough qualify.
    """
    c_lib_palette = {'ARCHS4 Coexpression':'rgb(196, 8, 8)',
                 'ENCODE ChIP-seq':'rgb(244, 109, 67)',
                 'Enrichr Queries':'rgb(242, 172, 68)',
                 'GTEx Coexpression':'rgb(236, 252, 68)',
                 'Literature ChIP-seq':'rgb(165, 242, 162)',
                 'ReMap ChIP-seq':'rgb(92, 217, 78)'}

    # Integrated mean/topRank are compiled from these six libraries, so they never appear
    # as one of a TF's libraries and are not plotted.
    libs_sorted = ['ARCHS4 Coexpression','ENCODE ChIP-seq','Enrichr Queries',
                'GTEx Coexpression','Literature ChIP-seq','ReMap ChIP-seq']

    # For meanRank, the TFs are already ranked by Score. Copy each entry so the extra
    # keys added below do not leak into the caller's data.
    mr_results = [dict(entry) for entry in results['Integrated--meanRank']]

    # One integer column per library: the TF's rank there, 0 when the TF is absent.
    for entry in mr_results:
        entry.update({lib: 0 for lib in libs_sorted})
        for pair in entry['Library'].split(';'):
            library, value = pair.split(',')
            entry[library] = int(value)

    # Rank of every per-library value among the TFs present in that library.
    rank_lookup = {
        lib: _chea3_rank_lookup(sorted(entry[lib] for entry in mr_results))
        for lib in libs_sorted
    }

    for entry in mr_results:
        lib_ranks = []
        for pair in entry['Library'].split(';'):
            lib, value = pair.split(',')
            lib_ranks.append((lib, int(rank_lookup[lib][int(value)])))
        entry['_lib_ranks'] = lib_ranks
        entry['SumRank'] = sum(rank for _, rank in lib_ranks)
        entry['AvgRank'] = entry['SumRank'] / len(lib_ranks)

    sorted_results = sorted(mr_results, key = lambda k: k['AvgRank'])

    # Best-ranked TFs supported by at least `min_libraries` libraries. Filtering then slicing
    # avoids the IndexError the previous while-loop raised when fewer than num_tfs qualified.
    sorted_top_results = [
        entry for entry in sorted_results
        if len(entry['Library'].split(';')) >= min_libraries
    ][:num_tfs]
    # Reverse so the best-ranked TF is last in the y list (lowest to highest rank order).
    sorted_top_results = sorted_top_results[::-1]

    sorted_tfs = [entry.get('TF') for entry in sorted_top_results]

    # Each TF's bar is split into per-library segments proportional to that library's rank,
    # scaled so the segments add up to the TF's AvgRank.
    c_lib_means = {lib: [0] * len(sorted_top_results) for lib in libs_sorted}
    for i, entry in enumerate(sorted_top_results):
        for lib, rank in entry['_lib_ranks']:
            c_lib_means[lib][i] = float((rank * entry['AvgRank']) / entry['SumRank'])

    fig = go.Figure(data = [go.Bar(name = c_lib,
                                x = c_lib_means[c_lib],
                                y = sorted_tfs,
                                marker = go.bar.Marker(color = c_lib_palette[c_lib]),
                                orientation = 'h')
                            for c_lib in libs_sorted])

    fig.update_layout(barmode = 'stack')
    fig.update_layout(
        title = {
            'text': '',
            'y': 0.87,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        xaxis_title = 'Average of Ranks Across All Libraries',
        yaxis_title = 'Transcription Factors',
        font = dict(
            size = 13,
            color = 'black',
            family = 'Arial'
        ),
        width=700,
        margin=dict(
            t=30
        )
    )

    for fmt in save_formats:
        file_name = save_name+'.'+fmt
        fig.write_image(os.path.join(save_path, file_name), scale=2)

    if save_html:
        fig.write_html(os.path.join(save_path, f"{save_name}.html"))
    else:
        fig.show()

### **Upregulated Set**

In [None]:
#TFs of upregulated genes
for sig_name in sig_names_clean:
    save_name = sig_name + '_upchea'
    up_genes = upreg.get(sig_name, [])
    # Heading first, so an inline plot from mean_rank_bar (save_html=False) lands under it.
    display(Markdown(f"#### **{sig_name}**"))
    if len(up_genes) == 0:
        display(Markdown("No up-regulated genes passed the significance and fold-change thresholds; CHEA3 analysis skipped."))
        continue

    up_results = get_chea3_results(up_genes, 'query')
    mean_rank_bar(up_results, save_name=save_name, save_formats=save_formats, save_html=save_html, save_path=resource_path)

    display(Image(os.path.join(resource_path, f"{save_name}.png"), width=700))
    display(Markdown(f"**Figure {fig_num}**: Horizontal bar chart, y-axis represents transcription factors. Displays the top ranked transcription factors for the upregulated set according to their average integrated scores across all the libraries."))
    fig_num+=1

    for fmt in save_formats:
        file_name=os.path.join(resource_path, f"{save_name}.{fmt}")
        display(FileLink(file_name, result_html_prefix=f"Download bar plot as {fmt}: "))


### **Downregulated Set**

In [None]:
#TFs of downregulated genes
for sig_name in sig_names_clean:
    save_name = sig_name + '_dnchea'
    dn_genes = downreg.get(sig_name, [])
    # Heading first, so an inline plot from mean_rank_bar (save_html=False) lands under it.
    display(Markdown(f"#### **{sig_name}**"))
    if len(dn_genes) == 0:
        display(Markdown("No down-regulated genes passed the significance and fold-change thresholds; CHEA3 analysis skipped."))
        continue

    dn_results = get_chea3_results(dn_genes, 'query')
    mean_rank_bar(dn_results, save_name=save_name, save_formats=save_formats, save_html=save_html, save_path=resource_path)

    display(Image(os.path.join(resource_path, f"{save_name}.png"), width=700))
    display(Markdown(f"**Figure {fig_num}**: Horizontal bar chart, y-axis represents transcription factors. Displays the top ranked transcription factors for the downregulated set according to their average integrated scores across all the libraries."))
    fig_num+=1

    for fmt in save_formats:
        file_name=os.path.join(resource_path, f"{save_name}.{fmt}")
        display(FileLink(file_name, result_html_prefix=f"Download bar plot as {fmt}: "))

## **L2S2 and DRUG-seqr: Reverser and Mimicker Drugs**

In [None]:
import json
import time
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# from matplotlib_venn import venn2
from matplotlib.ticker import MaxNLocator
import seaborn as sns
# from scipy.stats import fisher_exact
from IPython.display import HTML, display, Markdown, FileLink, Image
import os
# Global side effect: from here on, floats in every displayed pandas object render in
# scientific notation with two decimals (e.g. 1.23e-04).
pd.set_option('display.float_format', '{:.2e}'.format)

class druganalysis:
    """Find drug signatures in L2S2 and DRUG-seqr that reverse or mimic a gene signature.

    On construction the up/down gene sets are filtered to symbols known to L2S2.
    They are then registered as user gene sets on both servers, and paired (up + down)
    enrichment is run four times: L2S2 and DRUG-seqr, each with and without the
    FDA-approved filter. The results are stored as DataFrames and rendered by
    ``display_table`` and ``display_barplot``.

    Parameters
    ----------
    geneset, geneset_dn : list of str
        Up- and down-regulated gene symbols of the query signature.
    save_path : str
        Directory for exported TSV tables and plot files.
    save_name : str
        Prefix for exported file names.
    direction : str
        "mimickers" or "up-regulators" rank signatures by the mimic p-value and filter
        " up" terms. Any other value (e.g. "reversers", "down-regulators") ranks by the
        reverse p-value and filters " down" terms.
    tab_num, fig_num : int
        Starting table/figure numbers for captions; incremented as items are displayed.

    Raises
    ------
    ValueError
        If either gene set has no symbols recognised by L2S2.
    """

    L2S2_URL = "http://l2s2.maayanlab.cloud/graphql"
    DRUGSEQR_URL = "http://drugseqr.maayanlab.cloud/graphql"
    _HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
    _ADD_USER_GENESET_MUTATION = "mutation AddUserGeneSet($genes: [String] = [\"AKT1\"], $description: String = \"\") {\n  addUserGeneSet(input: {genes: $genes, description: $description}) {\n    userGeneSet {\n      id\n    }\n  }\n}"

    # db key -> (results attribute, FDA-approved only?, library label, server prefix of the gene-set id attributes)
    _DATASETS = {
        "l2s2_fda": ("l2s2_df", True, "LINCS L1000", "l2s2"),
        "l2s2_all": ("l2s2_df_nofda", False, "LINCS L1000", "l2s2"),
        "drugseqr_fda": ("drugseqr_df", True, "DRUG-seq", "drugseqr"),
        "drugseqr_all": ("drugseqr_df_nofda", False, "DRUG-seq", "drugseqr"),
    }
    # server prefix -> (web site base URL, display name); user gene-set ids only exist on the server that created them.
    _SITES = {
        "l2s2": ("https://l2s2.maayanlab.cloud", "L2S2"),
        "drugseqr": ("https://drugseqr.maayanlab.cloud", "DRUG-seqr"),
    }

    def __init__(self, geneset, geneset_dn, save_path, save_name="", direction="down-regulators", tab_num=1, fig_num=1):
        self.direction = direction
        self.save_path = save_path
        self.save_name = save_name
        self.fig_num = fig_num
        self.tab_num = tab_num
        self.direction_str = 'up' if direction in ('up-regulators', 'mimickers') else 'down'

        # Restrict both gene sets to symbols L2S2 recognises.
        self.geneset = self.get_l2s2_valid_genes(geneset)
        self.geneset_dn = self.get_l2s2_valid_genes(geneset_dn)

        # Validate the *filtered* sets: non-empty inputs may still share no genes with L2S2.
        if len(self.geneset) == 0 or len(self.geneset_dn) == 0:
            raise ValueError("Insufficient genes in the input gene sets that overlap with the L2S2 database.")

        # Register the signature on each server; the ids are used for the "View in ..." links.
        self.l2s2_geneset_up_id, self.l2s2_geneset_dn_id = self.add_user_geneset(
            self.geneset, geneset_dn=self.geneset_dn, url=self.L2S2_URL)
        self.drugseqr_geneset_up_id, self.drugseqr_geneset_dn_id = self.add_user_geneset(
            self.geneset, geneset_dn=self.geneset_dn, url=self.DRUGSEQR_URL)

        # L2S2 terms hold the perturbation in '_'-field 4, DRUG-seqr terms in field 0.
        self.l2s2_df = self._run_enrichment(self.L2S2_URL, fda_approved=True, token_index=4)
        self.drugseqr_df = self._run_enrichment(self.DRUGSEQR_URL, fda_approved=True, token_index=0)
        self.l2s2_df_nofda = self._run_enrichment(self.L2S2_URL, fda_approved=False, token_index=4)
        self.drugseqr_df_nofda = self._run_enrichment(self.DRUGSEQR_URL, fda_approved=False, token_index=0)

    def _post_graphql(self, url, query):
        """POST a GraphQL payload to ``url`` and return the response's ``data`` object.

        Raises requests.HTTPError on a non-2xx status and RuntimeError when the
        server returns no ``data`` (e.g. only GraphQL ``errors``).
        """
        response = requests.post(url, data=json.dumps(query), headers=self._HEADERS)
        response.raise_for_status()
        res = response.json()
        if res.get("data") is None:
            raise RuntimeError(f"GraphQL request to {url} returned no data: {res.get('errors')}")
        return res["data"]

    def _run_enrichment(self, url, fda_approved, token_index):
        """Paired enrichment (top 500, rows with NaN dropped) plus a lower-cased 'perturbation' column."""
        df = self.enrich_up_down(self.geneset, self.geneset_dn, first=500, url=url, fda_approved=fda_approved).dropna()
        return self._with_perturbation(df, token_index)

    @staticmethod
    def _with_perturbation(df, token_index):
        """Return ``df`` with 'perturbation' = lower-cased ``token_index``-th '_' field of 'term' (None if absent).

        Empty frames are returned unchanged.
        """
        if df.empty:
            return df

        def extract(term):
            parts = term.split('_')
            return parts[token_index].lower() if len(parts) > token_index else None

        return df.assign(perturbation=df['term'].map(extract))

    def enrich_single_set(self, geneset: list, first=500, url=L2S2_URL, fda_approved=True):
        """Single-set enrichment of ``geneset``; returns one row per matching gene set (empty frame if none).

        Rows come from each result's ``geneSets.nodes`` with the result-level statistics
        attached, plus 'approved'/'count' taken from the first FDA-count node.
        """
        query = {
        "operationName": "EnrichmentQuery",
        "variables": {
            "filterTerm": f" {self.direction_str}",
            "offset": 0,
            "first": first,
            "filterFda": fda_approved,
            "sortBy": "pvalue",
            "genes": geneset,
        },
        "query": """query EnrichmentQuery(
                        $genes: [String]!
                        $filterTerm: String = ""
                        $offset: Int = 0
                        $first: Int = 10
                        $filterFda: Boolean = false
                        $sortBy: String = ""
                        ) {
                        currentBackground {
                            enrich(
                            genes: $genes
                            filterTerm: $filterTerm
                            offset: $offset
                            first: $first
                            filterFda: $filterFda
                            sortby: $sortBy
                            ) {
                            nodes {
                                geneSetHash
                                pvalue
                                adjPvalue
                                oddsRatio
                                nOverlap
                                geneSets {
                                nodes {
                                    term
                                    id
                                    nGeneIds
                                    geneSetFdaCountsById {
                                    nodes {
                                        approved
                                        count
                                    }
                                    }
                                }
                                totalCount
                                }
                            }
                            totalCount
                            }
                        }
                        }
                        """,
        }

        data = self._post_graphql(url, query)
        enrichment = data['currentBackground']['enrich']['nodes']

        df_enrichment = pd.json_normalize(
            enrichment,
            record_path=['geneSets', 'nodes'],
            meta=['geneSetHash', 'pvalue', 'adjPvalue', 'oddsRatio', 'nOverlap']
        )

        if df_enrichment.empty:
            return pd.DataFrame()

        fda_nodes = df_enrichment.pop("geneSetFdaCountsById.nodes")
        df_enrichment["approved"] = fda_nodes.map(lambda x: x[0]['approved'] if len(x) > 0 else False)
        df_enrichment["count"] = fda_nodes.map(lambda x: x[0]['count'] if len(x) > 0 else 0)
        return df_enrichment

    def enrich_up_down(self, genes_up: list[str], genes_down: list[str], first=500, url=L2S2_URL, fda_approved=True):
        """Paired up/down enrichment, sorted by the mimic or reverse p-value (server-side p <= 0.05).

        Returns one row per signature pair with the mimic/reverse statistics plus
        'geneSetIdUp', 'geneSetIdDown', 'term', 'approved' and 'count'; empty frame if none.
        """
        # L2S2 exposes the paired query as `pairedEnrich`, DRUG-seqr as `pairEnrich`.
        field = "pairedEnrich" if 'l2s2' in url else "pairEnrich"
        query = {
            "operationName": "PairEnrichmentQuery",
            "variables": {
            "filterTerm": f" {self.direction_str}",
            "offset": 0,
            "first": first,
            "filterFda": fda_approved,
            "sortBy": "pvalue_mimic" if self.direction_str == "up" else "pvalue_reverse",
            "pvalueLe": 0.05,
            "genesUp": genes_up,
            "genesDown": genes_down
            },
            "query": """query PairEnrichmentQuery($genesUp: [String]!, $genesDown: [String]!, $filterTerm: String = "", $offset: Int = 0, $first: Int = 10, $filterFda: Boolean = false, $sortBy: String = "", $pvalueLe: Float = 0.05) {{
                            currentBackground {{
                                {}(
                                filterTerm: $filterTerm
                                offset: $offset
                                first: $first
                                filterFda: $filterFda
                                sortby: $sortBy
                                pvalueLe: $pvalueLe
                                genesDown: $genesDown
                                genesUp: $genesUp
                                ) {{
                                totalCount
                                nodes {{
                                    adjPvalueMimic
                                    adjPvalueReverse
                                    mimickerOverlap
                                    oddsRatioMimic
                                    oddsRatioReverse
                                    pvalueMimic
                                    pvalueReverse
                                    reverserOverlap
                                    geneSet {{
                                    nodes {{
                                        id
                                        nGeneIds
                                        term
                                        geneSetFdaCountsById {{
                                        nodes {{
                                            count
                                            approved
                                        }}
                                        }}
                                    }}
                                    }}
                                }}
                                }}
                            }}
                            }}""".format(field)
        }

        data = self._post_graphql(url, query)
        enrichment = data['currentBackground'][field]['nodes']

        df_enrichment_pair = pd.DataFrame(enrichment)
        if df_enrichment_pair.empty:
            return pd.DataFrame()

        def set_id(gene_set, marker):
            # Id of the first member set whose term contains ' up' / ' down'.
            return next((node['id'] for node in gene_set['nodes'] if marker in node['term']), None)

        def first_term(gene_set):
            nodes = gene_set['nodes']
            return nodes[0]['term'] if nodes else None

        def fda_field(gene_set, key, default):
            # FDA info is optional: fall back when any level of the nesting is missing.
            try:
                return gene_set['nodes'][0]['geneSetFdaCountsById']['nodes'][0][key]
            except (KeyError, IndexError, TypeError):
                return default

        gene_sets = df_enrichment_pair.pop("geneSet")
        df_enrichment_pair["geneSetIdUp"] = gene_sets.map(lambda t: set_id(t, ' up'))
        df_enrichment_pair["geneSetIdDown"] = gene_sets.map(lambda t: set_id(t, ' down'))
        df_enrichment_pair["term"] = gene_sets.map(first_term)
        df_enrichment_pair["approved"] = gene_sets.map(lambda t: fda_field(t, 'approved', False))
        df_enrichment_pair["count"] = gene_sets.map(lambda t: fda_field(t, 'count', 0))

        return df_enrichment_pair.reset_index(drop=True)

    def get_overlap(self, genes, id, url=L2S2_URL):
        """Return the symbols of ``genes`` that are members of the server-side gene set ``id``."""
        query = {
        "operationName": "OverlapQuery",
        "variables": {
            "id": id,
            "genes": genes
        },
        "query": """query OverlapQuery($id: UUID!, $genes: [String]!) {geneSet(id: $id) {
        overlap(genes: $genes) {
        nodes {
            symbol
            ncbiGeneId
            description
            summary
        }   }}}"""
        }

        data = self._post_graphql(url, query)
        return [item['symbol'] for item in data['geneSet']['overlap']['nodes']]

    def get_up_dn_overlap(self, genes_up: list[str], genes_down: list[str], id_up: str, id_down: str, overlap_type: str,  url=L2S2_URL):
        """Union of overlapping genes for a signature pair.

        'mimickers' intersects up with up and down with down; 'reversers' intersects
        up with down and down with up.

        Raises ValueError for any other ``overlap_type``.
        """
        if overlap_type == 'mimickers':
            first = self.get_overlap(genes_up, id_up, url)
            second = self.get_overlap(genes_down, id_down, url)
        elif overlap_type == 'reversers':
            first = self.get_overlap(genes_up, id_down, url)
            second = self.get_overlap(genes_down, id_up, url)
        else:
            raise ValueError(f"overlap_type must be 'mimickers' or 'reversers', got {overlap_type!r}")
        return list(set(first) | set(second))

    def add_user_geneset(self, geneset, geneset_dn = None, url=L2S2_URL):
        """Register gene set(s) on the server at ``url``.

        Returns the new id, or an ``(up_id, down_id)`` tuple when ``geneset_dn`` is given.
        """
        def register(genes, description):
            data = self._post_graphql(url, {
                "query": self._ADD_USER_GENESET_MUTATION,
                "variables": {"genes": genes, "description": description},
                "operationName": "AddUserGeneSet",
            })
            return data['addUserGeneSet']['userGeneSet']['id']

        # In a pair, the first set is the "(up)" half; alone it is just the user gene set.
        up_id = register(geneset, "User gene set (up)" if geneset_dn is not None else "User gene set")
        if geneset_dn is None:
            return up_id
        return up_id, register(geneset_dn, "User gene set (down)")

    def get_l2s2_valid_genes(self, genes: list[str], url=L2S2_URL):
        """Map ``genes`` through the server's gene map; return the symbols of those it recognises."""
        query = {
        "query": """query GenesQuery($genes: [String]!) {
            geneMap2(genes: $genes) {
                nodes {
                    gene
                    geneInfo {
                        symbol
                        }
                    }
                }
            }""",
        "variables": {"genes": genes},
        "operationName": "GenesQuery"
        }

        data = self._post_graphql(url, query)
        return [g['geneInfo']['symbol'] for g in data['geneMap2']['nodes'] if g['geneInfo'] is not None]

    def _dataset(self, db):
        """Resolve a db key to its results frame, labels, user gene-set ids and web site.

        Raises ValueError for an unknown key.
        """
        try:
            attr, fda_only, library, server = self._DATASETS[db]
        except KeyError:
            raise ValueError("Choose a valid dataset.") from None
        site, site_name = self._SITES[server]
        return {
            "df": getattr(self, attr),
            "fda_only": fda_only,
            "approval": "FDA-approved" if fda_only else "",
            "library": library,
            "up_id": getattr(self, f"{server}_geneset_up_id"),
            "down_id": getattr(self, f"{server}_geneset_dn_id"),
            "site": site,
            "site_name": site_name,
        }

    def display_table(self, db):
        """Show the top 20 rows for ``db``, a caption, a link to the web app, and a TSV of the top 200 rows.

        Raises ValueError if ``db`` is unknown or has no results.
        """
        ds = self._dataset(db)
        df = ds["df"]
        if df.empty:
            raise ValueError(f"No Results for {db}")

        # Show the statistics matching the server-side sort key (mimic for 'up', reverse for 'down').
        mimic = self.direction_str == 'up'
        if mimic:
            stat_columns = ['pvalueMimic', 'adjPvalueMimic', 'oddsRatioMimic', 'mimickerOverlap']
        else:
            stat_columns = ['pvalueReverse', 'adjPvalueReverse', 'oddsRatioReverse', 'reverserOverlap']
        columns = ['perturbation', 'term', *stat_columns, 'approved', 'count']
        display(df.iloc[:20][columns])

        verb = "mimic" if mimic else "reverse"
        display(Markdown(f"Table {self.tab_num}: Ranked {ds['approval']} {ds['library']} signatures predicted to {verb} the uploaded geneset."))
        sort = 'pvalue_mimic' if mimic else 'pvalue_reverse'
        fda = 'true' if ds['fda_only'] else 'false'
        display(HTML(f"<a href=\"{ds['site']}/enrichpair?dataset={ds['up_id']}&dataset={ds['down_id']}&fda={fda}&dir={self.direction_str}&sort={sort}\" target=\"_blank\">View in {ds['site_name']}</a>"))
        self.tab_num += 1

        filename = os.path.join(self.save_path, f"{self.save_name}_{self.direction}_{db}.tsv")
        df.head(200).to_csv(filename, sep='\t')
        display(FileLink(filename, result_html_prefix="Download table: "))

    def display_barplot(self, db, save_formats, color='tomato'):
        """Plot -log10 p-values of the top 20 signatures for ``db`` (averaged per perturbation).

        The plot is saved as ``{save_path}/{save_name}_{direction}_{db}.{fmt}`` for each
        format in ``save_formats``. A PNG is always written because it is shown inline.

        Raises ValueError if ``db`` is unknown or has no results.
        """
        ds = self._dataset(db)
        df = ds["df"]
        if df.empty:
            raise ValueError(f"No Results for {db}")

        pvalcol = 'pvalueMimic' if self.direction_str == 'up' else 'pvalueReverse'
        top = df.iloc[:20].copy()  # copy: never write into a slice of the stored results
        top['-log10(pvalue)'] = -np.log10(top[pvalcol])
        # Several signatures can share a perturbation; average their numeric columns.
        # numeric_only is required because the frame also holds text (term, gene-set ids).
        top = (
            top.groupby(by="perturbation", sort=False)
            .mean(numeric_only=True)
            .reset_index()
            .sort_values(by='-log10(pvalue)', ascending=False)  # strongest bar on top
            .reset_index(drop=True)
        )
        bar_colors = [color if p < 0.05 else "lightgrey" for p in top[pvalcol]]

        stem = os.path.join(self.save_path, f"{self.save_name}_{self.direction}_{db}")
        fig, ax = plt.subplots()
        try:
            sns.barplot(
                data=top,
                x="-log10(pvalue)",
                y="perturbation",
                hue="perturbation",
                palette=bar_colors,
                legend=False,
                edgecolor=None,
                linewidth=1,
                orient='y',
                errorbar=None,
                ax=ax,
            )
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.tick_params(axis='x', which='major', labelsize=10)
            ax.get_yaxis().set_visible(False)

            # Labels replace the hidden y axis and sit just inside the left edge of each bar.
            x_offset = max(ax.get_xlim()) / 200
            for i, (name, p) in enumerate(zip(top['perturbation'], top[pvalcol])):
                star = "*" if p < 0.05 else ""
                annot = f" {star}{name} {np.format_float_scientific(p, precision=2)}"
                ax.text(x_offset, i, annot, ha='left', va='center', wrap=True, fontsize=8)

            for fmt in dict.fromkeys([*save_formats, "png"]):
                fig.savefig(f"{stem}.{fmt}", bbox_inches="tight", dpi=300)
        finally:
            plt.close(fig)  # avoid leaking figures across the many plots in the loop

        display(Image(f"{stem}.png", width=600))
        display(Markdown(f"Figure {self.fig_num}: barplot representation depicting the -log10p values of the top {ds['approval']} {db} {self.direction}. Red bars represent statistically significant results; otherwise gray."))
        self.fig_num += 1

        for fmt in save_formats:
            display(FileLink(f"{stem}.{fmt}", result_html_prefix=f"Download bar plot as {fmt}: "))

### **Reverser Results**

In [None]:
# Result sets to show per signature: L2S2 (LINCS L1000) and DRUG-seqr, FDA-approved only and all.
dbs = ['l2s2_fda', 'l2s2_all', 'drugseqr_fda', 'drugseqr_all']

for sig_name in sig_names_clean:
    # Each signature's own up/down genes (never leftover variables from another cell or iteration).
    up_genes = upreg[sig_name]
    down_genes = downreg[sig_name]

    display(Markdown(f"#### **{sig_name}**"))
    try:
        rev_drugs = druganalysis(geneset=up_genes, geneset_dn=down_genes, direction="reversers", save_path=resource_path, save_name=sig_name, tab_num=tab_num, fig_num=fig_num)
    except ValueError as e:
        # e.g. no genes overlap the L2S2 gene universe: skip this signature, keep the report going.
        print("Caught error: ", e)
        display(Markdown("---"))
        continue

    for db in dbs:
        display(Markdown(f"**{db}**"))
        try:
            rev_drugs.display_table(db=db)
            rev_drugs.display_barplot(db=db, save_formats=save_formats)
        except ValueError as e:
            print("Caught error: ", e)

    # Carry figure/table numbering forward to the next signature and section.
    fig_num = rev_drugs.fig_num
    tab_num = rev_drugs.tab_num

    display(Markdown("---"))
    

### **Mimicker Results**

In [None]:
for sig_name in sig_names_clean:
    # Each signature's own up/down genes (never leftover variables from another cell or iteration).
    up_genes = upreg[sig_name]
    down_genes = downreg[sig_name]

    display(Markdown(f"#### **{sig_name}**"))
    try:
        mim_drugs = druganalysis(geneset=up_genes, geneset_dn=down_genes, direction="mimickers", save_path=resource_path, save_name=sig_name, tab_num=tab_num, fig_num=fig_num)
    except ValueError as e:
        # e.g. no genes overlap the L2S2 gene universe: skip this signature, keep the report going.
        print("Caught error: ", e)
        display(Markdown("---"))
        continue

    # `dbs` is defined in the Reverser Results cell above.
    for db in dbs:
        display(Markdown(f"\n**{db}**"))
        try:
            mim_drugs.display_table(db=db)
            mim_drugs.display_barplot(db=db, save_formats=save_formats)
        except ValueError as e:
            print("Caught error: ", e)

    # Carry figure/table numbering forward to the next signature and section.
    fig_num = mim_drugs.fig_num
    tab_num = mim_drugs.tab_num

    display(Markdown("---"))
    

## **References**

In [None]:
references = f'''
[1] {citation}

[2] McInnes L, Healy J, Saul N, Großberger L. UMAP: Uniform manifold approximation and projection. Journal of Open Source Software. 2018;3(29):861. doi:10.21105/joss.00861

[3] Clark NR, Ma’ayan A. Introduction to statistical methods to analyze large data sets: Principal Components Analysis. Science Signaling. 2011;4(190):tr3-tr3. doi:10.1126/scisignal.2001967 

[4] van der Maaten L, Hinton G. Visualizing Data using t-SNE. Journal of Machine Learning Research. 2008;9(86):2579-2605.

[5] Chen EY, Tan CM, Kou Y, Duan Q, Wang Z, Meirelles GV, Clark NR, Ma'ayan A. Enrichr: interactive and collaborative HTML5 gene list enrichment analysis tool. BMC Bioinformatics. 2013;14:128. doi:10.1186/1471-2105-14-128

[6] Kuleshov MV, Jones MR, Rouillard AD, Fernandez NF, Duan Q, Wang Z, Koplev S, Jenkins SL, Jagodnik KM, Lachmann A, McDermott MG, Monteiro CD, Gundersen GW, Ma'ayan A. Enrichr: a comprehensive gene set enrichment analysis web server 2016 update. Nucleic Acids Research. 2016;44(W1):W90-W97. doi:10.1093/nar/gkw377

[7] Xie Z, Bailey A, Kuleshov MV, Clarke DJB, Evangelista JE, Jenkins SL, Lachmann A, Wojciechowicz ML, Kropiwnicki E, Jagodnik KM, Jeon M, & Ma’ayan A. Gene set knowledge discovery with Enrichr. Current Protocols. 2021;1(3):e90. doi: 10.1002/cpz1.90

[8] Keenan AB, Torre D, Lachmann A, Leong AK, Wojciechowicz M, Utti V, Jagodnik K, Kropiwnicki E, Wang Z, Ma'ayan A (2019) ChEA3: transcription factor enrichment analysis by orthogonal omics integration. Nucleic Acids Research. doi: 10.1093/nar/gkz446

[9] Marino GB, Evangelista JE, Clarke DJB, Ma’ayan A. L2S2: chemical perturbation and CRISPR KO LINCS L1000 signature search engine. Nucleic Acids Res. 2025; gkaf373. doi:10.1093/nar/gkaf373

[10] Li J, Ho DJ, Henault M, et al. DRUG-seq Provides Unbiased Biological Activity Readouts for Neuroscience Drug Discovery. ACS Chem Biol. 2022;17(6):1401-1414. doi:10.1021/acschembio.1c00920

[11] Lachmann A, Torre D, Keenan AB, Jagodnik KM, Lee HJ, Wang L, Silverstein MC, Ma'ayan A. Massive mining of publicly available RNA-seq data from human and mouse. Nature Communications 9. Article number: 1366 (2018), doi: 10.1038/s41467-018-03751-6.

[12] Bray, N., Pimentel, H., Melsted, P. et al. Near-optimal probabilistic RNA-seq quantification. Nat Biotechnol 34, 525–527 (2016). https://doi.org/10.1038/nbt.3519

[13] Fernandez, N. F. et al. Clustergrammer, a web-based heatmap visualization and analysis tool for high-dimensional biological data. Sci. Data 4:170151 doi: 10.1038/sdata.2017.151 (2017).

[14] Ritchie ME, Phipson B, Wu D, Hu Y, Law CW, Shi W, Smyth GK. limma powers differential expression analyses for RNA-sequencing and microarray studies. Nucleic Acids Res. 2015 Apr 20;43(7):e47. doi: 10.1093/nar/gkv007.

[15] Milacic M, Beavers D, Conley P, Gong C, Gillespie M, Griss J, Haw R, Jassal B, Matthews L, May B, Petryszak R, Ragueneau E, Rothfels K, Sevilla C, Shamovsky V, Stephan R, Tiwari K, Varusai T, Weiser J, Wright A, Wu G, Stein L, Hermjakob H, D’Eustachio P. The Reactome Pathway Knowledgebase 2024. Nucleic Acids Research. 2024. doi: 10.1093/nar/gkad1025.

[16] Eppig JT, Smith CL, Blake JA, Ringwald M, Kadin JA, Richardson JE, Bult CJ. Mouse Genome Informatics (MGI): Resources for Mining Mouse Genetic, Genomic, and Biological Data in Support of Primary and Translational Research. Methods Mol Biol. 2017;1488:47-73. doi: 10.1007/978-1-4939-6427-7_3.

[17] Ashburner M, Ball CA, Blake JA, Botstein D, Butler H, Cherry JM, Davis AP, Dolinski K, Dwight SS, Eppig JT, Harris MA, Hill DP, Issel-Tarver L, Kasarskis A, Lewis S, Matese JC, Richardson JE, Ringwald M, Rubin GM, Sherlock G. Gene ontology: tool for the unification of biology. The Gene Ontology Consortium. Nat Genet. 2000 May;25(1):25-9. doi: 10.1038/75556.

[18] Cerezo M, Sollis E, Ji Y, et al. The NHGRI-EBI GWAS Catalog: standards for reusability, sustainability and diversity. Nucleic Acids Res. 2025;53(D1):D998-D1005. doi:10.1093/nar/gkae1070

[19] Kanehisa M, Furumichi M, Sato Y, Matsuura Y, Ishiguro-Watanabe M. KEGG: biological systems database as a model of the real world. Nucleic Acids Res. 2025;53(D1):D672-D677. doi:10.1093/nar/gkae909

[20] Kanehisa M, Goto S. KEGG: kyoto encyclopedia of genes and genomes. Nucleic Acids Res. 2000;28(1):27-30. doi:10.1093/nar/28.1.27

[21] Kanehisa M. Toward understanding the origin and evolution of cellular organisms. Protein Sci. 2019;28(11):1947-1951. doi:10.1002/pro.3715

[22] Pico AR, Kelder T, van Iersel MP, Hanspers K, Conklin BR, Evelo C. WikiPathways: pathway editing for the people. PLoS Biol. 2008 Jul 22;6(7):e184. doi: 10.1371/journal.pbio.0060184.

[23] GTEx Consortium. The Genotype-Tissue Expression (GTEx) project. Nat Genet. 2013 Jun;45(6):580-5. doi: 10.1038/ng.2653.

[24] ENCODE Project Consortium. An integrated encyclopedia of DNA elements in the human genome. Nature. 2012;489(7414):57-74. doi:10.1038/nature11247

[25] Luo Y, Hitz BC, Gabdank I, et al. New developments on the Encyclopedia of DNA Elements (ENCODE) data portal. Nucleic Acids Res. 2020;48(D1):D882-D889. doi:10.1093/nar/gkz1062

[26] Hammal F, de Langen P, Bergon A, Lopez F, Ballester B. ReMap 2022: a database of Human, Mouse, Drosophila and Arabidopsis regulatory regions from an integrative analysis of DNA-binding sequencing experiments. Nucleic Acids Res. 2022;50(D1):D316-D325. doi:10.1093/nar/gkab996

'''

In [None]:
# Render the reference list as Markdown in the report.
rendered_references = Markdown(references)
display(rendered_references)