Skip to content

Commit

Permalink
feat(#6): support loading range as data directly
Browse files Browse the repository at this point in the history
Only the SSIM classifier supports this so far.
  • Loading branch information
williamfzc committed Jul 23, 2019
1 parent 2f11c47 commit 7ff66f5
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 22 deletions.
61 changes: 48 additions & 13 deletions stagesepx/classifier/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import pathlib
import typing
import cv2
import numpy as np
from loguru import logger

from stagesepx.cutter import VideoCutRange
from stagesepx import toolbox


class StageData(object):
    """Plain record pairing a stage name with the pictures that belong to it."""

    def __init__(self, name: str, pic_list: typing.List[str]):
        """Store the stage ``name`` and its ``pic_list`` unchanged."""
        self.name = name
        self.pic_list = pic_list


class ClassifierResult(object):
def __init__(self,
video_path: str,
Expand All @@ -27,21 +22,61 @@ def __init__(self,

class BaseClassifier(object):
def __init__(self):
self.data: typing.Dict[str, typing.List[pathlib.Path]] = dict()
self._data: typing.Dict[
str,
typing.Union[
typing.List[pathlib.Path],
typing.List[int],
]
] = dict()

def load(self,
         data: typing.Union[str, pathlib.Path, typing.List["VideoCutRange"]],
         *args, **kwargs):
    """Load stage data, dispatching on the type of ``data``.

    :param data: a directory path (``str`` or ``pathlib.Path``), handled by
        ``load_from_dir``, or a ``list``, handled by ``load_from_list``.
    :param args: forwarded unchanged to the selected loader.
    :param kwargs: forwarded unchanged to the selected loader.
    :return: whatever the selected loader returns.
    :raises TypeError: if ``data`` is neither a path nor a list.
    """
    # load_from_dir wraps its argument in pathlib.Path, so a Path is as
    # valid as a str here.
    if isinstance(data, (str, pathlib.Path)):
        return self.load_from_dir(data, *args, **kwargs)
    if isinstance(data, list):
        return self.load_from_list(data, *args, **kwargs)
    raise TypeError(
        'data type error, should be str or typing.List[VideoCutRange], '
        f'got {type(data).__name__}'
    )

def load_from_list(self, data: typing.List[VideoCutRange], frame_count: int = None, *_, **__):
for stage_name, stage_data in enumerate(data):
target_frame_list = stage_data.pick(frame_count)
self._data[str(stage_name)] = target_frame_list

def load(self, data_home: str):
p = pathlib.Path(data_home)
def load_from_dir(self, dir_path: typing.Union[str, pathlib.Path], *_, **__):
    """Register one stage per sub-directory of ``dir_path``.

    The sub-directory name becomes the stage name and the absolute paths
    of the files directly inside it become the stage's picture list.

    :param dir_path: directory whose children are the stage directories.
    """
    root = pathlib.Path(dir_path)
    # iterdir() yields entries in arbitrary order; sort for a stable result.
    for each in sorted(root.iterdir()):
        # load dir only
        if each.is_file():
            continue
        stage_name = each.name
        # Keep regular files only: a nested directory is not a picture.
        stage_pic_list = sorted(i.absolute() for i in each.iterdir() if i.is_file())
        if not stage_pic_list:
            # read() inspects the first entry of every stage, so an empty
            # stage must not be stored.
            logger.warning(f'stage [{stage_name}] has no pics, ignored')
            continue
        self._data[stage_name] = stage_pic_list
        logger.debug(f'stage [{stage_name}] found, and got {len(stage_pic_list)} pics')

def _classify_frame(self, frame: np.ndarray, *args, **kwargs) -> str:
def read(self, *args, **kwargs):
for stage_name, stage_data in self._data.items():
if isinstance(stage_data[0], pathlib.Path):
yield stage_name, self.read_from_path(stage_data, *args, **kwargs)
elif isinstance(stage_data[0], int):
yield stage_name, self.read_from_list(stage_data, *args, **kwargs)
else:
raise TypeError(f'data type error, should be str or typing.List[VideoCutRange]')

@staticmethod
def read_from_path(data: typing.List[pathlib.Path], *_, **__):
return (cv2.imread(each.as_posix()) for each in data)

def read_from_list(self, data: typing.List[int], video_cap: "cv2.VideoCapture" = None, *_, **__):
    """Read the frames identified by ``data`` from ``video_cap``.

    The capture's current frame position is saved before reading and
    restored afterwards, whether or not reading succeeds.

    :param data: frame ids; each one is requested as ``each - 1``
        (presumably 1-based ids mapped to 0-based positions - confirm
        against ``toolbox.get_frame``).
    :param video_cap: opened capture to read from.
    :return: list with one frame per id, in the order given.
    :raises ValueError: if ``video_cap`` is not supplied.
    """
    if video_cap is None:
        raise ValueError('video_cap is required to read frames by id')

    cur_frame_id = toolbox.get_current_frame_id(video_cap)
    try:
        # Read eagerly. A lazy generator would only touch the capture when
        # consumed, i.e. after the position below had already been restored,
        # leaving the capture at the wrong frame.
        return [toolbox.get_frame(video_cap, each - 1) for each in data]
    finally:
        toolbox.video_jump(video_cap, cur_frame_id)

def _classify_frame(self,
frame: np.ndarray,
video_cap: cv2.VideoCapture,
*args, **kwargs) -> str:
raise NotImplementedError('must implement this function')

def classify(self,
Expand All @@ -50,7 +85,7 @@ def classify(self,
step: int = None,
*args, **kwargs) -> typing.List[ClassifierResult]:
logger.debug(f'classify with {self.__class__.__name__}')
assert self.data, 'should load data first'
assert self._data, 'should load data first'

if not step:
step = 1
Expand All @@ -68,7 +103,7 @@ def classify(self,
ret, frame = cap.read()
continue

result = self._classify_frame(frame, *args, **kwargs)
result = self._classify_frame(frame, cap, *args, **kwargs)
logger.debug(f'frame {frame_id} ({frame_timestamp}) belongs to {result}')
final_result.append(ClassifierResult(video_path, frame_id, frame_timestamp, result))
toolbox.video_jump(cap, frame_id + step - 1)
Expand Down
6 changes: 3 additions & 3 deletions stagesepx/classifier/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class SSIMClassifier(BaseClassifier):
def _classify_frame(self,
frame: np.ndarray,
video_cap: cv2.VideoCapture,
threshold: float = None,
*_, **__) -> str:
if not threshold:
Expand All @@ -17,10 +18,9 @@ def _classify_frame(self,
frame = toolbox.compress_frame(frame)

result = list()
for each_stage_name, each_stage_pic_list in self.data.items():
for each_stage_name, each_stage_pic_list in self.read(video_cap):
each_result = list()
for each in each_stage_pic_list:
target_pic = cv2.imread(each.as_posix())
for target_pic in each_stage_pic_list:
target_pic = toolbox.compress_frame(target_pic)
each_pic_ssim = toolbox.compare_ssim(frame, target_pic)
each_result.append(each_pic_ssim)
Expand Down
10 changes: 6 additions & 4 deletions stagesepx/classifier/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import cv2
import os
import pickle
import typing
import numpy as np
from sklearn.svm import LinearSVC

Expand Down Expand Up @@ -55,16 +56,17 @@ def load_model(self, model_path: str, overwrite: bool = None):
with open(model_path, 'rb') as f:
self._model = pickle.load(f)

def read_from_list(self, data: typing.List[int], video_cap: "cv2.VideoCapture" = None, *_, **__):
    """Reject frame-id data: this classifier only reads pictures from files.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError('svm classifier only support loading data from files')

def train(self):
if not self._model:
self._model = LinearSVC()

train_data = list()
train_label = list()
for each_label, each_label_pic_list in self.data.items():
for each_pic in each_label_pic_list:
logger.debug(f'loading {each_pic} ...')
each_pic_object = cv2.imread(each_pic.as_posix())
for each_label, each_label_pic_list in self.read():
for each_pic_object in each_label_pic_list:
each_pic_object = toolbox.compress_frame(each_pic_object)
each_pic_object = self.feature_func(each_pic_object).flatten()
train_data.append(each_pic_object)
Expand Down
8 changes: 6 additions & 2 deletions stagesepx/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ def get_frame_size(video_cap: cv2.VideoCapture) -> typing.Tuple[int, int]:
return int(w), int(h)


def get_frame(video_cap: cv2.VideoCapture, frame_id: int) -> np.ndarray:
def get_frame(video_cap: "cv2.VideoCapture", frame_id: int, recover: bool = None) -> np.ndarray:
    """Jump to ``frame_id`` and return the frame read there.

    :param video_cap: opened capture to read from.
    :param frame_id: position handed to ``video_jump`` before reading.
    :param recover: when truthy, move the capture back to the position it
        had on entry, even if the read fails.
    :return: the frame that was read.
    :raises AssertionError: if the capture reports a failed read.
    """
    cur = get_current_frame_id(video_cap)
    try:
        video_jump(video_cap, frame_id)
        ret, frame = video_cap.read()
        if not ret:
            # Raised explicitly (same exception type as the former assert)
            # so the check survives ``python -O``.
            raise AssertionError(f'read frame failed, frame id: {frame_id}')
        return frame
    finally:
        if recover:
            video_jump(video_cap, cur)


Expand Down

0 comments on commit 7ff66f5

Please sign in to comment.