In [None]:
import tensorflow as tf
import os
import argparse
from tensorflow.python.keras.callbacks import Callback
from datetime import datetime, timezone
import logging

## Define where the logs are written.
## File logging to /result/mnist.log (DEBUG level) is configured only when the
## /result path exists; otherwise the logging module keeps its defaults.
if os.path.exists("/result"):
    logging.basicConfig(filename='/result/mnist.log', level=logging.DEBUG)

## ML model definition (the output locations of the trained model and the
## training hyperparameters are all supplied as command-line arguments)
class MyFashionMnist(object):
  """Dense Fashion-MNIST classifier whose training run is configured by CLI flags."""

  def train(self):
    """Parse the CLI flags, then build, fit, evaluate and export the model.

    Flags (all optional):
      --learning_rate    learning rate passed to the Adam optimizer (default 0.001)
      --dropout_rate     rate passed to the Dropout layers (default 0.3)
      --epoch            number of epochs passed to ``model.fit`` (default 5)
      --act              activation of the hidden Dense layers (default 'relu')
      --layer            number of hidden Dense(128) layers (default 1)
      --model_version    sub-directory name of the exported model (default '0001')
      --checkpoint_dir   directory receiving the ``ckpt_{epoch}`` weight files
      --saved_model_dir  parent directory of the exported model
      --tensorboard_log  parsed, but not used anywhere in this method

    Side effects: loads the Fashion-MNIST dataset through ``tf.keras.datasets``,
    saves weights-only checkpoints during ``fit`` and writes the final model to
    ``<saved_model_dir>/<model_version>`` in the 'tf' save format.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', required=False, type=float, default=0.001)
    parser.add_argument('--dropout_rate', required=False, type=float, default=0.3)
    parser.add_argument('--epoch', required=False, type=int, default=5)
    parser.add_argument('--act', required=False, type=str, default='relu')
    parser.add_argument('--layer', required=False, type=int, default=1)
    parser.add_argument('--model_version', required=False, type=str, default='0001')
    parser.add_argument('--checkpoint_dir', required=False, default='/result/training_checkpoints')
    parser.add_argument('--saved_model_dir', required=False, default='/result/saved_model')
    # NOTE(review): --tensorboard_log is accepted but never read below; no
    # TensorBoard callback is registered. Confirm whether one was intended.
    parser.add_argument('--tensorboard_log', required=False, default='/result/log')
    args = parser.parse_args()

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    # Divide the raw image values by 255 before feeding them to the network.
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

    # --layer is already parsed as int by argparse, so no extra cast is needed.
    for i in range(args.layer):
        model.add(tf.keras.layers.Dense(128, activation=args.act))
        # Dropout is only appended after the fourth and later hidden layers
        # (i > 2); with the default --layer=1 no Dropout layer is added.
        if i > 2:
            model.add(tf.keras.layers.Dropout(args.dropout_rate))

    model.add(tf.keras.layers.Dense(10, activation='softmax'))

    # Use the `learning_rate` keyword: `lr` is only a deprecated alias and is
    # rejected by newer Keras releases.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
                  loss='sparse_categorical_crossentropy',
                  # The metric is named 'acc' on purpose: KatibMetricLog reads
                  # the 'acc' / 'val_acc' keys from the epoch logs.
                  metrics=['acc'])

    model.summary()

    # "{epoch}" is left as a placeholder for ModelCheckpoint to fill in.
    checkpoint_prefix = os.path.join(args.checkpoint_dir, "ckpt_{epoch}")

    model.fit(x_train, y_train,
              verbose=0,
              validation_data=(x_test, y_test),
              epochs=args.epoch,
              callbacks=[KatibMetricLog(),
                         tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                                            save_weights_only=True)
                         ])
    # The evaluation result is not captured; the call only runs the evaluation.
    model.evaluate(x_test, y_test, verbose=0)

    # Build the export path the same way as the checkpoint path above.
    path = os.path.join(args.saved_model_dir, args.model_version)
    model.save(path, save_format='tf')

## Training-log definition (used by Katib)
class KatibMetricLog(Callback):
    """Keras callback that logs the metrics of every finished epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Log accuracy/loss and their validation counterparts for one epoch.

        Reads the 'acc', 'loss', 'val_acc' and 'val_loss' entries of ``logs``
        and emits them, prefixed with a timestamp, through ``logging.info``.
        """
        # RFC 3339 timestamp in the local timezone
        timestamp = datetime.now(timezone.utc).astimezone().isoformat()
        line_template = "\n{} accuracy={:.4f} loss={:.4f} Validation-accuracy={:.4f} Validation-loss={:.4f}"
        # Look the metrics up in the same order as the template placeholders.
        metric_values = [logs[key] for key in ('acc', 'loss', 'val_acc', 'val_loss')]
        logging.info(line_template.format(timestamp, *metric_values))

## Build and deploy the ML model image.
## When the FAIRING_RUNTIME environment variable is not set, this block
## configures the Kubeflow Fairing builder and calls fairing.config.run();
## when it is set, it runs the training directly instead.
if __name__ == '__main__':
    if os.getenv('FAIRING_RUNTIME', None) is None:
        from kubeflow import fairing
        # NOTE(review): k8s_utils is imported but not referenced in this block.
        from kubeflow.fairing.kubernetes import utils as k8s_utils
        from kubeflow.fairing.builders.cluster.minio_context import MinioContextSource
        
        ## Register the auth information needed to push the image.
        ## These are IPython shell escapes (`!`), so this branch only works when
        ## the cell is executed by IPython/Jupyter, not by a plain interpreter.
        ## The ConfigMap is deleted first and then recreated from the local
        ## docker config file.
        ! kubectl delete cm docker-config
        ! kubectl create cm docker-config --from-file=/home/jovyan/.docker/config.json
        
        ## Docker registry account
        DOCKER_REGISTRY = 'docker.io/kitaeyoo777'
        
        # kubeflow minio context
        # NOTE(review): the MinIO secret values are hard-coded here; presumably
        # these are the stock Kubeflow defaults -- confirm before reusing.
        minio_context_source = MinioContextSource(endpoint_url='http://minio-service.kubeflow.svc.cluster.local:9000', 
                                                  minio_secret='minio', 
                                                  minio_secret_key='minio123', 
                                                  region_name='us-east-1')
                
        # Select the 'cluster' builder with the image name, base image,
        # registry and MinIO context defined above; push=True is passed through
        # to the builder.
        fairing.config.set_builder(
            'cluster',
            image_name='sample-job',
            base_image='brightfly/kubeflow-jupyter-lab:tf2.0-cpu',
            registry=DOCKER_REGISTRY, 
            context_source=minio_context_source,
            push=True)


        # Presumably builds the image and submits the job -- see the Fairing docs.
        fairing.config.run()
    else:
        # FAIRING_RUNTIME is set: run the training in this process.
        remote_train = MyFashionMnist()
        remote_train.train()