In [None]:
%%capture
#no need to refresh kernel when changes are made to the helper scripts
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
# Silence every Python warning for the rest of the kernel session so they do not
# appear in the rendered report. NOTE(review): this also hides deprecation and
# numerical warnings raised by later cells.
warnings.filterwarnings('ignore')

In [None]:
import pandas as pd
import numpy as np
import os
import sys
from IPython.display import display,FileLink, Markdown, HTML, Image

# NOTE: despite its name, `cwd` holds the PARENT of the current working directory.
# It is prepended to sys.path so that project packages located there can be
# imported (python_scripts is imported further down; presumably it lives in
# that parent directory -- confirm).
cwd = os.path.dirname(os.getcwd()) #parent of the working directory, used as the import root
if cwd not in sys.path:
    sys.path.insert(0, cwd)

#to save in the same directory as the notebook, change to resource_path="".

# print(cwd)
# print(project_root)
# print(resource_path)

In [None]:
#parameters
# GEO series accession that the whole notebook reanalyzes.
# NOTE(review): the "parameters" marker suggests this cell is overridden by a
# notebook runner when reports are generated in bulk -- confirm.
gse = "GSE247175"
# Root under which outputs ("public/<gse>") and the catalog ("data/metadata.json")
# are written; defaults to the parent of the working directory (see `cwd`).
project_root = cwd

In [None]:
from pathlib import Path
# Per-study output directory: <project_root>/public/<gse>. Tables and figures
# produced by later cells are saved here.
resource_path = os.path.join(project_root, "public", gse) #make sure to replace with output!!!
# Create the directory (and any missing parents); a no-op if it already exists.
Path(resource_path).mkdir(parents=True, exist_ok=True)

In [None]:
import contextlib

@contextlib.contextmanager
def suppress_output(stdout=True, stderr=True, dest='/dev/null'):
    '''Temporarily redirect sys.stdout and/or sys.stderr to a file.

    Usage:
    with suppress_output():
        print('hi')

    Parameters:
        stdout: when true, sys.stdout is redirected for the duration of the block.
        stderr: when true, sys.stderr is redirected for the duration of the block.
        dest: path of the file that receives the redirected text (opened in
            append mode). The default discards the output on POSIX systems.

    The original streams are restored when the block exits, including when it
    exits with an exception, and the destination file is closed afterwards.
    '''
    # Remember the current streams up front so they can always be restored.
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    # The `with` guarantees the sink is closed; previously the handle leaked
    # on every use of this context manager.
    with open(dest, 'a') as sink:
        if stdout:
            sys.stdout = sink
        if stderr:
            sys.stderr = sink
        try:
            yield
        finally:
            # Restore before the sink is closed so nothing writes to a closed file.
            if stdout:
                sys.stdout = saved_stdout
            if stderr:
                sys.stderr = saved_stderr

In [None]:
from Bio import Entrez
from dotenv import load_dotenv

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Entrez.email is the contact address Biopython attaches to E-utilities requests.
# Fail with a clear message when it is missing instead of the opaque TypeError
# that assigning None into os.environ would raise.
entrez_email = os.getenv('ENTREZ_EMAIL')
if not entrez_email:
    raise EnvironmentError(
        "ENTREZ_EMAIL is not set. Define it in the environment or in a .env file."
    )
os.environ['ENTREZ_EMAIL'] = entrez_email

Entrez.email = entrez_email

# Resolve the GSE accession to its UID in the GEO DataSets ("gds") database.
id_handle = Entrez.esearch(db="gds", term=f"{gse}[Accession]", retmax=1)
try:
    id_record = Entrez.read(id_handle)
finally:
    id_handle.close()

# An unknown accession yields an empty IdList; report that explicitly rather
# than failing with an IndexError.
if not id_record["IdList"]:
    raise ValueError(f"No GEO DataSets record found for accession {gse}.")
gds_id = id_record["IdList"][0]

In [None]:
# Fetch the GEO DataSets summary document for the UID resolved above. Later cells
# read the 'taxon', 'PubMedIds', 'Samples' and 'summary' fields of record[0].
stream = Entrez.esummary(db='gds', id=gds_id)
record = Entrez.read(stream)

