VerticalCAS fails to learn the score table policy #4

Closed
nskh opened this issue Jul 11, 2023 · 2 comments

Comments

@nskh

nskh commented Jul 11, 2023

I'm trying to retrain VerticalCAS on my own machine, but the trained networks seem to converge to much simpler policies that just linearly partition the input space and do not match the score table's original behavior (see images below, one of which is from @kjulian3's DASC '19 paper on VerticalCAS and HorizontalCAS). I couldn't get the provided Julia visualization code to work (see #3), so I wrote my own in Python. It seems to work well, since the score table plotted with it matches the image from the DASC '19 paper. Note that training does seem to converge: loss plateaus around 0.0042 and accuracy around 0.9826.

I've tried training on both the master and dasc branches, though I had to modify some code locally because of TensorFlow 1 vs. 2 compatibility issues. Are there any other versions of this code or alternative training approaches that can reproduce the intended behavior? I can't think of anything else to try.

Score table behavior, from DASC '19:
[Image: Screenshot 2023-07-10 at 4 26 37 PM]

Score table behavior, plotted with my code:
[Image: Screenshot 2023-07-11 at 11 21 50 AM]

Trained policy after 100 training epochs on master (the dasc branch with TensorFlow 2 adaptations is very similar):
[Image: Screenshot 2023-07-11 at 11 19 04 AM]

@smkatz12

It seems like there may be a mismatch between your network plotting code and the 0.9826 accuracy you are observing for the network. If that accuracy is correct, I would expect the network plot to match the score table plot much more closely. Maybe double-check your plotting code for the network policy?
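One way to test for such a mismatch is to compare the advisories the plotting pipeline selects against the score table on the same grid of states. The sketch below is only an illustration: `agreement_rate` is a hypothetical helper, the score arrays are made-up toy values, and it assumes the advisory at each state is the argmax of that state's scores.

```python
import numpy as np


def agreement_rate(table_scores, network_scores):
    """Fraction of states where table and network select the same advisory.

    Both arrays have shape (num_states, num_advisories). Assumption: the
    selected advisory is the argmax over the advisory axis.
    """
    table_actions = np.argmax(table_scores, axis=1)
    network_actions = np.argmax(network_scores, axis=1)
    return float(np.mean(table_actions == network_actions))


# Toy scores for 4 states and 3 advisories (illustrative values only).
table_scores = np.array([
    [1.0, 0.2, 0.1],
    [0.1, 0.9, 0.3],
    [0.2, 0.1, 0.8],
    [0.7, 0.6, 0.1],
])
network_scores = np.array([
    [0.9, 0.3, 0.1],
    [0.2, 0.8, 0.1],
    [0.1, 0.2, 0.7],
    [0.4, 0.5, 0.2],
])

rate = agreement_rate(table_scores, network_scores)
print(f"agreement on plotted grid: {rate:.2f}")
```

If the rate computed on the plotted grid is far below the accuracy reported during training, the plotting path (input ordering, scaling, or advisory indexing) is the likely culprit rather than the training run.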

@nskh

nskh commented Jul 14, 2023

Thank you! I was feeding my networks un-normalized inputs, so everything I was plotting was "way out there" in the COC region. Normalization fixed my issues. Thanks!
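For anyone hitting the same problem, here is a minimal sketch of the kind of input normalization that has to happen before querying the network for plotting. Everything specific here is an assumption: `normalize_inputs`, the number of features, and the values of `INPUT_MEANS` and `INPUT_RANGES` are hypothetical placeholders. The real constants must be whatever the training script used.

```python
import numpy as np

# Hypothetical per-feature constants (placeholders, NOT the repository's values).
# Replace them with the means and ranges used when the network was trained.
INPUT_MEANS = np.array([0.0, 0.0, 0.0, 20.0])
INPUT_RANGES = np.array([16000.0, 200.0, 200.0, 40.0])


def normalize_inputs(raw, means=INPUT_MEANS, ranges=INPUT_RANGES):
    """Map raw state values into the scaled space the network was trained on."""
    raw = np.asarray(raw, dtype=float)
    return (raw - means) / ranges


# Raw grid points in physical units, one row per state (illustrative values).
raw_grid = np.array([
    [-8000.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 20.0],
    [8000.0, 100.0, -100.0, 40.0],
])

normalized_grid = normalize_inputs(raw_grid)
print(normalized_grid)
```

Feeding `raw_grid` directly would place every query thousands of units away from the region the network saw during training, which is consistent with the plots collapsing into a single advisory region.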

@nskh nskh closed this as completed Jul 14, 2023