In [1]:
import albumentations as A
# Define the per-frame preprocessing pipeline applied to every RGB frame before it
# is stacked into a clip. It runs three steps in order:
#   1. resize to 128x171 (height x width);
#   2. center-crop to 112x112;
#   3. normalize each channel with the mean/std below.
# `p=1.0` makes every step unconditional. It replaces the deprecated
# `always_apply=True` flag.
# NOTE(review): an earlier version of this cell said it "may need to be run twice"
# because the first run raised an error. The cause was never recorded, so
# investigate it instead of relying on a re-run.
transform = A.Compose([
    A.Resize(128, 171, p=1.0),
    A.CenterCrop(112, 112, p=1.0),
    A.Normalize(mean=[0.43216, 0.394666, 0.37645],
                std=[0.22803, 0.22145, 0.216989],
                p=1.0),
])

In [2]:
# File locations used by the rest of the notebook, relative to the working directory.
names, nn_model, video = (
    "action_recognition_kinetics.txt",  # class labels, one per line
    "resnet-34_kinetics.onnx",          # NOTE(review): not used to load the model; a torchvision r3d_18 is loaded instead
    "example_activities.mp4",           # input video to annotate
)

In [3]:
import argparse
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

In [4]:
# Run configuration. The keys mirror the flags of the original command-line script
# (--input, --clip-len, --classes, --model). argparse is not used here because
# parse_args() would read the Jupyter kernel's own command line.
args = dict(
    input=video,
    clip_len=16,  # number of frames in each prediction window
    classes=names,
    model=nn_model,
)
#### PRINT INFO #####
print(f"Number of frames to consider for each prediction: {args['clip_len']}")

Number of frames to consider for each prediction: 16


In [5]:
# Get the labels: one class name per line. The line index is the class index the
# model predicts (it is looked up as class_names[prediction] in the inference loop).
# `with` closes the file handle instead of leaking it in the kernel.
with open(args["classes"], encoding="utf-8") as classes_file:
    class_names = classes_file.read().strip().split("\n")

# run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the 3D ResNet-18 video classifier with pretrained weights (downloaded on
# first use).
# NOTE(review): presumably trained on Kinetics-400 to match the class list;
# confirm this.
# NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
# `weights=...`. It is kept here for compatibility with older versions.
model = torchvision.models.video.r3d_18(pretrained=True, progress=True)

# switch to eval mode and move the model onto the computation device
model = model.eval().to(device)

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /home/psl/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100.0%


In [6]:
# Open the input video. Fail fast if it cannot be read: every later cell (writer
# size, frame loop, FPS summary) needs a valid capture. Continuing would only
# produce an empty output and a division by zero at the end.
cap = cv2.VideoCapture(args['input'])
if not cap.isOpened():
    raise IOError('Error while trying to read video. Please check path again')
# source frame size, reused as the output writer's frame size
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

In [7]:
# Output file name = the input file name without its directory or final extension.
# os.path handles OS-specific separators and names with several dots,
# e.g. "a.b.mp4" -> "a.b".
save_name = os.path.splitext(os.path.basename(args['input']))[0]

# Directory for the annotated video.
# NOTE(review): "filepath" looks like a placeholder; point it at the intended
# output location.
output_dir = "filepath"
# cv2.VideoWriter does not create missing directories; it would just fail to open.
os.makedirs(output_dir, exist_ok=True)

# Define the codec and create the VideoWriter. Frames written later are the
# original BGR frames at (frame_width, frame_height).
# NOTE(review): the output frame rate is fixed at 15 fps regardless of the
# source video's rate.
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
out = cv2.VideoWriter(os.path.join(output_dir, f"{save_name}.mp4"), fourcc, 15.0,
                      (frame_width, frame_height))
# without this check, out.write() silently drops every frame
if not out.isOpened():
    raise IOError(f"Could not open output video for writing in {output_dir!r}")

In [8]:
# Per-run state for the inference loop. Re-run this cell before re-running the loop
# so the counters and the frame buffer start empty.
clips = []  # sliding window of preprocessed frames (never longer than clip_len)
frame_count, total_fps = 0, 0  # frames classified / running sum of per-frame FPS

In [10]:
# Sliding-window inference. Keep the last `clip_len` preprocessed frames. Once the
# window is full:
#   1. classify it;
#   2. draw the predicted label on the newest original (BGR) frame;
#   3. write that frame to the output video;
#   4. drop the oldest frame.
# NOTE(review): the first `clip_len - 1` frames are never written because no
# prediction exists for them yet, so the output is slightly shorter than the input.
while cap.isOpened():
    # capture the next frame; `ret` is False at the end of the video or on a read error
    ret, frame = cap.read()
    if not ret:
        break
    # perf_counter is a monotonic, high-resolution clock intended for timing intervals
    start_time = time.perf_counter()
    # keep an untouched BGR copy for annotation and writing
    image = frame.copy()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = transform(image=frame)['image']
    clips.append(frame)
    if len(clips) != args['clip_len']:
        # window not full yet: nothing to classify for this frame
        continue
    with torch.no_grad():  # inference only: no gradients needed
        # stack the window and add a batch dimension -> (1, clip_len, height, width, 3)
        input_frames = np.expand_dims(np.array(clips), axis=0)
        # reorder to [1, 3, clip_len, height, width] (channels before time)
        input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
        input_frames = torch.tensor(input_frames, dtype=torch.float32).to(device)
        # forward pass, then the index of the highest-scoring class of the single clip
        outputs = model(input_frames)
        pred_idx = outputs.argmax(dim=1).item()
        # map the prediction index to its class name
        label = class_names[pred_idx].strip()
    # per-frame throughput: preprocessing + inference for this frame
    fps = 1 / (time.perf_counter() - start_time)
    total_fps += fps
    frame_count += 1
    # (0, 0, 255) is red in OpenCV's BGR channel order
    cv2.putText(image, label, (15, 25),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2,
                lineType=cv2.LINE_AA)
    # slide the window forward by one frame
    clips.pop(0)
    out.write(image)

In [12]:
# Release the input capture and finalize the output video. The VideoWriter must be
# released explicitly: in a notebook `out` stays alive in the kernel, so it is not
# closed automatically, which can leave the .mp4 incomplete.
cap.release()
out.release()
# cv2.destroyAllWindows() is only needed when frames are shown with cv2.imshow.
# Average of the per-frame FPS values measured in the inference loop. Guard against
# runs where no clip was classified (e.g. a video shorter than clip_len).
if frame_count:
    avg_fps = total_fps / frame_count
    print(f"Average FPS: {avg_fps:.3f}")
else:
    avg_fps = float("nan")
    print("Average FPS: n/a (no clips were classified)")

Average FPS: 49.738