In [None]:
# Maps the lower-cased taxon string from the GEO summary to the short species
# name used in the rest of the notebook. Only these two species are supported.
map_species = {
    "homo sapiens": "human",
    "mus musculus": "mouse"
}

In [None]:
# Translate the study's taxon into the short species name. Studies whose taxon is
# not in map_species cannot be processed; stop with an explanatory error instead
# of a bare KeyError.
taxon = record[0]['taxon'].lower()
if taxon not in map_species:
    raise ValueError(
        f"Unsupported taxon '{record[0]['taxon']}'; supported: {sorted(map_species)}."
    )
species = map_species[taxon]

In [None]:
# Eligibility gate: the study must be linked to at least one PubMed record,
# because the title, citation and abstract below are built from it.
if len(record[0]['PubMedIds'])==0: #discard studies with no pubmed citation
    raise ValueError("No PubMed citation found for this study.")

In [None]:
# Use the first linked PubMed ID (as an int) as the study's reference publication.
pmid = int(record[0]['PubMedIds'][0])

In [None]:
# Eligibility gate: range(6, 25) covers 6..24 inclusive, so studies with fewer
# than 6 or more than 24 samples are rejected.
if len(record[0]['Samples']) not in range(6, 25): #discard studies that dont have 6-24 samples
    raise ValueError("Number of samples need to be within 6-24.")

In [None]:
def fetch_pubmed_metadata(pmid):
    """Fetch bibliographic metadata for one PubMed record via Entrez.efetch.

    Parameters:
        pmid: PubMed identifier of the article.

    Returns:
        dict with the keys "title", "journal" (ISO abbreviation, falling back to
        the full journal title), "year", "volume", "issue", "pages" and
        "authors" (list of "LastName Initials" strings; group authors are
        represented by their collective name). Missing optional fields are
        returned as empty strings.

    Raises:
        KeyError / IndexError if the response lacks the article structure read
        below (PubmedArticle -> MedlineCitation -> Article -> Journal ...).
    """
    handle = Entrez.efetch(db="pubmed", id=pmid, retmode="xml")
    try:
        records = Entrez.read(handle)
    finally:
        handle.close()
    article = records['PubmedArticle'][0]['MedlineCitation']['Article']
    journal = article['Journal']
    journal_issue = journal['JournalIssue']
    pub_date = journal_issue['PubDate']

    year = pub_date.get('Year', '')
    if not year:
        # Some records carry a free-text MedlineDate (e.g. "2019 Nov-Dec")
        # instead of a Year element; use its leading four digits when present.
        leading = str(pub_date.get('MedlineDate', ''))[:4]
        year = leading if leading.isdigit() else ''

    def format_author(author):
        # Group/consortium authors have no LastName, only a CollectiveName.
        if 'LastName' not in author:
            return str(author.get('CollectiveName', '')).strip()
        return f"{author['LastName']} {author.get('Initials', '')}".strip()

    author_names = [format_author(a) for a in article.get('AuthorList', [])]

    return {
        "title": article['ArticleTitle'],
        "journal": journal.get('ISOAbbreviation') or journal.get('Title', ''),
        "year": year,
        "volume": journal_issue.get('Volume', ''),
        "issue": journal_issue.get('Issue', ''),
        "pages": article.get('Pagination', {}).get('MedlinePgn', ''),
        # Drop entries that produced no usable name.
        "authors": [name for name in author_names if name]
    }

In [None]:
# Bibliographic metadata of the reference publication (title, journal, year,
# volume, issue, pages, authors), reused for the heading, citation and prompt.
pmdict = fetch_pubmed_metadata(pmid)

