[BUG] GDPL could not train. #20

Closed
sherlock1987 opened this issue Jun 16, 2020 · 7 comments
Labels
bug Something isn't working

Comments

@sherlock1987

Describe the bug
When I try to train the GDPL model, even after loading the MLE-pretrained model, the loss and the evaluation results always stay around 0.26. The details are below; could you help me out? GDPL performs well, and I plan to use it as my baseline model.

To Reproduce

  1. Go to policy/gdpl/train.py and add the argument --load_model followed by the path of the MLE model. When you run it, the loss keeps getting bigger and bigger. The output should look like this:

WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
DEBUG:root:<> epoch 0, loss_real:-0.5383382267836068, loss_gen:-1.5583195904683735
INFO:root:<> epoch 0: saved network to mdl
DEBUG:root:<> weight -3.7587242126464844
DEBUG:root:<> log pi -11.807324409484863
/home/raliegh/视频/convlab2_github_code_theirs/ConvLab-2/convlab2/policy/gdpl/gdpl.py:183: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
DEBUG:root:<> epoch 0, iteration 0, value, loss 3489.1388260690787
DEBUG:root:<> epoch 0, iteration 0, policy, loss -0.0036238288800967368
DEBUG:root:<> epoch 0, iteration 1, value, loss 3480.9435135690787
DEBUG:root:<> epoch 0, iteration 1, policy, loss -0.09092773252019756
DEBUG:root:<> epoch 0, iteration 2, value, loss 3498.0641061883225
DEBUG:root:<> epoch 0, iteration 2, policy, loss -0.11517706787899921
DEBUG:root:<> epoch 0, iteration 3, value, loss 3488.2195530941613
DEBUG:root:<> epoch 0, iteration 3, policy, loss -0.12360558266702451
DEBUG:root:<> epoch 0, iteration 4, value, loss 3476.682437294408
DEBUG:root:<> epoch 0, iteration 4, policy, loss -0.12722392360630788
INFO:root:<> epoch 0: saved network to mdl
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 1, loss_real:-2.1718062476107947, loss_gen:-6.248041303534257
INFO:root:<> epoch 1: saved network to mdl
DEBUG:root:<> weight -9.06725788116455
DEBUG:root:<> log pi -11.601991653442383
DEBUG:root:<> epoch 1, iteration 0, value, loss 1590.3297087016858
DEBUG:root:<> epoch 1, iteration 0, policy, loss -0.0042587477517755405
DEBUG:root:<> epoch 1, iteration 1, value, loss 1590.0544883326481
DEBUG:root:<> epoch 1, iteration 1, policy, loss -0.07637144262461286
DEBUG:root:<> epoch 1, iteration 2, value, loss 1589.7801545795642
DEBUG:root:<> epoch 1, iteration 2, policy, loss -0.09997303185886458
DEBUG:root:<> epoch 1, iteration 3, value, loss 1589.4738512541119
DEBUG:root:<> epoch 1, iteration 3, policy, loss -0.11133970398651927
DEBUG:root:<> epoch 1, iteration 4, value, loss 1589.1489193564967
DEBUG:root:<> epoch 1, iteration 4, policy, loss -0.11775584558123037
INFO:root:<> epoch 1: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hospital
WARNING:root:illegal booking slot: time, slot: attraction domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 2, loss_real:-3.781325187948015, loss_gen:-10.217867334683737
INFO:root:<> epoch 2: saved network to mdl
DEBUG:root:<> weight -12.925418853759766
DEBUG:root:<> log pi -12.265064239501953
DEBUG:root:<> epoch 2, iteration 0, value, loss 4830.441213507402
DEBUG:root:<> epoch 2, iteration 0, policy, loss -0.020781385271172775
DEBUG:root:<> epoch 2, iteration 1, value, loss 4839.154656661184
DEBUG:root:<> epoch 2, iteration 1, policy, loss -0.08836260036026176
DEBUG:root:<> epoch 2, iteration 2, value, loss 4831.741853412829
DEBUG:root:<> epoch 2, iteration 2, policy, loss -0.10602868407180435
DEBUG:root:<> epoch 2, iteration 3, value, loss 4824.3883634868425
DEBUG:root:<> epoch 2, iteration 3, policy, loss -0.12300284697036994
DEBUG:root:<> epoch 2, iteration 4, value, loss 4831.304481907895
DEBUG:root:<> epoch 2, iteration 4, policy, loss -0.12597578234578433
INFO:root:<> epoch 2: saved network to mdl
WARNING:root:illegal booking slot: time, domain: attraction
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 3, loss_real:-5.254823472764757, loss_gen:-13.987894455591837
INFO:root:<> epoch 3: saved network to mdl
DEBUG:root:<> weight -16.43012809753418
DEBUG:root:<> log pi -11.844439506530762
DEBUG:root:<> epoch 3, iteration 0, value, loss 6681.600123355263
DEBUG:root:<> epoch 3, iteration 0, policy, loss -0.014684114396866215
DEBUG:root:<> epoch 3, iteration 1, value, loss 6697.302657277961
DEBUG:root:<> epoch 3, iteration 1, policy, loss -0.08244152585546927
DEBUG:root:<> epoch 3, iteration 2, value, loss 6687.997532894737
DEBUG:root:<> epoch 3, iteration 2, policy, loss -0.10515823467683635
DEBUG:root:<> epoch 3, iteration 3, value, loss 6690.9089997944075
DEBUG:root:<> epoch 3, iteration 3, policy, loss -0.11676324161357786
DEBUG:root:<> epoch 3, iteration 4, value, loss 6678.3968313116775
DEBUG:root:<> epoch 3, iteration 4, policy, loss -0.12235850389697589
INFO:root:<> epoch 3: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: attraction
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 4, loss_real:-6.739606408511891, loss_gen:-18.933229109820196
INFO:root:<> epoch 4: saved network to mdl
DEBUG:root:<> weight -21.545021057128906
DEBUG:root:<> log pi -12.236998558044434
DEBUG:root:<> epoch 4, iteration 0, value, loss 16275.491156684027
DEBUG:root:<> epoch 4, iteration 0, policy, loss -0.014838041116793951
DEBUG:root:<> epoch 4, iteration 1, value, loss 16267.9013671875
DEBUG:root:<> epoch 4, iteration 1, policy, loss -0.09151227782583898
DEBUG:root:<> epoch 4, iteration 2, value, loss 16256.190104166666
DEBUG:root:<> epoch 4, iteration 2, policy, loss -0.11655553637279405
DEBUG:root:<> epoch 4, iteration 3, value, loss 16265.713351779514
DEBUG:root:<> epoch 4, iteration 3, policy, loss -0.12722003553062677
DEBUG:root:<> epoch 4, iteration 4, value, loss 16243.192165798611
DEBUG:root:<> epoch 4, iteration 4, policy, loss -0.13666448928415775
INFO:root:<> epoch 4: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: taxi
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 5, loss_real:-7.912413765402401, loss_gen:-22.03384522830739
INFO:root:<> epoch 5: saved network to mdl
DEBUG:root:<> weight -24.468324661254883
DEBUG:root:<> log pi -12.261258125305176
DEBUG:root:<> epoch 5, iteration 0, value, loss 27010.648274739582
DEBUG:root:<> epoch 5, iteration 0, policy, loss -0.013149608030087419
DEBUG:root:<> epoch 5, iteration 1, value, loss 27043.53125
DEBUG:root:<> epoch 5, iteration 1, policy, loss -0.0839987989101145
DEBUG:root:<> epoch 5, iteration 2, value, loss 27066.318250868055
DEBUG:root:<> epoch 5, iteration 2, policy, loss -0.10623834199375576
DEBUG:root:<> epoch 5, iteration 3, value, loss 27043.93825954861
DEBUG:root:<> epoch 5, iteration 3, policy, loss -0.11813025466269916
DEBUG:root:<> epoch 5, iteration 4, value, loss 26953.104600694445
DEBUG:root:<> epoch 5, iteration 4, policy, loss -0.1252221003588703
INFO:root:<> epoch 5: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: hospital
DEBUG:root:<> epoch 6, loss_real:-9.242614388465881, loss_gen:-24.15580458111233
INFO:root:<> epoch 6: saved network to mdl
DEBUG:root:<> weight -26.42582893371582
DEBUG:root:<> log pi -11.808538436889648
DEBUG:root:<> epoch 6, iteration 0, value, loss 35887.18179481908
DEBUG:root:<> epoch 6, iteration 0, policy, loss -0.020953503682425146
DEBUG:root:<> epoch 6, iteration 1, value, loss 35494.21656558388
DEBUG:root:<> epoch 6, iteration 1, policy, loss -0.08569272891863396
DEBUG:root:<> epoch 6, iteration 2, value, loss 35628.84801603619
DEBUG:root:<> epoch 6, iteration 2, policy, loss -0.10266891509098441
DEBUG:root:<> epoch 6, iteration 3, value, loss 35657.03916529605
DEBUG:root:<> epoch 6, iteration 3, policy, loss -0.11386555943049882
DEBUG:root:<> epoch 6, iteration 4, value, loss 35917.57833059211
DEBUG:root:<> epoch 6, iteration 4, policy, loss -0.11797217848269563
INFO:root:<> epoch 6: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 7, loss_real:-11.321088128619724, loss_gen:-29.293851852416992
INFO:root:<> epoch 7: saved network to mdl
DEBUG:root:<> weight -32.10945129394531
DEBUG:root:<> log pi -11.713705062866211
DEBUG:root:<> epoch 7, iteration 0, value, loss 44522.42914496528
DEBUG:root:<> epoch 7, iteration 0, policy, loss -0.015966814425256517
DEBUG:root:<> epoch 7, iteration 1, value, loss 44453.58452690972
DEBUG:root:<> epoch 7, iteration 1, policy, loss -0.07723193801939487
DEBUG:root:<> epoch 7, iteration 2, value, loss 44377.24782986111
DEBUG:root:<> epoch 7, iteration 2, policy, loss -0.09828437285290824
DEBUG:root:<> epoch 7, iteration 3, value, loss 44297.86208767361
DEBUG:root:<> epoch 7, iteration 3, policy, loss -0.11189984074897236
DEBUG:root:<> epoch 7, iteration 4, value, loss 44211.8828125
DEBUG:root:<> epoch 7, iteration 4, policy, loss -0.12044301960203382
INFO:root:<> epoch 7: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 8, loss_real:-14.25956932703654, loss_gen:-33.106894387139214
INFO:root:<> epoch 8: saved network to mdl
DEBUG:root:<> weight -35.563194274902344
DEBUG:root:<> log pi -11.887650489807129
DEBUG:root:<> epoch 8, iteration 0, value, loss 61228.02682976974
DEBUG:root:<> epoch 8, iteration 0, policy, loss -0.019527194384289414
DEBUG:root:<> epoch 8, iteration 1, value, loss 60913.86245888158
DEBUG:root:<> epoch 8, iteration 1, policy, loss -0.08493027012599141
DEBUG:root:<> epoch 8, iteration 2, value, loss 60804.58943256579
DEBUG:root:<> epoch 8, iteration 2, policy, loss -0.10401363087523925
DEBUG:root:<> epoch 8, iteration 3, value, loss 60740.71361019737
DEBUG:root:<> epoch 8, iteration 3, policy, loss -0.11570279148872942
DEBUG:root:<> epoch 8, iteration 4, value, loss 60633.64113898026
DEBUG:root:<> epoch 8, iteration 4, policy, loss -0.12276971943088268
INFO:root:<> epoch 8: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 9, loss_real:-16.396672407786053, loss_gen:-39.38313462999132
INFO:root:<> epoch 9: saved network to mdl
DEBUG:root:<> weight -42.118408203125
DEBUG:root:<> log pi -11.91506290435791
DEBUG:root:<> epoch 9, iteration 0, value, loss 102404.39268092105
DEBUG:root:<> epoch 9, iteration 0, policy, loss -0.023536940546412217
DEBUG:root:<> epoch 9, iteration 1, value, loss 102286.93421052632
DEBUG:root:<> epoch 9, iteration 1, policy, loss -0.0810224729541101
DEBUG:root:<> epoch 9, iteration 2, value, loss 101849.27960526316
DEBUG:root:<> epoch 9, iteration 2, policy, loss -0.10366031547126017
DEBUG:root:<> epoch 9, iteration 3, value, loss 101598.78638980263
DEBUG:root:<> epoch 9, iteration 3, policy, loss -0.11581830601943166
DEBUG:root:<> epoch 9, iteration 4, value, loss 101350.11461759868
DEBUG:root:<> epoch 9, iteration 4, policy, loss -0.1236358410433719
INFO:root:<> epoch 9: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 10, loss_real:-17.94006437725491, loss_gen:-41.33010853661431
INFO:root:<> epoch 10: saved network to mdl
DEBUG:root:<> weight -43.82462692260742
DEBUG:root:<> log pi -12.179319381713867
DEBUG:root:<> epoch 10, iteration 0, value, loss 111196.29091282895
DEBUG:root:<> epoch 10, iteration 0, policy, loss -0.015502721169277242
DEBUG:root:<> epoch 10, iteration 1, value, loss 108579.41981907895
DEBUG:root:<> epoch 10, iteration 1, policy, loss -0.08138108037804302
DEBUG:root:<> epoch 10, iteration 2, value, loss 108351.37541118421
DEBUG:root:<> epoch 10, iteration 2, policy, loss -0.10115281825787142
DEBUG:root:<> epoch 10, iteration 3, value, loss 109070.85341282895
DEBUG:root:<> epoch 10, iteration 3, policy, loss -0.10706739313900471
DEBUG:root:<> epoch 10, iteration 4, value, loss 108081.73663651316
DEBUG:root:<> epoch 10, iteration 4, policy, loss -0.11929772833460256
INFO:root:<> epoch 10: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 11, loss_real:-22.859329329596626, loss_gen:-50.24238416883681
INFO:root:<> epoch 11: saved network to mdl
DEBUG:root:<> weight -53.37864685058594
DEBUG:root:<> log pi -12.136919975280762
DEBUG:root:<> epoch 11, iteration 0, value, loss 201200.13569078947
DEBUG:root:<> epoch 11, iteration 0, policy, loss -0.023343098202818317
DEBUG:root:<> epoch 11, iteration 1, value, loss 195454.23190789475
DEBUG:root:<> epoch 11, iteration 1, policy, loss -0.09736867954856471
DEBUG:root:<> epoch 11, iteration 2, value, loss 199148.953125
DEBUG:root:<> epoch 11, iteration 2, policy, loss -0.10236057227379397
DEBUG:root:<> epoch 11, iteration 3, value, loss 203306.05283717104
DEBUG:root:<> epoch 11, iteration 3, policy, loss -0.10679333225676887
DEBUG:root:<> epoch 11, iteration 4, value, loss 197667.32565789475
DEBUG:root:<> epoch 11, iteration 4, policy, loss -0.12387701702353202
INFO:root:<> epoch 11: saved network to mdl

Thank you guys, have a good day! Appreciate your help.
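The trend in the log above is easier to see as a series: the value loss climbs from about 3,500 in epoch 0 to about 200,000 in epoch 11, while the logged weight drifts from about -3.8 to about -53.4. The sketch below pulls both series out of a saved log. It only assumes the DEBUG line format shown above; the function names are hypothetical and not part of ConvLab-2.

```python
import re
import sys
from collections import defaultdict
from typing import Dict, List

# Regexes written for the DEBUG lines shown in the log above, e.g.
#   DEBUG:root:<> epoch 3, iteration 0, value, loss 6681.600123355263
#   DEBUG:root:<> weight -16.43012809753418
NUMBER = r"(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
VALUE_LOSS_RE = re.compile(r"epoch (\d+), iteration \d+, value, loss " + NUMBER)
WEIGHT_RE = re.compile(r"<> weight " + NUMBER)


def value_loss_per_epoch(log_text: str) -> Dict[int, float]:
    """Average the value loss over all iterations logged for each epoch."""
    losses: Dict[int, List[float]] = defaultdict(list)
    for line in log_text.splitlines():
        match = VALUE_LOSS_RE.search(line)
        if match:
            losses[int(match.group(1))].append(float(match.group(2)))
    return {epoch: sum(vals) / len(vals) for epoch, vals in sorted(losses.items())}


def weights_in_order(log_text: str) -> List[float]:
    """Collect the logged 'weight' values in the order they appear."""
    return [float(m.group(1)) for m in WEIGHT_RE.finditer(log_text)]


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python parse_gdpl_log.py train.log  (hypothetical file names)
    with open(sys.argv[1], encoding="utf-8") as f:
        text = f.read()
    for epoch, loss in value_loss_per_epoch(text).items():
        print(f"epoch {epoch:2d}: mean value loss {loss:12.1f}")
    print("weights:", [round(w, 1) for w in weights_in_order(text)])
```

Policy-loss and loss_real/loss_gen lines are deliberately skipped by the value-loss regex, since they do not contain "iteration N, value, loss".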

sherlock1987 added the bug (Something isn't working) label on Jun 16, 2020
@liangrz15
Contributor

Hi, at the moment the GDPL model improves slightly over the pretrained MLE model in the first few epochs, but its performance drops later on. We will fix this problem as soon as possible.
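Since GDPL reportedly helps only in the early epochs, one interim workaround is to keep the best checkpoint rather than the latest one (the log shows the network being saved to mdl after every epoch). A minimal sketch, assuming some per-epoch evaluation score where higher is better; the class, its parameters, and the patience-based stopping rule are hypothetical and not ConvLab-2's API.

```python
from typing import Optional


class BestCheckpointTracker:
    """Decide when to save and when to stop, based on a per-epoch score.

    Hypothetical helper: the score would come from whatever evaluation is
    run after each epoch; higher is assumed to be better.
    """

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience  # epochs allowed without improvement
        self.best_score: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.epochs_since_best = 0

    def update(self, epoch: int, score: float) -> bool:
        """Record a score; return True if this epoch is a new best (save it)."""
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.epochs_since_best = 0
            return True
        self.epochs_since_best += 1
        return False

    def should_stop(self) -> bool:
        """True once `patience` epochs have passed without a new best."""
        return self.epochs_since_best >= self.patience


if __name__ == "__main__":
    tracker = BestCheckpointTracker(patience=2)
    # Toy scores that rise briefly and then fall; illustrative only.
    for epoch, score in enumerate([0.60, 0.63, 0.62, 0.58, 0.55]):
        if tracker.update(epoch, score):
            print(f"epoch {epoch}: new best {score:.2f}, save checkpoint")
        if tracker.should_stop():
            print(f"stopping at epoch {epoch}; best was epoch {tracker.best_epoch}")
            break
```

This only changes which checkpoint is kept; it does not address the underlying divergence.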

@sherlock1987
Author

Thanks Bro

@sherlock1987
Author

Any clues? We could fix this problem together. I suspect the reward estimator has some problems, since the loss function is based on that estimator.
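One way to probe the reward-estimator hypothesis: in the log, the logged weight grows steadily more negative (about -3.8 at epoch 0, about -53.4 at epoch 11) while the value loss grows from thousands to hundreds of thousands, which is what one might expect if the value network were chasing targets whose scale keeps drifting. A common stabilization in policy-gradient training is to normalize rewards with running statistics before computing value targets. The sketch below is a generic, hypothetical helper (Welford's running mean/variance), not ConvLab-2 code, and whether it helps GDPL here is untested.

```python
import math
from typing import List


class RunningRewardNormalizer:
    """Track a running mean/variance of rewards (Welford's algorithm)
    and rescale rewards to roughly zero mean and unit variance.

    Hypothetical helper for illustration; not part of ConvLab-2.
    """

    def __init__(self, eps: float = 1e-8) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean
        self.eps = eps  # avoids division by zero when std is tiny

    def update(self, rewards: List[float]) -> None:
        for r in rewards:
            self.count += 1
            delta = r - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (r - self.mean)

    @property
    def std(self) -> float:
        """Sample standard deviation; 1.0 until there are two samples."""
        if self.count < 2:
            return 1.0
        return math.sqrt(self.m2 / (self.count - 1))

    def normalize(self, rewards: List[float]) -> List[float]:
        """Fold this batch into the statistics, then rescale it."""
        self.update(rewards)
        return [(r - self.mean) / (self.std + self.eps) for r in rewards]


if __name__ == "__main__":
    norm = RunningRewardNormalizer()
    # Toy batches whose scale drifts, loosely mimicking the logged weights.
    for batch in ([-3.8, -4.0, -3.5], [-20.0, -22.0, -19.0], [-50.0, -53.0, -55.0]):
        print([round(x, 2) for x in norm.normalize(batch)])
```

Normalizing changes the effective reward the policy optimizes, so it is a diagnostic and stabilization aid rather than a fix for a biased estimator.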

@sherlock1987
Author

Hey, has anyone started looking at this?

@liangrz15
Contributor

> Hey, has anyone started looking at this?

Yes, I am working on it.

@sherlock1987
Author

Cool!

@zqwerty
Member

zqwerty commented Jul 16, 2020

Moved to #54.

zqwerty closed this as completed on Jul 16, 2020