I noticed a problem in `pipeline.train_val_utils.validate` when running in CRF classifier mode: the `inference` function in `crf.py` returns the predicted target sequence (which is actually the predicted class ids), not the probabilities for each class. However, in `pipeline.train_val_utils.validate` we obtain the predicted class ids by running `torch.argmax`. As said above, we do not have probabilities here; we already have the predicted class ids directly.
I think we can solve this by adding an `is_crf` argument to `pipeline.train_val_utils.validate` and an if block that decides whether to apply the argmax.
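A minimal sketch of the proposed fix. The helper shape and the `is_crf` flag below are hypothetical illustrations, not the repository's actual code:

```python
import torch

def get_pred_ids(output: torch.Tensor, is_crf: bool = False) -> torch.Tensor:
    """Return predicted class ids from the model output.

    In CRF mode the decoder already returns class ids, so argmax
    must be skipped; otherwise the output holds per-class scores.
    """
    if is_crf:
        # CRF decode already produced the predicted target sequence.
        return output
    # Non-CRF mode: reduce per-class probabilities/logits to ids.
    return torch.argmax(output, dim=-1)
```

Inside `validate`, the call site would then pass `is_crf=True` whenever the CRF classifier mode is active.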
When using the CRF decoder, the seqeval library will be used as the evaluation metric; in that case only the predicted classes are needed. Maybe a restriction should be added in the validation step. I will work on it later~
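To illustrate why probabilities are unnecessary here: seqeval scores sequences of tag strings, so the CRF's predicted class ids only need to be mapped through a label map. The `ID2LABEL` mapping below is a made-up example, not the repository's real label set:

```python
# Assumed, illustrative label map; the real project defines its own.
ID2LABEL = {0: "O", 1: "B-KEY", 2: "I-KEY"}

def ids_to_tags(pred_id_seqs):
    """Convert sequences of predicted class ids into seqeval-style tags."""
    return [[ID2LABEL[i] for i in seq] for seq in pred_id_seqs]

tags = ids_to_tags([[1, 2, 0]])
# Tag lists like this can then be scored directly, e.g. with
# seqeval.metrics.f1_score(y_true_tags, tags) -- no probabilities needed.
```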
Hi, thanks for your effort.
Also, irrelevant to the above, but there is a typo here:
ViBERTgrid-PyTorch/model/field_type_classification_head.py
Line 474 in 97af769
Thanks, sincerely.