Skip to content

Commit

Permalink
Connected GO grouping and xlsx writer in find_enrichment script
Browse files Browse the repository at this point in the history
  • Loading branch information
dvklopfenstein committed Jul 22, 2018
1 parent d79d7ef commit ab6434e
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 30 deletions.
74 changes: 51 additions & 23 deletions goatools/cli/find_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from goatools.grouper.sorter import Sorter
from goatools.grouper.aart_geneproducts_all import AArtGeneProductSetsAll
from goatools.grouper.wr_sections import WrSectionsTxt
from goatools.grouper.wrxlsx import WrXlsxSortedGos


# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -144,7 +145,7 @@ def __init__(self, args):
# Run GOEA
self.results_all = self.objgoea.run_study(_study)
# Prepare for grouping, if user-specified. Create GroupItems
self.prepgrp = GroupItems(_assoc, self) if self.sections else None
self.prepgrp = GroupItems(_assoc, self, self.godag.version) if self.sections else None

def prt_results(self, goea_results):
"""Print GOEA results to the screen or to a file."""
Expand All @@ -154,16 +155,28 @@ def prt_results(self, goea_results):
else:
# Users can print to both tab-separated file and xlsx file in one run.
outfiles = self.args.outfile.split(",")
self.prt_outfiles(goea_results, outfiles)
grpwr = self.prepgrp.get_objgrpwr(goea_results) if self.prepgrp else None
if grpwr is None:
self.prt_outfiles_flat(goea_results, outfiles)
else:
self.prt_outfiles_grouped(grpwr, goea_results, outfiles)

def prt_outfiles(self, goea_results, outfiles):
def prt_outfiles_flat(self, goea_results, outfiles):
    """Write the flat (ungrouped) GOEA results to each requested output file.

    A filename ending in ".xlsx" is written with the GOEA object's xlsx writer;
    every other filename is written with its tsv writer.
    """
    for fout in outfiles:
        # Pick the writer from the file extension, then call it once
        is_xlsx = fout.endswith(".xlsx")
        write = self.objgoea.wr_xlsx if is_xlsx else self.objgoea.wr_tsv
        write(fout, goea_results, indent=self.args.indent)

def prt_outfiles_grouped(self, grpwr, goea_results, outfiles):
    """Write GOEA results to each requested output file, grouping the xlsx output.

    A filename ending in ".xlsx" is written by the group writer, grpwr.
    Every other filename is written with the GOEA object's tsv writer.
    """
    for fout in outfiles:
        if not fout.endswith(".xlsx"):
            # Non-xlsx files are still written by the GOEA object's tsv writer
            self.objgoea.wr_tsv(fout, goea_results, indent=self.args.indent)
            continue
        grpwr.wr_xlsx(fout)

