In [None]:
from sparktorch import serialize_torch_obj, SparkTorch
import torch
import torch.nn as nn
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.ml.pipeline import Pipeline
from pyspark.sql.functions import col, udf, column
import json
from pyspark.sql.types import StructType
from pyspark.conf import SparkConf
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [None]:
# Spark configuration: run on YARN, talking to the ResourceManager at 127.0.0.1:8032,
# with 6g driver/executor memory and 2 cores per executor.
# For a local run instead, use setMaster("local[2]") with 8g driver/executor memory
# and 1 executor core.
sparkConf = (
    SparkConf()
    .setMaster("yarn")
    .setAppName("MNIST_TRAIN")
    .setAll([
        ("spark.hadoop.yarn.resourcemanager.address", "127.0.0.1:8032"),
        ("spark.driver.memory", "6g"),
        ("spark.executor.memory", '6g'),
        ('spark.executor.cores', 2),
    ])
)

# Build (or reuse) the session from that config and silence everything below ERROR.
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# Load the explicit DataFrame schema, stored as a JSON-serialized StructType.
with open("schema.json") as schema_file:
    schema_json = json.load(schema_file)
schema = StructType.fromJson(schema_json)

In [None]:
# Read the training CSV with the explicit `schema` (header row skipped), rename the
# first column to the label name used downstream, and reduce to 2 partitions.
# NOTE(review): the rename only takes effect if `schema` names that column "_c0";
# otherwise it is a silent no-op — confirm against schema.json.
# NOTE(review): presumably the partition count controls how many SparkTorch training
# workers run — TODO confirm.
df = (
    spark.read
    .csv('/user/hduser/input/mnist_train.csv', schema=schema, header=True)
    .withColumnRenamed("_c0", "labels")
    .coalesce(2)
)

In [None]:
# Number of records in the training set (shown as the cell output).
record_count = df.count()
record_count

In [None]:
# Define the neural network: a fully connected 784 -> 256 -> 256 -> 10 classifier.
# The last layer intentionally has NO Softmax. nn.CrossEntropyLoss applies log-softmax
# internally and expects raw, unnormalized logits. Putting a Softmax in front of it
# squashes the outputs into [0, 1] and severely weakens the gradients.
# If SparkTorch derives class predictions via argmax (assumed — confirm), they are
# unaffected, since argmax(logits) == argmax(softmax(logits)).
INPUT_DIM = 784    # one input per pixel column (presumably a 28 x 28 MNIST image)
HIDDEN_DIM = 256
NUM_CLASSES = 10   # one logit per digit class

network = nn.Sequential(
    nn.Linear(INPUT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, NUM_CLASSES),
)

In [None]:
# Bundle model, loss function and optimizer into the serialized object that the
# SparkTorch estimator receives as `torchObj`.
# nn.CrossEntropyLoss applies log-softmax itself, so it expects the model to emit raw
# logits. The optimizer is passed as a class, not an instance; `lr` is presumably
# forwarded to it when it is instantiated — confirm against sparktorch docs.
loss_fn = nn.CrossEntropyLoss()
torch_obj = serialize_torch_obj(
    model=network,
    criterion=loss_fn,
    optimizer=torch.optim.Adam,
    lr=0.0001,
)

In [None]:
# Setup features: assemble every non-label column into one vector column.
# The inputs are derived from the label name instead of a hard-coded
# df.columns[1:785] slice. This keeps the label out of the features even if column
# order changes, and fails loudly if the pixel count no longer matches the
# network's 784 inputs (rather than silently dropping or including columns).
feature_cols = [c for c in df.columns if c != 'labels']
if len(feature_cols) != 784:
    raise ValueError(
        "Expected 784 feature columns besides 'labels', got %d" % len(feature_cols)
    )
vector_assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

In [None]:
# Create a SparkTorch estimator with torch distributed (per the original note, barrier
# execution is on by default for this mode).
# Hyperparameter meanings below are inferred from their names — confirm against the
# sparktorch docs: up to 200 training iterations on mini-batches of 256 rows, 20% of
# the data held out for validation, early stop after 40 iterations without improvement.
sparktorch_params = dict(
    inputCol='features',
    labelCol='labels',
    predictionCol='predictions',
    torchObj=torch_obj,
    iters=200,
    verbose=1,
    miniBatch=256,
    earlyStopPatience=40,
    validationPct=0.2,
)
spark_model = SparkTorch(**sparktorch_params)

In [None]:
# Chain feature assembly and the SparkTorch estimator into one Pipeline, then fit it on
# the training DataFrame. `p` holds the fitted PipelineModel.
training_pipeline = Pipeline(stages=[vector_assembler, spark_model])
p = training_pipeline.fit(df)

In [None]:
# Evaluate the model on `df`, the same data the pipeline was fit on. This is a
# training-set (optimistic) accuracy, not a generalization estimate; score a
# held-out test set for that.
# The predictions are persisted because the evaluator may scan them more than once,
# and unpersisted afterwards so the cached partitions don't stay pinned in executor
# memory for the rest of the session.
predictions = p.transform(df).persist()
try:
    evaluator = MulticlassClassificationEvaluator(
        labelCol="labels", predictionCol="predictions", metricName="accuracy"
    )
    accuracy = evaluator.evaluate(predictions)
finally:
    predictions.unpersist()
print("Train accuracy = %g" % accuracy)

In [None]:
# Save the fitted pipeline to HDFS, overwriting any model already stored at that path.
MODEL_PATH = '/user/hduser/models/mnist'
p.write().overwrite().save(MODEL_PATH)

In [None]:
# Shut down Spark. SparkSession.stop() stops the underlying SparkContext and also
# clears the active/default session. A later getOrCreate() therefore builds a fresh
# session instead of handing back this stopped one.
spark.stop()

In [None]:
#END