Skip to content

Commit

Permalink
feat(#22): support FrameSaveHook
Browse files Browse the repository at this point in the history
really good for debug
  • Loading branch information
williamfzc committed Aug 11, 2019
1 parent 58a6a85 commit 0253304
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
4 changes: 2 additions & 2 deletions example/cut_and_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# 你可以通过 thumbnail 将阶段的变化过程转化成一张缩略图,这样可以很直观地看出阶段的变化过程!
# 例如,你希望查看第一个unstable阶段发生了什么
# 这样做能够将转化后的缩略图保存到当前目录下
res.thumbnail(unstable[0], to_dir='.')
# res.thumbnail(unstable[0], to_dir='.')

# 对区间进行采样
data_home = res.pick_and_save(
Expand Down Expand Up @@ -115,7 +115,7 @@
# 你可以将把一些文件夹路径插入到报告中
# 这样你可以很方便地从报告中查看各项相关内容
# 当然,你需要想好这些路径与报告最后所在位置之间的相对位置,以确保他们能够被访问到
r.add_dir_link(data_home)
# r.add_dir_link(data_home)

# 在0.3.2及之后的版本,你可以在报告中加入一些自定义内容 (https://github.com/williamfzc/stagesepx/issues/13)
# r.add_extra('here is title', 'here is content')
Expand Down
32 changes: 28 additions & 4 deletions stagesepx/hook.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import os
from loguru import logger
import cv2

from stagesepx import toolbox


class BaseHook(object):
name = 'hook'

def __init__(self, *_, **__):
# default: dict
self.result = dict()
Expand All @@ -14,12 +16,34 @@ def do(self, frame_id: int, frame: np.ndarray, *_, **__):


class ExampleHook(BaseHook):
name = 'example_hook'

def __init__(self):
# you can handle result by yourself
# change the type, or anything you want
super().__init__()
self.result = dict()

def do(self, frame_id: int, frame: np.ndarray, *_, **__):
frame = toolbox.turn_grey(frame)
self.result[frame_id] = frame.shape


class FrameSaveHook(BaseHook):
""" add this hook, and save all the frames you want to specific dir """
def __init__(self, target_dir: str, *_, **__):
super().__init__(*_, **__)

self.target_dir = target_dir
os.makedirs(target_dir, exist_ok=True)
logger.debug(f'init frame saver, frames will be saved to {target_dir}')

def do(self,
frame_id: int,
frame: np.ndarray,
compress_rate: float = None,
*args, **kwargs):
if not compress_rate:
compress_rate = 0.2

compressed = toolbox.compress_frame(frame, compress_rate=compress_rate)
target_path = os.path.join(self.target_dir, f'{frame_id}.png')
cv2.imwrite(target_path, compressed)

0 comments on commit 0253304

Please sign in to comment.