In [None]:
def format_ama_citation(metadata, pmid=pmid):
    """Build an AMA-style citation string from a PubMed metadata dict.

    Parameters:
        metadata: dict with the keys "authors" (list of str), "title",
            "journal", "year", "volume", "issue" and "pages", as returned by
            fetch_pubmed_metadata.
        pmid: PubMed ID appended as " PMID: <pmid>"; pass a falsy value to omit
            it. The default is the notebook-level `pmid` captured when this
            function is defined.

    Returns:
        "<authors>. <title> <journal>. <year>;<volume>(<issue>):<pages>. PMID: <pmid>"
        At most six authors are listed, followed by ", et al" when there are
        more. Empty volume/issue/pages parts are left out instead of producing
        "()" or a dangling ":".
    """
    authors = metadata['authors']
    author_str = ', '.join(authors[:6])
    if len(authors) > 6:
        author_str += ', et al'

    title = str(metadata['title']).strip()
    # Make sure the title is terminated before the journal name follows it.
    if title and title[-1] not in '.?!':
        title += '.'

    # volume(issue):pages -- each part only when it is available.
    locator = str(metadata['volume'])
    if metadata['issue']:
        locator += f"({metadata['issue']})"
    if metadata['pages']:
        locator += f":{metadata['pages']}"

    citation = f"{title} {metadata['journal']}. {metadata['year']}"
    if author_str:
        citation = f"{author_str}. {citation}"
    if locator:
        citation += f";{locator}"
    citation += "."

    if pmid:
        citation += f" PMID: {pmid}"

    return citation

In [None]:
# Citation string for the reference publication; shown below the heading and
# stored in the metadata catalog.
citation = format_ama_citation(pmdict)

In [None]:
# Render the report heading from the publication title, first author, journal and
# year. NOTE(review): pmdict['authors'][0] raises IndexError if no authors were parsed.
display(Markdown(f"# **Reanalysis of \"{pmdict['title']}\" by {pmdict['authors'][0]} et al., {pmdict['journal']}, {pmdict['year']}**"))

In [None]:
# Show the citation followed by a link to the study's GEO accession page. The
# HTML object is the cell's last expression, so it is rendered as the cell output.
# Note: `link` is reassigned later by the Clustergrammer upload cell.
link = f"https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc={gse}"
display(Markdown(f"{citation}"))
HTML(f'<a href="{link}" target="_blank">Visit GEO accession page</a>')

In [None]:
from google import genai

# Read the Gemini API key from the environment (populated by load_dotenv above).
# Fail with a clear message when it is missing instead of the opaque TypeError
# that assigning None into os.environ would raise. The key is never printed.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise EnvironmentError(
        "GOOGLE_API_KEY is not set. Define it in the environment or in a .env file."
    )
os.environ["GOOGLE_API_KEY"] = google_api_key

client = genai.Client(api_key=google_api_key)
# Prompt for the abstract. It interpolates the author list, the publication title
# and the GEO summary text, and fixes the sentence template plus the reference
# numbers ([1]..[10]) that the generated text must keep in order.
prompt = f'''
  You are an expert academic writer. Your task is to reformat the provided research information into a concise abstract around 250 words following this exact template:

  "In this study, <FIRST AUTHOR> et al. [1] profiled <CELLS AND CONDITIONS> to further our understanding of <TOPIC>. The reanalysis of this dataset include <FILL IN>"

  Here is the contextual information:
  Author: {pmdict['authors']}
  Title: {pmdict['title']}
  Summary: {record[0]['summary']}

  In the reanalysis explanation, use the following information: the reanalysis is a full RNA-seq analysis pipeline that consists of: UMAP[2], PCA[3], t-SNE[4] plots of the samples; clustergram heatmap; differential gene expression analysis
  for each pair of control and perturbation samples; Enrichment analysis for each gene signature using Enrichr [5, 6, 7]; Transcription factor analysis of gene signatures
  using ChEA3 [8] ; Reverser and mimicker drug match analysis using L2S2 [9] and DRUG-seqr [10], both FDA and non-FDA approved. Results are provided as tables in addition to bar charts.

  Please write the reanalysis as a complete paragraph with smoothly transitioning sentences. Use consistent, present tense.
  Do not omit or change the ordering of the reference numbers.
  Do not change the reference and insert it, in parentheses, where indicated. 

  Now, generate the abstract strictly following the template. Do not include any other text or introductory/concluding remarks.
'''

# Send the prompt to the model; the generated text is read from response.text
# by the Abstract cell below.
response = client.models.generate_content(
    model="gemini-2.0-flash", contents=prompt
)
#print(response.text)

## **Abstract**

In [None]:
# Render the generated abstract followed by a disclosure line naming the model.
display(Markdown(f"{response.text}\n*This abstract was generated with the assistance of Gemini 2.0 Flash.*"))

## **Methods**

