Skip to content

Commit

Permalink
Initial commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarras committed Feb 5, 2019
0 parents commit f3a0446
Show file tree
Hide file tree
Showing 32 changed files with 6,346 additions and 0 deletions.
410 changes: 410 additions & 0 deletions LICENSE.txt

Large diffs are not rendered by default.

230 changes: 230 additions & 0 deletions README.md

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Global configuration."""

#----------------------------------------------------------------------------
# Paths.

result_dir = 'results'
data_dir = 'datasets'
cache_dir = 'cache'

#----------------------------------------------------------------------------
645 changes: 645 additions & 0 deletions dataset_tool.py

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions dnnlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

from . import submission

from .submission.run_context import RunContext

from .submission.submit import SubmitTarget
from .submission.submit import PathType
from .submission.submit import SubmitConfig
from .submission.submit import get_path_from_template
from .submission.submit import submit_run

from .util import EasyDict

submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
9 changes: 9 additions & 0 deletions dnnlib/submission/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

from . import run_context
from . import submit
45 changes: 45 additions & 0 deletions dnnlib/submission/_internal/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper for launching run functions in computing clusters.
During the submit process, this file is copied to the appropriate run dir.
When the job is launched in the cluster, this module is the first thing that
is run inside the docker container.
"""

import os
import pickle
import sys

# PYTHONPATH should have been set so that the run_dir/src is in it
import dnnlib

def main():
if not len(sys.argv) >= 4:
raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

run_dir = str(sys.argv[1])
task_name = str(sys.argv[2])
host_name = str(sys.argv[3])

submit_config_path = os.path.join(run_dir, "submit_config.pkl")

# SubmitConfig should have been pickled to the run dir
if not os.path.exists(submit_config_path):
raise RuntimeError("SubmitConfig pickle file does not exist!")

submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

submit_config.task_name = task_name
submit_config.host_name = host_name

dnnlib.submission.submit.run_wrapper(submit_config)

if __name__ == "__main__":
main()
99 changes: 99 additions & 0 deletions dnnlib/submission/run_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helpers for managing the run/training loop."""

import datetime
import json
import os
import pprint
import time
import types

from typing import Any

from . import submit


class RunContext(object):
"""Helper class for managing the run/training loop.
The context will hide the implementation details of a basic run/training loop.
It will set things up properly, tell if run should be stopped, and then cleans up.
User should call update periodically and use should_stop to determine if run should be stopped.
Args:
submit_config: The SubmitConfig that is used for the current run.
config_module: The whole config module that is used for the current run.
max_epoch: Optional cached value for the max_epoch variable used in update.
"""

def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
self.submit_config = submit_config
self.should_stop_flag = False
self.has_closed = False
self.start_time = time.time()
self.last_update_time = time.time()
self.last_update_interval = 0.0
self.max_epoch = max_epoch

# pretty print the all the relevant content of the config module to a text file
if config_module is not None:
with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

# write out details about the run to a text file
self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

def __enter__(self) -> "RunContext":
return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()

def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
"""Do general housekeeping and keep the state of the context up-to-date.
Should be called often enough but not in a tight loop."""
assert not self.has_closed

self.last_update_interval = time.time() - self.last_update_time
self.last_update_time = time.time()

if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
self.should_stop_flag = True

max_epoch_val = self.max_epoch if max_epoch is None else max_epoch

def should_stop(self) -> bool:
"""Tell whether a stopping condition has been triggered one way or another."""
return self.should_stop_flag

def get_time_since_start(self) -> float:
"""How much time has passed since the creation of the context."""
return time.time() - self.start_time

def get_time_since_last_update(self) -> float:
"""How much time has passed since the last call to update."""
return time.time() - self.last_update_time

def get_last_update_interval(self) -> float:
"""How much time passed between the previous two calls to update."""
return self.last_update_interval

def close(self) -> None:
"""Close the context and clean up.
Should only be called once."""
if not self.has_closed:
# update the run.txt with stopping time
self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

self.has_closed = True
Loading

0 comments on commit f3a0446

Please sign in to comment.