In [None]:
# default_exp entry_point

In [3]:
#export

import os
import argparse
import time
import importlib
import inspect
import logging
import re
from gcp_runner.ai_platform_constants import DistributionStrategyType

def _connect_tpu_strategy(tf):
    """Resolve, connect to and initialize the TPU cluster, returning a TPUStrategy.

    Returns ``None`` (after logging) when resolving or connecting raises ``ValueError``.
    """
    try:
        logging.info('resolving to TPU cluster')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        logging.info('connecting to TPU cluster')
        tf.config.experimental_connect_to_cluster(resolver)
    except ValueError as err:
        logging.info('error connecting to TPU cluster: %s', err)
        return None

    logging.info('initializing TPU system')
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    logging.info(
        'training using TPUStrategy, tpu.cluster_spec: %s', resolver.cluster_spec())
    return strategy


def get_distribution_strategy_instance(distribution_strategy_type:DistributionStrategyType):
    """Create the ``tf.distribute`` strategy matching ``distribution_strategy_type``.

    Args:
        distribution_strategy_type: a ``DistributionStrategyType`` member; a falsy
            value (e.g. ``None``) means "no strategy".

    Returns:
        * ``TPU_STRATEGY``: a ``TPUStrategy``, or ``None`` if the TPU cluster
          could not be resolved/connected (``ValueError``).
        * ``ONE_DEVICE_STRATEGY``: a ``OneDeviceStrategy`` pinned to ``/cpu:0``.
        * any other truthy member: the result of calling the object its
          ``.value`` evaluates to.
        * falsy input: ``None``.
    """
    # Imported inside the function, so importing this module does not require TensorFlow.
    import tensorflow as tf

    logging.info('initializing distribution strategy: %s', distribution_strategy_type)

    if distribution_strategy_type == DistributionStrategyType.TPU_STRATEGY:
        return _connect_tpu_strategy(tf)
    if distribution_strategy_type == DistributionStrategyType.ONE_DEVICE_STRATEGY:
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    if not distribution_strategy_type:
        return None

    # NOTE(review): ``eval`` of the enum value. Presumably each value is an
    # expression naming a strategy class (e.g. a ``tf.distribute...`` path); it is
    # resolved in this function's scope, which is why ``tf`` must stay a local here.
    # Never route untrusted strings into this call.
    return eval(distribution_strategy_type.value)()
    
def parse_unknown_args(unknown_args):
    """Convert ``--key=value`` command-line tokens into a keyword-argument dict.

    Each token must start with ``--`` and contain ``=``. It is split on the
    FIRST ``=``, so values may themselves contain ``=`` (``--filter=a=b`` ->
    ``{'filter': 'a=b'}``). Dashes in the key become underscores
    (``--batch-size=32`` -> ``{'batch_size': '32'}``). Values are kept as strings.

    Tokens that do not have that shape are reported on stdout and skipped.
    Examples are bare flags such as ``--verbose``, positional values, and
    tokens with an empty key or empty value.

    Args:
        unknown_args: iterable of argument strings, e.g. the leftovers returned
            by ``argparse.ArgumentParser.parse_known_args``. ``None`` is
            treated as empty.

    Returns:
        dict mapping parameter names to string values; a repeated key keeps
        its last value.
    """
    kwargs = {}
    for arg in unknown_args or ():
        # partition() splits on the first '=' only; ``sep`` is '' when there is no '='.
        key, sep, value = arg[2:].partition('=')
        if arg.startswith('--') and key and sep and value:
            kwargs[key.replace('-', '_')] = value
        else:
            print("can't parse argument: %s" % arg)
    return kwargs



In [4]:
#export

import os
import json
import socket