*RNA-seq alignment*

Gene count matrices were obtained from ARCHS4 [11], which preprocessed the raw FASTQ data using the Kallisto [12] and STAR [] pseudoalignment algorithm.

*Gene matrix processing* 

The raw gene matrix was filtered to remove genes with an average of fewer than 3 reads across the samples. It was then quantile, log2, and z-score normalized. A regex-based function was used to infer whether individual samples belong to a “control” or a “perturbation” group by processing the metadata associated with each sample.

*Dimensionality Reduction Visualization*

Three types of dimensionality reduction techniques were applied on the processed expression matrices: UMAP[2], PCA[3], and t-SNE[4]. UMAP was calculated by the UMAP Python package and PCA and t-SNE were calculated using the Scikit-Learn Python library. The samples were then represented on 2D scatterplots.

*Clustergram Heatmap*

As a preliminary step, the top 1000 genes exhibiting most variability were selected. Using this new set, clustergram heatmaps were generated. Two versions of the clustergram exist: an interactive one generated by Clustergrammer [13] and a publication-ready alternative.

*Differentially Expressed Genes Calculation and Volcano Plot*

Differentially expressed genes between the control and perturbation samples were calculated using Limma Voom [14]. The logFC and -log10p values of each gene were visualized as a volcano scatterplot. Upregulated and downregulated genes were selected according to these criteria: p < 0.05 and |logFC| > 1.0.

*Enrichr Enrichment Analysis*

The upregulated and down-regulated sets were separately submitted to Enrichr [5, 6, 7]. These sets were compared against libraries from ChEA [8], ARCHS4 [11], Reactome Pathways [15], MGI Mammalian Phenotype [16], Gene Ontology Biological Processes [17], GWAS Catalog [18], KEGG [19, 20, 21], and WikiPathways [22]. The top matched terms from each library and their respective -log10p values were visualized as barplots.

*ChEA3 Transcription Factor Analysis*

The upregulated and down-regulated sets were separately submitted to ChEA3 [8]. These sets were compared against the libraries ARCHS4 Coexpression [11], GTEx Coexpression [23], Enrichr [5, 6, 7], ENCODE ChIP-seq [24, 25], ReMap ChIP-seq [26], and Literature-mined ChIP-seq. The top matched TFs were ranked according to their average score across each library and represented as barplots.

*L2S2 and Drug-seqr drug analysis*

The top 500 upregulated and downregulated genes were submitted simultaneously to identify reverser and mimicker molecules, both FDA and non-FDA approved, from the L2S2 [9] and Drug-seqr [10] databases. The top matched molecules were compiled into tables and visualized as barplots.


In [None]:
%%capture

import json
from datetime import datetime

# Maintain a JSON catalog of processed studies, keyed by GSE accession.
# The in-memory catalog is named `catalog` (not `metadata`) because `metadata`
# is used further down for the ARCHS4 sample-metadata DataFrame.
# Note: this cell runs under %%capture, so the [INFO] messages are not shown.
metadata_path = Path(os.path.join(project_root, "data", "metadata.json"))

# 1. Load the existing catalog (start empty if the file doesn't exist yet)
if metadata_path.exists():
    with open(metadata_path, "r") as f:
        catalog = json.load(f)
else:
    catalog = {}

# 2. Build the entry describing this study
entry = {
    "GSE": gse,
    "author": ", ".join(pmdict['authors']),
    "year": pmdict['year'],
    "species": species,
    "title": pmdict['title'],
    "pmid": pmid,
    "num_samps": len(record[0]['Samples']),
    "samples": ", ".join(sorted([w['Accession'] for w in record[0]['Samples']])),
    "citation": citation,
    "notebook_path": f"{resource_path}/{gse}.ipynb",
    #"report_path": f"{resource_path}/{gse}.html",
    "timestamp": datetime.now().isoformat()
}

# 3. Add the entry only if the study isn't catalogued yet, and
# 4. write the file back only when something changed. The parent directory is
#    created first so a fresh checkout without a "data" folder does not fail.
if gse not in catalog:
    catalog[gse] = entry
    metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metadata_path, "w") as f:
        json.dump(catalog, f, indent=4)
    print(f"[INFO] Added metadata for {gse}")
