Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate test checkpoint data automatically (#1738)
* Genearte checkpoint for test automatically * commit the python script to generate mnist checkpoint file * Reformat code by pre-commit * Recovery integration test of training and evaluation * Fix the path of python scripts for client_test.sh * Fix the command to generate checkpoint * Add log for checkpoint * Add log for checkpoints * fix mount path * print the mount dir content * Remove codes to print log * Remove a backslash * Only create train end task for training
- Loading branch information
1 parent
791484f
commit 36af4e8
Showing
10 changed files
with
132 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file removed
BIN
-2.29 MB
elasticdl/python/tests/testdata/functional_ckpt/version-100/variables-0-of-1.ckpt
Binary file not shown.
Binary file removed
BIN
-434 KB
elasticdl/python/tests/testdata/mnist_functional_api_model/version-110/variables-0-of-1.ckpt
Binary file not shown.
Binary file removed
BIN
-2.29 MB
elasticdl/python/tests/testdata/subclass_ckpt/version-100/variables-0-of-1.ckpt
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import argparse | ||
|
||
import tensorflow as tf | ||
|
||
from elasticdl.python.tests.test_utils import save_checkpoint_without_embedding | ||
|
||
|
||
def mnist_custom_model(): | ||
inputs = tf.keras.Input(shape=(28, 28), name="image") | ||
x = tf.keras.layers.Reshape((28, 28, 1))(inputs) | ||
x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(x) | ||
x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x) | ||
x = tf.keras.layers.BatchNormalization()(x) | ||
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x) | ||
x = tf.keras.layers.Dropout(0.25)(x) | ||
x = tf.keras.layers.Flatten()(x) | ||
outputs = tf.keras.layers.Dense(10)(x) | ||
|
||
return tf.keras.Model(inputs=inputs, outputs=outputs, name="mnist_model") | ||
|
||
|
||
def add_params(parser): | ||
parser.add_argument( | ||
"--checkpoint_dir", | ||
help="The directory to store the mnist checkpoint", | ||
default="", | ||
type=str, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
add_params(parser) | ||
args, _ = parser.parse_known_args() | ||
print(args) | ||
model = mnist_custom_model() | ||
save_checkpoint_without_embedding(model, args.checkpoint_dir) |