Skip to content

Commit

Permalink
Merge pull request #12 from psteinb/try-tf-1.6
Browse files Browse the repository at this point in the history
Try tf 1.6
  • Loading branch information
Peter Steinbach committed Apr 23, 2018
2 parents 98fa9c6 + 227a4d2 commit e9ec14e
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 12 deletions.
Empty file removed models/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def versions(self):

value = ""

if "keras" in self.backend.lower():
if self.backend.lower().startswith("keras"):

import keras
from keras import backend as K
Expand All @@ -143,7 +143,7 @@ def versions(self):

else:

if "tensorflow" in self.backend.lower():
if self.backend.lower() == "tensorflow" or self.backend.lower() == "tf":
import tensorflow as tf
value = "tensorflow:{ver}".format(ver=tf.__version__)
else:
Expand Down
7 changes: 6 additions & 1 deletion models/tf_details/resnet_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def can_train():
warnings.simplefilter(action='ignore', category=FutureWarning)

from tensorflow import __version__ as tfv
required = "1.7.0"
required = "1.6.0"

#only require major and minor release number as the patch number may contain 'rc' etc
if versiontuple(tfv,2) >= versiontuple(required,2):
Expand Down Expand Up @@ -111,4 +111,9 @@ def train(train, test, datafraction, opts):
logging.info('handing over \n >> %s \n >> %s',flags,opts)
history, timings = run_loop.resnet_main(flags, cfmain.cifar10_model_fn, cfmain.input_fn, opts)

if not opts['checkpoint_epochs']:
logging.info("unable to ensure pure no-checkpoint behavior with resnet in pure tensorflow, removing result directory")
import shutil
shutil.rmtree(model_dir)

return history, timings, { 'num_weights' : None }
7 changes: 1 addition & 6 deletions models/tf_details/resnet_run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def resnet_main(flags, model_function, input_function, opts = None):
logging.warning("batch sizes differ in model %i %s", flags.batch_size, opts["batch_size"])

if ngpus > 1:
steps_per_epoch -= 1
validate_batch_size_for_multi_gpu(flags.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
Expand Down Expand Up @@ -407,12 +408,6 @@ def input_fn_eval():
validation_results = classifier.evaluate(input_fn=input_fn_eval,
steps=flags.max_train_steps)


# for (k,v) in train_hooks["CaptureTensorsHook"].captured.items():
# print(">> ",k,v[:5],v[-2:])

#epoch_times.extend(train_hooks["TimePerEpochHook"].epoch_durations)

for k in validation_results.keys():
if "global_step" in k:
continue
Expand Down
8 changes: 5 additions & 3 deletions models/tf_details/utils/logging/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def add(self,other):
class TimePerEpochHook(tf.train.SessionRunHook):
def __init__(self,
every_n_steps,
warm_steps=0):
warm_steps=-1):

self.every_n_steps = every_n_steps
logging.info("TimePerEpochHook triggering every %i steps",every_n_steps)
Expand Down Expand Up @@ -112,8 +112,8 @@ def after_run(self, run_context, run_values): # pylint: disable=unused-argument
global_step = run_values.results
sess = run_context.session


if self._timer.should_trigger_for_step(global_step) and global_step > self._warm_steps:
#if self._timer.should_trigger_for_step(global_step) and global_step > self._warm_steps:
if self._step % self.every_n_steps == 0:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
Expand All @@ -124,6 +124,8 @@ def after_run(self, run_context, run_values): # pylint: disable=unused-argument
tf.logging.info('Epoch [%g steps]: %g (%s)', self._total_steps,self._epoch_train_time,str(self.epoch_durations))

self._epoch_train_time = 0
else:
logging.warning("step %i, elapsed_time is None!", global_step)



Expand Down

0 comments on commit e9ec14e

Please sign in to comment.