Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/mnist/keras/mnist_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ def main_fun(args, ctx):
import numpy
import os
import tensorflow as tf
import tensorflow.contrib.keras as keras
from tensorflow.contrib.keras.api.keras import backend as K
from tensorflow.contrib.keras.api.keras.models import Sequential, load_model, save_model
from tensorflow.contrib.keras.api.keras.layers import Dense, Dropout
from tensorflow.contrib.keras.api.keras.optimizers import RMSprop
from tensorflow.contrib.keras.python.keras.callbacks import LambdaCallback, TensorBoard
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Sequential, load_model, save_model
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.optimizers import RMSprop
from tensorflow.python.keras.callbacks import LambdaCallback, TensorBoard
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
Expand Down Expand Up @@ -51,7 +51,7 @@ def generate_rdd_data(tf_feed, batch_size):

# the data, shuffled and split between train and test sets
if args.input_mode == 'tf':
from tensorflow.contrib.keras.api.keras.datasets import mnist
from tensorflow.python.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
Expand Down