Skip to content

Commit

Permalink
added daml code
Browse files Browse the repository at this point in the history
  • Loading branch information
tianheyu927 committed Jul 4, 2018
1 parent 7a201e8 commit cd5df89
Show file tree
Hide file tree
Showing 5 changed files with 657 additions and 570 deletions.
10 changes: 6 additions & 4 deletions data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ def generate_batches(self, noisy=False):
offset = self.dataset_size - FLAGS.training_set_size - FLAGS.val_set_size
else:
offset = 0
train_img_folders = {i: os.path.join(self.demo_gif_dir, self.gif_prefix + '_%d' % i) for i in self.train_idx}
val_img_folders = {i: os.path.join(self.demo_gif_dir, self.gif_prefix + '_%d' % (i+offset)) for i in self.val_idx}
img_folders = natsorted(glob.glob(self.demo_gif_dir + self.gif_prefix + '_*'))
train_img_folders = {i: img_folders[i] for i in self.train_idx}
val_img_folders = {i: img_folders[i+offset] for i in self.val_idx}
if noisy:
noisy_train_img_folders = {i: os.path.join(self.noisy_demo_gif_dir, self.gif_prefix + '_%d' % i) for i in self.train_idx}
noisy_val_img_folders = {i: os.path.join(self.noisy_demo_gif_dir, self.gif_prefix + '_%d' % (i+offset)) for i in self.val_idx}
noisy_img_folders = natsorted(glob.glob(self.noisy_demo_gif_dir + self.gif_prefix + '_*'))
noisy_train_img_folders = {i: noisy_img_folders[i] for i in self.train_idx}
noisy_val_img_folders = {i: noisy_img_folders[i] for i in self.val_idx}
TEST_PRINT_INTERVAL = 500
TOTAL_ITERS = FLAGS.metatrain_iterations
self.all_training_filenames = []
Expand Down
19 changes: 17 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
flags.DEFINE_bool('no_action', False, 'do not include actions in the demonstrations for inner update')
flags.DEFINE_bool('no_state', False, 'do not include states in the demonstrations during training')
flags.DEFINE_bool('no_final_eept', False, 'do not include final ee pos in the demonstrations for inner update')
flags.DEFINE_bool('zero_state', False, 'zero-out states (meta-learn state) in the demonstrations for inner update')
flags.DEFINE_bool('zero_state', False, 'zero-out states (meta-learn state) in the demonstrations for inner update (used in the paper with video-only demos)')
flags.DEFINE_bool('two_arms', False, 'use two-arm structure when state is zeroed-out')
flags.DEFINE_integer('training_set_size', -1, 'size of the training set, 1500 for sim_reach, 693 for sim push, and \
-1 for all data except those in validation set')
Expand All @@ -55,6 +55,7 @@
flags.DEFINE_bool('learn_final_eept', False, 'learn an auxiliary loss for predicting final end-effector pose')
flags.DEFINE_bool('learn_final_eept_whole_traj', False, 'learn an auxiliary loss for predicting final end-effector pose \
by passing the whole trajectory of eepts (used for video-only models)')
flags.DEFINE_bool('stopgrad_final_eept', True, 'stop the gradient when concatenate the predicted final eept with the feature points')
flags.DEFINE_integer('final_eept_min', 6, 'first index of the final eept in the action array')
flags.DEFINE_integer('final_eept_max', 8, 'last index of the final eept in the action array')
flags.DEFINE_float('final_eept_loss_eps', 0.1, 'the coefficient of the auxiliary loss')
Expand All @@ -78,6 +79,14 @@
flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
flags.DEFINE_integer('num_fc_layers', 3, 'number of fully-connected layers')
flags.DEFINE_integer('layer_size', 100, 'hidden dimension of fully-connected layers')
flags.DEFINE_bool('temporal_conv_2_head', False, 'whether or not to use temporal convolutions for the two-head architecture in video-only setting.')
flags.DEFINE_bool('temporal_conv_2_head_ee', False, 'whether or not to use temporal convolutions for the two-head architecture in video-only setting \
for predicting the ee pose.')
flags.DEFINE_integer('temporal_filter_size', 5, 'filter size for temporal convolution')
flags.DEFINE_integer('temporal_num_filters', 64, 'number of filters for temporal convolution')
flags.DEFINE_integer('temporal_num_filters_ee', 64, 'number of filters for temporal convolution for ee pose prediction')
flags.DEFINE_integer('temporal_num_layers', 3, 'number of layers for temporal convolution for ee pose prediction')
flags.DEFINE_integer('temporal_num_layers_ee', 3, 'number of layers for temporal convolution for ee pose prediction')
flags.DEFINE_string('init', 'random', 'initializer for conv weights. Choose among random, xavier, and he')
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
Expand All @@ -88,7 +97,8 @@
flags.DEFINE_bool('resume', False, 'resume training if there is a model available')
flags.DEFINE_bool('train', True, 'True to train, False to test.')
flags.DEFINE_integer('restore_iter', 0, 'iteration to load model (-1 for latest model)')
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).')
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training \
(use if you want to test with a different number).')
flags.DEFINE_integer('test_update_batch_size', 1, 'number of demos used during test time')
flags.DEFINE_float('gpu_memory_fraction', 1.0, 'fraction of memory used in gpu')
flags.DEFINE_bool('record_gifs', True, 'record gifs during evaluation')
Expand Down Expand Up @@ -241,6 +251,11 @@ def main():
exp_string += '.two_heads'
if FLAGS.two_arms:
exp_string += '.two_arms'
if FLAGS.temporal_conv_2_head:
exp_string += '.1d_conv_act_' + str(FLAGS.temporal_num_layers) + '_' + str(FLAGS.temporal_num_filters)
if FLAGS.temporal_conv_2_head_ee:
exp_string += '_ee_' + str(FLAGS.temporal_num_layers_ee) + '_' + str(FLAGS.temporal_num_filters_ee)
exp_string += '_' + str(FLAGS.temporal_filter_size) + 'x1_filters'
if FLAGS.training_set_size != -1:
exp_string += '.' + str(FLAGS.training_set_size) + '_trials'

Expand Down
Loading

0 comments on commit cd5df89

Please sign in to comment.