In [122]:
import os
import sys
import pickle
import cv2
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

sys.path.append("../")

from codes.utils.img_utils import add_bboxes_on_image
from codes.video_handling import videoObj
from codes.utils.time_consist import tube_space_time, filter_pipes, pipe_to_frame_instances
from codes.utils.evaluation import filter_preds_score_video
from codes.time_consistency import run_on_video

In [144]:
# Video under analysis; the phase-correlation, prediction and output file names
# used later in this notebook are all built from `video_name`.
video_name = "20181112_rectified_DJI_0043"
video_path = "../data/v1/videos/{}.avi".format(video_name)

vid = videoObj(video_path)

In [145]:
# Phase-correlation data for this video; passed to `tube_space_time` further down.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
phase_corr_file = f"../data/phase_correlation/{video_name}_phaseCorr.pkl"
with open(phase_corr_file, 'rb') as phase_corr_pkl:
    phase_corr = pickle.load(phase_corr_pkl)

In [146]:
# Detector selection: training fold, object class and checkpoint iteration.
fold, obj, model_iter = "train+val", "watertank", "0014176"

# Cached predictions for this video / checkpoint combination.
pred_file = (
    "../output/v1_new_final/faster_rcnn_R_50_FPN_1x/"
    f"mbg_{fold}_{obj}/{video_name}_preds_model_{model_iter}.pkl"
)

In [147]:
# Load cached predictions for this video if present; otherwise run the detector on
# every frame and cache the result. `fold`, `obj`, `model_iter` and `pred_file` are
# defined once in the configuration cell above.
if os.path.isfile(pred_file):
    print("loading predictions...")
    # NOTE: pickle.load can execute arbitrary code — only load caches you produced.
    with open(pred_file, 'rb') as f:
        preds = pickle.load(f)
    print("done!")

else:
    print("computing predictions...")

    config_file = "../codes/configs/mosquitoes/faster_rcnn_R_50_FPN_1x.yaml"

    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    # The checkpoint lives in the same directory as the cached predictions.
    cfg.MODEL.WEIGHTS = os.path.join(os.path.dirname(pred_file), f"model_{model_iter}.pth")
    # Low test threshold on purpose; detections are re-filtered at 0.5 later in this notebook.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 600
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 300
    cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN = 50
    cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 50

    print(f"weights: {cfg.MODEL.WEIGHTS}")

    predictor = DefaultPredictor(cfg)
    preds = run_on_video(video_path, predictor, every=1)

    print("saving predictions...")
    # Dump to a temporary file and atomically rename it, so an interrupted dump never
    # leaves a truncated cache that the `isfile` branch above would later try to load.
    tmp_pred_file = pred_file + ".tmp"
    with open(tmp_pred_file, 'wb') as f:
        pickle.dump(preds, f)
    os.replace(tmp_pred_file, pred_file)
    print("done!")
        

loading predictions...
done!


In [148]:
# Per-frame baseline: score-filter the raw detections (threshold 0.5), no temporal linking.
# NOTE(review): `.copy()` is presumably a shallow dict copy — if the helper mutates the
# nested per-frame entries, `preds` itself is still modified; confirm.
preds_frames = filter_preds_score_video(
    preds.copy(),
    0.5,
)

In [149]:
# Temporal-consistency variant: presumably links the score-filtered detections into
# space-time "pipes" (aided by the phase-correlation data), drops weak pipes, then maps
# the surviving pipes back to per-frame instances.
preds_tube = filter_preds_score_video(preds.copy(), 0.5)
pipes = tube_space_time(
    preds_tube,
    phase_corr,
    h=vid.videoInfo.getHeight(),
    w=vid.videoInfo.getWidth(),
)
filtered_pipes = filter_pipes(pipes, thr=0.4, cut_exts=0)
preds_time = pipe_to_frame_instances(filtered_pipes, vid)

In [153]:
def make_video_detections(video, detections, output_path, rescale_factor=0.5,
                          start_frame=50, end_frame=1010, fourcc="x264"):
    """Draw detection boxes on a range of frames of ``video`` and save them as a new video.

    Frames whose index lies in ``[start_frame, end_frame)`` are taken from
    ``video.frame_from_video()``, annotated with the boxes of the matching
    ``detections`` entry, resized by ``rescale_factor`` and written to ``output_path``.

    Parameters
    ----------
    video : videoObj
        Source video. Its frame generator is consumed by this call, which is
        presumably why callers build a fresh ``videoObj`` per render (TODO confirm).
    detections : dict
        Keyed by ``f'frame_{idx:04d}'``; each value's ``'instances'`` must support
        ``.get('pred_boxes').to('cpu').tensor`` (looks like detectron2
        ``Instances``/``Boxes`` — verify). A frame missing from the mapping raises
        ``KeyError``.
    output_path : str
        Destination video file.
    rescale_factor : float, default 0.5
        Scale applied to both the width and the height of the written frames.
    start_frame : int, default 50
        First frame index written (inclusive).
    end_frame : int or None, default 1010
        Frame index at which rendering stops (exclusive); ``None`` renders to the end
        of the video. The defaults reproduce the previously hard-coded range 50..1009.
    fourcc : str, default "x264"
        Four-character codec code. Some OpenCV installations do not support x264
        (due to its license); try another format (e.g. "mp4v") in that case.

    Raises
    ------
    RuntimeError
        If OpenCV cannot open a writer for ``output_path`` with ``fourcc``.
    """
    from itertools import islice

    w, h = video.videoInfo.getWidthHeight()
    out_size = (int(w * rescale_factor), int(h * rescale_factor))

    writer = cv2.VideoWriter(
        filename=output_path,
        fourcc=cv2.VideoWriter_fourcc(*fourcc),
        fps=video.videoInfo.getFrameRateFloat(),
        frameSize=out_size,
        isColor=True,
    )
    try:
        # Fail loudly instead of silently dropping every frame on an unopened writer.
        if not writer.isOpened():
            raise RuntimeError(
                f"cv2.VideoWriter could not open {output_path!r} with fourcc {fourcc!r}; "
                "try another codec (e.g. 'mp4v')."
            )

        n_frames = video.videoInfo.getNumberOfFrames()
        stop = n_frames if end_frame is None else min(end_frame, n_frames)
        # Progress bar counts only the frames actually written.
        total = max(0, stop - start_frame)

        # islice still pulls (and discards) the frames before start_frame, so the
        # enumerate index stays equal to the absolute frame number.
        frames = islice(enumerate(video.frame_from_video()), start_frame, end_frame)
        for idx, frame in tqdm(frames, total=total):
            boxes = detections[f'frame_{idx:04d}']['instances'].get('pred_boxes').to('cpu').tensor
            annotated = add_bboxes_on_image(frame, boxes, color=(255, 0, 0))
            writer.write(cv2.resize(annotated, out_size))
    finally:
        # Always finalize the output file, even if a lookup or drawing step fails.
        writer.release()

In [154]:
# Render the tube-based (pipe-filtered) detections. A new videoObj is built first,
# presumably because the previous one's frame generator has already been consumed.
vid = videoObj(video_path)
make_video_detections(
    vid,
    preds_time,
    output_path=f'{video_name}_tubes_part.mp4',
)

HBox(children=(FloatProgress(value=0.0, max=3862.0), HTML(value='')))




In [155]:
# Render the per-frame (score-filtered only) baseline detections for comparison,
# again from a freshly constructed videoObj.
vid = videoObj(video_path)
make_video_detections(
    vid,
    preds_frames,
    output_path=f'{video_name}_frames_part.mp4',
)

HBox(children=(FloatProgress(value=0.0, max=3862.0), HTML(value='')))


