diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index a525746bc..9149a383b 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -3,7 +3,7 @@ import numpy as np import torch -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as from tianshou.policy import BasePolicy from tianshou.utils import RunningMeanStd @@ -131,7 +131,7 @@ def learn( # type: ignore result = self(minibatch) dist = result.dist act = to_torch_as(minibatch.act, result.act) - ret = to_torch_as(minibatch.returns, result.act) + ret = to_torch(minibatch.returns, torch.float, result.act.device) log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) loss = -(log_prob * ret).mean() loss.backward()