9 changes: 9 additions & 0 deletions .travis.settings.xml
@@ -0,0 +1,9 @@
+<settings xmlns="http://maven.apache.org/SETTINGS/1.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/SETTINGS/1.0.0 http://maven.apache.org/xsd/settings-1.0.0.xsd">
+  <servers>
+    <server>
+      <id>bintray-tensorflowonspark-repo</id>
+      <username>${env.BINTRAY_USER}</username>
+      <password>${env.BINTRAY_API_KEY}</password>
+    </server>
+  </servers>
+</settings>
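Note: the `server` id above must match the `repository` id added to the `distributionManagement` section of `pom.xml` below; Maven pairs the two to attach these credentials to the Bintray upload, resolving the `${env.*}` placeholders from the environment at deploy time. A minimal pre-flight sketch (a hypothetical helper, not part of this change) that fails fast if the variables are unset:

```python
import os

# Maven interpolates ${env.BINTRAY_USER} and ${env.BINTRAY_API_KEY} from the
# environment; checking them up front beats a cryptic 401 from Bintray.
for var in ("BINTRAY_USER", "BINTRAY_API_KEY"):
    if not os.environ.get(var):
        raise SystemExit("missing required credential: " + var)
print("credentials present; safe to run: mvn deploy --settings .travis.settings.xml")
```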
98 changes: 62 additions & 36 deletions .travis.yml
@@ -1,39 +1,65 @@
-language: python
-python:
-  - 2.7
-  - 3.6
-cache: pip
-before_install:
-  - curl -LO http://www-us.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz
-  - export SPARK_HOME=./spark
-  - mkdir $SPARK_HOME
-  - tar -xf spark-2.3.1-bin-hadoop2.7.tgz -C $SPARK_HOME --strip-components=1
-  - export PATH=$SPARK_HOME/bin:$PATH
-  - export SPARK_LOCAL_IP=127.0.0.1
-  - export SPARK_CLASSPATH=./lib/tensorflow-hadoop-1.0-SNAPSHOT.jar
-  - export PYTHONPATH=$(pwd)
-install:
-  - pip install -r requirements.txt
-script:
-  - sphinx-build -b html docs/source docs/build/html
-  - test/run_tests.sh
+matrix:
+  include:
+    - language: python
+      python: 2.7
+      before_install:
+        - curl -LO http://www-us.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz
+        - export SPARK_HOME=./spark
+        - mkdir $SPARK_HOME
+        - tar -xf spark-2.3.1-bin-hadoop2.7.tgz -C $SPARK_HOME --strip-components=1
+        - export PATH=$SPARK_HOME/bin:$PATH
+        - export SPARK_LOCAL_IP=127.0.0.1
+        - export SPARK_CLASSPATH=./lib/tensorflow-hadoop-1.0-SNAPSHOT.jar
+        - export PYTHONPATH=$(pwd)
+      install:
+        - pip install -r requirements.txt
+      script:
+        - test/run_tests.sh
+    - language: python
+      python: 3.6
+      before_install:
+        - curl -LO http://www-us.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz
+        - export SPARK_HOME=./spark
+        - mkdir $SPARK_HOME
+        - tar -xf spark-2.3.1-bin-hadoop2.7.tgz -C $SPARK_HOME --strip-components=1
+        - export PATH=$SPARK_HOME/bin:$PATH
+        - export SPARK_LOCAL_IP=127.0.0.1
+        - export SPARK_CLASSPATH=./lib/tensorflow-hadoop-1.0-SNAPSHOT.jar
+        - export PYTHONPATH=$(pwd)
+      install:
+        - pip install -r requirements.txt
+      script:
+        - sphinx-build -b html docs/source docs/build/html
+        - test/run_tests.sh
+    - language: java
+      jdk: oraclejdk8
 notifications:
   email: false
 deploy:
-  - provider: pages
-    skip_cleanup: true
-    github_token: $GITHUB_TOKEN
-    local_dir: docs/build/html
-    on:
-      branch: master
-      python: 3.6
-      tags: true
-  - provider: pypi
-    user: leewyang
-    password:
-      secure: T2Q8VM6SgcMtJDO2kJbaELE/5ICR5mx8pkM6TyNAJZ2Mr3fLIy6iDfPKunBAYVljl+SDEWmuoPTWqJdqMyo47LBKPKtBHbGzATqGSRTLvxLOYNSXUX+uCpPtr7CMp1eP3xpZ3YbAJZvoEFlWnBQKeBtX/PjNCpmKdp7ir+46CvR/pR1tcM5cFnSgU+uCPAMUt8KTZIxeRo+oJtaE0DM2RxLJ9nGnaRNz9fdXxwhViNj/bMnDRUI0G6k+Iy4sO2669si8nhTDr+Oq66ONUcJtAQymNUM/hzBTCkrJvuIq1TqTlKkA39UrtD5/wCkCqPUbCLVuIfNwkYfW2C8AlXcbphBKN4PhwaoL5XECr3/AOsgNpnPWhCF1Z1uLi58FhIlSyp+5c/x2wVJLZi2IE+c996An7UO3t16ZFpFEgzS6m9PVbi6Qil6Tl4AhV5QLKb0Qn0hLe2l0WixzK9KLMHfkqX8h5ZGC7i0TvCNcU2uIFjY8we91GORZKZhwUVDKbPqiUZIKn64Qq8EwJIsk/S344OrUTzm7z0lFCqtPphg1duU42QOFmaYWi6hgsbtDxN6+CubLw23G3PtKjOpNt8hHnrjZsz9H1MKbSAoYQ4fo+Iwb3owTjXnSTBr94StW7qysggWH6xQimFDh/SKOE9MfroMGt5YTXfduTbqyeameYqE=
-    distributions: sdist bdist_wheel
-    on:
-      branch: master
-      python: 3.6
-      tags: true
+  - provider: pages
+    skip_cleanup: true
+    github_token: "$GITHUB_TOKEN"
+    local_dir: docs/build/html
+    on:
+      branch: master
+      python: 3.6
+      tags: true
+      condition: "$TRAVIS_TAG =~ ^v.*$"
+  - provider: pypi
+    user: leewyang
+    password:
+      secure: T2Q8VM6SgcMtJDO2kJbaELE/5ICR5mx8pkM6TyNAJZ2Mr3fLIy6iDfPKunBAYVljl+SDEWmuoPTWqJdqMyo47LBKPKtBHbGzATqGSRTLvxLOYNSXUX+uCpPtr7CMp1eP3xpZ3YbAJZvoEFlWnBQKeBtX/PjNCpmKdp7ir+46CvR/pR1tcM5cFnSgU+uCPAMUt8KTZIxeRo+oJtaE0DM2RxLJ9nGnaRNz9fdXxwhViNj/bMnDRUI0G6k+Iy4sO2669si8nhTDr+Oq66ONUcJtAQymNUM/hzBTCkrJvuIq1TqTlKkA39UrtD5/wCkCqPUbCLVuIfNwkYfW2C8AlXcbphBKN4PhwaoL5XECr3/AOsgNpnPWhCF1Z1uLi58FhIlSyp+5c/x2wVJLZi2IE+c996An7UO3t16ZFpFEgzS6m9PVbi6Qil6Tl4AhV5QLKb0Qn0hLe2l0WixzK9KLMHfkqX8h5ZGC7i0TvCNcU2uIFjY8we91GORZKZhwUVDKbPqiUZIKn64Qq8EwJIsk/S344OrUTzm7z0lFCqtPphg1duU42QOFmaYWi6hgsbtDxN6+CubLw23G3PtKjOpNt8hHnrjZsz9H1MKbSAoYQ4fo+Iwb3owTjXnSTBr94StW7qysggWH6xQimFDh/SKOE9MfroMGt5YTXfduTbqyeameYqE=
+    distributions: sdist bdist_wheel
+    on:
+      branch: master
+      python: 3.6
+      tags: true
+      condition: "$TRAVIS_TAG =~ ^v.*$"
+  - provider: script
+    script: mvn deploy -DskipTests --settings .travis.settings.xml
+    skip_cleanup: true
+    on:
+      branch: master
+      jdk: oraclejdk8
+      tags: true
+      condition: "$TRAVIS_TAG =~ ^scala_.*$"
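The three deploy providers are now gated by `$TRAVIS_TAG` patterns: `v*` tags release to PyPI and publish docs to GitHub Pages, while `scala_*` tags publish the Scala artifact to Bintray via `mvn deploy`. A small sketch (hypothetical tag names) of how those two regexes route a tag:

```python
import re

# The two tag conditions from the deploy section above.
PYPI_AND_PAGES = re.compile(r"^v.*$")
MAVEN_BINTRAY = re.compile(r"^scala_.*$")

for tag in ["v1.3.0", "scala_1.0", "nightly"]:  # hypothetical tag names
    targets = []
    if PYPI_AND_PAGES.match(tag):
        targets += ["pypi", "pages"]
    if MAVEN_BINTRAY.match(tag):
        targets.append("bintray")
    print(tag, "->", targets or "no deploy")
```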
1 change: 1 addition & 0 deletions examples/mnist/spark/mnist_dist.py
@@ -138,6 +138,7 @@ def feed_dict(batch):
     with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(task_index == 0),
                                            checkpoint_dir=logdir,
+                                           save_checkpoint_secs=10,
                                            hooks=[tf.train.StopAtStepHook(last_step=args.steps)],
                                            chief_only_hooks=[ExportHook(ctx.absolute_path(args.export_dir), x, prediction)]) as mon_sess:
       step = 0
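The added `save_checkpoint_secs=10` overrides `MonitoredTrainingSession`'s default checkpoint interval of 600 seconds, so even short training runs leave a restorable checkpoint behind. A minimal single-process sketch of the parameter in isolation (TF 1.x API; the checkpoint directory and step count are hypothetical):

```python
import tensorflow as tf  # TensorFlow 1.x, as used by this example

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training op

# Checkpoints are written every 10 seconds instead of every 600.
with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mnist_ckpt",
                                       save_checkpoint_secs=10,
                                       hooks=[tf.train.StopAtStepHook(last_step=1000)]) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```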
22 changes: 11 additions & 11 deletions examples/mnist/spark/mnist_spark.py
@@ -26,7 +26,7 @@
 parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
 parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
 parser.add_argument("--export_dir", help="HDFS path to export saved_model", default="mnist_export")
-parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv")
+parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")
 parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
 parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
 parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
@@ -56,22 +56,22 @@ def toNumpy(bytestr):
     return (image, label)
 
   dataRDD = images.map(lambda x: toNumpy(bytes(x[0])))
-else:
-  if args.format == "csv":
-    images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
-    labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
-  else: # args.format == "pickle":
-    images = sc.pickleFile(args.images)
-    labels = sc.pickleFile(args.labels)
+else: # "csv"
   print("zipping images and labels")
+  # If partitions of images/labels don't match, you can use the following code:
+  # images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')]).zipWithIndex().map(lambda x: (x[1], x[0]))
+  # labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')]).zipWithIndex().map(lambda x: (x[1], x[0]))
+  # dataRDD = images.join(labels).map(lambda x: (x[1][0], x[1][1]))
+  images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
+  labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
   dataRDD = images.zip(labels)
 
 cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model)
 if args.mode == "train":
   cluster.train(dataRDD, args.epochs)
-else:
-  labelRDD = cluster.inference(dataRDD)
-  labelRDD.saveAsTextFile(args.output)
+else: # inference
+  predRDD = cluster.inference(dataRDD)
+  predRDD.saveAsTextFile(args.output)
 
 cluster.shutdown(grace_secs=30)
 
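The new comment block preserves a fallback for when the image and label text files load with mismatched partitioning, which makes `zip()` fail; keying both RDDs by a global index and joining removes that constraint at the cost of a shuffle. A standalone sketch of that fallback (hypothetical input paths):

```python
from pyspark import SparkContext

sc = SparkContext(appName="mnist_zip_fallback")  # hypothetical standalone context

images = sc.textFile("mnist/csv/train/images").map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile("mnist/csv/train/labels").map(lambda ln: [float(x) for x in ln.split(',')])

# zip() requires identical partition counts and sizes; zipWithIndex() + join()
# aligns records by position regardless of partitioning.
images_by_idx = images.zipWithIndex().map(lambda x: (x[1], x[0]))
labels_by_idx = labels.zipWithIndex().map(lambda x: (x[1], x[0]))
dataRDD = images_by_idx.join(labels_by_idx).map(lambda x: (x[1][0], x[1][1]))
```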
20 changes: 16 additions & 4 deletions pom.xml
@@ -5,11 +5,18 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.yahoo.ml</groupId>
   <artifactId>tensorflowonspark</artifactId>
-  <version>1.0-SNAPSHOT</version>
+  <version>1.0</version>
   <packaging>jar</packaging>
   <name>tensorflowonspark</name>
   <description>Spark Scala inferencing for TensorFlowOnSpark</description>
 
+  <distributionManagement>
+    <repository>
+      <id>bintray-tensorflowonspark-repo</id>
+      <url>https://api.bintray.com/maven/yahoo/maven/tensorflowonspark</url>
+    </repository>
+  </distributionManagement>
+
   <properties>
     <maven.compiler.source>1.8</maven.compiler.source>
     <maven.compiler.target>1.8</maven.compiler.target>
@@ -22,11 +29,11 @@
     <scala.version>2.11.8</scala.version>
     <scala-maven-plugin.version>3.2.1</scala-maven-plugin.version>
     <scala-parser-combinators.version>1.1.0</scala-parser-combinators.version>
-    <scalatest.version>3.0.3</scalatest.version>
+    <scalatest.version>3.0.5</scalatest.version>
     <scalatest-maven-plugin.version>1.0</scalatest-maven-plugin.version>
     <scopt.version>3.7.0</scopt.version>
-    <tensorflow.version>1.8.0</tensorflow.version>
-    <tensorflow-hadoop.version>1.0-SNAPSHOT</tensorflow-hadoop.version>
+    <tensorflow.version>1.9.0</tensorflow.version>
+    <tensorflow-hadoop.version>1.9.0</tensorflow-hadoop.version>
   </properties>
   <dependencies>
     <dependency>
@@ -67,6 +74,11 @@
       <artifactId>hadoop</artifactId>
       <version>${tensorflow-hadoop.version}</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>3.5.1</version>
+    </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_2.11</artifactId>