In [1]:
import re
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from functools import partial
from sklearn.model_selection import train_test_split
import tempfile

In [None]:
from source.support_functions import load_test_dataset

In [None]:
# sagemaker parameters
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
import boto3

session = sagemaker.Session()
role = get_execution_role()
bucket = session.default_bucket()
prefix = 'kaggle/melanoma'  # Prefix should not tontain '/' at the end!
s3 = boto3.client('s3')

In [None]:
data_location = os.path.join('s3://', bucket, prefix)

In [None]:
train_new = True
if train_new:
    from sagemaker.tensorflow import TensorFlow
    # Initiate a TensorFlow instance.

    """When using version 2.1, 'setup.py' is needed in the 'source' directory"""

    tf_estimator = TensorFlow(
        entry_point='VGG16_train.py',
        source_dir='source',
        role=role,
        train_instance_count=1,
        train_instance_type='ml.p2.xlarge',
        #framework_version='2.1.0',
        #py_version='py3',
        output_path=os.path.join('s3://', bucket, prefix, 'model'),
        model_dir=os.path.join('s3://', bucket, prefix, 'model'),
        hyperparameters={
            'epochs': 15,
            'lr_max': 0.0001,
            'batch_size': 32,
            'image_size': 1024
        }
    )
    tf_estimator.fit(data_location)
    
    model_data = tf_estimator.model_data
    subprocess.check_call('echo {} > bert_model.txt'.format(model_data), shell=True)

else:
    with open('bert_model.txt', 'r') as f:
        model_data = f.read().split()[0]
    print("Use a previously trained model.")
    
print(model_data)

In [16]:
ds = load_test_dataset(PATH_TO_TEST_DATA)


In [17]:
pred = model.predict(ds)

In [18]:
pred


array([[0.74847436],
       [0.7555566 ],
       [0.75395626],
       [0.7540812 ],
       [0.7481322 ],
       [0.7573752 ],
       [0.74900055],
       [0.7445854 ],
       [0.7304162 ],
       [0.75058496],
       [0.7622591 ],
       [0.76029503],
       [0.75988936],
       [0.755166  ],
       [0.7539537 ],
       [0.7537733 ],
       [0.75265926],
       [0.7496214 ],
       [0.75989544],
       [0.75561464],
       [0.7473942 ],
       [0.75829566],
       [0.7496162 ],
       [0.75186676],
       [0.7457945 ],
       [0.7582307 ],
       [0.7597212 ],
       [0.76188326],
       [0.75703835],
       [0.7518357 ],
       [0.7542299 ],
       [0.75614274]], dtype=float32)