Skip to content

Commit

Permalink
feat: support feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Jul 17, 2019
1 parent d01cae5 commit cacecff
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
13 changes: 8 additions & 5 deletions example/classify_with_svm.py
Expand Up @@ -2,18 +2,21 @@
from stagesepx.reporter import Reporter


cl = SVMClassifier()
# 默认情况下使用 HoG 进行特征提取
# 你可以将其关闭从而直接对原始图片进行训练与测试:feature_type='none'
cl = SVMClassifier(feature_type='hog')

# 基本与SSIM分类器的流程一致
# 但它对数据的要求可能有所差别,具体参见 cut.py 中的描述
cl.load('./cut_result')

# 在加载数据完成之后需要先训练
cl.train()
# 在训练后你可以把模型保存起来
cl.save_model('model.pkl')
# 或者直接读取已经训练好的模型
cl.load_model('model.pkl')

# # 在训练后你可以把模型保存起来
# cl.save_model('model.pkl')
# # 或者直接读取已经训练好的模型
# cl.load_model('model.pkl')

# 开始分类
res = cl.classify('../test.mp4')
Expand Down
27 changes: 20 additions & 7 deletions stagesepx/classifier.py
Expand Up @@ -46,10 +46,6 @@ def load(self, data_home: str):


class SSIMClassifier(_BaseClassifier):
def __init__(self):
# TODO 指定分析算法 (是否进行特征提取等)
pass

def classify(self, video_path: str, threshold: float = None) -> typing.List[ClassifierResult]:
logger.debug(f'classify with {self.__class__.__name__}')
assert self.data, 'should load data first'
Expand Down Expand Up @@ -89,9 +85,25 @@ def classify(self, video_path: str, threshold: float = None) -> typing.List[Clas


class SVMClassifier(_BaseClassifier):
def __init__(self):
FEATURE_DICT = {
'hog': toolbox.turn_hog_desc,
# TODO not implemented
# 'surf': toolbox.turn_surf_desc,

# do not use feature transform
'none': lambda x: x,
}

def __init__(self, feature_type: str = None):
super().__init__()

if not feature_type:
feature_type = 'hog'
if feature_type not in self.FEATURE_DICT:
raise AttributeError(f'no feature func named {feature_type}')
self.feature_func = self.FEATURE_DICT[feature_type]
self._model = None
logger.debug(f'feature function: {feature_type}')

def clean_model(self):
self._model = None
Expand Down Expand Up @@ -129,7 +141,7 @@ def train(self):
logger.debug(f'loading {each_pic} ...')
each_pic_object = cv2.imread(each_pic.as_posix())
each_pic_object = toolbox.compress_frame(each_pic_object)
each_pic_object = toolbox.turn_hog_desc(each_pic_object)
each_pic_object = self.feature_func(each_pic_object).flatten()
train_data.append(each_pic_object)
train_label.append(each_label)
logger.debug('data ready')
Expand All @@ -142,7 +154,8 @@ def predict(self, pic_path: str) -> str:

def predict_with_object(self, pic_object: np.ndarray) -> str:
pic_object = toolbox.compress_frame(pic_object)
pic_object = toolbox.turn_hog_desc(pic_object).reshape(1, -1)
pic_object = self.feature_func(pic_object)
pic_object = pic_object.reshape(1, -1)
return self._model.predict(pic_object)[0]

def classify(self, video_path: str) -> typing.List[ClassifierResult]:
Expand Down

0 comments on commit cacecff

Please sign in to comment.