Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Replacing tf.contrib.data.batch_and_drop_remainder by batch(..., drop…
…_remainder=True). Also checkpointing at (epoch + 1) % x while saving the model to consider the last epoch's variables.
  • Loading branch information
yashk2810 committed Aug 12, 2018
1 parent d880275 commit b416db3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
Expand Up @@ -495,7 +495,7 @@
" random_vector_for_generation)\n",
" \n",
" # saving (checkpoint) the model every 15 epochs\n",
" if epoch % 15 == 0:\n",
" if (epoch + 1) % 15 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
Expand Down
Expand Up @@ -132,6 +132,7 @@
"tf.enable_eager_execution()\n",
"\n",
"import numpy as np\n",
"import os\n",
"import re\n",
"import random\n",
"import unidecode\n",
Expand Down Expand Up @@ -313,7 +314,7 @@
"outputs": [],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n",
"dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
"dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
Expand Down Expand Up @@ -493,7 +494,7 @@
"source": [
"# Training step\n",
"\n",
"EPOCHS = 30\n",
"EPOCHS = 20\n",
"\n",
"for epoch in range(EPOCHS):\n",
" start = time.time()\n",
Expand All @@ -520,7 +521,7 @@
" batch,\n",
" loss))\n",
" # saving (checkpoint) the model every 5 epochs\n",
" if epoch % 5 == 0:\n",
" if (epoch + 1) % 5 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
Expand Down
Expand Up @@ -319,7 +319,7 @@
"vocab_tar_size = len(targ_lang.word2idx)\n",
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
"dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
"dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
Expand Down Expand Up @@ -619,7 +619,7 @@
" batch,\n",
" batch_loss.numpy()))\n",
" # saving (checkpoint) the model every 2 epochs\n",
" if epoch % 2 == 0:\n",
" if (epoch + 1) % 2 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
Expand Down
Expand Up @@ -701,7 +701,7 @@
" generate_images(generator, inp, tar)\n",
" \n",
" # saving (checkpoint) the model every 20 epochs\n",
" if epoch % 20 == 0:\n",
" if (epoch + 1) % 20 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
Expand Down

0 comments on commit b416db3

Please sign in to comment.