Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

请教一下关于loss部分的修改 #74

Open
maxwell-five opened this issue Sep 16, 2023 · 1 comment
Open

请教一下关于loss部分的修改 #74

maxwell-five opened this issue Sep 16, 2023 · 1 comment

Comments

@maxwell-five
Copy link

我目前将其他模型移植到BasicTS,但是面临问题是移植模型的loss部分需要多个损失项的相加,loss = loss1(data1)+loss2(data2)+....,但BasicTS框架中losses.py由于自定义损失函数被包装了,传参只有input_data, target_data两个,是否有解决方案?

@zezhishao
Copy link
Owner

您可以参考STEP的loss
总的来说,目前自定义loss函数的参数输入需要满足下述限制:

def customized_loss(prediction, real_value, other_param_1, other_param_2, ..., null_val=np.nan):
    # main loss
    pass

其中,prediction, real_value, other_param_1, other_param_2, ...,这些参数是和runner的forward函数的返回值相匹配的。换句话说,runner的返回值会自动作为参数注入到loss中。同理您可以参考STEP的runner

最后一个参数,null_val是用来识别数据集中需要被忽略的异常点,默认一般为np.nan。您可以通过CFG.NULL_VAL在配置文件中进行设定。例如,对于交通数据集来说,0值一般是异常值(传感器宕机)。我们不希望模型强制拟合这些异常值,此时的NULL_VAL就会被设定为0.0,再进一步采用masked_mae等指标。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants