Skip to content

Commit

Permalink
For hierarchical reports, allow user-defined mark characters (expands…
Browse files Browse the repository at this point in the history
… '>')
  • Loading branch information
dvklopfenstein committed Sep 18, 2018
1 parent 34c1ed5 commit bacc42a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 25 deletions.
6 changes: 3 additions & 3 deletions goatools/cli/wr_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, args=None, prt=sys.stdout):
def _init_goids(self):
goids_ret = []
godagconsts = Consts()
print("WWWWWWWWWWWWWWWWWWWWWW", self.kws)
# print("WWWWWWWWWWWWWWWWWWWWWW", self.kws)
if 'GO' in self.kws:
for goid in self.kws['GO']:
if goid[:3] == "GO:":
Expand Down Expand Up @@ -144,7 +144,7 @@ def _adj_item_marks(self):
# --item_marks=GO:0043473,GO:0009987
# --item_marks=item_marks.txt
if goids:
self.kws['item_marks'] = goids
self.kws['item_marks'] = {go:'>' for go in goids}
else:
raise Exception("NO GO IDs FOUND IN item_marks")

Expand All @@ -166,7 +166,7 @@ def _adj_for_assc(self):
if self.gene2gos:
gos_assoc = set(get_b2aset(self.gene2gos).keys())
if 'item_marks' not in self.kws:
self.kws['item_marks'] = set(gos_assoc)
self.kws['item_marks'] = {go:'>' for go in gos_assoc}
if 'include_only' not in self.kws:
gosubdag = GoSubDag(gos_assoc, self.gosubdag.go2obj,
self.gosubdag.relationships)
Expand Down
47 changes: 32 additions & 15 deletions goatools/gosubdag/rpt/write_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def __init__(self, gosubdag, **kws):
self.usrset = set([k for k, v in kws.items() if k in kws and v])
# ' {NS} {dcnt:6,} L{level:02} D{depth:02} {D1:5} {GO_name}'

def prt_hier_all(self, prt=sys.stdout, sortby=None):
def prt_hier_all(self, prt=sys.stdout):
"""Write hierarchy for all GO Terms in obo file."""
# Print: [biological_process, molecular_function, and cellular_component]
items_printed = set()
items_list = set()
for goid in ['GO:0008150', 'GO:0003674', 'GO:0005575']:
items_printed.update(self.prt_hier_down(goid, prt, sortby))
return items_printed
items_list.update(self.prt_hier_down(goid, prt))
return items_list

def prt_hier_down(self, goid, prt=sys.stdout, sortby=None):
def prt_hier_down(self, goid, prt=sys.stdout):
"""Write hierarchy for all GO IDs below GO ID in arg, goid."""
wrhiercfg = self._get_wrhiercfg()
obj = WrHierPrt(self.gosubdag.go2obj, self.gosubdag.go2nt, wrhiercfg, prt)
Expand All @@ -42,13 +42,11 @@ def prt_hier_up(self, goids, prt=sys.stdout):
"""Write hierarchy for all GO IDs below GO ID in arg, goid."""
go2goterm_all = {go:self.gosubdag.go2obj[go] for go in goids}
objp = GoPaths()
items_printed = set()
wrhiercfg = self._get_wrhiercfg()
items_list = []
for namespace, go2term_ns in self._get_namespace2go2term(go2goterm_all).items():
go_root = self.consts.NAMESPACE2GO[namespace]
goids_all = set() # GO IDs from user-specfied GO to root
for goid, goterm in go2term_ns.items():
goids_all.add(goid)
for goid_usr, goterm in go2term_ns.items():
goids_all.add(goid_usr)
paths = objp.get_paths_from_to(goterm, goid_end=None, dn0_up1=True)
goids_all.update(set(o.id for p in paths for o in p))
# Only include GO IDs from user-specified GO to the root
Expand All @@ -57,12 +55,17 @@ def prt_hier_up(self, goids, prt=sys.stdout):
self.usrdct['include_only'].update(goids_all)
# Mark the user-specfied GO term
if 'item_marks' not in self.usrdct:
self.usrdct['item_marks'] = set()
self.usrdct['item_marks'].update(go2term_ns.keys())
self.usrdct['item_marks'] = {}
for goid_usr in go2term_ns.keys():
if goid_usr not in self.usrdct['item_marks']:
self.usrdct['item_marks'][goid_usr] = '*'
# Write the hierarchy
wrhiercfg = self._get_wrhiercfg()
obj = WrHierPrt(self.gosubdag.go2obj, self.gosubdag.go2nt, wrhiercfg, prt)
items_printed.update(obj.items_printed)
go_root = self._get_goroot(goids_all, namespace)
obj.prt_hier_rec(go_root)
return items_printed
items_list.extend(obj.items_list)
return items_list

