In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.python.lib.io import file_io
import pathlib

import sys
import json
import pandas as pd
import os
import argparse

from sklearn.metrics import confusion_matrix

from datetime import datetime

# Helper libraries
import numpy as np

In [2]:
def parse_arguments():
    """Build the command-line parser for the training step.

    Options (all optional, each with a default):
        --bucket_name     base path holding ``output/*.csv`` inputs and receiving artifacts
        --epochs          number of training epochs (int)
        --batch_size      mini-batch size (int)
        --katib           0 -> save model/visualizations, other values skip saving (int)
        --optimizer_name  Keras optimizer identifier (str)

    Returns:
        An ``argparse.ArgumentParser``; callers decide between ``parse_args``
        and ``parse_known_args``.
    """
    parser = argparse.ArgumentParser()
    # (flag, converter, default, help text)
    option_specs = (
        ('--bucket_name', str, 'gs://kbc/ccc', 'The bucket where the model has to be stored'),
        ('--epochs', int, 1, 'Number of epochs for training the model'),
        ('--batch_size', int, 64, 'the batch size for each epoch'),
        ('--katib', int, 0, 'to save model or not'),
        ('--optimizer_name', str, 'Adam', 'optimizer to use in model'),
    )
    for flag, converter, fallback, description in option_specs:
        parser.add_argument(flag, type=converter, default=fallback, help=description)
    return parser

In [3]:
def train(bucket_name, epochs, batch_size, katib, optimizer_name):
    """Train the binary classifier, report test metrics, and optionally persist artifacts.

    Args:
        bucket_name: base path under which ``output/*.csv`` inputs are read (via
            ``load_data``) and, when ``katib == 0``, the model and confusion matrix
            are written.
        epochs: number of epochs passed to ``fit``.
        batch_size: batch size passed to ``fit``.
        katib: ``0`` saves the model and writes the Kubeflow visualization files;
            any other value skips both.
        optimizer_name: Keras optimizer identifier (e.g. ``'Adam'``, ``'SGD'``)
            resolved with ``tf.keras.optimizers.get``.
    """
    testX, testy, trainX, trainy = load_data(bucket_name)

    dnn = create_tfmodel(
        # Use the requested optimizer; it was previously hard-coded to 'Adam',
        # which silently ignored --optimizer_name.
        optimizer=tf.keras.optimizers.get(optimizer_name),
        loss='binary_crossentropy',
        metrics=['accuracy'],
        input_dim=trainX.shape[1])

    dnn.summary()

    dnn.fit(trainX, trainy, epochs=epochs, batch_size=batch_size)

    test_loss, test_acc = dnn.evaluate(testX, testy, verbose=2)
    # "name=value" lines; presumably parsed by a metrics collector (e.g. Katib) — confirm.
    print("accuracy={:.2f}".format(test_acc))
    print("test-loss={:.2f}".format(test_loss))

    # Sequential.predict_classes was removed in TF 2.6. For a single sigmoid
    # output the equivalent is thresholding the probability at 0.5.
    predictions = (dnn.predict(testX) > 0.5).astype("int32")

    if katib == 0:
        save_tfmodel_in_gcs(bucket_name, dnn)
        create_kf_visualization(bucket_name, testy, predictions, test_acc)

In [4]:
def save_tfmodel_in_gcs(bucket_name, model):
    """Export ``model`` in SavedModel format to ``<bucket_name>/export/model/1``.

    The trailing ``1`` is a version directory (presumably for a model server such
    as TensorFlow Serving — confirm with the deployment step).
    """
    versioned_dir = '/'.join((bucket_name, 'export', 'model', '1'))
    tf.saved_model.save(model, export_dir=versioned_dir)
    

