In [1]:
!pip install -q kfp==1.7.1 kfpdist  -i https://pypi.tuna.tsinghua.edu.cn/simple

In [5]:
import kfp
from kfp import dsl as v1dsl
from kfp.v2 import dsl
from kfp.v2.dsl import Input, InputPath, Output, OutputPath, Dataset, Model, Metrics, component
import kfp.compiler as compiler
import os
from typing import List
from kfpdist import set_dist_train_config

@component
def prepare(n_workers: int,
            output_list_path: OutputPath(List[int])):
    """Emit the list of worker ranks as JSON.

    Args:
        n_workers: Number of ranks to generate; the list is 0 .. n_workers - 1.
        output_list_path: File path that receives the JSON-encoded list.
    """
    import json

    ranks = list(range(n_workers))
    with open(output_list_path, 'w') as out_file:
        json.dump(ranks, out_file)

@component(base_image='tensorflow/tensorflow:2.5.1-gpu', packages_to_install=['tensorflow_datasets'])
def train(rank:   int,
          nranks: int,
          output_model_path: OutputPath('output_model_path'),
          output_tblog_path: OutputPath('output_tblog_path'),
          output_ckpt_path:  OutputPath('output_ckpt_path')):
    """Train an MNIST CNN as one worker of a multi-worker TensorFlow job.

    Args:
        rank: Index of this worker; rank 0 is treated as the chief when saving.
        nranks: Total number of workers, forwarded to set_dist_train_config.
        output_model_path: Directory under which the trained model is saved.
        output_tblog_path: Directory that receives TensorBoard logs.
        output_ckpt_path: Directory that receives per-epoch weight checkpoints.
    """
    import tensorflow as tf
    from tensorflow.keras import layers, models
    import os

    # NOTE(review): set_dist_train_config is imported from kfpdist at notebook
    # level; confirm that the name is resolvable when this function body runs
    # inside the component container.
    set_dist_train_config(rank, nranks, 'train', port=9888)

    def make_datasets_unbatched():
        """Load the MNIST training split, scaled, cached and shuffled."""
        import tensorflow_datasets as tfds
        BUFFER_SIZE = 10000

        # Scale MNIST pixel values from [0, 255] to [0., 1.]
        def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255
            return image, label

        datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True)

        return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

    def build_and_compile_cnn_model():
        """Build the small convolutional classifier and compile it with Adam."""
        model = models.Sequential()
        model.add(
          layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))
        model.summary()
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        return model

    def decay(epoch):
        """Step learning-rate schedule: 1e-3, then 1e-4 from epoch 3, then 1e-5 from epoch 7."""
        if epoch < 3:
            return 1e-3
        if 3 <= epoch < 7:
            return 1e-4
        return 1e-5

    # Use NCCL if you need to use GPU
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      communication=tf.distribute.experimental.CollectiveCommunication.RING)
    BATCH_SIZE_PER_REPLICA = 64
    # Global batch size grows with the number of replicas taking part in training.
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

    with strategy.scope():
        # Fixed: the original called make_datasets_unbatched_s3(), which is not
        # defined anywhere in this component and raised NameError at runtime.
        ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat()
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = \
            tf.data.experimental.AutoShardPolicy.DATA
        ds_train = ds_train.with_options(options)
        multi_worker_model = build_and_compile_cnn_model()

    checkpoint_prefix = os.path.join(output_ckpt_path, "ckpt_{epoch}")

    class PrintLR(tf.keras.callbacks.Callback):
        """Callback that prints the optimizer's learning rate after each epoch."""

        def on_epoch_end(self, epoch, logs=None): #pylint: disable=no-self-use
            print('\nLearning rate for epoch {} is {}'.format(
                epoch + 1, multi_worker_model.optimizer.lr.numpy()))

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=output_tblog_path),
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                           save_weights_only=True),
        tf.keras.callbacks.LearningRateScheduler(decay),
        PrintLR()
    ]
    # The dataset repeats indefinitely, so steps_per_epoch bounds each epoch.
    multi_worker_model.fit(ds_train,
                           epochs=10,
                           steps_per_epoch=70,
                           callbacks=callbacks)

    # Every worker calls save(); only rank 0 (the chief) writes to the real
    # model directory, the others write to a per-rank scratch directory.
    is_chief = rank == 0
    if is_chief:
        model_path = os.path.join(output_model_path, 'v1')
    else:
        model_path = os.path.join(output_model_path, 'worker_tmp_' + str(rank))
    multi_worker_model.save(model_path)

@dsl.pipeline(name='quickstart-pipeline', pipeline_root='')
def dist_train_pipeline(
        n_workers: int = 2,
        use_gpu: int = 0,
        gpus_per_worker: int = 1,
        num_epochs: int = 15,
        train_batchsize: int = 128,
        test_batchsize: int = 128,
        learning_rate: float = 0.001,
        model_version: int = 1):
    """Fan out one `train` task per worker rank produced by `prepare`.

    Only `n_workers` is consumed by the steps below; the remaining
    parameters are declared on the pipeline but not passed to any step.
    """
    rank_list_task = prepare(n_workers=n_workers)
    worker_ranks = rank_list_task.outputs['output_list']
    with dsl.ParallelFor(worker_ranks) as worker_rank:
        train_task = train(rank=worker_rank, nranks=n_workers)
        # Uncomment below line if you need to use GPU.
        # train_task.set_gpu_limit(1)

if __name__ == '__main__':
    # Submit the pipeline as a one-off run in V2-compatible mode, using a
    # custom launcher image and no argument overrides (pipeline defaults apply).
    client = kfp.Client()
    client.create_run_from_pipeline_func(
        dist_train_pipeline,
        arguments={},
        mode=v1dsl.PipelineExecutionMode.V2_COMPATIBLE,
        launcher_image="typhoon1986/ml-pipeline-kfp-launcher:1.7.1",
    )