Skip to content

Commit

Permalink
handle empty plotvals (in yet another place...)
Browse files Browse the repository at this point in the history
and fiddle with sklearn/numpy imports to avoid warning
  • Loading branch information
psathyrella committed Nov 15, 2019
1 parent f2ce3cf commit 82f4581
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
7 changes: 5 additions & 2 deletions python/mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy
import itertools
import subprocess
import warnings

import utils

Expand Down Expand Up @@ -233,8 +234,10 @@ def run_bios2mds(n_components, n_clusters, seqfos, base_workdir, seed, aligned=F
# ----------------------------------------------------------------------------------------
def run_sklearn_mds(n_components, n_clusters, seqfos, seed, reco_info=None, region=None, aligned=False, n_init=4, max_iter=300, eps=1e-3, n_jobs=-1, plotdir=None):
print '%s not testing this after moving these imports down here' % utils.color('red', 'hey')
from sklearn import manifold # these are both slow af to import, even on local ssd
from sklearn.cluster import KMeans
with warnings.catch_warnings(): # NOTE not sure this is actually catching the warnings
warnings.simplefilter('ignore') # numpy is complaining about how sklearn is importing something, and I really don't want to *@*($$ing hear about it
from sklearn import manifold # these are both slow af to import, even on local ssd
from sklearn.cluster import KMeans

if len(set(sfo['name'] for sfo in seqfos)) != len(seqfos):
raise Exception('duplicate sequence ids in <seqfos>')
Expand Down
4 changes: 4 additions & 0 deletions python/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,8 @@ def get_xval_dict(uids, xkey):
# ----------------------------------------------------------------------------------------
def getbounds(xkey):
all_xvals = [x for c in sorted_clusters for x in get_xval_list(c, xkey)]
if len(all_xvals) == 0:
return None
bounds = [f(all_xvals) for f in [min, max]]
if global_max_vals is not None and xkey in global_max_vals:
bounds[1] = global_max_vals[xkey]
Expand Down Expand Up @@ -1243,6 +1245,8 @@ def add_hist(xkey, xvals, yval, iclust, cluster, median_x1, fixed_x1max, base_al
xbounds = {x1key : getbounds(x1key)} # these are the smallest/largest x values in any of <sorted_clusters>, whereas <high_x_val> is a fixed calling-fcn-specified value that may be more or less (kind of wasteful to get all the x vals here and then also in the main loop)
if x2key is not None:
xbounds[x2key] = getbounds(x2key)
if any(xbounds[xk] is None for xk in xbounds):
return 'no values' if high_x_val is None else high_x_clusters # 'no values' isn't really a file name, it just shows up as a dead link in the html
fixed_xmax = high_x_val if high_x_val is not None else xbounds[x1key][1] # xmax to use for the plotting (ok now there's three max x values, this is getting confusing)

if debug:
Expand Down
8 changes: 6 additions & 2 deletions python/treeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import dendropy
import time
import math
from sklearn import tree as sktree
from sklearn import ensemble
import yaml
import pickle
import warnings
if StrictVersion(dendropy.__version__) < StrictVersion('4.0.0'): # not sure on the exact version I need, but 3.12.0 is missing lots of vital tree fcns
raise RuntimeError("dendropy version 4.0.0 or later is required (found version %s)." % dendropy.__version__)

Expand Down Expand Up @@ -87,6 +86,11 @@ def dtrfname(dpath, cg):

# ----------------------------------------------------------------------------------------
def train_dtr(cgroup, dtrfo, dmodels, outdir, min_samples_leaf=5, max_depth=10, n_estimators=10, dump_training_data=False):
with warnings.catch_warnings(): # NOTE not sure this is actually catching the warnings
warnings.simplefilter('ignore', category=DeprecationWarning) # numpy is complaining about how sklearn is importing something, and I really don't want to *@*($$ing hear about it
from sklearn import tree as sktree
from sklearn import ensemble

print ' %s' % cgroup.replace('-', ' ')

base_regr = sktree.DecisionTreeRegressor(min_samples_leaf=min_samples_leaf, max_depth=max_depth)
Expand Down

0 comments on commit 82f4581

Please sign in to comment.