In [None]:
import sys, os, json, numpy as np, requests

In [None]:
!pip install clipper_admin

In [None]:
from clipper_admin import Clipper

In [None]:
# Admin handle for a Clipper instance on this machine.
clipper_client = Clipper("localhost")

In [None]:
# Start Clipper on the host given to Clipper(...) above.
# NOTE(review): presumably this launches the Clipper services, including the
# query frontend on port 1337 that the requests below target. Confirm against
# the clipper_admin docs.
clipper_client.start()

In [None]:
# List the applications registered with Clipper; the result is shown as the
# cell output. On a first run nothing has been registered yet at this point.
clipper_client.get_all_apps()

In [None]:
# An application in Clipper corresponds to a REST prediction endpoint.
# Positional arguments. NOTE(review): meanings are inferred from how they are
# used later in this notebook; confirm against the clipper_admin API docs.
#   "digits"      - application name; queried at http://localhost:1337/digits/predict
#   "pyspark_svm" - model name; the same name is passed to deploy_pyspark_model below
#   "ints"        - input type; the same value is passed to deploy_pyspark_model below
#   "-1.0"        - presumably the default output returned when no model answers in time
#   100000        - presumably the latency SLO in microseconds (i.e. 100 ms)
clipper_client.register_application(
    "digits",
    "pyspark_svm", "ints", "-1.0", 100000)

In [None]:
# Send a test prediction
headers = {"Content-type": "application/json"}
requests.post(
    "http://localhost:1337/digits/predict",
    headers=headers,
    data=json.dumps({"input": [np.random.randint(255) for _ in range(784)]})).json()

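# Train a logistic regression classifier with PySpark

The model is a binary classifier that labels a digit image 0 if it is a 3 and 1 otherwise. It is deployed to Clipper under the model name `pyspark_svm`.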

> Note that this code uses the `findspark` package to import Spark. You can install it with `pip install findspark`.

In [None]:
import findspark
findspark.init()
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint
from pyspark.sql import SparkSession

In [None]:
# Build (or, via getOrCreate, reuse) the SparkSession for this notebook.
spark = (
    SparkSession.builder
    .appName("clipper-pyspark")
    .getOrCreate()
)
# SparkContext handle used for the RDD API and passed to Clipper on deployment.
sc = spark.sparkContext



In [None]:
def normalize(x):
    """Standardize a feature vector to zero mean and unit standard deviation.

    Parameters
    ----------
    x : array-like
        Raw feature values. Converted to float64 first, so numeric strings
        (as produced by ``line.split(',')`` in ``parse``) are accepted.

    Returns
    -------
    numpy.ndarray
        ``(x - mean(x)) / std(x)``, with the same shape as ``x``. If every value
        is identical (zero spread), an all-zeros array of that shape is returned
        instead of dividing by zero.
    """
    x = np.asarray(x).astype(np.double)
    mu = np.mean(x)
    sigma = np.std(x)  # standard deviation, i.e. sqrt of the variance
    if sigma > 0:
        return (x - mu) / sigma
    # Constant input (e.g. a blank image): the centered values are all zero.
    # Returning an array keeps the feature dimensionality intact. A scalar
    # would not match what LabeledPoint / model.predict receive for other inputs.
    return np.zeros_like(x)

def obj(y):
    """Map a digit label to the binary training target: 0 for a 3, 1 for anything else."""
    return 0 if y == 3 else 1

def parse(line):
    """Turn one CSV line into a LabeledPoint.

    The first comma-separated field is the digit label, which is mapped to a
    binary target by ``obj``. The remaining fields are the features, which are
    standardized by ``normalize``.
    """
    label_field, *feature_fields = line.strip().split(',')
    target = obj(int(label_field))
    features = normalize(np.array(feature_fields))
    return LabeledPoint(target, features)

# Location of the training CSV. Set the MNIST_TRAIN_PATH environment variable
# to point at your copy instead of editing this machine-specific default.
train_path = os.environ.get(
    "MNIST_TRAIN_PATH",
    "/Users/crankshaw/model-serving/data/mnist_data/train.data")
# Parse every line into a LabeledPoint. Cache the result so the file is not
# re-read and re-parsed each time the RDD is used.
trainRDD = sc.textFile(train_path).map(parse).cache()

In [None]:
# Fit a logistic regression model (optimized with L-BFGS) on the cached
# training RDD. The targets are the binary "is it a 3?" labels produced by `obj`.
model = LogisticRegressionWithLBFGS.train(trainRDD)

In [None]:
def simple_predict(spark, model, xs):
    """Prediction function handed to Clipper via deploy_pyspark_model.

    Parameters
    ----------
    spark : SparkSession
        Not used by the body. Presumably required by Clipper's calling
        convention for PySpark models.
    model : object with ``predict``
        The trained PySpark model.
    xs : iterable
        Raw input vectors; each one is normalized before prediction.

    Returns
    -------
    list of str
        One stringified prediction per input, in input order.
    """
    predictions = []
    for raw in xs:
        label = model.predict(normalize(raw))
        predictions.append(str(label))
    return predictions

In [None]:
# Sanity-check the prediction function locally before deploying it. The input
# is 784 random values in [0, 255], presumably standing in for a flattened
# 28x28 MNIST image. randint's upper bound is exclusive, so 256 makes 255
# reachable. Drawing all values in one vectorized call gives an integer ndarray.
test_point = np.random.randint(0, 256, size=784)
simple_predict(spark, model, [test_point])

In [None]:
# Deploy the trained model to Clipper.
# NOTE(review): argument meanings are inferred from this notebook; confirm with
# the clipper_admin API.
#   "pyspark_svm"  - matches the model name given to register_application above
#   1              - presumably the model version
#   simple_predict - the prediction function
#   model, sc      - the PySpark model and the SparkContext it needs
#   "ints"         - matches the input type given to register_application above
clipper_client.deploy_pyspark_model("pyspark_svm", 1, simple_predict, model, sc, "ints")

In [None]:
# Query the endpoint again, now that the model is deployed.
requests.post(
    "http://localhost:1337/digits/predict",
    headers=headers,
    # 784 random values in [0, 255] (randint's upper bound is exclusive).
    # A list of Python ints keeps json.dumps happy.
    data=json.dumps({"input": [np.random.randint(256) for _ in range(784)]}),
    # Fail fast instead of hanging the kernel if Clipper is not reachable.
    timeout=10).json()