Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.<br /><br />Licensed under the Amazon Software License (the "License"). You may not<br />use this file except in compliance with the License. A copy of the<br />License is located at:<br />   http://aws.amazon.com/asl/<br />or in the "license" file accompanying this file. This file is distributed<br />on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express<br />or implied. See the License for the specific language governing permissions<br />and limitations under the License.
# Verifying metric definitions

This notebook helps you verify that your metric definitions are behaving the way you expect them to.
It lets you quickly try different regular expressions (regexes) on actual TrainingJob log files.

In [1]:
import boto3

# Region configured for a default boto3 session (not referenced again in the cells shown here).
region = boto3.Session().region_name
# Low-level SageMaker client used below to look up the tuning job and its training jobs.
sagemaker = boto3.Session().client(service_name='sagemaker')

## Look up training jobs & metrics from a tuning job
To test metric definitions, we'll apply them to CloudWatch logs from real training jobs.
The easiest way to get both of these is from an actual HyperParameterTuningJob, which we'll show here.
You can override these if you'd like.

In [2]:
# Name of the hyperparameter tuning job whose training jobs and metric definitions are looked up below.
TUNING_JOB_NAME = 'xgboost-tuningjob-21-14-31-09'

In [3]:
# Only one page of results is requested (no pagination), so large tuning jobs will be
# listed incompletely.  That is fine here: verification just needs a sample of jobs.
training_job_list = sagemaker.list_training_jobs_for_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=TUNING_JOB_NAME
)['TrainingJobSummaries']
training_job_names = [summary['TrainingJobName'] for summary in training_job_list]
print(
    "Found %d training jobs starting with %s"
    % (len(training_job_names), training_job_names[:5])
)

Found 10 training jobs starting with ['xgboost-tuningjob-21-14-31-09-020-5b3b2be0', 'xgboost-tuningjob-21-14-31-09-019-ffb835b5', 'xgboost-tuningjob-21-14-31-09-018-ad3f88ee', 'xgboost-tuningjob-21-14-31-09-017-40c72940', 'xgboost-tuningjob-21-14-31-09-016-1b5dac84']


In [4]:
# Pick the specific training job to try applying metrics to.  You can specify it explicitly here if you want.
# TRAINING_JOB_NAME = 'your-training-job-name'
# But by default, we'll take the first one from the tuning job.
if not training_job_names:
    # Fail with an actionable message instead of a bare IndexError from [0].
    raise RuntimeError("The tuning job returned no training jobs; "
                       "set TRAINING_JOB_NAME explicitly instead.")
TRAINING_JOB_NAME = training_job_names[0]
print("Using logs for training job %s" % TRAINING_JOB_NAME)

Using logs for training job xgboost-tuningjob-21-14-31-09-020-5b3b2be0


In [5]:
from pprint import pprint

# Describe the tuning job and pull the metric definitions out of its
# training job definition's algorithm specification.
tuning_job_desc = sagemaker.describe_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=TUNING_JOB_NAME
)
metric_definitions = (
    tuning_job_desc['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions']
)
print("Found metric definitions:")
pprint(metric_definitions)

