In [15]:
import torch
from Model.model import ModelSTGCN
import cv2
import moviepy.editor as mpy
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model,vis_pose_result
import numpy as np
import shutil
import os.path as osp
import os
import glob
import torch.nn as nn

In [4]:
def extract_frame(video_path, out_dir='temp', frame_size=(640, 480)):
    """Decode a video into resized JPEG frames on disk.

    Every frame that can be read from ``video_path`` is resized to
    ``frame_size`` and written to ``out_dir`` as ``img_00001.jpg``,
    ``img_00002.jpg``, ... (numbering starts at 1).

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        out_dir: Directory receiving the frames; created if missing. Files
            already present are not removed, so frames from an earlier,
            longer video may remain next to the new ones.
        frame_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        list[str]: Paths of the frames written, in decoding order. The list is
        empty when the video cannot be opened or holds no readable frame.
    """
    os.makedirs(out_dir, exist_ok=True)
    frame_paths = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            # Format only the file name so braces in out_dir cannot break str.format.
            frame_name = 'img_{:05d}.jpg'.format(len(frame_paths) + 1)
            frame_path = osp.join(out_dir, frame_name)
            cv2.imwrite(frame_path, cv2.resize(frame, frame_size))
            frame_paths.append(frame_path)
    finally:
        # Release the capture even when resizing or writing a frame fails.
        cap.release()
    return frame_paths

# Config / checkpoint pairs for the two pretrained networks, given as paths
# relative to the working directory (everything lives under Pose/).
# An earlier setup pointed the pose config at the mmpose source tree instead:
#   mmpose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py
pose_config, pose_checkpoint = (
    'Pose/hrnet_w48_coco_256x192.py',
    'Pose/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
)
det_config, det_checkpoint = (
    'Pose/yolox_s_8x8_300e_coco.py',
    'Pose/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
)

# Instantiate both networks when this cell runs: pose estimator first, then detector.
# NOTE(review): the inference helpers visible in this notebook build their own
# model instances from the config paths; confirm whether these two globals are
# still needed by other cells.
pose_model = init_pose_model(pose_config, pose_checkpoint)
det_model = init_detector(det_config, det_checkpoint)
def detection_inference(det_config, det_checkpoint ,frame_paths, det_score_thr=0.5,device='cuda' ):
    """Detect people in every frame and keep the confident boxes.

    A detector is built from ``det_config`` / ``det_checkpoint`` on ``device``
    and run on each path in ``frame_paths``. Only the first class of the
    detector output is used, which the assertion requires to be ``'person'``.

    Args:
        det_config: Detector config path.
        det_checkpoint: Detector checkpoint path.
        frame_paths: Image paths, one per frame.
        det_score_thr: Rows whose fifth column (index 4, the score) is below
            this value are dropped.
        device: Device string forwarded to ``init_detector``.

    Returns:
        list: One array of surviving detections per frame, in input order.
    """
    detector = init_detector(det_config, det_checkpoint, device)
    assert detector.CLASSES[0] == 'person', ('We require you to use a detector '
                                             'trained on COCO')
    print('Performing Human Detection for each frame')
    progress = mmcv.ProgressBar(len(frame_paths))
    kept_per_frame = []
    for path in frame_paths:
        # Index 0 selects the detections of the first class ('person').
        person_dets = inference_detector(detector, path)[0]
        confident = person_dets[:, 4] >= det_score_thr
        kept_per_frame.append(person_dets[confident])
        progress.update()
    return kept_per_frame