else:
    print(f"[INFO] {gse} already exists in metadata. Skipping update.")


In [None]:
# Running counters used to number table and figure captions; each caption cell
# increments the matching counter after displaying it.
tab_num = 1
fig_num = 1
# File formats requested for every saved figure (also used to build download links).
save_formats = ['png', 'svg', 'jpeg']

In [None]:
import archs4py as a4
# One-time download of the ARCHS4 gene-count file (uncomment if it is missing):
#file_path = a4.download.counts(species, path="", version="latest")

# Directory holding the ARCHS4 H5 files. It can be overridden with the ARCHS4_DIR
# environment variable; the fallback is the location used so far.
archs4_dir = os.getenv("ARCHS4_DIR", "/home/ajy20/projects/8--auto-playbook-geo-reports")
# Pick the file matching the study's species. Previously the human file was
# always opened, even when `species` resolved to "mouse".
file = os.path.join(archs4_dir, f"{species}_gene_v2.latest.h5")
if not os.path.isfile(file):
    raise FileNotFoundError(
        f"ARCHS4 count file not found: {file}. Download it or set ARCHS4_DIR."
    )

In [None]:
# Per-sample metadata for this series from the ARCHS4 file. Later cells read the
# 'characteristics_ch1', 'title' and 'geo_accession' columns and index rows by
# sample accession via metadata.loc[...].
metadata = a4.meta.series(file, gse)

In [None]:
%%capture
import re
import nltk
from nltk.corpus import stopwords

# Make sure the NLTK stop-word corpus is available locally.
nltk.download('stopwords')
# Extra tokens dropped from sample descriptions in addition to the English stop
# words: generic field words (including the misspelling 'experiement') and the
# single digits 1-9.
words_to_remove = ['experiement', 'tissue', 'type', 'batch', 'treatment', 'experiment', 'patient', 'batch', '1', '2', '3', '4', '5', '6', '7', '8', '9']
stopwords_plus = set(stopwords.words('english') + (words_to_remove))
# Separator characters (- , _ . :) that are replaced by spaces before tokenizing.
pattern = r'[-,_.:]'


In [None]:
# Multi-word field names stripped from the characteristics text before tokenizing.
terms_to_remove = ["cell line", "cell type", "genotype", "treatment"]


# Whole-word alternation of the (escaped) terms above.
pattern1 = r"\b(" + "|".join(map(re.escape, terms_to_remove)) + r")\b"

# Step 1: remove the field names case-insensitively, collapse whitespace runs to a
# single space and trim.
metadata["cleaned_characteristics"] = metadata["characteristics_ch1"].str.replace(
    pattern1, 
    "", 
    flags=re.IGNORECASE, 
    regex=True
).str.replace(r"\s+", " ", regex=True).str.strip()

# Step 2: replace separator characters with spaces, trim and lower-case.
metadata['cleaned_characteristics'] = metadata['cleaned_characteristics'].apply(lambda x: re.sub(pattern, " ", x).strip().lower())
# Step 3: drop stop words and the extra tokens in stopwords_plus.
metadata['cleaned_characteristics'] = metadata['cleaned_characteristics'].apply(
    lambda text: " ".join([word for word in text.split() if word not in stopwords_plus])
)

In [None]:
# Normalize sample titles into group labels: strip all digits (e.g. replicate
# numbers), replace separator characters with spaces, lower-case, then drop stop
# words. Samples whose titles differ only in those parts get the same clean_title.
metadata['clean_title'] = metadata['title'].apply(lambda x: re.sub('[0-9]+', '', x))
metadata['clean_title'] = metadata['clean_title'].apply(lambda x: re.sub(pattern, " ", x).strip().lower())
metadata['clean_title'] = metadata['clean_title'].apply(
    lambda text: " ".join([word for word in text.split() if word not in stopwords_plus])
)

#metadata

In [None]:
# Group samples by their normalized title; each group is a candidate condition.
groups = metadata.groupby(by='clean_title', level=None)

In [None]:
# Tokens that mark a condition as a control when they appear as a whole word in
# its label (matched against lower-cased, whitespace-split labels).
ctrl_words = set(['wt', 'wildtype', 'control', 'cntrl', 'ctrl', 'uninfected', 'normal', 'untreated', 'unstimulated', 'shctrl', 'ctl', 'healthy', 'sictrl', 'sicontrol', 'ctr', 'wild', 'dmso', 'vehicle', 'naive'])

