We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
我目前将其他模型移植到BasicTS,但是面临问题是移植模型的loss部分需要多个损失项的相加,loss = loss1(data1)+loss2(data2)+....,但BasicTS框架中losses.py由于自定义损失函数被包装了,传参只有input_data, target_data两个,是否有解决方案?
The text was updated successfully, but these errors were encountered:
您可以参考STEP的loss。 总的来说,目前自定义loss函数的参数输入需要满足下述限制:
def customized_loss(prediction, real_value, other_param_1, other_param_2, ..., null_val=np.nan): # main loss pass
其中,prediction, real_value, other_param_1, other_param_2, ...,这些参数是和runner的forward函数的返回值相匹配的。换句话说,runner的返回值会自动作为参数注入到loss中。同理您可以参考STEP的runner。
最后一个参数,null_val是用来识别数据集中需要被忽略的异常点,默认一般为np.nan。您可以通过CFG.NULL_VAL在配置文件中进行设定。例如,对于交通数据集来说,0值一般是异常值(传感器宕机)。我们不希望模型强制拟合这些异常值,此时的NULL_VAL就会被设定为0.0,再进一步采用masked_mae等指标。
Sorry, something went wrong.
No branches or pull requests
我目前将其他模型移植到BasicTS,但是面临问题是移植模型的loss部分需要多个损失项的相加,loss = loss1(data1)+loss2(data2)+....,但BasicTS框架中losses.py由于自定义损失函数被包装了,传参只有input_data, target_data两个,是否有解决方案?
The text was updated successfully, but these errors were encountered: