In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from PyTorchEstimator import PyTorchEstimator
from azureml.core.workspace import Workspace
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

sc = SparkContext.getOrCreate()
spark = SparkSession(sc)

In [3]:
# Get the CIFAR10 dataset as Python dictionary
import os, tarfile, pickle
import urllib.request
cdnURL = "https://amldockerdatasets.azureedge.net"
# Please note that this is a copy of the CIFAR10 dataset originally found here:
# http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
dataFile = "cifar-10-python.tar.gz"
dataURL = cdnURL + "/CIFAR10/" + dataFile
if not os.path.isfile(dataFile):
    urllib.request.urlretrieve(dataURL, dataFile)
with tarfile.open(dataFile, "r:gz") as f:
    test_dict = pickle.load(f.extractfile("cifar-10-batches-py/test_batch"),
                            encoding="latin1")

In [4]:
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField
from petastorm.codecs import ScalarCodec, NdarrayCodec
from pyspark.sql.types import *
import numpy as np

# Generate Petastorm dataset
image_zip = zip(test_dict["data"], test_dict["labels"])

CIFARSchema = Unischema('CIFARSchema', [
    UnischemaField('image', np.uint8, (3,32,32), NdarrayCodec(), False),
    UnischemaField('label', np.int32, (), ScalarCodec(IntegerType()), False),
])

def reshape_image(record):
    image, label = record
    return {'image': image.reshape(3,32,32), 'label': label}

rows_rdd = sc.parallelize(image_zip)\
    .map(reshape_image)\
    .map(lambda x: dict_to_spark_row(CIFARSchema, x))
    
imagesWithLabels = spark.createDataFrame(rows_rdd)

In [5]:
# Split the images with labels into a train and test data
train, test = imagesWithLabels.randomSplit([0.8, 0.2], seed=123)

In [6]:
train.show(3)
train.printSchema()

+--------------------+-----+
|               image|label|
+--------------------+-----+
|[93 4E 55 4D 50 5...|    2|
|[93 4E 55 4D 50 5...|    2|
|[93 4E 55 4D 50 5...|    1|
+--------------------+-----+
only showing top 3 rows

root
 |-- image: binary (nullable = true)
 |-- label: long (nullable = true)



In [None]:
# Initializing the estimator
workspace = Workspace('e54229a3-0e6f-40b3-82a1-ae9cda6e2b81', 'mmlspark-serano', 'playground')
clusterName = 'train-target'
trainingScript = 'pytorch_train.py'
nodeCount = 1
modelPath = 'outputs/model.pt'
experimentName = 'pytorch-cifar'
unischema = CIFARSchema

estimator = PyTorchEstimator(workspace, clusterName, trainingScript, nodeCount, modelPath, experimentName, unischema)

In [None]:
model = estimator.fit(train)

Found existing compute target.
Uploading /tmp/data/dataset.parquet/._SUCCESS.crc
Uploading /tmp/data/dataset.parquet/.part-00000-cb6d2487-ab82-4841-addf-e5b58ee4f07a-c000.snappy.parquet.crc
Uploading /tmp/data/dataset.parquet/.part-00001-cb6d2487-ab82-4841-addf-e5b58ee4f07a-c000.snappy.parquet.crc
Uploading /tmp/data/dataset.parquet/_SUCCESS
Uploading /tmp/data/dataset.parquet/_common_metadata
Uploading /tmp/data/dataset.parquet/part-00000-cb6d2487-ab82-4841-addf-e5b58ee4f07a-c000.snappy.parquet
Uploading /tmp/data/dataset.parquet/part-00001-cb6d2487-ab82-4841-addf-e5b58ee4f07a-c000.snappy.parquet
Uploaded /tmp/data/dataset.parquet/._SUCCESS.crc, 1 files out of an estimated total of 7
Uploaded /tmp/data/dataset.parquet/_common_metadata, 2 files out of an estimated total of 7
Uploaded /tmp/data/dataset.parquet/_SUCCESS, 3 files out of an estimated total of 7
Uploaded /tmp/data/dataset.parquet/.part-00000-cb6d2487-ab82-4841-addf-e5b58ee4f07a-c000.snappy.parquet.crc, 4 files out of an est