From 0ac0a94ab879f5e1eff6738e94c584445488791d Mon Sep 17 00:00:00 2001
From: Kenneth-Schroeder
Date: Fri, 4 Feb 2022 14:51:36 +0100
Subject: [PATCH 1/4] fixing cast to int of returns in PGPolicy

---
 tianshou/policy/modelfree/pg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index a525746bc..e53c5d471 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -131,7 +131,7 @@ def learn( # type: ignore
                 result = self(minibatch)
                 dist = result.dist
                 act = to_torch_as(minibatch.act, result.act)
-                ret = to_torch_as(minibatch.returns, result.act)
+                ret = to_torch(minibatch.returns, result.act.device, torch.float)
                 log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                 loss = -(log_prob * ret).mean()
                 loss.backward()

From 41ee416bed442f6f9d138a97681702a3615eabe0 Mon Sep 17 00:00:00 2001
From: Kenneth-Schroeder
Date: Fri, 4 Feb 2022 14:55:33 +0100
Subject: [PATCH 2/4] change parameter order in to_torch call

---
 tianshou/policy/modelfree/pg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index e53c5d471..087c91d61 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -131,7 +131,7 @@ def learn( # type: ignore
                 result = self(minibatch)
                 dist = result.dist
                 act = to_torch_as(minibatch.act, result.act)
-                ret = to_torch(minibatch.returns, result.act.device, torch.float)
+                ret = to_torch(minibatch.returns, torch.float, result.act.device)
                 log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                 loss = -(log_prob * ret).mean()
                 loss.backward()

From dc45d4d8d5e55bf73dfda0e0a4326e2d05b5ae53 Mon Sep 17 00:00:00 2001
From: Kenneth-Schroeder
Date: Fri, 4 Feb 2022 15:04:03 +0100
Subject: [PATCH 3/4] add to_torch import

---
 tianshou/policy/modelfree/pg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 087c91d61..6ad2e190d 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_torch
 from tianshou.policy import BasePolicy
 from tianshou.utils import RunningMeanStd
 

From 5ff31453f10adfba2f7063d7d099b1c6670eef02 Mon Sep 17 00:00:00 2001
From: Kenneth-Schroeder
Date: Fri, 4 Feb 2022 23:04:37 +0100
Subject: [PATCH 4/4] fix import ordering, ran all contributing make commands

---
 tianshou/policy/modelfree/pg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 6ad2e190d..9149a383b 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 
-from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
 from tianshou.policy import BasePolicy
 from tianshou.utils import RunningMeanStd
 