@staticmethod
def _get_namespace2go2term(go2terms):
Expand All @@ -80,13 +83,27 @@ def _get_wrhiercfg(self):
return {'name2prtfmt':{'ITEM':prtfmt, 'ID':'{GO}{alt:1}'},
'max_indent': self.usrdct.get('max_indent'),
'include_only': self.usrdct.get('include_only'),
'item_marks': self.usrdct.get('item_marks', set()),
'item_marks': self.usrdct.get('item_marks', {}),
'concise_prt': 'concise' in self.usrset,
'indent': 'no_indent' not in self.usrset,
'dash_len': self.usrdct.get('dash_len', 6),
'sortby': self.usrdct.get('sortby')
}

def _get_goroot(self, goids_all, namespace):
    """Return the GO ID that sits at the top of the GO IDs in goids_all.

    The namespace's root GO ID (looked up in self.consts.NAMESPACE2GO) is
    returned when it is a member of goids_all.  Otherwise the GO terms for
    goids_all are looked up in self.gosubdag.go2obj and the single term
    whose depth is 0 is returned.

    Raises RuntimeError if the namespace root is absent and the number of
    depth-0 terms found is anything other than exactly one.
    """
    namespace_root = self.consts.NAMESPACE2GO[namespace]
    if namespace_root in goids_all:
        return namespace_root
    # The namespace root was not among the GO IDs; fall back to depth-0 terms
    terms = (self.gosubdag.go2obj[goid] for goid in goids_all)
    depth0_goids = {term.id for term in terms if term.depth == 0}
    if len(depth0_goids) != 1:
        raise RuntimeError("UNEXPECTED NUMBER OF ROOTS: {R}".format(R=depth0_goids))
    (top_goid,) = depth0_goids
    return top_goid

#### Examples:
####
#### Print the GO IDs associated with human genes
Expand Down
13 changes: 11 additions & 2 deletions goatools/rpt/write_hierarchy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, id2obj, id2nt, cfg, prt=sys.stdout):
self.nm2prtfmt = cfg['name2prtfmt'] if self.do_prtfmt else None
self.max_indent = cfg['max_indent']
self.include_only = cfg['include_only']
self.item_marks = cfg['item_marks']
self.item_marks = self._init_item_marks(cfg.get('item_marks'))
self.concise_prt = cfg.get('concise_prt', False)
self.indent = cfg.get('indent', True)
self.space_branches = cfg.get('space_branches', False)
Expand All @@ -42,7 +42,8 @@ def prt_hier_rec(self, item_id, depth=1):
self.prt.write("\n")
# Print marks if provided
if self.item_marks:
self.prt.write('{MARK} '.format(MARK='>' if item_id in self.item_marks else ' '))
self.prt.write('{MARK} '.format(
MARK=self.item_marks[item_id] if item_id in self.item_marks else ' '))

no_repeat = self.concise_prt and item_id in self.items_printed
# Print content
Expand Down Expand Up @@ -89,5 +90,13 @@ def _str_dash(self, depth, no_repeat, obj):
return ''.join([letter]*depth)
return ""

@staticmethod
def _init_item_marks(item_marks):
    """Initialize the marked item dict, which maps an item ID to its mark.

    item_marks may be:
      * a dict (item ID -> mark string): returned unchanged
      * any other non-empty iterable of item IDs: every ID gets the mark '>'
      * None or empty: no marks, an empty dict is returned

    Always returns a dict so the attribute has one consistent type; an
    empty dict is falsy, so truthiness checks behave as they did with None.
    """
    if isinstance(item_marks, dict):
        return item_marks
    if item_marks:
        return {item_id: '>' for item_id in item_marks}
    # Nothing to mark: be explicit instead of implicitly returning None
    return {}


# Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ clobber:
rm -f *.png
rm -f gos_*
rm -f cell_cycle_genes_*.txt
rm -r *.gpa.gz
rm -f *.gpa.gz

# Tests which run longer and have much functionality covered by other tests
# tests/test_annotations_gaf.py \
Expand Down
31 changes: 27 additions & 4 deletions tests/test_write_hier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
"""Tests alternate OBOReader."""

import os
Expand Down Expand Up @@ -58,17 +59,37 @@ def write_hier_lim(gosubdag, out):
'GO:0000009']


def write_hier_mrk(gosubdag, out):
def write_hier_mrk_lst(gosubdag, out):
"""Print all paths, but mark GO Terms of interest. """
mark_lst = ['GO:0000001', 'GO:0000003', 'GO:0000006', 'GO:0000008', 'GO:0000009']
out.write('\nTEST MARK: 01->03->06->08->09:\n')
out.write('\nTEST MARK LIST: 01->03->06->08->09:\n')
objwr = WrHierGO(gosubdag, item_marks=mark_lst, sortby=lambda o: o.item_id)
gos_printed = objwr.prt_hier_down("GO:0000001", out)
assert gos_printed == ['GO:0000001', 'GO:0000002', 'GO:0000005', 'GO:0000010',
'GO:0000003', 'GO:0000004', 'GO:0000007', 'GO:0000009',
'GO:0000005', 'GO:0000010', 'GO:0000006', 'GO:0000008',
'GO:0000009', 'GO:0000010']
#item_marks=[oGO.id for oGO in oGOs_in_cluster])

def write_hier_mrk_dct(gosubdag, out):
    """Write the hierarchy below GO:0000001, passing item_marks as a dict.

    The dict maps each GO ID of interest to its own mark character.  The
    list returned by prt_hier_down is checked against the expected GO IDs.
    """
    go2mark = {
        'GO:0000001': 'a',
        'GO:0000003': 'b',
        'GO:0000006': 'c',
        'GO:0000008': 'd',
        'GO:0000009': 'e',
    }
    expected = [
        'GO:0000001', 'GO:0000002', 'GO:0000005', 'GO:0000010',
        'GO:0000003', 'GO:0000004', 'GO:0000007', 'GO:0000009',
        'GO:0000005', 'GO:0000010', 'GO:0000006', 'GO:0000008',
        'GO:0000009', 'GO:0000010',
    ]
    out.write('\nTEST MARK DICT: 01->03->06->08->09:\n')
    writer = WrHierGO(gosubdag, item_marks=go2mark, sortby=lambda o: o.item_id)
    actual = writer.prt_hier_down("GO:0000001", out)
    assert actual == expected

def write_hier_up(gosubdag, out):
    """Write the hierarchy for GO:0000005 using prt_hier_up, with a dict of marks.

    item_marks is passed as a dict of GO ID -> mark character.  The list
    returned by prt_hier_up is checked against the expected GO IDs.
    """
    go2mark = {
        'GO:0000001': 'a',
        'GO:0000003': 'b',
        'GO:0000006': 'c',
        'GO:0000008': 'd',
    }
    expected = ['GO:0000001', 'GO:0000002', 'GO:0000005', 'GO:0000003', 'GO:0000005']
    out.write('\nTEST MARK DICT: 01->03->06->08->09:\n')
    writer = WrHierGO(gosubdag, item_marks=go2mark, sortby=lambda o: o.item_id)
    actual = writer.prt_hier_up(["GO:0000005"], out)
    assert actual == expected


#################################################################
# Tests
Expand All @@ -84,7 +105,9 @@ def test_all():
write_hier_all(gosubdag, out)
write_hier_norep(gosubdag, out)
write_hier_lim(gosubdag, out)
write_hier_mrk(gosubdag, out)
write_hier_mrk_lst(gosubdag, out)
write_hier_mrk_dct(gosubdag, out)
write_hier_up(gosubdag, out)
msg = "Elapsed HMS: {}\n\n".format(str(datetime.timedelta(seconds=(toc-tic))))
sys.stdout.write(msg)

Expand Down

0 comments on commit bacc42a

Please sign in to comment.