updated parallel GPU to work with trainer
williamFalcon committed Aug 3, 2019
2 parents fdf9b42 + 2230629 commit 4e900b1
Showing 4 changed files with 220 additions and 70 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 from setuptools import find_packages, setup

-version = '0.651'
+version = '0.6.7.4'

 setup(
     name='test_tube',
@@ -13,7 +13,7 @@
         'pandas>=0.20.3',
         'numpy>=1.13.3',
         'imageio>=2.3.0',
-        'tensorboard>=1.14.0',
+        'tb-nightly==1.15.0a20190708',
         'torch>=1.1.0',
         'future'
     ],
19 changes: 18 additions & 1 deletion test_tube/argparse_hopt.py
@@ -17,7 +17,7 @@
 try:
     import torch
     import multiprocessing
-    multiprocessing.set_start_method('spawn', force=True)
+    # multiprocessing.set_start_method('spawn', force=True)
 except ModuleNotFoundError:
     pass

@@ -97,6 +97,23 @@ def __init__(self, strategy='grid_search', **kwargs):
         self.json_config_arg_name = None
         self.pool = None

+    def __getstate__(self):
+        # capture what is normally pickled
+        state = self.__dict__.copy()
+
+        # remove all functions from the namespace
+        clean_state = {}
+        for k, v in state.items():
+            if not hasattr(v, '__call__'):
+                clean_state[k] = v
+
+        # what we return here will be stored in the pickle
+        return clean_state
+
+    def __setstate__(self, newstate):
+        # re-instate our __dict__ state from the pickled state
+        self.__dict__.update(newstate)
+
     def add_argument(self, *args, **kwargs):
         super(HyperOptArgumentParser, self).add_argument(*args, **kwargs)

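Note on the change above: the new __getstate__/__setstate__ pair keeps the argument parser picklable when trials are handed to spawned worker processes, by dropping any callable attributes (bound functions, pools) from the namespace before pickling. A minimal standalone sketch of the same idea, using a hypothetical Namespace class that is not part of test_tube:

    import pickle

    class Namespace:
        def __init__(self):
            self.strategy = 'grid_search'
            self.pool = None
            self.optimize_fx = lambda trial: trial  # a callable like this would normally break pickling

        def __getstate__(self):
            # keep only non-callable attributes, mirroring the filter in the diff above
            return {k: v for k, v in self.__dict__.items() if not callable(v)}

        def __setstate__(self, newstate):
            # restore the surviving (data-only) attributes
            self.__dict__.update(newstate)

    ns = pickle.loads(pickle.dumps(Namespace()))
    print(ns.strategy)                 # 'grid_search'
    print(hasattr(ns, 'optimize_fx'))  # False -- the lambda was stripped before pickling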
32 changes: 15 additions & 17 deletions test_tube/hpc.py
@@ -27,16 +27,12 @@ def __init__(
             python_cmd='python3',
             enable_log_err=True,
             enable_log_out=True,
-            test_tube_exp_name=None
     ):
         self.hyperparam_optimizer = hyperparam_optimizer
         self.log_path = log_path
-        if log_path is not None:
-            self.log_path = os.path.join(log_path, 'test_tube_data')

         self.enable_log_err = enable_log_err
         self.enable_log_out = enable_log_out
-        self.test_tube_exp_name = test_tube_exp_name
         self.slurm_files_log_path = None
         self.err_log_path = None
         self.out_log_path = None
@@ -58,6 +54,7 @@ def __init__(
         self.call_load_checkpoint = False
         self.commands = []
         self.slurm_commands = []
+        self.hpc_exp_number = 0

         # these are set via getters and setters so we can use a BaseManager which can be shared across processes
         self.checkpoint_save_function = None
@@ -66,6 +63,7 @@ def __init__(
         # detect when this was called because a slurm object started a hopt.
         # if true, remove the flag so tt logs don't show it
         if hyperparam_optimizer is not None:
+
             self.is_from_slurm_object = HyperOptArgumentParser.TRIGGER_CMD in vars(self.hyperparam_optimizer) and vars(self.hyperparam_optimizer)[HyperOptArgumentParser.TRIGGER_CMD] == True
             if self.is_from_slurm_object:
                 self.hyperparam_optimizer.__delattr__(HyperOptArgumentParser.TRIGGER_CMD)
@@ -74,6 +72,8 @@ def __init__(
             if self.call_load_checkpoint:
                 self.hyperparam_optimizer.__delattr__(HyperOptArgumentParser.SLURM_LOAD_CMD)

+            self.hpc_exp_number = self.hyperparam_optimizer.hpc_exp_number
+
     def set_checkpoint_save_function(self, fx, kwargs):
         self.checkpoint_save_function = [fx, kwargs]

@@ -171,11 +171,12 @@ def __optimize_parallel_cluster_internal(
         trials = self.hyperparam_optimizer.generate_trials(nb_trials)

         # get the max test tube exp version so far if it's there
-        next_test_tube_version = self.__get_max_test_tube_version(self.log_path)
+        scripts_path = os.path.join(self.log_path, 'slurm_out_logs')
+        next_trial_version = self.__get_max_trial_version(scripts_path)

         # for each trial, generate a slurm command
         for i, trial_params in enumerate(trials):
-            exp_i = i + next_test_tube_version
+            exp_i = i + next_trial_version
             self.schedule_experiment(trial_params, exp_i)

     def schedule_experiment(self, trial_params, exp_i):
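Note on the change above: the cluster scheduler now offsets new experiment indices by the highest existing trial_<n> entry under slurm_out_logs (via the renamed __get_max_trial_version helper shown in the next hunk). A rough standalone sketch of that numbering scheme; next_trial_number is a hypothetical helper, and returning 0 for an empty folder is an assumption since the else branch is collapsed in this view:

    import os

    def next_trial_number(scripts_path):
        # find entries that look like 'trial_<n>' and continue numbering after the largest n
        trial_files = [f for f in os.listdir(scripts_path) if 'trial_' in f]
        if not trial_files:
            return 0  # assumption: start from zero when no prior trials exist
        versions = [int(f.split('_')[1]) for f in trial_files]
        return max(versions) + 1

    # e.g. a folder containing entries trial_0 and trial_3 yields 4,
    # so each scheduled experiment gets exp_i = i + 4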
@@ -315,12 +316,12 @@ def __save_slurm_cmd(self, slurm_cmd, slurm_cmd_script_path):
         with open(slurm_cmd_script_path, mode='w') as file:
             file.write(slurm_cmd)

-    def __get_max_test_tube_version(self, path):
+    def __get_max_trial_version(self, path):
         files = os.listdir(path)
-        version_files = [f for f in files if 'version_' in f]
+        version_files = [f for f in files if 'trial_' in f]
         if len(version_files) > 0:
             # regex out everything except file version for ve
-            versions = [int(re.sub('version_', '', f_name)) for f_name in version_files]
+            versions = [int(f_name.split('_')[1]) for f_name in version_files]
             max_version = max(versions)
             return max_version + 1
         else:
@@ -333,10 +334,7 @@ def __layout_logging_dir(self):
"""

# format the logging folder path
if self.test_tube_exp_name is not None:
slurm_out_path = os.path.join(self.log_path, self.test_tube_exp_name)
else:
slurm_out_path = os.path.join(self.log_path, self.job_name)
slurm_out_path = os.path.join(self.log_path, self.job_name)

self.log_path = slurm_out_path

@@ -446,14 +444,14 @@ def __build_slurm_command(self, trial, slurm_cmd_script_path, timestamp, exp_i,
         # add nb of gpus
         if self.per_experiment_nb_gpus > 0 and on_gpu:
             command = [
-                '# gpus per cluster',
-                '#SBATCH --gres gpu:{}'.format(self.per_experiment_nb_gpus),
+                '# gpus per node',
+                '#SBATCH --gres=gpu:{}'.format(self.per_experiment_nb_gpus),
                 '#################\n'
             ]
             if self.gpu_type is not None:
                 command = [
-                    '# gpus per cluster',
-                    '#SBATCH --gres gpu:{}:{}'.format(self.gpu_type, self.per_experiment_nb_gpus),
+                    '# gpus per node',
+                    '#SBATCH --gres=gpu:{}:{}'.format(self.gpu_type, self.per_experiment_nb_gpus),
                     '#################\n'
                 ]
             sub_commands.extend(command)
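Note on the change above: the generated SBATCH header now uses the --gres=gpu:<count> form (optionally --gres=gpu:<type>:<count>) instead of the earlier --gres gpu:<count> spelling, and the comment reads "per node" rather than "per cluster". A small hypothetical helper, separate from the actual __build_slurm_command method, that reproduces just this formatting:

    def gres_lines(nb_gpus, gpu_type=None):
        # build the '#SBATCH --gres=...' block the way the patched code does
        if gpu_type is not None:
            gres = '#SBATCH --gres=gpu:{}:{}'.format(gpu_type, nb_gpus)
        else:
            gres = '#SBATCH --gres=gpu:{}'.format(nb_gpus)
        return ['# gpus per node', gres, '#################\n']

    print('\n'.join(gres_lines(2)))           # includes '#SBATCH --gres=gpu:2'
    print('\n'.join(gres_lines(2, 'volta')))  # includes '#SBATCH --gres=gpu:volta:2'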
