
Naming and typing improvements in Actor/Critic/Policy forwards #1032

Merged (14 commits) on Apr 1, 2024

Conversation

@arnaujc91 (Contributor) commented on Jan 23, 2024:

Closes #917

Internal Improvements

Breaking Changes

@arnaujc91 (Contributor, Author):
@MischaPanch as of now, if you agree with the naming, there are other places, as you mentioned, where logits should be replaced. I will continue once we agree that the naming is fine for you.

@codecov-commenter commented on Jan 23, 2024:

Codecov Report

Attention: Patch coverage is 97.75281%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 88.21%. Comparing base (edae9e4) to head (89d888f).
Report is 4 commits behind head on master.

❗ Current head 89d888f differs from pull request most recent head ef0b0dc. Consider uploading reports for the commit ef0b0dc to get more accurate results

| Files | Patch % | Lines |
| --- | --- | --- |
| tianshou/policy/imitation/base.py | 87.50% | 1 Missing ⚠️ |
| tianshou/utils/net/common.py | 66.66% | 1 Missing ⚠️ |

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #1032      +/-   ##
==========================================
+ Coverage   88.15%   88.21%   +0.05%     
==========================================
  Files         100      100              
  Lines        8180     8297     +117     
==========================================
+ Hits         7211     7319     +108     
- Misses        969      978       +9     
```

| Flag | Coverage Δ |
| --- | --- |
| unittests | 88.21% <97.75%> (+0.05%) ⬆️ |

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@arnaujc91 (Contributor, Author):
ModelOutputBatchProtocol might need to be modified too, if the changes are accepted.

@arnaujc91 (Contributor, Author) commented on Feb 1, 2024:

Hi guys, in the following places I think the rename

logits --> action_dist_input

should also be applied:

ImitationPolicy Class

```python
) -> ModelOutputBatchProtocol:
    logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
    act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
    result = Batch(logits=logits, act=act, state=hidden)
    return cast(ModelOutputBatchProtocol, result)
```

DiscreteSACPolicy Class

```python
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
if self.deterministic_eval and not self.training:
    act = logits.argmax(axis=-1)
else:
    act = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist)
```

SACPolicy Class

```python
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training:
    act = logits[0]
else:
    act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
    -1,
    keepdim=True,
)
result = Batch(
    logits=logits,
    act=squashed_action,
    state=hidden,
    dist=dist,
    log_prob=log_prob,
)
```
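For reference, the Tanh-squashing correction in the snippet above corresponds to Eq. 21 (Appendix C) of the SAC paper: with the squashed action $a = \tanh(u)$ and $u$ sampled from the Gaussian,

$$\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^2(u_i)\right),$$

where the code additionally adds a small eps inside the log for numerical stability.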

REDQPolicy Class: should loc_scale also be renamed here? @MischaPanch

```python
def forward(  # type: ignore
    self,
    batch: ObsBatchProtocol,
    state: dict | Batch | np.ndarray | None = None,
    **kwargs: Any,
) -> Batch:
    loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
    loc, scale = loc_scale
    dist = Independent(Normal(loc, scale), 1)
    act = loc if self.deterministic_eval and not self.training else dist.rsample()
    log_prob = dist.log_prob(act).unsqueeze(-1)
    # apply correction for Tanh squashing when computing logprob from Gaussian
    # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
    # in appendix C to get some understanding of this equation.
    squashed_action = torch.tanh(act)
    log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
        -1,
        keepdim=True,
    )
    return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
```

Net Class

```python
def forward(
    self,
    obs: np.ndarray | torch.Tensor,
    state: Any = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, Any]:
    """Mapping: obs -> flatten (inside MLP) -> logits.

    :param obs:
    :param state: unused and returned as is
    :param kwargs: unused
    """
    logits = self.model(obs)
    batch_size = logits.shape[0]
    if self.use_dueling:  # Dueling DQN
        assert self.Q is not None
        assert self.V is not None
        q, v = self.Q(logits), self.V(logits)
        if self.num_atoms > 1:
            q = q.view(batch_size, -1, self.num_atoms)
            v = v.view(batch_size, -1, self.num_atoms)
        logits = q - q.mean(dim=1, keepdim=True) + v
    elif self.num_atoms > 1:
        logits = logits.view(batch_size, -1, self.num_atoms)
    if self.softmax:
        logits = torch.softmax(logits, dim=-1)
    return logits, state
```

Actor Class (continuous)

```python
def forward(
    self,
    obs: np.ndarray | torch.Tensor,
    state: Any = None,
    info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
    """Mapping: obs -> logits -> action."""
    if info is None:
        info = {}
    logits, hidden = self.preprocess(obs, state)
    logits = self.max_action * torch.tanh(self.last(logits))
    return logits, hidden
```

Actor Class (discrete)

```python
def forward(
    self,
    obs: np.ndarray | torch.Tensor,
    state: Any = None,
    info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
    r"""Mapping: s -> Q(s, \*)."""
    if info is None:
        info = {}
    logits, hidden = self.preprocess(obs, state)
    logits = self.last(logits)
    if self.softmax_output:
        logits = F.softmax(logits, dim=-1)
    return logits, hidden
```

So basically this applies to all actors and all policies that use actors. It also seems to apply to the Net class, for which I still have to figure out how it is properly related to Policies and Actors.

Any objections? @MischaPanch @opcode81

@MischaPanch (Collaborator):
> REDQPolicy Class: should loc_scale also be renamed here? @MischaPanch


What's wrong with the loc_scale?

> So basically this applies to all actors and all policies that use actors. It also seems to apply to the Net class, for which I still have to figure out how it is properly related to Policies and Actors.
>
> Any objections? @MischaPanch @opcode81

Looks good, go ahead! But you should also rename things inside some BatchProtocols, as partially discussed above.

@arnaujc91 (Contributor, Author):
> > REDQPolicy Class: should loc_scale also be renamed here? @MischaPanch
>
> What's wrong with the loc_scale?

Because the logical flow I have seen is usually

self.policy -> dist

so the output of the policy network (the torch forward method) is usually a tuple, where the second element is some state and the first element is what you always feed into the Torch distribution, as the code demonstrates:

```python
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
```

I just wanted to keep consistency. If I call it action_dist_input somewhere because it is the input to the Torch distribution, I should also do so anywhere else where that is the case. That was my thought.
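To illustrate this flow with a self-contained toy example (TinyActor and the shapes below are made up purely for illustration, not tianshou code):

```python
import torch
from torch import nn
from torch.distributions import Categorical


class TinyActor(nn.Module):
    """Toy actor whose forward returns (dist input, state), mirroring the flow described above."""

    def __init__(self, obs_dim: int = 8, n_actions: int = 3) -> None:
        super().__init__()
        self.net = nn.Linear(obs_dim, n_actions)

    def forward(self, obs: torch.Tensor, state=None) -> tuple[torch.Tensor, object]:
        # first element: whatever will later be fed into the torch distribution
        # second element: some (here unused) hidden state
        return self.net(obs), state


actor = TinyActor()
action_dist_input, hidden = actor(torch.randn(5, 8))
dist = Categorical(logits=action_dist_input)  # the "policy -> dist" step
act = dist.sample()
```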

@MischaPanch (Collaborator):

> I just wanted to keep consistency. If I call it action_dist_input somewhere because it is the input to the Torch distribution, I should also do so anywhere else where that is the case. That was my thought.

In this case it is not passed to a general action dist, but to a Gaussian, where it is used as loc and scale. So here the name loc_scale is the most appropriate one, I think. Names should always reflect the semantics of the variable as precisely as possible.

@arnaujc91 (Contributor, Author) commented on Feb 1, 2024:

Hmm I bet that in the case of a continuous actor:

```python
def forward(
    self,
    obs: np.ndarray | torch.Tensor,
    state: Any = None,
    info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
    """Mapping: obs -> logits -> action."""
    if info is None:
        info = {}
    logits, hidden = self.preprocess(obs, state)
    logits = self.max_action * torch.tanh(self.last(logits))
    return logits, hidden
```

The output of line 85, which is named logits, is also a mean and standard deviation, no? But it is still called logits and not loc_scale. Maybe I am misunderstanding something.

@MischaPanch (Collaborator) commented on Feb 1, 2024:

> The output of line 85, which is named logits, is also a mean and standard deviation, no? But it is still called logits and not loc_scale. Maybe I am misunderstanding something.

Here you don't know how the output of forward will be used (into which kind of continuous action_dist it would go), so you shouldn't call it loc_scale. In the code above, on the other hand, the variable is fed into a Gaussian as loc and scale, so the semantics are clear.

If tianshou only supported Gaussians for continuous actors, then the output of the continuous Actor should indeed be loc_scale. It should certainly never be logits - please change that as part of your PR!

The fact that tianshou right now implicitly depends on the action_dist_input being of the loc_scale format in some places is a separate problem and deserves a separate issue. Fixing that is outside the scope of this PR.

In this PR, the goal should be to make the naming at least a bit better and consistent with the current interfaces. If you know that the variable has to be of the form (loc, scale) in some place (like in the example above), then it should be named as such. If you only know that it will be added to a Batch's field called action_dist_input, then the variable should be called action_dist_input, because in principle tianshou's current interfaces don't allow you to infer more precise semantics.

That's what I mean when I say that the variable names should reflect semantics as precisely as possible.
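A minimal sketch of this distinction, with made-up tensors and a made-up dist_fn used only to restate the rule in code:

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Case 1: the distribution is constructed right here, so the precise name (loc, scale) is known.
loc_scale = (torch.zeros(4, 2), torch.ones(4, 2))  # e.g. what a Gaussian actor returns
loc, scale = loc_scale
gaussian_dist = Independent(Normal(loc, scale), 1)

# Case 2: the distribution is built elsewhere (e.g. by a user-supplied dist_fn), so only
# the generic role "input to the action distribution" can be inferred at this point.
def dist_fn(action_dist_input: torch.Tensor) -> Categorical:
    return Categorical(logits=action_dist_input)

action_dist_input = torch.randn(4, 3)
generic_dist = dist_fn(action_dist_input)
```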

@arnaujc91 (Contributor, Author) commented on Feb 1, 2024:

Understood. So wherever logits appears and can be interpreted as action_dist_input, I will make the replacement

logits -> action_dist_input

but if a more accurate name (e.g. loc_scale) is already given, I will leave it untouched. I think we agree here.

@arnaujc91 marked this pull request as ready for review on March 26, 2024, 19:49
@MischaPanch (Collaborator) left a comment:

Thanks @arnaujc91, this brings us an important step closer to removing wrong names across the source code!

The issue is not fully resolved at this point, since all kinds of BatchProtocol still carry logits, but it can't be fully resolved before introducing an Algorithm abstraction as part of release 2.0.0 anyway. So for the 1.x versions this is likely as far as we can push it.

I will address the comments myself and merge this

Resolved (outdated) review comments on:
- tianshou/policy/imitation/discrete_bcq.py
- tianshou/policy/imitation/discrete_cql.py
- tianshou/policy/imitation/discrete_crr.py (2 comments)
- tianshou/policy/imitation/gail.py
- tianshou/policy/modelfree/pg.py
```python
action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
# in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
# therefore action_dist_input_BD is equivalent to logits_BA
if self.action_type == "discrete":
```
Collaborator:

This if-else seems weird and unnecessary; I'll look into it.

Collaborator:
Required a breaking change to fix, see d0d9d45. Also required a change in the internals of the high-level interfaces.

@opcode81 @maxhuettenrauch please have a look at that commit. Max, I believe you had written a todo on how to improve the typing of the dist-fn; this commit addresses that todo.

Collaborator:

I ran several low-level and high-level examples, things work. I think I haven't missed anything :)

Resolved (outdated) review comment on tianshou/policy/modelfree/sac.py
```diff
@@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]):  # t
     """Implementation of TD3, arXiv:1802.09477.

     :param actor: the actor network following the rules in
-        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+        :class:`~tianshou.policy.BasePolicy`. (s -> actions)
```
Collaborator:
Not sure it's true; the actor might be expected to output dist inputs. Gonna look into it.

```python
logits, hidden = self.preprocess(obs, state)
logits = self.max_action * torch.tanh(self.last(logits))
return logits, hidden
"""Mapping: s -> action_values, hidden_state.
```
Collaborator:
Follow naming convention

Resolved (outdated) review comments on tianshou/policy/modelfree/trpo.py (2 comments)
@MischaPanch changed the title from "Refactored pg.py logits variable name." to "Naming and typing improvements in Actor/Critic/Policy forwards" on Apr 1, 2024
@MischaPanch merged commit bf0d632 into thu-ml:master on Apr 1, 2024
ZhengLi1314 pushed a commit to ZhengLi1314/tianshou_0.5.1 that referenced this pull request Apr 15, 2024
…l#1032)

Closes thu-ml#917 

### Internal Improvements
- Better variable names related to model outputs (logits, dist input
etc.). thu-ml#1032
- Improved typing for actors and critics, using Tianshou classes like
`Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. thu-ml#1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the
presence of `forward` methods. thu-ml#1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see
associated breaking change). thu-ml#1032
- Use `.mode` of distribution instead of relying on knowledge of the
distribution type. thu-ml#1032

### Breaking Changes

- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to
take a single argument in both
continuous and discrete cases. thu-ml#1032

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
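To make the breaking change listed above concrete, here is a rough sketch of what single-argument dist_fn callables could look like for the discrete and continuous cases (the exact types used in tianshou may differ):

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

# discrete case: dist_fn receives a single tensor of logits
def discrete_dist_fn(action_dist_input: torch.Tensor) -> Distribution:
    return Categorical(logits=action_dist_input)

# continuous case: dist_fn receives a single (loc, scale) tuple instead of two separate arguments
def continuous_dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)
```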
Linked issue that may be closed by this pull request: Rename logits to model_output
4 participants