-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(test): case structure improvement
- Loading branch information
1 parent
68707ec
commit 2c51277
Showing
5 changed files
with
131 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from stagesepx.classifier import SSIMClassifier, SVMClassifier | ||
from stagesepx.reporter import Reporter | ||
|
||
from test_cutter import test_default as cutter_default | ||
from test_cutter import RESULT_DIR as CUTTER_RESULT_DIR | ||
|
||
import os | ||
|
||
PROJECT_PATH = os.path.dirname(os.path.dirname(__file__)) | ||
VIDEO_PATH = os.path.join(PROJECT_PATH, 'demo.mp4') | ||
MODEL_PATH = os.path.join(PROJECT_PATH, 'model.pkl') | ||
|
||
# cut, and get result dir | ||
cutter_res = cutter_default() | ||
|
||
|
||
def _draw_report(res): | ||
r = Reporter() | ||
report_path = os.path.join(CUTTER_RESULT_DIR, 'report.html') | ||
r.draw( | ||
res, | ||
report_path=report_path, | ||
) | ||
assert os.path.isfile(report_path) | ||
|
||
|
||
def test_default(): | ||
# --- classify --- | ||
cl = SVMClassifier() | ||
cl.load(CUTTER_RESULT_DIR) | ||
cl.train() | ||
cl.save_model(MODEL_PATH) | ||
classify_result = cl.classify(VIDEO_PATH) | ||
|
||
# --- draw --- | ||
_draw_report(classify_result) | ||
|
||
|
||
def test_ssim_classifier(): | ||
cl = SSIMClassifier() | ||
cl.load(CUTTER_RESULT_DIR) | ||
classify_result = cl.classify(VIDEO_PATH) | ||
|
||
# --- draw --- | ||
_draw_report(classify_result) | ||
|
||
|
||
def test_save_and_load(): | ||
# test save and load | ||
cl = SVMClassifier() | ||
cl.load_model(MODEL_PATH) | ||
classify_result = cl.classify(VIDEO_PATH) | ||
|
||
# --- draw --- | ||
_draw_report(classify_result) | ||
|
||
|
||
def test_work_with_cutter(): | ||
cl = SVMClassifier() | ||
cl.load_model(MODEL_PATH) | ||
stable, _ = cutter_res.get_range() | ||
classify_result = cl.classify( | ||
VIDEO_PATH, | ||
stable, | ||
) | ||
|
||
# --- draw --- | ||
_draw_report(classify_result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from stagesepx.cutter import VideoCutter, VideoCutResult | ||
import os | ||
|
||
PROJECT_PATH = os.path.dirname(os.path.dirname(__file__)) | ||
VIDEO_PATH = os.path.join(PROJECT_PATH, 'demo.mp4') | ||
RESULT_DIR = os.path.join(PROJECT_PATH, 'cut_result') | ||
|
||
|
||
def test_default(): | ||
cutter = VideoCutter() | ||
res = cutter.cut(VIDEO_PATH) | ||
stable, unstable = res.get_range() | ||
assert len(stable) == 3, 'count of stable range is not correct' | ||
|
||
data_home = res.pick_and_save(stable, 5, to_dir=RESULT_DIR) | ||
assert data_home == RESULT_DIR | ||
assert os.path.isdir(data_home), 'result dir not existed' | ||
return res | ||
|
||
|
||
def test_limit(): | ||
cutter = VideoCutter() | ||
res = cutter.cut(VIDEO_PATH) | ||
stable, unstable = res.get_range(limit=3) | ||
# when limit=3, final stage should be ignored. | ||
assert len(stable) == 1, 'count of stable range is not correct' | ||
|
||
|
||
def test_step(): | ||
cutter = VideoCutter(step=2) | ||
res = cutter.cut(VIDEO_PATH) | ||
stable, unstable = res.get_range() | ||
# when limit=3, final stage should be ignored. | ||
assert len(stable) == 1, 'count of stable range is not correct' | ||
|
||
|
||
def test_dump_and_load(): | ||
cutter = VideoCutter() | ||
res = cutter.cut(VIDEO_PATH) | ||
json_path = 'cutter_result.json' | ||
res.dump(json_path) | ||
|
||
res_from_file = VideoCutResult.load(json_path) | ||
assert res.dumps() == res_from_file.dumps() | ||
|
||
|
||
def test_prune(): | ||
cutter = VideoCutter() | ||
res = cutter.cut(VIDEO_PATH) | ||
stable, unstable = res.get_range() | ||
assert len(stable) == 3, 'count of stable range is not correct' | ||
|
||
data_home = res.pick_and_save(stable, 5, prune=0.99) | ||
assert os.path.isdir(data_home), 'result dir not existed' |
This file was deleted.
Oops, something went wrong.