Skip to content

Commit

Permalink
Adding train_predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
w4nderlust committed May 15, 2019
1 parent 00c0762 commit b420755
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 26 deletions.
6 changes: 6 additions & 0 deletions ludwig/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self):
experiment Runs a full experiment training a model and testing it
train Trains a model
predict Predicts using a pretrained model
test Tests a pretrained model
visualize Visualizes experimental results
collect_weights Collects tensors containing a pretrained model weights
collect_activations Collects tensors for each datapoint using a pretrained model
Expand Down Expand Up @@ -76,6 +77,11 @@ def predict(self):
ludwig.contrib.contrib_command("predict", *sys.argv)
predict.cli(sys.argv[2:])

def test(self):
from ludwig import test_performance
ludwig.contrib.contrib_command("test", *sys.argv)
test_performance.cli(sys.argv[2:])

def visualize(self):
from ludwig import visualize
ludwig.contrib.contrib_command("visualize", *sys.argv)
Expand Down
22 changes: 19 additions & 3 deletions ludwig/contribs/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Comet():
"""
Class that defines the methods necessary to hook into process.
"""

@staticmethod
def import_call(argv, *args, **kwargs):
"""
Expand All @@ -33,7 +34,8 @@ def import_call(argv, *args, **kwargs):
try:
import comet_ml
except ImportError:
logging.error("Ignored --comet: Please install comet_ml; see www.comet.ml")
logging.error(
"Ignored --comet: Please install comet_ml; see www.comet.ml")
return None

try:
Expand All @@ -50,7 +52,8 @@ def experiment(self, *args, **kwargs):
try:
self.cometml_experiment = comet_ml.Experiment(log_code=False)
except Exception:
logging.error("comet_ml.Experiment() had errors. Perhaps you need to define COMET_API_KEY")
logging.error(
"comet_ml.Experiment() had errors. Perhaps you need to define COMET_API_KEY")
return

logging.info("comet.experiment() called......")
Expand All @@ -66,7 +69,8 @@ def train(self, *args, **kwargs):
try:
self.cometml_experiment = comet_ml.Experiment(log_code=False)
except Exception:
logging.error("comet_ml.Experiment() had errors. Perhaps you need to define COMET_API_KEY")
logging.error(
"comet_ml.Experiment() had errors. Perhaps you need to define COMET_API_KEY")
return

logging.info("comet.train() called......")
Expand Down Expand Up @@ -139,6 +143,18 @@ def predict(self, *args, **kwargs):
cli = self._make_command_line(args)
self._log_html(cli)

def test(self, *args, **kwargs):
import comet_ml
try:
self.cometml_experiment = comet_ml.ExistingExperiment()
except Exception:
logging.error("Ignored --comet. No '.comet.config' file")
return

logging.info("comet.test() called......")
cli = self._make_command_line(args)
self._log_html(cli)

def _save_config(self, config):
## save the .comet.config here:
config["comet.experiment_key"] = self.cometml_experiment.id
Expand Down
9 changes: 4 additions & 5 deletions ludwig/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from ludwig.data.postprocessing import postprocess
from ludwig.globals import LUDWIG_VERSION, set_on_master, is_on_master
from ludwig.predict import predict
from ludwig.predict import print_prediction_results
from ludwig.predict import print_test_results
from ludwig.predict import save_prediction_outputs
from ludwig.predict import save_prediction_statistics
from ludwig.predict import save_test_statistics
from ludwig.train import full_train
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.print_utils import logging_level_registry
Expand Down Expand Up @@ -233,10 +233,9 @@ def experiment(
)

if is_on_master():
print_prediction_results(test_results)

print_test_results(test_results)
save_prediction_outputs(postprocessed_output, experiment_dir_name)
save_prediction_statistics(test_results, experiment_dir_name)
save_test_statistics(test_results, experiment_dir_name)

model.close_session()

Expand Down
26 changes: 8 additions & 18 deletions ludwig/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def full_predict(
save_prediction_outputs(postprocessed_output, experiment_dir_name)

if evaluate_performance:
print_prediction_results(prediction_results)
save_prediction_statistics(prediction_results, experiment_dir_name)
print_test_results(prediction_results)
save_test_statistics(prediction_results, experiment_dir_name)

logging.info('Saved to: {0}'.format(experiment_dir_name))

Expand Down Expand Up @@ -208,18 +208,18 @@ def save_prediction_outputs(
save_csv(csv_filename.format(output_field, output_type), values)


def save_prediction_statistics(prediction_stats, experiment_dir_name):
def save_test_statistics(test_stats, experiment_dir_name):
test_stats_fn = os.path.join(
experiment_dir_name,
'prediction_statistics.json'
'test_statistics.json'
)
save_json(test_stats_fn, prediction_stats)
save_json(test_stats_fn, test_stats)


def print_prediction_results(prediction_stats):
for output_field, result in prediction_stats.items():
def print_test_results(test_stats):
for output_field, result in test_stats.items():
if (output_field != 'combined' or
(output_field == 'combined' and len(prediction_stats) > 2)):
(output_field == 'combined' and len(test_stats) > 2)):
logging.info('\n===== {} ====='.format(output_field))
for measure in sorted(list(result)):
if measure != 'confusion_matrix' and measure != 'roc_curve':
Expand Down Expand Up @@ -314,16 +314,6 @@ def cli(sys_argv):
default=128,
help='size of batches'
)
parser.add_argument(
'-ep',
'--evaluate_performance',
action='store_true',
default=False,
help='performs performance metrics calculation.'
'Requires that the dataset contains one column '
'for each output feature the model predicts '
'to use as ground truth for the performance calculation.'
)

# ------------------
# Runtime parameters
Expand Down
169 changes: 169 additions & 0 deletions ludwig/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging
import sys

from ludwig.contrib import contrib_command
from ludwig.globals import set_on_master, is_on_master, LUDWIG_VERSION
from ludwig.predict import full_predict
from ludwig.utils.print_utils import logging_level_registry, print_ludwig


def cli(sys_argv):
parser = argparse.ArgumentParser(
description='This script loads a pretrained model '
'and tests its performance by comparing'
'its predictions with ground truth.',
prog='ludwig test',
usage='%(prog)s [options]'
)

# ---------------
# Data parameters
# ---------------
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'--data_csv',
help='input data CSV file. '
'If it has a split column, it will be used for splitting '
'(0: train, 1: validation, 2: test), '
'otherwise the dataset will be randomly split'
)
group.add_argument(
'--data_hdf5',
help='input data HDF5 file. It is an intermediate preprocess version of'
' the input CSV created the first time a CSV file is used in the '
'same directory with the same name and a hdf5 extension'
)
parser.add_argument(
'--train_set_metadata_json',
help='input metadata JSON file. It is an intermediate preprocess file '
'containing the mappings of the input CSV created the first time '
'a CSV file is used in the same directory with the same name and '
'a json extension'
)

parser.add_argument(
'-s',
'--split',
default='test',
choices=['training', 'validation', 'test', 'full'],
help='the split to test the model on'
)

# ----------------
# Model parameters
# ----------------
parser.add_argument(
'-m',
'--model_path',
help='model to load',
required=True
)

# -------------------------
# Output results parameters
# -------------------------
parser.add_argument(
'-od',
'--output_directory',
type=str,
default='results',
help='directory that contains the results'
)
parser.add_argument(
'-ssuo',
'--skip_save_unprocessed_output',
help='skips saving intermediate NPY output files',
action='store_true', default=False
)

# ------------------
# Generic parameters
# ------------------
parser.add_argument(
'-bs',
'--batch_size',
type=int,
default=128,
help='size of batches'
)

# ------------------
# Runtime parameters
# ------------------
parser.add_argument(
'-g',
'--gpus',
type=int,
default=0,
help='list of gpu to use'
)
parser.add_argument(
'-gf',
'--gpu_fraction',
type=float,
default=1.0,
help='fraction of gpu memory to initialize the process with'
)
parser.add_argument(
'-uh',
'--use_horovod',
action='store_true',
default=False,
help='uses horovod for distributed training'
)
parser.add_argument(
'-dbg',
'--debug',
action='store_true',
default=False,
help='enables debugging mode'
)
parser.add_argument(
'-l',
'--logging_level',
default='info',
help='the level of logging to use',
choices=['critical', 'error', 'warning', 'info', 'debug', 'notset']
)

args = parser.parse_args(sys_argv)
args.evaluate_performance = True

logging.basicConfig(
stream=sys.stdout,
level=logging_level_registry[args.logging_level],
format='%(message)s'
)

set_on_master(args.use_horovod)

if is_on_master():
print_ludwig('Test', LUDWIG_VERSION)

full_predict(**vars(args))


if __name__ == '__main__':
contrib_command("test", *sys.argv)
cli(sys.argv[1:])

0 comments on commit b420755

Please sign in to comment.