In [None]:
# Title-based grouping: normalized title -> list of sample accessions.
# The study is rejected as soon as any group has a size other than 3 or 4.
groupings = {}
for label, group in groups:
    if len(group) not in {3, 4}: #enforce 3-4 samples per group
        raise ValueError("Study does not have 3-4 samples per group")
    
    groupings[label] = group['geo_accession'].tolist()

# print(groupings)

In [None]:
# All title-based condition labels, and the subset that contains at least one
# control keyword as a whole word.
title_conditions = list(groupings.keys())
title_ctrl = []
for c in title_conditions:
    if len(set(c.split()).intersection(ctrl_words)) > 0:
        title_ctrl.append(c)
        
# print(title_conditions)
# print(title_ctrl)    

In [None]:
# Alternative, characteristics-based labels for the same title groups.
#   og_labels:        characteristics label -> title label it was derived from
#   labled_groupings: characteristics label -> list of sample accessions
og_labels = {}
labled_groupings = {}

for label in groupings:
    samps = groupings[label]
    data = list(map(lambda s: s.lower(), metadata.loc[samps]['characteristics_ch1'].values))
    # Tokenize each sample's characteristics: separators -> spaces, drop stop words.
    data_clean = []
    for d in data:
        data_clean.append(set(filter(lambda w: w not in stopwords_plus, re.sub(pattern, ' ', d).split())))
    # Keep only the tokens shared by every sample of the group.
    condition = set(data_clean[0])
    for s in data_clean[1:]:
        condition.intersection_update(s)
    # Sort before joining: set iteration order is arbitrary, so an unsorted join
    # made the label text (and everything keyed by it) differ between runs.
    condition = ' '.join(sorted(condition))
    # NOTE(review): two title groups sharing the same token set map to the same
    # label here, and the later one silently replaces the earlier one.
    labled_groupings[condition] = samps
    og_labels[condition] = label

# Characteristics-based labels, and the subset containing a control keyword.
ch1_ctrl = []
ch1_conditions = list(labled_groupings.keys())

for condition in labled_groupings:
    split_conditions = condition.lower().split()
    if len(set(split_conditions).intersection(ctrl_words)) > 0:
        ch1_ctrl.append(condition)

# print(og_labels)
# print(ch1_conditions)
# print(ch1_ctrl)

In [None]:
#must have 1-2 controls. Must have perturbation groups as well (not all groups can be controls).
def check_eligibility(conditions, ctrl_conditions):
    """Return True when a labelling has a usable control/perturbation split.

    Parameters:
        conditions: all condition labels of the labelling.
        ctrl_conditions: the labels among them recognized as controls.

    Returns:
        True if there are exactly one or two control conditions and the number
        of control conditions differs from the total number of conditions;
        False otherwise.
    """
    n_ctrl = len(ctrl_conditions)
    has_one_or_two_controls = 1 <= n_ctrl <= 2
    has_non_control = n_ctrl != len(conditions)
    return has_one_or_two_controls and has_non_control

In [None]:
# Evaluate both labellings independently: the characteristics-based one (ch1)
# and the title-based one.
ch1_eligibility = check_eligibility(ch1_conditions, ch1_ctrl)
title_eligibility = check_eligibility(title_conditions, title_ctrl)
#print(ch1_eligibility)
#print(title_eligibility)

In [None]:
def compare_groups(title_ctrl, ch1_ctrl, og_labels):
    """Check that each characteristics-based control matches its title group.

    For every label in ch1_ctrl, the label is translated to the title label it
    was derived from (via og_labels) and the sample set stored for that title
    label in the notebook-level `groupings` is compared with the sample set
    stored for the characteristics label in the notebook-level
    `labled_groupings`.

    Parameters:
        title_ctrl: title-based control labels (accepted but not used here).
        ch1_ctrl: characteristics-based control labels to verify.
        og_labels: mapping characteristics label -> title label.

    Returns:
        True if every pair of sample sets is equal (also when ch1_ctrl is
        empty), False at the first mismatch.
    """
    return all(
        set(groupings[og_labels[ch1_label]]) == set(labled_groupings[ch1_label])
        for ch1_label in ch1_ctrl
    )

