Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rraileanu committed Oct 8, 2020
1 parent 4081032 commit 90c1aa2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion env_utils.py
Expand Up @@ -410,7 +410,7 @@ def generic_step(self, policy_fn, state):
.reshape(self.enc_input_size, 1)

res = {'next_state': next_state_tensor, 'action': action_tensor, 'sa_tensor': state_action_tensor,
- 'sas_tensor': sas_tensor, 'reward': reward, 'done': done}
+ 'sas_tensor': sas_tensor.to(self.args.device), 'reward': reward, 'done': done}
return res

def sample_env_context(self, policy_fn, env_idx=None):
Expand Down
8 changes: 4 additions & 4 deletions train_pdvf.py
Expand Up @@ -269,7 +269,7 @@ def train_pdvf():
total_eval_loss += eval_loss
if eval_loss < BEST_EVAL_LOSS:
BEST_EVAL_LOSS = eval_loss
- utils.save_model("pdvf-stage0.", value_net, optimizer, \
+ pdvf_utils.save_model("pdvf-stage0.", value_net, optimizer, \
i, args, args.env_name, save_dir=args.save_dir_pdvf)

if i % args.log_interval == 0:
Expand All @@ -294,7 +294,7 @@ def train_pdvf():

if decoder_eval_loss < DECODER_BEST_EVAL_LOSS:
DECODER_BEST_EVAL_LOSS = decoder_eval_loss
- utils.save_model("policy-decoder-stage0.", policy_decoder, decoder_optimizer, \
+ pdvf_utils.save_model("policy-decoder-stage0.", policy_decoder, decoder_optimizer, \
i, args, args.env_name, save_dir=args.save_dir_pdvf)

if i % args.log_interval == 0:
Expand Down Expand Up @@ -482,7 +482,7 @@ def train_pdvf():

if eval_loss < BEST_EVAL_LOSS:
BEST_EVAL_LOSS = eval_loss
- utils.save_model("pdvf-stage{}.".format(k+1), value_net, optimizer, \
+ pdvf_utils.save_model("pdvf-stage{}.".format(k+1), value_net, optimizer, \
i, args, args.env_name, save_dir=args.save_dir_pdvf)

if i % args.log_interval == 0:
Expand All @@ -507,7 +507,7 @@ def train_pdvf():

if decoder_eval_loss < DECODER_BEST_EVAL_LOSS:
DECODER_BEST_EVAL_LOSS = decoder_eval_loss
- utils.save_model("policy-decoder-stage{}.".format(k+1), policy_decoder, decoder_optimizer, \
+ pdvf_utils.save_model("policy-decoder-stage{}.".format(k+1), policy_decoder, decoder_optimizer, \
i, args, args.env_name, save_dir=args.save_dir_pdvf)

if i % args.log_interval == 0:
Expand Down

0 comments on commit 90c1aa2

Please sign in to comment.