def setup_keras_tuner_config():
    """Derive Keras Tuner environment variables from ``TF_CONFIG``.

    Does nothing when ``TF_CONFIG`` is not set. Otherwise it parses it as JSON,
    reads the first ``cluster['chief']`` address (``host:port``) and the current
    ``task``, and sets:

    * ``KERASTUNER_ORACLE_IP``: the chief host resolved with ``socket.gethostbyname``
    * ``KERASTUNER_ORACLE_PORT``: the chief port
    * ``KERASTUNER_TUNER_ID``: ``'chief'`` for the chief task, else ``'tuner<index>'``

    This is best effort: any error while reading the config is printed and
    swallowed. All values are computed before any variable is written, so a
    malformed ``TF_CONFIG`` leaves the environment untouched instead of
    half-configured.
    """
    if 'TF_CONFIG' not in os.environ:
        return

    try:
        tf_config = json.loads(os.environ['TF_CONFIG'])
        cluster = tf_config['cluster']
        task = tf_config['task']
        chief_addr = cluster['chief'][0].split(':')
        chief_ip = socket.gethostbyname(chief_addr[0])
        chief_port = chief_addr[1]
        if task['type'] == 'chief':
            tuner_id = 'chief'
        else:
            tuner_id = 'tuner{}'.format(task['index'])
    except Exception as ex:
        print('Error setting up keras tuner config: %s' % str(ex))
        return

    # Every value is a str at this point, so these writes cannot fail midway.
    env_updates = {
        'KERASTUNER_ORACLE_IP': chief_ip,
        'KERASTUNER_ORACLE_PORT': chief_port,
        'KERASTUNER_TUNER_ID': tuner_id,
    }
    os.environ.update(env_updates)

    print('set following environment arguments:')
    for name in env_updates:
        print('%s: %s' % (name, os.environ[name]))

In [None]:
#export

def _build_args_parser():
    """Create the parser for the arguments the entry point consumes itself."""
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument(
        '--module-name',
        help='module name of an app to run',
        required=True)
    args_parser.add_argument(
        '--function-name',
        help='function name to run',
        required=True)
    args_parser.add_argument(
        '--distribution-strategy-type',
        help='distribution strategy',
        choices=list([e.value for e in DistributionStrategyType]))
    args_parser.add_argument(
        '--use-distribution-strategy-scope',
        action='store_true',
        help='whether to run training in a distribution strategy scope',
        default=False)
    return args_parser


def _accepts_var_kwargs(func):
    """Return True if ``func`` declares a ``**kwargs``-style parameter.

    Uses ``inspect.signature`` because ``inspect.getargspec`` was removed in
    Python 3.11 and could not handle annotated or keyword-only signatures.
    Callables whose signature cannot be introspected are treated as not
    accepting keyword arguments.
    """
    try:
        params = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        return False
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def main():
    """Command-line entry point: import ``--module-name`` and call ``--function-name``.

    Arguments the parser does not recognise are converted by
    ``parse_unknown_args`` into keyword arguments. When
    ``--distribution-strategy-type`` is given, two more keyword arguments are
    added: ``distribution_strategy_type`` (the enum member) and, if a strategy
    could be created, ``distribution_strategy``.

    These keyword arguments are passed only if the target function accepts
    ``**kwargs``; otherwise it is called with no arguments. With
    ``--use-distribution-strategy-scope`` the call runs inside
    ``distribution_strategy.scope()``.

    Returns:
        Whatever the target function returns.
    """
    print('in gcp_runner entry point')
    args_parser = _build_args_parser()

    setup_keras_tuner_config()
    args, unknown_args = args_parser.parse_known_args()
    print('running entrypoint function: %s.%s' % (args.module_name, args.function_name))
    module = importlib.import_module(args.module_name)
    func = getattr(module, args.function_name)

    if unknown_args:
        print('additional args: %s' % str(unknown_args))
    elif args.distribution_strategy_type is None:
        # Nothing to forward at all.
        return func()

    if not _accepts_var_kwargs(func):
        print('provided function does not take any arguments, running as is')
        return func()

    kwargs = parse_unknown_args(unknown_args)
    distribution_strategy = None
    if args.distribution_strategy_type is not None:
        distribution_strategy_type = DistributionStrategyType(args.distribution_strategy_type)
        distribution_strategy = get_distribution_strategy_instance(distribution_strategy_type)
        kwargs['distribution_strategy_type'] = distribution_strategy_type

    if distribution_strategy is None:
        return func(**kwargs)

    kwargs['distribution_strategy'] = distribution_strategy
    if args.use_distribution_strategy_scope:
        print('running code in %s scope' % args.distribution_strategy_type)
        with distribution_strategy.scope():
            return func(**kwargs)
    return func(**kwargs)


if __name__ == '__main__':
    main()

In [6]:
#hide

from gcp_runner import core
# NOTE(review): judging by its name, this presumably regenerates the library
# module(s) from this notebook's `#export` cells and reloads them. Confirm in
# gcp_runner.core. The cell is marked `#hide`, so it is kept out of the rendered docs.
core.export_and_reload_all(silent=True, ignore_errors=False)