Skip to content

Commit

Permalink
gene_set name for dict, gmt, #181
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Dec 16, 2022
1 parent 207e01d commit 132d68e
Showing 1 changed file with 42 additions and 32 deletions.
74 changes: 42 additions & 32 deletions gseapy/enrichr.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,34 @@ def __init__(
self._organism = None
self.ENRICHR_URL = "http://maayanlab.cloud"
# init logger
logfile = self.prepare_outdir()
self._logger = log_init(
outlog=logfile, log_level=logging.INFO if self.verbose else logging.WARNING
)
self.prepare_outdir()

def __del__(self):
    """Release logging handlers and the temporary directory on finalization.

    Every handler attached to ``self._logger`` is closed and detached, then
    ``self._tmpdir`` is cleaned up if the log file still exists on disk.

    Python may call ``__del__`` on an object whose ``__init__`` did not run
    to completion, so each attribute is looked up defensively instead of
    being assumed present; a finalizer must never raise.
    """
    logger = getattr(self, "_logger", None)
    if logger is not None:
        # Iterate over a copy: removeHandler mutates logger.handlers.
        for handler in list(logger.handlers):
            handler.close()  # close the underlying log file
            logger.removeHandler(handler)

    tmpdir = getattr(self, "_tmpdir", None)
    logfile = getattr(self, "_logfile", None)
    # Same condition as before for fully initialised objects: only clean up
    # when the log file is still present in the temporary directory.
    if tmpdir is not None and logfile is not None and os.path.exists(logfile):
        tmpdir.cleanup()

def prepare_outdir(self):
"""create temp directory."""
self._outdir = self.outdir
if self._outdir is None:
self._tmpdir = TemporaryDirectory()
self.outdir = self._tmpdir.name
elif isinstance(self.outdir, str):

if isinstance(self.outdir, str):
mkdirs(self.outdir)
else:
raise Exception("Error parsing outdir: %s" % type(self.outdir))

# handle gene_sets
self._tmpdir = TemporaryDirectory()
self.outdir = self._tmpdir.name
logfile = os.path.join(
self.outdir, "gseapy.%s.%s.log" % (self.module, self.descriptions)
self.outdir, "gseapy.%s.%s.log" % (self.module, id(self))
)
self._logfile = logfile
self._logger = log_init(
name=str(self.module) + str(id(self)),
log_level=logging.INFO if self.verbose else logging.WARNING,
filename=logfile,
)
return logfile

def __parse_gmt(self, g: str):
with open(g) as genesets:
Expand All @@ -86,20 +93,25 @@ def __parse_gmt(self, g: str):
}
return g_dict

def __gmt2dict(self, gene_sets: List[str]) -> List[Dict[str, List[str]]]:
def __gs2dict(self, gene_sets: List[str]) -> List[Dict[str, List[str]]]:
"""helper function, only convert gmt to dict and keep strings"""
gss = []
for g in gene_sets:
self._gs_name = []
for i, g in enumerate(gene_sets):
# only convert gmt to dict. local mode
if isinstance(g, str) and g.lower().endswith(".gmt"):
if os.path.exists(g):
self._logger.info("User Defined gene sets is given: %s" % g)
gss.append(self.__parse_gmt(g))
self._gs_name.append(os.path.basename(g))
else:
self._logger.warning("User Defined gene sets is not found: %s" % g)
else:
gss.append(g)

_name = g
if isinstance(g, dict):
_name = "gs_ind_" + str(i)
self._gs_name.append(_name)
return gss

def parse_genesets(self, gene_sets=None):
Expand All @@ -109,14 +121,13 @@ def parse_genesets(self, gene_sets=None):

gss = []
if isinstance(gene_sets, list):
gss = self.__gmt2dict(gene_sets)

gss = self.__gs2dict(gene_sets)
elif isinstance(self.gene_sets, str):
gss = [g.strip() for g in gene_sets.strip().split(",")]
gss = self.__gmt2dict(gss)
gss = self.__gs2dict(gss)

elif isinstance(gene_sets, dict):
gss = [gene_sets.copy()]
gss = self.__gs2dict([gene_sets.copy()])
else:
raise Exception(
"Error parsing enrichr libraries, please provided corrected one"
Expand All @@ -126,23 +137,28 @@ def parse_genesets(self, gene_sets=None):
if len(gss) < 1:
raise Exception("No GeneSets are valid !!! Check your gene_sets input.")
gss_exist = []
gss_name = []
enrichr_library = []
# if all local gmts (local mode), skip connect to enrichr server
if not all([isinstance(g, dict) for g in gss]):
enrichr_library = self.get_libraries()

# check enrichr libraries are valid

for g in gss:
for n, g in zip(self._gs_name, gss):
if isinstance(g, dict):
gss_exist.append(g)
gss_name.append(n)
continue
if isinstance(g, str):
if g in enrichr_library:
gss_exist.append(g)
gss_name.append(n)
else:
self._logger.warning("Enrichr library not found: %s" % g)

self._gs_name = gss_name # update names
if len(gss_exist) < 1:
raise Exception("No GeneSets are valid !!! Check your gene_sets input.")
return gss_exist

def parse_genelists(self) -> str:
Expand Down Expand Up @@ -469,24 +485,21 @@ def run(self):
)
self.results = []

for g in gss:
for name, g in zip(self._gs_name, gss):
if isinstance(g, dict):
## local mode
res = self.enrich(g)
shortID, self._gs = str(id(g)), "CUSTOM%s" % id(g)
shortID, self._gs = str(id(g)), name
if res is None:
self._logger.info(
"No hits return, for gene set: Custom%s" % shortID
)
continue
else:
## online mode
self._gs = str(g)
self._gs = name
self._logger.debug("Start Enrichr using library: %s" % (self._gs))
self._logger.info(
"Analysis name: %s, Enrichr Library: %s"
% (self.descriptions, self._gs)
)
# self._logger.info("Enrichr Library: %s"% self._gs)
shortID, res = self.get_results(genes_list)
# Remember gene set library used
res.insert(0, "Gene_set", self._gs)
Expand Down Expand Up @@ -520,8 +533,5 @@ def run(self):
self._logger.warning(msg)
self._logger.info("Done.\n")
self.results = pd.concat(self.results, ignore_index=True)
# clean up tmpdir
if self._outdir is None:
self._tmpdir.cleanup()

return

0 comments on commit 132d68e

Please sign in to comment.