Skip to content

Commit

Permalink
feat(#54): support score_threshold
Browse files Browse the repository at this point in the history
IMPORTANT:
these scores are not always precise
at the most of time, we used a tiny train data set for training
which may causes 'liblinear failed to converge'
actually, it can know which one is the target class
but the calculated value may becomes weird
  • Loading branch information
williamfzc committed Sep 8, 2019
1 parent cc899a7 commit b2d5ba7
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion stagesepx/classifier/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,37 @@ class SVMClassifier(BaseClassifier):
# do not use feature transform
'raw': lambda x: x,
}
UNKNOWN_STAGE_NAME = '-2'

def __init__(self,
feature_type: str = None,
score_threshold: float = None,
*args, **kwargs):
"""
init classifier
:param feature_type:
before training, classifier will convert pictures into feature, for better classification.
eg: 'hog', 'lbp' or 'raw'
:param score_threshold:
float, 0 - 1.0, under this value, label -> UNKNOWN_STAGE_NAME
default value is 0 (None)
"""
super().__init__(*args, **kwargs)

# feature settings
if not feature_type:
feature_type = 'hog'
if feature_type not in self.FEATURE_DICT:
raise AttributeError(f'no feature func named {feature_type}')
self.feature_func: typing.Callable = self.FEATURE_DICT[feature_type]
self._model: typing.Optional[LinearSVC] = None
logger.debug(f'feature function: {feature_type}')

# model settings
self._model: typing.Optional[LinearSVC] = None
self.score_threshold: float = score_threshold or 0.
logger.debug(f'score threshold: {self.score_threshold}')

def clean_model(self):
self._model = None

Expand Down Expand Up @@ -128,8 +138,19 @@ def predict_with_object(self, pic_object: np.ndarray) -> str:
pic_object = self.feature_func(pic_object)
pic_object = pic_object.reshape(1, -1)
# scores for each stages
# IMPORTANT:
# these scores are not always precise
# at the most of time, we used a tiny train data set for training
# which may causes 'liblinear failed to converge'
# actually, it can know which one is the target class
# but the calculated value may becomes weird
scores = self._model.decision_function(pic_object)[0]
logger.debug(f'scores: {scores}')
# unknown
if max(scores) < self.score_threshold:
logger.warning(f'max score is lower than {self.score_threshold}, unknown class')
return self.UNKNOWN_STAGE_NAME

return self._model.classes_[np.argmax(scores)]

def _classify_frame(self,
Expand Down

0 comments on commit b2d5ba7

Please sign in to comment.