Skip to content

Commit

Permalink
refactor(#89): add BaseModelClassifier for extending
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Jan 7, 2020
1 parent f0f1ea6 commit 1f996fd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
31 changes: 30 additions & 1 deletion stagesepx/classifier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ def classify(
frame = self._apply_hook(frame, *args, **kwargs)

# ignore some ranges
if limit_range and not any([each.contain(frame.frame_id) for each in limit_range]):
if limit_range and not any(
[each.contain(frame.frame_id) for each in limit_range]
):
logger.debug(
f"frame {frame.frame_id} ({frame.timestamp}) not in target range, skip"
)
Expand Down Expand Up @@ -366,3 +368,30 @@ def classify(
)
frame = operator.get_frame_by_id(frame.frame_id + step)
return ClassifierResult(final_result)


class BaseModelClassifier(BaseClassifier):
# model
def save_model(self, model_path: str, overwrite: bool = None):
raise NotImplemented

def load_model(self, model_path: str, overwrite: bool = None):
raise NotImplemented

def clean_model(self):
raise NotImplemented

# actions
def train(self):
raise NotImplemented

def predict(self, pic_path: str) -> str:
raise NotImplemented

def predict_with_object(self, frame: np.ndarray) -> str:
raise NotImplemented

def read_from_list(
self, data: typing.List[int], video_cap: cv2.VideoCapture = None, *_, **__
):
raise ValueError("model-like classifier only support loading data from files")
9 changes: 2 additions & 7 deletions stagesepx/classifier/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import numpy as np
from sklearn.svm import LinearSVC

from stagesepx.classifier.base import BaseClassifier
from stagesepx.classifier.base import BaseModelClassifier
from stagesepx import toolbox
from stagesepx.video import VideoFrame
from stagesepx import constants


class SVMClassifier(BaseClassifier):
class SVMClassifier(BaseModelClassifier):
FEATURE_DICT = {
"hog": toolbox.turn_hog_desc,
"lbp": toolbox.turn_lbp_desc,
Expand Down Expand Up @@ -92,11 +92,6 @@ def load_model(self, model_path: str, overwrite: bool = None):
with open(model_path, "rb") as f:
self._model = pickle.load(f)

def read_from_list(
self, data: typing.List[int], video_cap: cv2.VideoCapture = None, *_, **__
):
raise NotImplementedError("svm classifier only support loading data from files")

def train(self):
"""
train your classifier with data. must be called before prediction
Expand Down

0 comments on commit 1f996fd

Please sign in to comment.