Skip to content

Commit

Permalink
feat(#14): support specific size in compression
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Jul 30, 2019
1 parent 45d475c commit f95a4df
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 24 deletions.
7 changes: 6 additions & 1 deletion stagesepx/classifier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ def __init__(self,


class BaseClassifier(object):
def __init__(self):
def __init__(self,
compress_rate: float = None,
target_size: typing.Tuple[int, int] = None):
self.compress_rate = compress_rate
self.target_size = target_size

self._data: typing.Dict[
str,
typing.Union[
Expand Down
8 changes: 6 additions & 2 deletions stagesepx/classifier/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ def _classify_frame(self,
if not threshold:
threshold = 0.85

frame = toolbox.compress_frame(frame)
frame = toolbox.compress_frame(
frame,
self.compress_rate,
self.target_size,
)

result = list()
for each_stage_name, each_stage_pic_list in self.read(video_cap):
each_result = list()
for target_pic in each_stage_pic_list:
target_pic = toolbox.compress_frame(target_pic)
target_pic = toolbox.compress_frame(target_pic, self.compress_rate, self.target_size)
each_pic_ssim = toolbox.compare_ssim(frame, target_pic)
each_result.append(each_pic_ssim)
ssim = max(each_result)
Expand Down
4 changes: 2 additions & 2 deletions stagesepx/classifier/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def train(self):
train_label = list()
for each_label, each_label_pic_list in self.read():
for each_pic_object in each_label_pic_list:
each_pic_object = toolbox.compress_frame(each_pic_object)
each_pic_object = toolbox.compress_frame(each_pic_object, self.compress_rate, self.target_size)
each_pic_object = self.feature_func(each_pic_object).flatten()
train_data.append(each_pic_object)
train_label.append(each_label)
Expand All @@ -79,7 +79,7 @@ def predict(self, pic_path: str) -> str:
return self.predict_with_object(pic_object)

def predict_with_object(self, pic_object: np.ndarray) -> str:
pic_object = toolbox.compress_frame(pic_object)
pic_object = toolbox.compress_frame(pic_object, self.compress_rate, self.target_size)
pic_object = self.feature_func(pic_object)
pic_object = pic_object.reshape(1, -1)
return self._model.predict(pic_object)[0]
Expand Down
32 changes: 19 additions & 13 deletions stagesepx/cutter.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ def pick_and_save(self,
range_list: typing.List[VideoCutRange],
frame_count: int,
to_dir: str = None,
compress_rate: float = None,

# accepted through **kwargs and forwarded to toolbox.compress_frame:
# compress_rate: float = None,
# target_size: typing.Tuple[int, int] = None,
# not_grey: bool = None,

*args, **kwargs) -> str:
stage_list = list()
for index, each_range in enumerate(range_list):
Expand All @@ -230,25 +235,26 @@ def pick_and_save(self,
for each_frame_id in each_frame_list:
each_frame_path = os.path.join(each_stage_dir, f'{uuid.uuid4()}.png')
each_frame = toolbox.get_frame(cap, each_frame_id - 1)
if compress_rate:
each_frame = toolbox.compress_frame(each_frame, compress_rate)
each_frame = toolbox.compress_frame(each_frame, **kwargs)
cv2.imwrite(each_frame_path, each_frame)
logger.debug(f'frame [{each_frame_id}] saved to {each_frame_path}')

return to_dir


class VideoCutter(object):
def __init__(self, step: int = None, compress_rate: float = None):
def __init__(self,
step: int = None,
# TODO: deprecated here; compress_rate has moved to `cut` and will be removed in the future
compress_rate: float = None):
if not step:
step = 1
if not compress_rate:
compress_rate = 0.2

self.step = step
self.compress_rate = compress_rate

def convert_video_into_ssim_list(self, video_path: str) -> typing.List[VideoCutRange]:
if compress_rate:
logger.warning('compress_rate has been moved to func `cut`')

def convert_video_into_ssim_list(self, video_path: str, **kwargs) -> typing.List[VideoCutRange]:
ssim_list = list()
with toolbox.video_capture(video_path) as cap:
# get video info
Expand All @@ -265,10 +271,10 @@ def convert_video_into_ssim_list(self, video_path: str) -> typing.List[VideoCutR
end_frame_id = toolbox.get_current_frame_id(cap)

# compress
start = toolbox.compress_frame(start, compress_rate=self.compress_rate)
start = toolbox.compress_frame(start, **kwargs)

while ret:
end = toolbox.compress_frame(end, compress_rate=self.compress_rate)
end = toolbox.compress_frame(end, **kwargs)
ssim = toolbox.compare_ssim(start, end)
logger.debug(f'ssim between {start_frame_id} & {end_frame_id}: {ssim}')

Expand All @@ -289,10 +295,10 @@ def convert_video_into_ssim_list(self, video_path: str) -> typing.List[VideoCutR

return ssim_list

def cut(self, video_path: str) -> VideoCutResult:
def cut(self, video_path: str, **kwargs) -> VideoCutResult:
logger.info(f'start cutting: {video_path}')
assert os.path.isfile(video_path), f'video [{video_path}] not existed'
ssim_list = self.convert_video_into_ssim_list(video_path)
ssim_list = self.convert_video_into_ssim_list(video_path, **kwargs)
logger.info(f'cut finished: {video_path}')
return VideoCutResult(
video_path,
Expand Down
20 changes: 14 additions & 6 deletions stagesepx/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,22 @@ def turn_lbp_desc(old: np.ndarray, radius: int = None) -> np.ndarray:
return lbp


def compress_frame(old: np.ndarray, compress_rate: float = None, interpolation: int = None) -> np.ndarray:
if not compress_rate:
compress_rate = 0.2
def compress_frame(old: np.ndarray,
compress_rate: float = None,
target_size: typing.Tuple[int, int] = None,
not_grey: bool = None,
interpolation: int = None) -> np.ndarray:
target = turn_grey(old) if not not_grey else old
if not interpolation:
interpolation = cv2.INTER_AREA

grey = turn_grey(old)
return cv2.resize(grey, (0, 0), fx=compress_rate, fy=compress_rate, interpolation=interpolation)
# target_size takes precedence over compress_rate
if target_size:
return cv2.resize(target, target_size, interpolation=interpolation)
# otherwise fall back to compress_rate;
# when it is unset, no resizing is done (equivalent to a rate of 1)
if not compress_rate:
return target
return cv2.resize(target, (0, 0), fx=compress_rate, fy=compress_rate, interpolation=interpolation)


def get_timestamp_str() -> str:
Expand Down

0 comments on commit f95a4df

Please sign in to comment.