In [1]:
# Re-import edited modules (e.g. the PyTorchEstimator module imported below)
# automatically before each cell executes, so edits are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from PyTorchEstimator import PyTorchEstimator
from azureml.core.workspace import Workspace
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
# Start (or reuse) a local Spark session. getOrCreate() returns the session
# that is already running in this kernel, so re-executing this cell does not
# fail the way constructing a second SparkContext('local') does (PySpark only
# allows one active SparkContext per process).
spark = SparkSession.builder.master("local").getOrCreate()
# Keep the SparkContext handle available under its original name.
sc = spark.sparkContext

In [3]:
# Get the CIFAR10 test batch as a Python dictionary (stored in `test_dict`)
import os, tarfile, pickle
import urllib.request
cdnURL = "https://amldockerdatasets.azureedge.net"
# Please note that this is a copy of the CIFAR10 dataset originally found here:
# http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
dataFile = "cifar-10-python.tar.gz"
dataURL = cdnURL + "/CIFAR10/" + dataFile
if not os.path.isfile(dataFile):
    # Download under a temporary name and rename only after the transfer
    # completes. An interrupted urlretrieve leaves a partial file behind; if it
    # were written straight to `dataFile`, the isfile() guard above would treat
    # the truncated archive as complete on every later run.
    partialFile = dataFile + ".part"
    urllib.request.urlretrieve(dataURL, partialFile)
    os.replace(partialFile, dataFile)
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load archives obtained from a source you trust (here: the CDN above).
with tarfile.open(dataFile, "r:gz") as f:
    # Close the extracted member handle explicitly rather than leaving it open.
    with f.extractfile("cifar-10-batches-py/test_batch") as batchFile:
        test_dict = pickle.load(batchFile, encoding="latin1")

In [4]:
# Create the images with labels from CIFAR dataset,
# reformat the labels using OneHotEncoder
import array
from pyspark.sql.functions import udf
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.ml.feature import OneHotEncoderEstimator
from pyspark.sql.functions import col
from pyspark.sql.types import *

def reshape_image(record):
    """Convert one ``(image, label, filename)`` record into a row of Python values.

    Args:
        record: A 3-tuple. ``image`` is array-like with a ``.reshape`` method
            and must contain exactly 3 * 32 * 32 = 3072 values; ``label`` and
            ``filename`` are passed through unchanged.

    Returns:
        ``(pixels, label, filename)`` where ``pixels`` is a flat list of
        Python floats in row-major order of ``image``.

    Raises:
        ValueError: if ``image`` cannot be reshaped to ``(3, 32, 32)``.
    """
    pixels, label, name = record
    # The (3, 32, 32) reshape acts as a size check; ravel() then returns the
    # values in the same row-major order they started in.
    flat = pixels.reshape(3, 32, 32).ravel()
    return list(map(float, flat)), label, name

# Identity UDF that types its input as array<double>. reshape_image already
# emits Python floats, so Spark infers array<double> for "images" and the
# pipeline below no longer routes data through it (that only added a second
# Python-UDF serialization pass). The name is kept defined for any later cell
# that may still reference it.
convert_to_double = udf(lambda x: x, ArrayType(DoubleType()))

# Wraps a Python list of floats in a dense pyspark.ml vector.
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())

# (image, label, filename) triples from the CIFAR10 test batch.
records = zip(test_dict["data"], test_dict["labels"], test_dict["filenames"])
image_rdd = spark.sparkContext.parallelize(list(records)).map(reshape_image)

rawImages = image_rdd.toDF(["images", "labels", "filename"])

# Convert pixel arrays straight to dense vectors in a single UDF pass.
vectorImages = rawImages \
    .withColumn("images", list_to_vector_udf(col("images"))) \
    .select("images", "labels")

# One-hot encode the class label. dropLast=False keeps a slot for every class
# (the output below shows 10-element label vectors).
# NOTE(review): OneHotEncoderEstimator was renamed OneHotEncoder in Spark 3.0.
ohe = OneHotEncoderEstimator() \
        .setInputCols(["labels"]).setOutputCols(["tmplabels"]) \
        .setDropLast(False)
imagesWithLabels = ohe.fit(vectorImages) \
                      .transform(vectorImages) \
                      .select("images", "tmplabels") \
                      .withColumnRenamed("tmplabels", "labels")

imagesWithLabels.printSchema()

# Mark for caching before count(): the count is the action that materializes
# the cache, which the train/test split downstream then reuses.
imagesWithLabels.cache()
print(imagesWithLabels.count())
imagesWithLabels.show(5)

root
 |-- images: vector (nullable = true)
 |-- labels: vector (nullable = true)

