
[model] Deep model parameter interpretation #883

Merged
merged 17 commits into from
Dec 7, 2022

Conversation

karl-richter
Collaborator

Status quo
When the NN uses hidden layers, the plot_parameters() function interprets only the model weights of the first layer. This can be misleading to a user who wants to interpret the model.

Note: in some places this led to wrong results in the plots, e.g. when a user sets highlight_nth_forecast in plot_parameters(), the weights of the lags w.r.t. the n-th hidden unit are returned instead (since the first-layer weights have shape input x hidden1). If the shape of the hidden layer does not match the shape of the output, this also leads to runtime errors.

Change
Instead of using the weights of the first layer, we use a model attribution method to calculate the attributions of the lags w.r.t. each forecast. We use PyTorch's Captum library for the saliency calculation.
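The difference can be sketched with a toy linear example (pure Python, hypothetical dimensions; a real model applies nonlinearities between layers, which is why the PR uses gradient-based saliency via Captum rather than a simple weight product):

```python
# Toy 2-layer linear AR-net: 3 lags -> 4 hidden units -> 2 forecasts.
# All weight values below are made up for illustration.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

w1 = [[0.1, 0.2, 0.3],       # first layer: hidden x lags (4 x 3)
      [0.0, 0.5, 0.1],
      [0.4, 0.1, 0.2],
      [0.3, 0.0, 0.6]]
w2 = [[1.0, 0.5, 0.0, 0.2],  # output layer: forecasts x hidden (2 x 4)
      [0.0, 1.0, 0.3, 0.1]]

# Naive interpretation: the first-layer weights. Their row count matches
# the hidden layer, not the forecasts, which is why indexing them by
# forecast (highlight_nth_forecast) can pick the wrong row or go out
# of bounds when hidden size != number of forecasts.
print(len(w1), len(w1[0]))  # 4 3

# Attribution: for a purely linear net the Jacobian d forecast / d lag
# is exactly w2 @ w1, shape forecasts x lags -- one row per forecast.
attributions = matmul(w2, w1)
print(len(attributions), len(attributions[0]))  # 2 3
```

With nonlinear activations this product is no longer exact, so the PR computes the gradient of each forecast w.r.t. the lagged inputs (saliency) instead.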

@karl-richter karl-richter self-assigned this Oct 21, 2022
@github-actions

github-actions bot commented Oct 21, 2022

c8e21bc

Model Benchmark

| Benchmark | Metric | main | current | diff |
| --- | --- | --- | --- | --- |
| AirPassengers | SmoothL1Loss | 0.00031 | 0.00032 | 1.79% |
| AirPassengers | MAE | 6.35364 | 6.37421 | 0.32% |
| AirPassengers | RMSE | 7.68085 | 7.75532 | 0.97% |
| AirPassengers | Loss | 0.00023 | 0.00023 | 1.55% |
| AirPassengers | RegLoss | 0 | 0 | 0.0% |
| AirPassengers | SmoothL1Loss_val | 0.06051 | 0.06031 | -0.32% |
| AirPassengers | MAE_val | 85.1099 | 84.9838 | -0.15% |
| AirPassengers | RMSE_val | 108.276 | 108.103 | -0.16% |
| PeytonManning | SmoothL1Loss | 0.00587 | 0.00587 | 0.0% |
| PeytonManning | MAE | 0.34839 | 0.34839 | 0.0% |
| PeytonManning | RMSE | 0.48617 | 0.48617 | 0.0% |
| PeytonManning | Loss | 0.00464 | 0.00464 | 0.0% |
| PeytonManning | RegLoss | 0 | 0 | 0.0% |
| PeytonManning | SmoothL1Loss_val | 0.03038 | 0.03038 | -0.0% |
| PeytonManning | MAE_val | 0.92518 | 0.92518 | -0.0% |
| PeytonManning | RMSE_val | 1.13074 | 1.13074 | -0.0% |
| YosemiteTemps | SmoothL1Loss | 0.00086 | 0.00086 | -0.01% |
| YosemiteTemps | MAE | 1.43672 | 1.43794 | 0.09% |
| YosemiteTemps | RMSE | 2.14749 | 2.14874 | 0.06% |
| YosemiteTemps | Loss | 0.00064 | 0.00064 | -0.0% |
| YosemiteTemps | RegLoss | 0 | 0 | 0.0% |
| YosemiteTemps | SmoothL1Loss_val | 0.00097 | 0.00096 | -0.81% |
| YosemiteTemps | MAE_val | 1.71173 | 1.70236 | -0.55% |
| YosemiteTemps | RMSE_val | 2.2758 | 2.26615 | -0.42% |

Model Training

PeytonManning

YosemiteTemps

AirPassengers


Collaborator Author

@karl-richter karl-richter left a comment


fyi

@codecov-commenter

codecov-commenter commented Oct 21, 2022

Codecov Report

Merging #883 (c0dc34c) into main (10c0cb0) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
+ Coverage   90.26%   90.29%   +0.03%     
==========================================
  Files          21       21              
  Lines        4737     4752      +15     
==========================================
+ Hits         4276     4291      +15     
  Misses        461      461              
| Impacted Files | Coverage Δ |
| --- | --- |
| neuralprophet/plot_model_parameters_matplotlib.py | 90.50% <100.00%> (ø) |
| neuralprophet/plot_utils.py | 89.91% <100.00%> (ø) |
| neuralprophet/utils_torch.py | 89.18% <100.00%> (+7.37%) ⬆️ |


Owner

@ourownstory ourownstory left a comment


great looking!
final task: find and update notebooks

neuralprophet/time_net.py (review comment, resolved)
neuralprophet/utils_torch.py (review comment, resolved)
@karl-richter karl-richter marked this pull request as ready for review October 25, 2022 16:06
@karl-richter karl-richter added the status: ready PR is ready to be merged label Oct 25, 2022
@Kevin-Chen0
Collaborator

@karl-richter Can you resolve merge conflict and fix flake8? Thx

@Kevin-Chen0 Kevin-Chen0 added status: needs update PR has outstanding comment(s) or PR test(s) that need to be resolved and removed status: ready PR is ready to be merged labels Nov 16, 2022
@Kevin-Chen0 Kevin-Chen0 requested review from Kevin-Chen0 and removed request for Kevin-Chen0 November 17, 2022 00:18
@karl-richter karl-richter added status: needs review PR needs to be reviewed by Reviewer(s) and removed status: needs update PR has outstanding comment(s) or PR test(s) that need to be resolved labels Nov 17, 2022
@github-actions

github-actions bot commented Nov 17, 2022

3e28c93

Model Benchmark

| Benchmark | Metric | main | current | diff |
| --- | --- | --- | --- | --- |
| AirPassengers | MAE_val | 15.2698 | 15.2698 | 0.0% |
| AirPassengers | RMSE_val | 19.4209 | 19.4209 | 0.0% |
| AirPassengers | Loss_val | 0.00195 | 0.00195 | 0.0% |
| AirPassengers | RegLoss_val | 0 | 0 | 0.0% |
| AirPassengers | epoch | 89 | 89 | 0.0% |
| AirPassengers | MAE | 9.82902 | 9.82902 | 0.0% |
| AirPassengers | RMSE | 11.7005 | 11.7005 | 0.0% |
| AirPassengers | Loss | 0.00056 | 0.00056 | 0.0% |
| AirPassengers | RegLoss | 0 | 0 | 0.0% |
| AirPassengers | time | 4.51 | 5.35 | 18.63% ⚠️ |
| AirPassengers | system_performance | 0.8004 | 0.9326 | 16.52% ⚠️ |
| AirPassengers | system_std | 0.0008 | 0.00855 | 968.75% ⚠️ |
| PeytonManning | MAE_val | 0.64636 | 0.64636 | 0.0% |
| PeytonManning | RMSE_val | 0.79276 | 0.79276 | 0.0% |
| PeytonManning | Loss_val | 0.01494 | 0.01494 | 0.0% |
| PeytonManning | RegLoss_val | 0 | 0 | 0.0% |
| PeytonManning | epoch | 37 | 37 | 0.0% |
| PeytonManning | MAE | 0.42701 | 0.42701 | 0.0% |
| PeytonManning | RMSE | 0.57032 | 0.57032 | 0.0% |
| PeytonManning | Loss | 0.00635 | 0.00635 | 0.0% |
| PeytonManning | RegLoss | 0 | 0 | 0.0% |
| PeytonManning | time | 11.81 | 13.97 | 18.29% ⚠️ |
| PeytonManning | system_performance | 0.7942 | 0.934 | 17.6% ⚠️ |
| PeytonManning | system_std | 0.0004 | 0.01126 | 2715.0% ⚠️ |
| YosemiteTemps | MAE_val | 1.72949 | 1.72949 | 0.0% |
| YosemiteTemps | RMSE_val | 2.27386 | 2.27386 | 0.0% |
| YosemiteTemps | Loss_val | 0.00096 | 0.00096 | 0.0% |
| YosemiteTemps | RegLoss_val | 0 | 0 | 0.0% |
| YosemiteTemps | epoch | 84 | 84 | 0.0% |
| YosemiteTemps | MAE | 1.45189 | 1.45189 | 0.0% |
| YosemiteTemps | RMSE | 2.16631 | 2.16631 | 0.0% |
| YosemiteTemps | Loss | 0.00066 | 0.00066 | 0.0% |
| YosemiteTemps | RegLoss | 0 | 0 | 0.0% |
| YosemiteTemps | time | 94.18 | 116.66 | 23.87% ⚠️ |
| YosemiteTemps | system_performance | 0.8008 | 0.9418 | 17.61% ⚠️ |
| YosemiteTemps | system_std | 0.00117 | 0.01514 | 1194.02% ⚠️ |
Model Training

PeytonManning

YosemiteTemps

AirPassengers

@noxan
Collaborator

noxan commented Nov 22, 2022

@ourownstory You asked about some notebook updates - is this still open or can we merge this PR?

@noxan noxan dismissed ourownstory’s stale review November 22, 2022 06:07

Most likely outdated, added a comment regarding current status.

Collaborator

@noxan noxan left a comment


LGTM - very clean and structured code, was good to review even without knowing all the specifics :)

Two minor remarks, see comments.

neuralprophet/utils_torch.py (review comment)
neuralprophet/utils_torch.py (review comment, resolved)
@noxan noxan added status: needs update PR has outstanding comment(s) or PR test(s) that need to be resolved and removed status: needs review PR needs to be reviewed by Reviewer(s) labels Dec 3, 2022
@karl-richter karl-richter added status: ready PR is ready to be merged and removed status: needs update PR has outstanding comment(s) or PR test(s) that need to be resolved labels Dec 7, 2022
@noxan noxan merged commit 3e28c93 into main Dec 7, 2022
@noxan noxan deleted the feature/deep_model_interpretation branch December 7, 2022 17:07
@ourownstory ourownstory removed the status: ready PR is ready to be merged label Mar 14, 2023
Successfully merging this pull request may close these issues.

AR with hidden layers, highlighted weights: plot importance, not weights
5 participants