How to use custom loss? #58
That's a good question. Currently, you can either inherit the policy class (as you mentioned) or change the original framework's code to meet your expectations. It can be discussed further. Some existing frameworks (like RLlib) modularize the loss function part, but in my opinion this could be inconvenient for further development: since the loss function is highly customizable, abstracting it would double the code complexity.
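For the first option, here is a minimal sketch (assuming the on-policy `learn(batch, batch_size, repeat)` signature, which may differ across versions; `my_extra_term` is a placeholder, not an existing API):

```python
from tianshou.policy import PPOPolicy

class CustomLossPPO(PPOPolicy):
    """Sketch of inheriting the policy class to customize the loss."""

    def learn(self, batch, batch_size, repeat, **kwargs):
        # Copy the parent's learn() body here and add the extra term to
        # the total loss right before loss.backward(), e.g.:
        #   loss = clip_loss + vf_loss + ... + my_extra_term(batch)
        raise NotImplementedError("paste and adapt the original PPO loop")
```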
Ok, I got your point and I agree with you. But what about adding a dedicated hook? In this case, it is not a custom loss strictly speaking, but rather an additional component of the original loss function (a regularization term) that may depend on the actor, so it only consists of an extra function call before calling backward. I don't know if doing so is usual or not.
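Something like this, where `extra_loss_fn` is a hypothetical hook name (not an existing Tianshou API), just to illustrate the mechanism:

```python
import torch

def update_step(loss, optim, extra_loss_fn=None, **hook_inputs):
    # The framework would call the hook just before backward(), so the
    # extra term is a plain additive component of the original loss and
    # autograd differentiates through it automatically.
    if extra_loss_fn is not None:
        loss = loss + extra_loss_fn(**hook_inputs)
    optim.zero_grad()
    loss.backward()
    optim.step()
```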
@Trinkle23897 Up!
I have no time after #106 before this Friday... many things to do.
No problem! I can do it! But what do you think about the idea?
I think that add
@duburcqa It's a great idea to make it easier to use a customized loss. I was wondering if you have made any progress on that. Thanks!
The loss is an integral part of the algorithm, so maybe inheriting and overriding is better than allowing users to pass in custom losses. It's a central design question; I don't see it being necessary for the 1.0.0 release, but I would keep the issue open.
I would like to add the following extra term to the loss function,

$\| y_{pred} - y_{ref} \|_2^2$

where $y_{pred}$ is the action sampled by the distribution, and $y_{ref}$ can be computed by the actor.
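In plain PyTorch, the term itself would be something like this (a sketch; the function name is mine):

```python
import torch

def action_matching_term(y_pred: torch.Tensor, y_ref: torch.Tensor) -> torch.Tensor:
    # Squared L2 distance between sampled and reference actions, averaged
    # over the batch; autograd then provides the analytical gradient.
    return ((y_pred - y_ref) ** 2).sum(dim=-1).mean()
```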
What is the best way to do it using your framework? The point is to be able to take advantage of the analytical gradient computation.
The only way I can think of is to overwrite the whole `learn` method of the policy (i.e. the PPO algorithm), but it feels inconvenient just to add an extra line of code...

Thank you in advance,
Best,
Alexis