Permalink
Switch branches/tags
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
53 lines (41 sloc) 1.29 KB
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from wandb.tensorflow import WandbHook
import wandb
# Load the MNIST splits used below (mnist.train, mnist.test) from the local
# "MNIST_data" directory.
# NOTE(review): read_data_sets is expected to download the files there if they
# are missing and, with its default arguments, to return flattened images with
# integer (non-one-hot) labels — confirm against the TF version in use.
mnist = input_data.read_data_sets(train_dir='MNIST_data')
def input(dataset):
    """Return the ``(images, labels)`` pair for one MNIST split.

    NOTE(review): this name shadows the builtin ``input``; it is kept because
    the rest of the script calls it under this name.

    Args:
        dataset: an object exposing ``.images`` and ``.labels`` array
            attributes (in this script, ``mnist.train`` or ``mnist.test``).

    Returns:
        A 2-tuple whose first item is ``dataset.images`` itself (not copied).
        The second item is a new int32 copy of ``dataset.labels``. The cast
        is presumably there so the estimator receives integer class ids —
        confirm.
    """
    images = dataset.images
    int_labels = dataset.labels.astype(np.int32)
    return images, int_labels
# Input features: a single dense column keyed "x" (the key used in the x={...}
# dicts handed to numpy_input_fn), declared as 28x28 = 784 values per example.
# NOTE(review): the image arrays passed in under "x" may be flat 784-vectors
# rather than 28x28. That is presumably tolerated because only the element
# count has to match — confirm the intended shape.
feature_columns = [
    tf.feature_column.numeric_column("x", shape=[28, 28]),
]

# Fully connected classifier: two hidden layers (256, then 32 units) and 10
# output classes, trained with Adam at learning rate 1e-4.
# NOTE(review): model_dir is a fixed relative path. If it already holds a
# checkpoint, the Estimator will likely restore from it — confirm that
# resuming is intended.
classifier = tf.estimator.DNNClassifier(
    hidden_units=[256, 32],
    feature_columns=feature_columns,
    n_classes=10,
    optimizer=tf.train.AdamOptimizer(1e-4),
    dropout=0.1,  # dropout rate passed through to DNNClassifier
    model_dir="./tmp/mnist_model",
)
# Start the Weights & Biases run. This must precede building WandbHook on the
# next lines; that ordering is presumably required by wandb — confirm.
wandb.init()
# NOTE(review): nothing has been added to the default graph yet at this point,
# because Estimators build their own graph when train() runs. So merge_all()
# presumably finds no summary ops and returns None — confirm. Even if it
# returned an op, that op would belong to the default graph, not to the
# Estimator's training graph.
summary_op = tf.summary.merge_all()
# Hook passed to classifier.train() below, presumably to report training
# summaries to wandb.
# NOTE(review): when summary_op is None, WandbHook is assumed to merge the
# summaries itself inside the training graph — verify against the installed
# wandb version.
hook = WandbHook(summary_op)
# Define the training inputs.
# Unpack the training split once. Calling input() separately for features and
# labels ran it twice, which made a redundant int32 copy of every label.
train_images, train_labels = input(mnist.train)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_images},
    y=train_labels,
    # No epoch limit (per numpy_input_fn docs); the `steps` argument of
    # classifier.train() below is what bounds training.
    num_epochs=None,
    batch_size=50,
    shuffle=True,
)
# Train for 100,000 steps of 50-example batches, running `hook` alongside.
classifier.train(input_fn=train_input_fn, steps=100000, hooks=[hook])
# Define the test inputs: one pass over the test split, in its stored order.
# Unpack the split once rather than calling input() twice; each extra call
# made another int32 copy of all test labels.
test_images, test_labels = input(mnist.test)
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": test_images},
    y=test_labels,
    num_epochs=1,
    shuffle=False,
)
# Evaluate accuracy on the test split. The "accuracy" metric is a fraction, so
# it is scaled by 100 for the percentage printout.
accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]
print("\nTest Accuracy: {0:f}%\n".format(accuracy_score*100))