In [None]:
# Decide which labelling defines the final groups.
#  - Both eligible: use the title-based labelling, but only if compare_groups
#    confirms the control sample sets agree; otherwise abort.
#  - Exactly one eligible: use that one. When it is the characteristics-based
#    labelling, `groupings` is replaced by `labled_groupings` so that later
#    cells see groups keyed by the chosen labels.
#  - Neither eligible: abort.
if ch1_eligibility and title_eligibility:
    if compare_groups(title_ctrl, ch1_ctrl, og_labels):
        ctrl_conditions = title_ctrl
        conditions = title_conditions
    else:
        raise Exception("Group Assignment Failed")
    
elif ch1_eligibility ^ title_eligibility:
    if ch1_eligibility:
        ctrl_conditions = ch1_ctrl
        conditions = ch1_conditions
        groupings = labled_groupings
    else:
        ctrl_conditions = title_ctrl
        conditions = title_conditions
else:
    raise Exception("Group Assignment Failed")

# print(ctrl_conditions)
# print(conditions)

In [None]:
%%capture
# Load the expression matrix for this series from the ARCHS4 file and save an
# unmodified copy next to the other report outputs.
gene_matrix = a4.data.series(file, gse) #raw counts
gene_matrix.to_csv(os.path.join(resource_path, "raw_counts.csv"))

In [None]:
# Preview: the first five rows are rendered as the cell output.
gene_matrix.head(5)

In [None]:
# Caption for the preview above, then advance the table counter and offer the
# saved CSV as a download link.
display(Markdown(f"**table {tab_num}**: This is a preview of the first 5 rows of the raw RNA-seq expression matrix from {gse}."))
tab_num +=1
display(FileLink(os.path.join(resource_path, "raw_counts.csv"), result_html_prefix="Download raw counts: "))

In [None]:
# Remove genes with all-zero counts
# (rows whose sum across samples is not positive are dropped).
filtered_matrix = gene_matrix.loc[gene_matrix.sum(axis=1) > 0, :]

# Then filter out low average expression
# (keep rows whose mean across samples is at least 3).
filtered_matrix = filtered_matrix.loc[filtered_matrix.mean(axis=1) >= 3, :]

In [None]:
from maayanlab_bioinformatics.normalization.log import log2_normalize
from maayanlab_bioinformatics.normalization.zscore import zscore_normalize 
from maayanlab_bioinformatics.normalization.quantile_legacy import quantile_normalize

def normalize(gene_counts):
    """Apply the notebook's three normalization steps in a fixed order.

    The input is passed to quantile_normalize, the result to log2_normalize,
    and that result to zscore_normalize (all imported from
    maayanlab_bioinformatics.normalization).

    Parameters:
        gene_counts: expression matrix accepted by quantile_normalize.

    Returns:
        The output of zscore_normalize.
    """
    return zscore_normalize(log2_normalize(quantile_normalize(gene_counts)))


In [None]:
%%capture
# from python_scripts.matrix import normalize
# Uses the normalize() defined in this notebook on the filtered counts.
norm_matrix = normalize(filtered_matrix) #normalize the matrix for dim reduction and clustergram

In [None]:
def annotate_matrix(expr_df, groupings, ctrl_conditions):
    """Pair an expression matrix with a control/perturbation label per sample.

    Parameters:
        expr_df: expression matrix; returned unchanged under the 'count' key.
        groupings: mapping condition label -> iterable of sample ids.
        ctrl_conditions: collection of condition labels treated as controls.

    Returns:
        dict with
          'count'       -- expr_df as passed in
          'annotations' -- DataFrame indexed by sample id with one column,
                           'group', holding "control" for samples of a control
                           condition and "perturbation" for all others.
    """
    # A sample listed under several conditions keeps the label of the last one.
    sample_to_group = {
        sample_id: ("control" if condition in ctrl_conditions else "perturbation")
        for condition, sample_ids in groupings.items()
        for sample_id in sample_ids
    }
    annotations = pd.DataFrame.from_dict(sample_to_group, orient='index', columns=['group'])
    return {'count': expr_df, 'annotations': annotations}