10000
+--------------------+--------------+
|              images|        labels|
+--------------------+--------------+
|[158.0,159.0,165....|(10,[3],[1.0])|
|[235.0,231.0,232....|(10,[8],[1.0])|
|[158.0,158.0,139....|(10,[8],[1.0])|
|[155.0,167.0,176....|(10,[0],[1.0])|
|[65.0,70.0,48.0,3...|(10,[6],[1.0])|
+--------------------+--------------+
only showing top 5 rows



In [5]:
# Split the labelled images: 80% for training, 20% held out for testing.
# The fixed seed keeps the split the same when the notebook is re-run.
trainFraction, testFraction = 0.8, 0.2
train, test = imagesWithLabels.randomSplit(weights=[trainFraction, testFraction], seed=123)

In [9]:
# Initializing the estimator
import os

# Azure ML workspace coordinates. The subscription id is read from the
# environment rather than committed with the notebook (it also shows up in
# run URLs). Resource group and workspace name default to the original values.
subscriptionId = os.environ.get("AZUREML_SUBSCRIPTION_ID")
if not subscriptionId:
    raise RuntimeError(
        "Set AZUREML_SUBSCRIPTION_ID to the Azure subscription that hosts the workspace")
resourceGroup = os.environ.get("AZUREML_RESOURCE_GROUP", "mmlspark-serano")
workspaceName = os.environ.get("AZUREML_WORKSPACE_NAME", "playground")
workspace = Workspace(subscription_id=subscriptionId,
                      resource_group=resourceGroup,
                      workspace_name=workspaceName)

clusterName = 'train-target'          # name of the Azure ML compute target
trainingScript = 'pytorch_train.py'   # script executed on the compute target
nodeCount = 1
modelPath = 'outputs/model.pt'        # presumably where the script saves the model; confirm in PyTorchEstimator
experimentName = 'pytorch-cifar'

estimator = PyTorchEstimator(workspace, clusterName, trainingScript, nodeCount, modelPath, experimentName)

In [10]:
# Upload the training data, submit the training run to the Azure ML compute
# target and stream its driver log until the run finishes.
# NOTE(review): the recorded run below failed. pytorch_train.py (line 27) calls
# SparkContext('local') on the remote node, where JAVA_HOME is not set, so the
# Java gateway exits before starting. The remote environment needs a JVM, or
# the script should read the uploaded data without Spark (e.g. read the
# parquet files directly).
model = estimator.fit(train)

Found existing compute target.
Uploading data/data.parquet/._SUCCESS.crc
Uploading data/data.parquet/.part-00000-bd2a5f56-d120-4ddf-a8d5-cb4447fbd213-c000.snappy.parquet.crc
Uploading data/data.parquet/_SUCCESS
Uploading data/data.parquet/part-00000-bd2a5f56-d120-4ddf-a8d5-cb4447fbd213-c000.snappy.parquet
Uploading data/dataset.parquet/._SUCCESS.crc
Uploading data/dataset.parquet/.part-00000-9a9b9f13-1741-4e94-aa28-d0cb638ca5db-c000.snappy.parquet.crc
Uploading data/dataset.parquet/_SUCCESS
Uploading data/dataset.parquet/part-00000-9a9b9f13-1741-4e94-aa28-d0cb638ca5db-c000.snappy.parquet
Uploading data/str_data.csv/._SUCCESS.crc
Uploading data/str_data.csv/.part-00000-70cd38bb-d800-40f6-9917-0d9b56003be6-c000.csv.crc
Uploading data/str_data.csv/_SUCCESS
Uploading data/str_data.csv/part-00000-70cd38bb-d800-40f6-9917-0d9b56003be6-c000.csv
Uploading data/str_data.parquet/._SUCCESS.crc
Uploading data/str_data.parquet/.part-00000-6538cf7b-02f4-409c-bcf5-76f44e3ea537-c000.snappy.parquet.crc




Uploaded data/str_data.csv/part-00000-70cd38bb-d800-40f6-9917-0d9b56003be6-c000.csv, 16 files out of an estimated total of 16
Job submitted!


_UserRunWidget(widget_settings={'childWidgetDisplay': 'popup', 'send_telemetry': False, 'log_level': 'INFO', '…

RunId: pytorch-cifar_1560437354_ff85bed4
Web View: https://mlworkspace.azure.ai/portal/subscriptions/e54229a3-0e6f-40b3-82a1-ae9cda6e2b81/resourceGroups/mmlspark-serano/providers/Microsoft.MachineLearningServices/workspaces/playground/experiments/pytorch-cifar/runs/pytorch-cifar_1560437354_ff85bed4

Streaming azureml-logs/70_driver_log.txt

bash: /azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/libtinfo.so.5: no version information available (required by bash)
bash: /azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/libtinfo.so.5: no version information available (required by bash)
JAVA_HOME is not set


The experiment failed. Finalizing run...
Logging experiment finalizing status in history service.
Cleaning up all outstanding Run operations, waiting 300.0 seconds
2 items cleaning up...
Cleanup took 0.0016446113586425781 seconds
Traceback (most recent call last):
  File "pytorch_train.py", line 27, in <module>
    sc = SparkContext('local')
  File "/azureml-envs/azur

ActivityFailedException: Activity Failed:
{
    "error": {
        "code": "UserError",
        "message": "Java gateway process exited before sending its port number",
        "details": [],
        "debugInfo": {
            "type": "Exception",
            "message": "Java gateway process exited before sending its port number",
            "stackTrace": "  File \"azureml-setup/context_manager_injector.py\", line 96, in execute_with_context\n    runpy.run_path(sys.argv[0], globals(), run_name=\"__main__\")\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/runpy.py\", line 263, in run_path\n    pkg_name=pkg_name, script_name=fname)\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/runpy.py\", line 96, in _run_module_code\n    mod_name, mod_spec, pkg_name, script_name)\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/runpy.py\", line 85, in _run_code\n    exec(code, run_globals)\n  File \"pytorch_train.py\", line 27, in <module>\n    sc = SparkContext('local')\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/site-packages/pyspark/context.py\", line 133, in __init__\n    SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/site-packages/pyspark/context.py\", line 316, in _ensure_initialized\n    SparkContext._gateway = gateway or launch_gateway(conf)\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/site-packages/pyspark/java_gateway.py\", line 46, in launch_gateway\n    return _launch_gateway(conf)\n  File \"/azureml-envs/azureml_ff2276d1e455a6c03e7998425b41690f/lib/python3.6/site-packages/pyspark/java_gateway.py\", line 108, in _launch_gateway\n    raise Exception(\"Java gateway process exited before sending its port number\")\n"
        }
    }
}