Skip to content

Commit

Permalink
Merge pull request #51 from williamFalcon/1.2
Browse files Browse the repository at this point in the history
1.2
  • Loading branch information
williamFalcon committed Aug 8, 2019
2 parents 736ba69 + 3fba70a commit 93db0ad
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions test_tube/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ class Experiment(SummaryWriter):

def __init__(
self,
save_dir=None,
name='default',
debug=False,
version=None,
save_dir=None,
autosave=False,
description=None,
create_git_tag=False,
Expand Down Expand Up @@ -150,8 +150,8 @@ def __init__(
print('Test tube created git tag:', 'tt_{}'.format(self.exp_hash))

# set the tensorboardx log path to the /tf folder in the exp folder
logdir = self.get_tensorboardx_path(self.name, self.version)
super().__init__(log_dir=logdir, *args, **kwargs)
log_dir = self.get_tensorboardx_path(self.name, self.version)
super().__init__(log_dir=log_dir, *args, **kwargs)

# register on exit fx so we always close the writer
atexit.register(self.on_exit)
Expand All @@ -169,15 +169,15 @@ def on_exit(self):
self.close()

def __clean_dir(self):
files = os.listdir(self.log_dir)
files = os.listdir(self.save_dir)

if self.rank == 0:
return

for f in files:
if str(self.process) in f:
self.close()
os.remove(os.path.join(self.log_dir, f))
os.remove(os.path.join(self.save_dir, f))

def argparse(self, argparser):
parsed = vars(argparser)
Expand Down Expand Up @@ -501,9 +501,10 @@ def _get_file_writer(self):
return TTDummyFileWriter()

if self.all_writers is None or self.file_writer is None:
if 'purge_step' in self.kwargs.keys():
most_recent_step = self.kwargs.pop('purge_step')
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
if self.purge_step is not None:
most_recent_step = self.purge_step
self.file_writer = FileWriter(self.save_dir, self.max_queue,
self.flush_secs, self.filename_suffix)
self.file_writer.debug = self.debug
self.file_writer.rank = self.rank

Expand All @@ -512,7 +513,8 @@ def _get_file_writer(self):
self.file_writer.add_event(
Event(step=most_recent_step, session_log=SessionLog(status=SessionLog.START)))
else:
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
self.file_writer = FileWriter(self.save_dir, self.max_queue,
self.flush_secs, self.filename_suffix)
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
return self.file_writer

Expand Down

0 comments on commit 93db0ad

Please sign in to comment.