Skip to content

Commit

Permalink
feat(#123): cmd shortcut
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Jun 11, 2020
1 parent d035fda commit 920a1c2
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 2 deletions.
38 changes: 38 additions & 0 deletions stagesepx/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import typing
import uuid
import tempfile
from loguru import logger

from stagesepx.cutter import VideoCutResult, VideoCutRange, VideoCutter
Expand Down Expand Up @@ -208,3 +209,40 @@ def keras_train(
model_path = f"{uuid.uuid4()}.h5"
logger.debug(f"trying to save it to {model_path}")
cl.save_model(model_path, overwrite=overwrite)


def analyse(
video: typing.Union[str, VideoObject],
output_path: str,
pre_load: bool = True,
threshold: float = 0.98,
offset: int = 3,
boost_mode: bool = True,
):
""" designed for https://github.com/williamfzc/stagesepx/issues/123 """

if isinstance(video, str):
video = VideoObject(video, pre_load=pre_load)

cutter = VideoCutter()
res = cutter.cut(video)

stable, unstable = res.get_range(threshold=threshold, offset=offset,)

with tempfile.TemporaryDirectory() as temp_dir:
res.pick_and_save(
stable, 5, to_dir=temp_dir,
)

cl = SVMClassifier()
cl.load(temp_dir)
cl.train()
classify_result = cl.classify(video, stable, boost_mode=boost_mode)

r = Reporter()
r.draw(
classify_result,
report_path=output_path,
unstable_ranges=unstable,
cut_result=res,
)
1 change: 1 addition & 0 deletions stagesepx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TerminalCli(object):
# or, some translations, default value, and so on
one_step = staticmethod(api.one_step)
train = staticmethod(api.keras_train)
analyse = staticmethod(api.analyse)


def main():
Expand Down
8 changes: 7 additions & 1 deletion test/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile

from stagesepx.api import _cut, _classify, one_step, _train
from stagesepx.api import _cut, _classify, one_step, _train, analyse
from stagesepx.reporter import Reporter
from stagesepx.video import VideoObject

Expand Down Expand Up @@ -56,3 +57,8 @@ def test_boost():
report_path=os.path.join(data_home, "report.html"),
cut_result=res,
)


def test_analyse():
with tempfile.NamedTemporaryFile(suffix=".html", mode="w") as f:
analyse(VIDEO_PATH, f.name)
6 changes: 5 additions & 1 deletion test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def test_cli():
logger.info("checking keras trainer ...")
subprocess.check_call(["stagesepx", "train", "output", "output.h5"])
# try to train
subprocess.check_call(["stagesepx", "train", "output", "output.h5", "--epochs", "1"])
subprocess.check_call(
["stagesepx", "train", "output", "output.h5", "--epochs", "1"]
)

# new
subprocess.check_call(["stagesepx", "analyse", VIDEO_PATH, "output"])
shutil.rmtree("output")

0 comments on commit 920a1c2

Please sign in to comment.