Skip to content

Commit

Permalink
Merge pull request #20 from LokiLuciferase/feature-shap-summary-plot
Browse files Browse the repository at this point in the history
Feature shap summary plot
  • Loading branch information
LokiLuciferase committed Mar 7, 2020
2 parents 02da355 + 280f29f commit 582fe11
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 33 deletions.
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ phenotrex
:target: https://codecov.io/gh/univieCUBE/phenotrex
:alt: Codecov

.. image:: https://img.shields.io/lgtm/grade/python/g/LokiLuciferase/phenotrex.svg?logo=lgtm&logoWidth=18
:target: https://lgtm.com/projects/g/LokiLuciferase/phenotrex/context:python
:alt: Code Quality

.. image:: https://travis-ci.com/univieCUBE/phenotrex.svg?branch=master
:target: https://travis-ci.com/univieCUBE/phenotrex
:alt: Travis CI
Expand Down
46 changes: 40 additions & 6 deletions phenotrex/cli/generic_func.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pprint import pformat

from phenotrex.io.flat import (load_training_files, load_params_file,
from phenotrex.io.flat import (load_training_files, load_genotype_file, load_params_file,
write_weights_file, write_params_file,
write_misclassifications_file,
write_cccv_accuracy_file)
from phenotrex.io.serialization import save_classifier
from phenotrex.io.serialization import save_classifier, load_classifier
from phenotrex.util.logging import get_logger
from phenotrex.ml import TrexSVM, TrexXGB
from phenotrex.ml import TrexSVM, TrexXGB, ShapHandler
from phenotrex.transforms.annotation import fastas_to_grs

CLF_MAPPER = {'svm': TrexSVM, 'xgb': TrexXGB}
logger = get_logger("phenotrex", verb=True)
Expand All @@ -23,7 +24,9 @@ def _fix_uppercase(kwargs):

def generic_train(type, genotype, phenotype, verb, weights, out,
n_features=None, params_file=None, *args, **kwargs):
"""Train and save a TrexClassifier model."""
"""
Train and save a TrexClassifier model.
"""
kwargs = _fix_uppercase(kwargs)
training_records, *_ = load_training_files(genotype_file=genotype,
phenotype_file=phenotype,
Expand Down Expand Up @@ -90,7 +93,8 @@ def generic_cv(type, genotype, phenotype, folds, replicates, threads, verb, opti


def generic_cccv(type, genotype, phenotype, folds, replicates, threads, comple_steps, conta_steps,
verb, groups=None, rank=None, optimize=False, out=None, n_features=None, params_file=None,
verb, groups=None, rank=None, optimize=False, out=None, n_features=None,
params_file=None,
*args, **kwargs):
"""
Perform crossvalidation over a range of simulated completeness/contamination values,
Expand All @@ -111,5 +115,35 @@ def generic_cccv(type, genotype, phenotype, folds, replicates, threads, comple_s
reduce_features = True if n_features is not None else False
cccv = clf.crossvalidate_cc(records=training_records, cv=folds, n_replicates=replicates,
comple_steps=comple_steps, conta_steps=conta_steps,
n_jobs=threads, reduce_features=reduce_features, n_features=n_features)
n_jobs=threads, reduce_features=reduce_features,
n_features=n_features)
write_cccv_accuracy_file(out, cccv)


def generic_compute_shaps(fasta_files, genotype, classifier, n_samples, verb):
"""
Given a genotype file and/or a collection of possibly gzipped FASTA files as well as a
phenotrex classifier, collect genotype information from both, get SHAP information about the
genotypes using the classifier, and return a finished ShapHandler object as well as the list
of GenotypeRecords created.
"""
if not len(fasta_files) and genotype is None:
raise RuntimeError(
'Must either supply FASTA file(s) or single genotype file for prediction.')
if len(fasta_files):
grs_from_fasta = fastas_to_grs(fasta_files, n_threads=None, verb=verb)
else:
grs_from_fasta = []

grs_from_file = load_genotype_file(genotype) if genotype is not None else []
gr = grs_from_fasta + grs_from_file

model = load_classifier(filename=classifier, verb=verb)
sh = ShapHandler.from_clf(model)
try:
fs, sv, bv = model.get_shap(gr, nsamples=n_samples)
except TypeError:
raise RuntimeError('This TrexClassifier is not capable of generating SHAP explanations.')
sh.add_feature_data(sample_names=[x.identifier for x in gr],
features=fs, shaps=sv, base_value=bv)
return sh, gr
55 changes: 31 additions & 24 deletions phenotrex/cli/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,34 @@ def cccv(inputs, out, title):
compleconta_plot(cccv_results=cccv_results, conditions=conditions, title=title, save_path=out)


@plot.command('shap-force')
@plot.command('shap-summary', short_help='Plot summary of SHAP feature contributions.')
@click.argument('fasta_files', type=click.Path(exists=True), nargs=-1)
@click.option('--genotype', type=click.Path(exists=True),
required=False, help='Input genotype file.')
@click.option('--classifier', required=True, type=click.Path(exists=True),
help='Path of pickled classifier file.')
@click.option('--out', required=True, type=click.Path(dir_okay=False),
help='The file to save the generated summary plot at.')
@click.option('--n_max_features', type=int, default=20,
help='The number of top most important features (by absolute SHAP value) to plot.')
@click.option('--n_samples', type=int, default=None,
help='The nsamples parameter of SHAP. '
'Only used by models which utilize a `shap.KernelExplainer` (e.g. TrexSVM).')
@click.option('--title', type=str, default='', help='Plot title.')
@click.option('--verb', is_flag=True)
def shap_summary(out, n_max_features, title, **kwargs):
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from .generic_func import generic_compute_shaps

sh, gr = generic_compute_shaps(**kwargs)
sh.plot_shap_summary(title=title, n_max_features=n_max_features)
plt.tight_layout()
plt.savefig(out)


@plot.command('shap-force', short_help='Plot SHAP feature contributions per sample.')
@click.argument('fasta_files', type=click.Path(exists=True), nargs=-1)
@click.option('--genotype', type=click.Path(exists=True),
required=False, help='Input genotype file.')
Expand All @@ -43,7 +70,7 @@ def cccv(inputs, out, title):
help='The nsamples parameter of SHAP. '
'Only used by models which utilize a `shap.KernelExplainer` (e.g. TrexSVM).')
@click.option('--verb', is_flag=True)
def shap_force(fasta_files, genotype, classifier, out_prefix, n_samples, verb):
def shap_force(out_prefix, **kwargs):
"""
Generate SHAP force plots for each sample (passed either as FASTA files or as genotype file).
All plots will be saved at the path `{out_prefix}_{sample_identifier}_force_plot.png`.
Expand All @@ -53,29 +80,9 @@ def shap_force(fasta_files, genotype, classifier, out_prefix, n_samples, verb):
mpl.use('Agg')
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
try:
from phenotrex.transforms import fastas_to_grs
except ModuleNotFoundError:
from phenotrex.util.helpers import fail_missing_dependency as fastas_to_grs
from phenotrex.io.flat import load_genotype_file
from phenotrex.io.serialization import load_classifier
from phenotrex.ml.shap_handler import ShapHandler
if not len(fasta_files) and genotype is None:
raise RuntimeError(
'Must either supply FASTA file(s) or single genotype file for prediction.')
if len(fasta_files):
grs_from_fasta = fastas_to_grs(fasta_files, n_threads=None, verb=verb)
else:
grs_from_fasta = []

grs_from_file = load_genotype_file(genotype) if genotype is not None else []
gr = grs_from_fasta + grs_from_file
from .generic_func import generic_compute_shaps

model = load_classifier(filename=classifier, verb=verb)
sh = ShapHandler.from_clf(model)
fs, sv, bv = model.get_shap(gr, nsamples=n_samples)
sh.add_feature_data(sample_names=[x.identifier for x in gr],
features=fs, shaps=sv, base_value=bv)
sh, gr = generic_compute_shaps(**kwargs)
for record in tqdm(gr, unit='samples', desc='Generating force plots'):
sh.plot_shap_force(record.identifier)
out_path = Path(f'{out_prefix}_{record.identifier}_force_plot.png')
Expand Down
3 changes: 2 additions & 1 deletion phenotrex/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from .clf.svm import TrexSVM
from .clf.xgbm import TrexXGB
from .shap_handler import ShapHandler


__all__ = ['TrexXGB', 'TrexSVM']
__all__ = ['TrexXGB', 'TrexSVM', 'ShapHandler']
5 changes: 4 additions & 1 deletion phenotrex/ml/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def predict(fasta_files=tuple(), genotype=None, classifier=None,

model = load_classifier(filename=classifier, verb=verb)
if out_explain_per_sample is not None or out_explain_summary is not None:
try:
fs, sv, bv = model.get_shap(gr, nsamples=shap_n_samples)
except TypeError:
raise RuntimeError('This TrexClassifier is not capable of generating SHAP explanations.')
sh = ShapHandler.from_clf(model)
fs, sv, bv = model.get_shap(gr, nsamples=shap_n_samples)
sh.add_feature_data(sample_names=[x.identifier for x in gr],
features=fs, shaps=sv, base_value=bv)
if out_explain_per_sample is not None:
Expand Down
2 changes: 1 addition & 1 deletion phenotrex/ml/shap_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def plot_shap_summary(self, title=None, n_max_features: int = 20,
max_display=n_max_features,
class_names=class_names,
title=f'SHAP Summary',
show=True,
show=False,
**kwargs)

def get_shap_force(self, sample_name: str, n_max_features: int = 20) -> pd.DataFrame:
Expand Down

0 comments on commit 582fe11

Please sign in to comment.