def _prt_results(self, goea_results):
"""Print GOEA results to the screen."""
min_ratio = self.args.ratio
Expand Down Expand Up @@ -196,8 +209,7 @@ def chk_genes(self, study, pop):
if len(pop) < len(study):
exit("\nERROR: The study file contains more elements than the population file. "
"Please check that the study file is a subset of the population file.\n")
# check the fraction of genomic ids that overlap between study
# and population
# check the fraction of genomic ids that overlap between study and population
overlap = self.get_overlap(study, pop)
if overlap < 0.95:
sys.stderr.write("\nWARNING: only {} fraction of genes/proteins in study are found in "
Expand Down Expand Up @@ -225,9 +237,6 @@ def get_overlap(study, pop):
def get_pval_field(self):
"""Get 'p_uncorrected' or the user-specified field for determining significant results."""
pval_fld = self.args.pval_field
# print("FFFFFFFFFFFF", pval_fld)
# print("FFFFFFFFFFFF", self.args)
# print("FFFFFFFFFFFF", self.methods)
# If --pval_field [VAL] was specified
if pval_fld is not None:
if pval_fld[:2] != 'p_':
Expand Down Expand Up @@ -273,7 +282,7 @@ def _read_geneset(self, study_fn, pop_fn):
class GroupItems(object):
"""Prepare for grouping, if specified by the user."""

def __init__(self, gene2gos, objcli):
def __init__(self, gene2gos, objcli, godag_version):
# _goids = set(o.id for o in godag.values() if not o.children)
_goids = set(r.GO for r in objcli.results_all)
_tobj = TermCounts(objcli.godag, gene2gos)
Expand All @@ -282,12 +291,13 @@ def __init__(self, gene2gos, objcli):
self.grprdflt = GrouperDflts(self.gosubdag, objcli.args.goslim)
self.hdrobj = HdrgosSections(self.grprdflt.gosubdag, self.grprdflt.hdrgos_dflt, objcli.sections)
self.pval_fld = objcli.get_pval_field() # primary pvalue of interest
self.ver_list = [godag_version, self.grprdflt.ver_goslims]
# self.objaartall = self._init_objaartall()

def get_objgrpwr(self, goea_results):
"""Get a GrpWr object to write grouped GOEA results."""
sortobj = self.get_sortobj(goea_results)
return GrpWr(sortobj, self.pval_fld)
return GrpWr(sortobj, self.pval_fld, ver_list=self.ver_list)

def get_sortobj(self, goea_results, **kws):
"""Return a Grouper object, given a list of GOEnrichmentRecord."""
Expand Down Expand Up @@ -325,20 +335,41 @@ def _init_objaartall(self):
class GrpWr(object):
"""Write GO term GOEA information, grouped."""

def __init__(self, sortobj, pval_fld):
def __init__(self, sortobj, pval_fld, ver_list):
self.sortobj = sortobj
self.pval_fld = pval_fld
self.ver_list = ver_list
self.objprt = PrtFmt()
self.flds = self._get_flds()

def prt_txt(self, prt=sys.stdout, hdrgo_prt=False):
self.flds_all = next(iter(self.sortobj.grprobj.go2nt.values()))._fields
self.flds_cur = self._init_flds_cur()
self.desc2nts = self.sortobj.get_desc2nts(hdrgo_prt=False)
# print("nnnnnnnnnnnnnnnnnnnnnttttttttttttttttt", self.flds_all)

def wr_xlsx(self, fout_xlsx):
    """Write the grouped GOEA results to the xlsx file named fout_xlsx."""
    xlsx_writer = WrXlsxSortedGos("GOEA", self.sortobj)
    # The title is built by joining the entries of ver_list
    title = "; ".join(self.ver_list)
    # Every field whose name starts with 'p_' is printed in scientific notation
    pval_fmts = {}
    for fld in self.flds_cur:
        if fld[:2] == 'p_':
            pval_fmts[fld] = '{:8.2e}'
    # NOTE: 'ntfld_wbfmt', 'ntval2wbfmtdict', and 'hdrs' are not passed to the writer here
    xlsx_writer.wr_xlsx_nts(
        fout_xlsx,
        self.desc2nts,
        title=title,
        fld2fmt=pval_fmts,
        prt_flds=self.flds_cur)

def prt_txt(self, prt=sys.stdout):
"""Print an ASCII text format."""
prtfmt = self.objprt.get_prtfmt(self._get_flds())
desc2nts = self.sortobj.get_desc2nts(hdrgo_prt=hdrgo_prt)
prt.write("{FLDS}\n".format(FLDS=" ".join(self.flds)))
WrSectionsTxt.prt_sections(prt, desc2nts['sections'], prtfmt, secspc=True)
prtfmt = self.objprt.get_prtfmt(self.flds_cur)
prt.write("{FLDS}\n".format(FLDS=" ".join(self.flds_cur)))
WrSectionsTxt.prt_sections(prt, self.desc2nts['sections'], prtfmt, secspc=True)

def _get_flds(self):
def _init_flds_cur(self):
"""Choose fields to print from a multitude of available fields."""
flds = []
# ('GO', 'NS', 'enrichment', 'name', 'ratio_in_study', 'ratio_in_pop', 'depth',
Expand All @@ -352,16 +383,13 @@ def _get_flds(self):
# 'REL_short', 'rel', 'id')
flds0 = ['GO', 'NS', 'enrichment', self.pval_fld, 'dcnt', 'tinfo', 'depth',
'ratio_in_study', 'ratio_in_pop', 'name']
flds_all = next(iter(self.sortobj.grprobj.go2nt.values()))._fields
flds_p = [f for f in flds_all if f[:2] == 'p_' and f != self.pval_fld]
flds_p = [f for f in self.flds_all if f[:2] == 'p_' and f != self.pval_fld]
flds.extend(flds0)
if flds_p:
flds.extend(flds_p)
flds.append('study_count')
flds.append('study_items')
# print("nnnnnnnnnnnnnnnnnnnnnttttttttttttttttt", flds_all)
return flds



# Copyright (C) 2010-2018, H Tang et al. All rights reserved.
3 changes: 2 additions & 1 deletion goatools/grouper/wrxlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def prt_txt_desc2nts(self, prt, desc2nts, prtfmt=None):
def _get_xlsx_kws(self, **kws_usr):
"""Return keyword arguments relevant to writing an xlsx."""
kws_xlsx = {'fld2col_widths':self._get_fld2col_widths(**kws_usr), 'items':'GO IDs'}
remaining_keys = set(["title", "hdrs", "prt_flds"])
remaining_keys = set(['title', 'hdrs', 'prt_flds', 'fld2fmt',
'ntval2wbfmtdict', 'ntfld_wbfmt'])
for usr_key, usr_val in kws_usr.items():
if usr_key in remaining_keys:
kws_xlsx[usr_key] = usr_val
Expand Down
2 changes: 1 addition & 1 deletion goatools/wr_tbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
For adding color or other formatting to a row based on value in a row:
'ntfld_wbfmt': namedtuple field containing a value used as a key for a xlsx format
'ntval2wbfmtdict': namedtuple value and corresponding xlsx format dict. Examples:
'ntval2wbfmtdict': namedtuple value and corresponding xlsx format dict.
"""

__copyright__ = "Copyright (C) 2016-2018, DV Klopfenstein, H Tang, All rights reserved."
Expand Down
13 changes: 8 additions & 5 deletions goatools/wr_tbl_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def wr_data(self, xlsx_data, row_i, worksheet):
try:
for data_nt in xlsx_data:
if prt_if is None or prt_if(data_nt):
wbfmt = get_wbfmt(data_nt)
wbfmt = get_wbfmt(data_nt) # xlsxwriter.format.Format created w/add_format
# Print an xlsx row by printing each column in order.
for col_i, fld in enumerate(prt_flds):
try:
Expand Down Expand Up @@ -176,15 +176,18 @@ def get_wbfmt(self, data_nt=None):
if wbfmt is not None:
return wbfmt
# 'ntfld_wbfmt': namedtuple field which contains a value used as a key for a xlsx format
# 'ntval2wbfmtdict': namedtuple value and corresponding xlsx format dict. Examples:
# 'ntval2wbfmtdict': namedtuple value and corresponding xlsx format dict.
return self.fmtname2wbfmtobj.get('plain')

def __get_wbfmt_usrfld(self, data_nt):
"""Return format for text cell from namedtuple field specified by 'ntfld_wbfmt'"""
if self.ntfld_wbfmt is not None:
ntval = getattr(data_nt, self.ntfld_wbfmt, None) # Ex: 'section'
if ntval is not None:
return self.fmtname2wbfmtobj.get(ntval, None)
if isinstance(self.ntfld_wbfmt, str):
ntval = getattr(data_nt, self.ntfld_wbfmt, None) # Ex: 'section'
if ntval is not None:
return self.fmtname2wbfmtobj.get(ntval, None)
#### elif isinstance(self.ntfld_wbfmt, dict):
#### print("DDDDDDDDDDDD IIIIIIIII CCCCCCCCCC TTTTTTTTTTTTTTTT")

def __get_wbfmt_format_txt(self, data_nt):
"""Return format for text cell from namedtuple field, 'format_txt'."""
Expand Down

0 comments on commit ab6434e

Please sign in to comment.