def pose_inference(pose_config,pose_checkpoint, frame_paths,image_shape, det_results, device='cuda',
                   show=True, bbox_thr=0.5):
    """Estimate 17-keypoint poses for the detected people in every frame.

    Args:
        pose_config: Pose model config path.
        pose_checkpoint: Pose model checkpoint path.
        frame_paths: Image paths, one per frame.
        image_shape: Accepted for interface compatibility; currently unused
            (keypoints are stored as returned by the pose model, un-normalised).
        det_results: Per-frame detection arrays, aligned with ``frame_paths``;
            the last column of each row is compared against ``bbox_thr``.
        device: Device string forwarded to ``init_pose_model``.
        show: When True (default) each processed frame is drawn in an OpenCV
            window and pressing ``q`` stops processing early, leaving later
            frames at zero. Pass False on machines without a display.
        bbox_thr: Boxes whose last column is not strictly greater than this
            value are skipped.

    Returns:
        np.ndarray: float32 array of shape ``(P, T, 17, 3)`` where ``T`` is
        ``len(det_results)`` and ``P`` is the largest number of detections in
        any frame (at least 1). Entries without a pose stay zero. The person
        index is simply the order in which the pose model returned the people
        for that frame; identities are not tracked across frames.
    """
    model = init_pose_model(pose_config, pose_checkpoint, device)
    print('Performing Human Pose Estimation for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))

    num_frame = len(det_results)
    # default=0 keeps an empty detection list (no frames) from raising in max().
    num_person = max((len(dets) for dets in det_results), default=0)
    # Always keep one person slot so the result has a person axis even with no detections.
    kp = np.zeros((max(num_person, 1), num_frame, 17, 3), dtype=np.float32)
    if num_person == 0:
        return kp

    try:
        for i, (f, d) in enumerate(zip(frame_paths, det_results)):
            if len(d) == 0:
                prog_bar.update()
                continue
            # Align input format: one dict per box, as the pose API expects.
            d = [dict(bbox=x) for x in list(d) if x[-1] > bbox_thr]
            pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
            for j, item in enumerate(pose):
                # Each entry holds the keypoints as (x, y, score) rows.
                kp[j, i] = item['keypoints']
            prog_bar.update()
            if show:
                vis_ske = vis_pose_result(model, f, pose,
                                          dataset=model.cfg.data.test.type, show=False)
                # Draw first, then poll the keyboard, so the frame on screen
                # is the one that was just processed.
                cv2.imshow('', vis_ske)
                if cv2.waitKey(20) & 0xFF == ord('q'):
                    break
    finally:
        # Close the preview window even if inference fails part-way through.
        if show:
            cv2.destroyAllWindows()
    return kp

def pose_extraction(vid,det_config, det_checkpoint,pose_config,pose_checkpoint,label, det_score_thr=0.5,device='cuda'):
    """Build a skeleton annotation dict for one labelled video.

    The video is split into frames, people are detected in each frame, and
    their poses are estimated; the result is packed into a dict.

    Args:
        vid: Path of the input video.
        det_config: Detector config path.
        det_checkpoint: Detector checkpoint path.
        pose_config: Pose model config path.
        pose_checkpoint: Pose model checkpoint path.
        label: Stored unchanged under ``'label'``.
        det_score_thr: Score threshold forwarded to ``detection_inference``.
        device: Device string forwarded to both inference helpers.

    Returns:
        dict with keys ``'kp'`` (first two keypoint channels), ``'kp_score'``
        (third channel), ``'frame_dir'`` (video file name without extension),
        ``'img_shape'`` and ``'original_shape'`` (``(width, height)`` of the
        first extracted frame), ``'total_frames'`` and ``'label'``.

    Raises:
        ValueError: If no frame could be extracted from ``vid`` or the first
            extracted frame cannot be read back.
    """
    frame_paths = extract_frame(vid)
    if not frame_paths:
        # Otherwise the failure shows up as an opaque IndexError on frame_paths[0].
        raise ValueError('No frames could be decoded from {!r}'.format(vid))
    # Read the reference frame before the expensive detection pass so a broken
    # frame dump fails fast.
    first_frame = cv2.imread(frame_paths[0])
    if first_frame is None:
        raise ValueError('Could not read extracted frame {!r}'.format(frame_paths[0]))
    image_shape = (first_frame.shape[1], first_frame.shape[0])

    det_results = detection_inference(det_config, det_checkpoint ,frame_paths, det_score_thr,device)
    pose_results = pose_inference(pose_config,pose_checkpoint, frame_paths,image_shape, det_results, device)

    anno = dict()
    anno['kp'] = pose_results[..., :2]
    anno['kp_score'] = pose_results[..., 2]
    anno['frame_dir'] = osp.splitext(osp.basename(vid))[0]
    anno['img_shape'] = image_shape
    anno['original_shape'] = image_shape
    anno['total_frames'] = pose_results.shape[1]
    anno['label'] = label
    # The extracted frames are deliberately left on disk; to clean up, use:
    # shutil.rmtree(osp.dirname(frame_paths[0]))
    return anno


load checkpoint from local path: Pose/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth
load checkpoint from local path: Pose/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth


In [None]:
# Video to classify and the trained action-recognition network.
file='FALL/Data_fall_10.mp4'
# NOTE(review): the arguments are presumably (in_channels, num_classes) —
# confirm against Model/model.py.
model=ModelSTGCN(3,2)
# Load the weights onto the CPU first so the checkpoint also opens on machines
# that lack the device it was saved from; ActionReg moves the model to the
# requested device afterwards.
model.load_state_dict(torch.load('Model.path', map_location='cpu'))
model.eval()

In [22]:
# Scratch check of the sliding-window slicing: a 30-frame window moved in
# steps of 10 over 50 frames. The trailing windows come out short (20 and 10
# frames), and the second printed number is the nominal window end, which can
# exceed the sequence length.
a = np.random.random((2, 50, 17, 3))
lst = list(range(50))
for i in a:
    for frame_i in range(0, len(lst), 10):
        window_end = frame_i + 30
        tempa = i[frame_i:window_end]
        print(tempa.shape, window_end)

(30, 17, 3) 30
(30, 17, 3) 40
(30, 17, 3) 50
(20, 17, 3) 60
(10, 17, 3) 70
(30, 17, 3) 30
(30, 17, 3) 40
(30, 17, 3) 50
(20, 17, 3) 60
(10, 17, 3) 70


In [None]:
def ActionReg(model: nn.Module = None, file: str = None,
              det_config = det_config , det_checkpoint = det_checkpoint,
              pose_config = pose_config , pose_checkpoint = pose_checkpoint,
              device = 'cuda', window_size: int = 30):
    """Classify the action of every detected person in a video, window by window.

    The video is split into frames, people are detected and their poses
    estimated, then each person's keypoint sequence is cut into consecutive,
    non-overlapping windows of ``window_size`` frames that are fed to ``model``.

    Args:
        model: Classifier applied to each keypoint window; it is moved to
            ``device`` and switched to eval mode.
        file: Path of the input video.
        det_config: Detector config path (defaults to the notebook-level value).
        det_checkpoint: Detector checkpoint path.
        pose_config: Pose model config path.
        pose_checkpoint: Pose model checkpoint path.
        device: Device used for the detector, the pose model and ``model``.
        window_size: Number of frames per window; the last window of a
            sequence is shorter when the frame count is not a multiple of it.

    Returns:
        list[torch.Tensor]: ``argmax`` over dim 1 of the model output, one
        entry per (person, window) pair, ordered person-major.

    Raises:
        AssertionError: If any required argument is None.
        ValueError: If ``window_size`` is not positive, no frame could be
            extracted from ``file``, or the first frame cannot be read back.
    """
    assert all(param is not None for param in [model,file,
                                   det_config,det_checkpoint,
                                   pose_config,pose_checkpoint]),"All param must be give in"
    if window_size <= 0:
        raise ValueError('window_size must be a positive integer, got {!r}'.format(window_size))

    model.to(device)
    # Inference only: disable dropout / batch-norm updates regardless of the
    # mode the caller left the model in.
    model.eval()

    frame_paths = extract_frame(file)
    if not frame_paths:
        # Otherwise the failure shows up as an opaque IndexError on frame_paths[0].
        raise ValueError('No frames could be decoded from {!r}'.format(file))
    image = cv2.imread(frame_paths[0])
    if image is None:
        raise ValueError('Could not read extracted frame {!r}'.format(frame_paths[0]))
    image_shape = (image.shape[1], image.shape[0])

    # Forward the device so the detector does not silently fall back to 'cuda'.
    det_results = detection_inference(det_config, det_checkpoint, frame_paths, device=device)
    pose_results = pose_inference(pose_config,pose_checkpoint, frame_paths,image_shape, det_results, device)
    leng_frame = len(frame_paths)

    labels = []
    # no_grad avoids building an autograd graph for every window.
    with torch.no_grad():
        for person in pose_results:
            for window in range(0, leng_frame, window_size):
                # NOTE(review): each clip is (frames, 17, 3) with no batch axis,
                # and the last clip may hold fewer than window_size frames —
                # confirm this matches the layout the model expects.
                clip = torch.from_numpy(person[window:window + window_size]).to(device)
                outputs = model(clip)
                labels.append(torch.argmax(outputs, dim=1))
    return labels