def create_tfmodel(optimizer, loss, metrics, input_dim):
    """Build and compile a small fully connected binary classifier.

    Architecture: Dense(input_dim, relu) -> Dense(128, relu) -> Dense(1, sigmoid).

    Args:
        optimizer: optimizer instance or identifier accepted by ``Model.compile``.
        loss: loss identifier or callable accepted by ``Model.compile``.
        metrics: list of metrics accepted by ``Model.compile``.
        input_dim: number of input features; also used as the width of the
            first hidden layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential([
        Dense(input_dim, activation='relu', input_dim=input_dim),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])
    # Pass by keyword: the positional order of compile() differs between Keras
    # versions (Keras 3 puts `loss_weights` before `metrics`).
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model


def create_kf_visualization(bucket_name, test_label, predict_label, test_acc):
    """Write Kubeflow Pipelines metric and confusion-matrix UI artifacts.

    Files written:
        * ``/mlpipeline-metrics.json`` - test accuracy as a KFP run metric.
        * ``<bucket_name>/metadata/cm.csv`` - confusion-matrix rows
          ``target,predicted,count`` with no header.
        * ``/mlpipeline-ui-metadata.json`` - tells the KFP UI to render that CSV.

    Args:
        bucket_name: base path the CSV is written under.
        test_label: ground-truth labels, one per row (a single-column DataFrame
            as returned by ``load_data`` is accepted).
        predict_label: predicted labels, one per row (an ``(n, 1)`` array is accepted).
        test_acc: accuracy value from Keras ``evaluate``.

    Returns:
        The confusion matrix as a long-format DataFrame with columns
        ``target``, ``predicted``, ``count``.
    """
    metrics = {
        'metrics': [{
            'name': 'accuracy-score',
            # KFP requires a numeric numberValue; float() also converts numpy
            # scalars, which json.dump cannot serialize.
            'numberValue': float(test_acc),
            'format': "PERCENTAGE"
        }]
    }

    with file_io.FileIO('/mlpipeline-metrics.json', 'w') as f:
        json.dump(metrics, f)

    vocab = [0, 1]
    # Flatten (n, 1) label/prediction containers into 1-D arrays for sklearn.
    cm = confusion_matrix(np.ravel(test_label), np.ravel(predict_label), labels=vocab)
    data = [
        (vocab[target_index], vocab[predicted_index], count)
        for target_index, target_row in enumerate(cm)
        for predicted_index, count in enumerate(target_row)
    ]
    df_cm = pd.DataFrame(data, columns=['target', 'predicted', 'count'])
    cm_file = bucket_name + '/metadata/cm.csv'
    print(df_cm)
    with file_io.FileIO(cm_file, 'w') as f:
        df_cm.to_csv(f, columns=['target', 'predicted', 'count'], header=False, index=False)

    print("***************************************")
    print("Writing the confusion matrix to ", cm_file)
    metadata = {
        'outputs': [{
            'type': 'confusion_matrix',
            'format': 'csv',
            'schema': [
                {'name': 'target', 'type': 'CATEGORY'},
                {'name': 'predicted', 'type': 'CATEGORY'},
                {'name': 'count', 'type': 'NUMBER'},
            ],
            'source': cm_file,
            'labels': list(map(str, vocab)),
        }]
    }

    with file_io.FileIO('/mlpipeline-ui-metadata.json', 'w') as f:
        json.dump(metadata, f)

    return df_cm


def load_data(bucket_name):
    """Load train/test features and labels from ``<bucket_name>/output/``.

    Reads ``train.csv``, ``train_label.csv``, ``test.csv`` and ``test_label.csv``
    (in that order) and drops the first column of each. Presumably that column
    is an index written by ``DataFrame.to_csv`` upstream; confirm with the producer.

    Returns:
        ``(testX, testy, trainX, trainy)`` as DataFrames. Note that the test
        split comes first.
    """
    def _read_without_first_column(stem):
        # Load one CSV and discard its leading column.
        frame = pd.read_csv(bucket_name + '/output/' + stem + '.csv')
        return frame.drop(frame.columns[0], axis=1)

    trainX = _read_without_first_column('train')
    trainy = _read_without_first_column('train_label')
    testX = _read_without_first_column('test')
    testy = _read_without_first_column('test_label')
    return testX, testy, trainX, trainy

In [7]:
if __name__ == '__main__':
    # Entry point: parse CLI options and run training.
    print("The arguments are ", str(sys.argv))

    parser = parse_arguments()
    # parse_known_args (not parse_args) so that extra arguments injected by the
    # Jupyter kernel launcher (e.g. "-f <kernel.json>") do not abort the run.
    # Every option has a default, so no separate usage check is needed
    # (the former `len(sys.argv) < 1` check could never be true).
    args, _unrecognized = parser.parse_known_args()
    print(args)
    # argparse `type=int` and the int defaults already produce ints here.
    train(args.bucket_name, args.epochs, args.batch_size, args.katib, args.optimizer_name)

The arguments are  ['/opt/tljh/user/lib/python3.7/site-packages/ipykernel_launcher.py', '-f', '/home/jupyter-tryster7/.local/share/jupyter/runtime/kernel-81832a79-a6dd-48e5-9d94-5da536a5b41c.json']
Namespace(batch_size=64, bucket_name='gs://kbc/ccc', epochs=1, katib=0, optimizer_name='Adam')


_call out of retries on exception: ('Failed to retrieve http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/?recursive=true from the Google Compute Enginemetadata service. Status: 404 Response:\nb\'<!DOCTYPE html>\\n<html lang=en>\\n  <meta charset=utf-8>\\n  <meta name=viewport content="initial-scale=1, minimum-scale=1, width=device-width">\\n  <title>Error 404 (Not Found)!!1</title>\\n  <style>\\n    *{margin:0;padding:0}html,code{font:15px/22px arial,sans-serif}html{background:#fff;color:#222;padding:15px}body{margin:7% auto 0;max-width:390px;min-height:180px;padding:30px 0 15px}* > body{background:url(//www.google.com/images/errors/robot.png) 100% 5px no-repeat;padding-right:205px}p{margin:11px 0 22px;overflow:hidden}ins{color:#777;text-decoration:none}a img{border:0}@media screen and (max-width:772px){body{background:none;margin-top:0;max-width:none;padding-right:0}}#logo{background:url(//www.google.com/images/branding/googlelogo/1x/googlelogo_colo

RefreshError: ('Failed to retrieve http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/?recursive=true from the Google Compute Enginemetadata service. Status: 404 Response:\nb\'<!DOCTYPE html>\\n<html lang=en>\\n  <meta charset=utf-8>\\n  <meta name=viewport content="initial-scale=1, minimum-scale=1, width=device-width">\\n  <title>Error 404 (Not Found)!!1</title>\\n  <style>\\n    *{margin:0;padding:0}html,code{font:15px/22px arial,sans-serif}html{background:#fff;color:#222;padding:15px}body{margin:7% auto 0;max-width:390px;min-height:180px;padding:30px 0 15px}* > body{background:url(//www.google.com/images/errors/robot.png) 100% 5px no-repeat;padding-right:205px}p{margin:11px 0 22px;overflow:hidden}ins{color:#777;text-decoration:none}a img{border:0}@media screen and (max-width:772px){body{background:none;margin-top:0;max-width:none;padding-right:0}}#logo{background:url(//www.google.com/images/branding/googlelogo/1x/googlelogo_color_150x54dp.png) no-repeat;margin-left:-5px}@media only screen and (min-resolution:192dpi){#logo{background:url(//www.google.com/images/branding/googlelogo/2x/googlelogo_color_150x54dp.png) no-repeat 0% 0%/100% 100%;-moz-border-image:url(//www.google.com/images/branding/googlelogo/2x/googlelogo_color_150x54dp.png) 0}}@media only screen and (-webkit-min-device-pixel-ratio:2){#logo{background:url(//www.google.com/images/branding/googlelogo/2x/googlelogo_color_150x54dp.png) no-repeat;-webkit-background-size:100% 100%}}#logo{display:inline-block;height:54px;width:150px}\\n  </style>\\n  <a href=//www.google.com/><span id=logo aria-label=Google></span></a>\\n  <p><b>404.</b> <ins>That\\xe2\\x80\\x99s an error.</ins>\\n  <p>The requested URL <code>/computeMetadata/v1/instance/service-accounts/default/?recursive=true</code> was not found on this server.  <ins>That\\xe2\\x80\\x99s all we know.</ins>\\n\'', <google.auth.transport.requests._Response object at 0x7fc1b3f46f28>)