Found metric definitions:
[{'Name': 'train:mae', 'Regex': '.*\\[[0-9]+\\]#011train-mae:(\\S+).*'},
 {'Name': 'validation:auc',
  'Regex': '.*\\[[0-9]+\\].*#011validation-auc:(\\S+)'},
 {'Name': 'train:merror', 'Regex': '.*\\[[0-9]+\\]#011train-merror:(\\S+).*'},
 {'Name': 'train:auc', 'Regex': '.*\\[[0-9]+\\]#011train-auc:(\\S+).*'},
 {'Name': 'validation:mae',
  'Regex': '.*\\[[0-9]+\\].*#011validation-mae:(\\S+)'},
 {'Name': 'validation:error',
  'Regex': '.*\\[[0-9]+\\].*#011validation-error:(\\S+)'},
 {'Name': 'validation:merror',
  'Regex': '.*\\[[0-9]+\\].*#011validation-merror:(\\S+)'},
 {'Name': 'validation:logloss',
  'Regex': '.*\\[[0-9]+\\].*#011validation-logloss:(\\S+)'},
 {'Name': 'train:rmse', 'Regex': '.*\\[[0-9]+\\]#011train-merror:(\\S+).*'},
 {'Name': 'train:logloss', 'Regex': '.*\\[[0-9]+\\]#011train-logloss:(\\S+).*'},
 {'Name': 'train:mlogloss',
  'Regex': '.*\\[[0-9]+\\]#011train-mlogloss:(\\S+).*'},
 {'Name': 'validation:rmse',
  'Regex': '.*\\[[0-9]+\\].*#011va

## Fetch TrainingJob log from CloudWatch and apply MetricDefinition to it
Simulate the existing metric definitions from that hyperparameter tuning job on the first training job's log.
Then later we'll try changing the regex and see how that works on the same log.

In [6]:
import boto3
# CloudWatch Logs client, used below to find the training job's log stream.
cwl = boto3.client(service_name="logs")

In [7]:
# Find log streams for this training job
log_stream_descs = cwl.describe_log_streams(logGroupName="/aws/sagemaker/TrainingJobs",
                                            logStreamNamePrefix=TRAINING_JOB_NAME)['logStreams']
if not log_stream_descs:
    # Fail with an actionable message instead of a bare IndexError from [0].
    raise RuntimeError("No CloudWatch log streams found for training job %s" % TRAINING_JOB_NAME)
# Just pick the first one
log_stream = log_stream_descs[0]['logStreamName']

In [8]:
import html
import re

import boto3
from IPython.core.display import display, HTML

class MetricDefinitionVerifier(object):
    """Applies metric-definition regexes to a training job's CloudWatch log.

    Log events are fetched one page at a time from a single log stream in the
    ``/aws/sagemaker/TrainingJobs`` log group.  Every message is searched with
    each registered metric regex and the outcome is rendered as HTML in the
    notebook.
    """

    def __init__(self, metric_definitions, log_stream, show_nonmatching=False):
        """Create a verifier for one log stream.

        Args:
            metric_definitions: iterable of dicts with ``'Name'`` and ``'Regex'`` keys.
            log_stream: name of the CloudWatch log stream to read.
            show_nonmatching: when true, lines matching no metric are displayed too.

        Raises:
            Exception: whatever ``re.compile`` raises for an invalid regex.
        """
        self._log_stream_name = log_stream
        self._show_nonmatching = show_nonmatching
        self.reset_log()
        self._cwl = boto3.client('logs')
        self._log_group = "/aws/sagemaker/TrainingJobs"
        self._metric_defns = []
        for md in metric_definitions:
            self.set_metric_definition(md['Name'], md['Regex'])

    def set_metric_definition(self, name, regex):
        """Add or replace a metric definition with the specified name.

        The regex is compiled before the current list is touched, so a failed
        replacement leaves the previously registered definition in place.

        Raises:
            Exception: re-raises the ``re.compile`` failure after printing which
                metric definition was at fault.
        """
        try:
            compiled = re.compile(regex)
        except Exception:
            print("Failed to compile regex for MetricDefinition %s." % name)
            raise
        # Drop any existing metric with this name, then append the new entry.
        self._metric_defns = [md for md in self._metric_defns if md['Name'] != name]
        self._metric_defns.append({
            'Name': name,
            'Regex': regex,
            're': compiled,
        })

    def reset_log(self):
        """Reset to the beginning of the log stream"""
        self._next_token = None

    def next_page(self, show_nonmatching=None):
        """Fetches a page of log events and processes them all.

        Args:
            show_nonmatching: if not None, replaces the stored setting and
                stays in effect for later calls as well.

        Raises:
            RuntimeError: if no metric definitions are registered.
        """
        if not self._metric_defns:
            raise RuntimeError("No metric definitions defined.  Use .set_metric_definition()")
        if show_nonmatching is not None:
            self._show_nonmatching = show_nonmatching
        token_args = {}
        if self._next_token:
            token_args['nextToken'] = self._next_token
        # startFromHead=True makes the first page begin at the earliest events
        # (the API default returns the latest ones) and is required when paging
        # with a previous nextForwardToken.
        events_result = self._cwl.get_log_events(logGroupName=self._log_group,
                                                 logStreamName=self._log_stream_name,
                                                 startFromHead=True,
                                                 **token_args)
        self._next_token = events_result['nextForwardToken']
        matches = 0
        cnt = 0
        for event in events_result['events']:
            matches += self.process_message(event['message'])
            cnt += 1
        print("Done with page.  Found matches on %d of %d lines" % (matches, cnt))

    def process_message(self, msg):
        """Processes a single cloudwatch event against the defined metrics.

        All text taken from the log or the metric definitions is HTML-escaped
        before display so that characters such as ``<`` and ``&`` show up
        literally instead of being interpreted as markup.

        Returns:
            1 if at least one metric regex matched ``msg``, otherwise 0.
        """
        safe_msg = html.escape(msg)
        markup = None
        for md in self._metric_defns:
            match = md['re'].search(msg)
            if not match:
                continue
            if markup is None:
                # Print the line on the first match
                markup = ("<div><div style='background-color:#afa'>"
                          "line matched: '<tt>%s</tt>'</div>\n" % safe_msg)
            if match.re.groups:
                captured = "'<tt>%s</tt>'" % html.escape(str(match.group(1)))
            else:
                # group(1) would raise IndexError; report the problem instead.
                captured = "<i>(regex has no capture group)</i>"
            markup += """
                    <ul>
                        <b>Metric %s matched</b><br/>
                        match='<tt>%s</tt>'<br/>
                        captured value=%s
                    </ul>""" % (html.escape(str(md['Name'])),
                                html.escape(match.group(0)),
                                captured)
        if markup:
            markup += "</div>\n"
            display(HTML(markup))
            return 1
        if self._show_nonmatching:
            display(HTML("<div style='background-color:#ccc'>no-match: '<tt>%s</tt>'</div>" % safe_msg))
        return 0

In [None]:
# Build the verifier from the tuning job's metric definitions; non-matching lines are shown as well.
mdv = MetricDefinitionVerifier(
    metric_definitions=metric_definitions,
    log_stream=log_stream,
    show_nonmatching=True,
)

In [None]:
# Process the first page of log events.  Because mdv was created with
# show_nonmatching=True, lines that match no metric are displayed as well;
# pass show_nonmatching=False to see only the lines with matches.
mdv.next_page()

In [None]:
# Can call this method repeatedly to go through the log: each call continues
# from the token saved by the previous call.
mdv.next_page()

In [None]:
# Try setting a new metric definition and see how it works.
# The quantifier belongs inside the capture group: with "([0-9\\.])+" the group
# is repeated and only its last repetition (a single character) is captured.
mdv.set_metric_definition(name="loss", regex="loss = ([0-9\\.]+)")
# Start again from the beginning of the log and show only lines with matches.
mdv.reset_log()
mdv.next_page(show_nonmatching=False)