Skip to content

Commit

Permalink
feat(#18): support comparison between different results of cutter
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Aug 4, 2019
1 parent 0a693bc commit 87a378b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
26 changes: 26 additions & 0 deletions stagesepx/classifier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@ def load_from_dir(self, dir_path: str, *_, **__):
self._data[stage_name] = stage_pic_list
logger.debug(f'stage [{stage_name}] found, and got {len(stage_pic_list)} pics')

def diff(self, another: 'BaseClassifier') -> typing.Dict[str, typing.Dict[str, float]]:
assert another._data, 'must load data first'
result_dict = dict()

self_data = dict(self.read())
for each_self_stage, each_self_data in self_data.items():
another_data = dict(another.read())
each_self_data_pic = next(each_self_data)

each_stage_dict = dict()
# compare with all the stages
for each_another_stage, each_another_data in another_data.items():
# compare with all the pictures in same stage, and pick the max one
max_ssim = -1
for each_pic in each_another_data:
ssim = toolbox.compare_ssim(
each_pic,
each_self_data_pic,
)
if ssim > max_ssim:
max_ssim = ssim
each_stage_dict[each_another_stage] = max_ssim

result_dict[each_self_stage] = each_stage_dict
return result_dict

def read(self, *args, **kwargs):
for stage_name, stage_data in self._data.items():
if isinstance(stage_data[0], pathlib.Path):
Expand Down
9 changes: 6 additions & 3 deletions stagesepx/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import typing
import numpy as np
from skimage.filters import threshold_otsu
from skimage.measure import compare_ssim
from skimage.measure import compare_ssim as origin_compare_ssim
from skimage.feature import hog, local_binary_pattern

compare_ssim = compare_ssim


@contextlib.contextmanager
def video_capture(video_path: str):
Expand All @@ -24,6 +22,11 @@ def video_jump(video_cap: cv2.VideoCapture, frame_id: int):
video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id - 1)


def compare_ssim(pic1: np.ndarray, pic2: np.ndarray) -> float:
pic1, pic2 = [turn_grey(i) for i in [pic1, pic2]]
return origin_compare_ssim(pic1, pic2)


def get_current_frame_id(video_cap: cv2.VideoCapture) -> int:
return int(video_cap.get(cv2.CAP_PROP_POS_FRAMES))

Expand Down

0 comments on commit 87a378b

Please sign in to comment.