In [None]:
# Normalized matrix plus per-sample group labels; used for the dimensionality
# reduction plot and the clustergram.
annotated_norm_matrix = annotate_matrix(norm_matrix, groupings, ctrl_conditions)

#print(annotated_norm_matrix['annotations'])

In [None]:
# Filtered counts cast to int64, with the same group labels.
annotated_matrix = annotate_matrix(filtered_matrix.astype('int64'), groupings, ctrl_conditions) #filtered but not normalized

# **Results**

In [None]:
# Rendering switches for the Results section.
save_html = True #if true, render plotly graphs as html and embed with ipython. else, use fig.show()
use_fig_plot = False #if true, render matplotlib graphs using show(), else it will render the saved pngs.

## **Dimensionality Reduction**

### **UMAP**

In [None]:
import python_scripts.visualizations as vis

# Run the project's plotting helper for a 2-component UMAP of the normalized
# matrix. NOTE(review): the cell then displays <resource_path>/umap.png and links
# umap.<fmt> for each save format, so vis.plot is presumably what writes those
# files -- confirm against python_scripts.visualizations.
vis.plot(annotated_norm_matrix['count'], annotated_norm_matrix['annotations'], n_components=2, save_formats=save_formats, decomp="umap", save_html=save_html, save_path=resource_path)
# if save_html: display(HTML(os.path.join(resource_path, "umap.html")))
display(Image(os.path.join(resource_path, "umap.png"), width=700))
display(Markdown(f"**Figure {fig_num}**: This figure displays a 2D scatter plot of a UMAP decomposition of the sample data. Each point represents an individual sample, colored by its experimental group."))
fig_num+=1

# One download link per saved format.
for fmt in save_formats:
    display(FileLink(os.path.join(resource_path, f"umap.{fmt}"), result_html_prefix=f"Download UMAP figure as {fmt}: "))


## **Clustergram Heatmaps**

In [None]:
from maayanlab_bioinformatics.normalization.filter import filter_by_var
# Reduce the normalized matrix with filter_by_var(top_n=1000) -- per the Methods
# text, the 1000 most variable genes. (The earlier .copy() was dead code: its
# result was overwritten immediately.)
norm_t1000 = filter_by_var(annotated_norm_matrix['count'], top_n=1000, axis=1)

# Relabel the columns with the samples' titles for the heatmap. When the columns
# are sample ids present in the metadata index, look each title up by id so a
# different row order in `metadata` cannot mislabel samples; otherwise fall back
# to the original positional assignment.
sample_ids = list(norm_t1000.columns)
if set(sample_ids).issubset(metadata.index):
    norm_t1000.columns = metadata.loc[sample_ids, 'title'].tolist()
else:
    norm_t1000.columns = metadata['title'].tolist()

In [None]:
# Save the reduced matrix as a tab-separated file; this file is what gets
# uploaded to Clustergrammer in the next cell.
t1000_path = os.path.join(resource_path, 'expression_matrix_top1000_genes.txt')
norm_t1000.to_csv(t1000_path, sep='\t')

In [None]:
import requests, json
# Upload the matrix file to Clustergrammer; the response body is used below as
# the URL of the interactive heatmap.
clustergrammer_url = 'http://amp.pharm.mssm.edu/clustergrammer/matrix_upload/'

# Open the file in a `with` block so the handle is closed after the upload
# (it was previously left open).
with open(t1000_path, 'rb') as matrix_file:
    r = requests.post(clustergrammer_url, files={'file': matrix_file})
# Stop here on an HTTP error instead of embedding an error page as the "link".
r.raise_for_status()
link = r.text

In [None]:
from IPython.display import IFrame
# Embed the Clustergrammer page returned by the upload, add the figure caption
# and advance the figure counter.
display(IFrame(link, width="600", height="650"))
display(Markdown(f"**Figure {fig_num}**: The figure contains an interactive heatmap displaying gene expression for each sample in the RNA-seq dataset. Every row of the heatmap represents a gene, every column represents a sample, and every cell displays normalized gene expression values. The heatmap additionally features color bars beside each column which represent prior knowledge of each sample, such as the tissue of origin or experimental treatment."))
fig_num+=1