In [83]:
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
# import transform
from torchvision import transforms
import math
import numpy as np
from tqdm import tqdm
from PIL import Image
from IPython.display import HTML
from base64 import b64encode
import subprocess
import os

In [33]:
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
NUM_CLASSES = 4
CLASSES = [
    '__background__',
    'basketball', 'hoop', 'person',
]
model = fasterrcnn_resnet50_fpn(
        weights='DEFAULT'
    )
in_features = model.roi_heads.box_predictor.cls_score.in_features
    # define a new head for the detector with required number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES) 
checkpoint = torch.load("fastercnn-pytorch-training-pipeline/outputs/training/basketball_detect_training/best_model.pth")
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE).eval();

In [48]:
def get_prediction(model, imgs):
    # convert list of images to a batch of tensor
    imgs = torch.stack(imgs).to(DEVICE)
    
    with torch.no_grad():
        predictions = model(imgs)
    return predictions

def convert_detections(
    outputs, 
    detection_threshold, 
    classes,
    args
):
    """
    Return the bounding boxes, scores, and classes.
    """
    boxes = outputs['boxes'].data.cpu().numpy()
    scores = outputs['scores'].data.cpu().numpy()


    boxes = boxes[scores >= detection_threshold].astype(np.int32)
    draw_boxes = boxes.copy()
    # Get all the predicited class names.
    pred_classes = [classes[i] for i in outputs['labels'].cpu().numpy()]

    return draw_boxes, pred_classes, scores


In [77]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

video_path = "clips/20230705_Game6.mp4"
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
batch_size = 64
num_batches = math.ceil(total_frames / batch_size)
results = []

frames = []
for i in tqdm(range(num_batches)):
    frame_count = 0
    #frames = []
    for i in range(batch_size):
        ret, img = cap.read()
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if ret:
            frames.append(transform(img))
        else:
            break
    break

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:03<?, ?it/s]


In [40]:
predictions = get_prediction(model, frames)

In [113]:


# Assuming DEVICE is defined (e.g., torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# Assuming 'model' is already defined and loaded
CLASSES = [
    '__background__',  # Typically, 0 is for background
    'basketball', 'hoop', 'person',
]
def rescale_boxes(boxes, original_width, original_height, model_width=640, model_height=640):
    # Calculate scale factors
    x_scale = original_width / model_width
    y_scale = original_height / model_height

    rescaled_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = box
        x1 = x1 * x_scale
        y1 = y1 * y_scale
        x2 = x2 * x_scale
        y2 = y2 * y_scale
        rescaled_boxes.append([x1, y1, x2, y2])

    return rescaled_boxes

def draw_boxes(img, boxes, labels, scores, frame_width, frame_height, threshold=0.5):
    # Rescale boxes to original frame size
    rescaled_boxes = rescale_boxes(boxes, frame_width, frame_height, 640, 640)

    # Draw bounding box with label and probability
    for box, label, score in zip(rescaled_boxes, labels, scores):
        if score > threshold:
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label_text = f'{CLASSES[label]}: {score:.2f}'
            cv2.putText(img, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return img

def get_prediction(model, imgs):
    # Convert list of images to a batch of tensor
    imgs = torch.stack(imgs).to(DEVICE)
    with torch.no_grad():
        predictions = model(imgs)
    return predictions

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

video_path = "clips/20230705_Game6.mp4"
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))
fps = int(cap.get(cv2.CAP_PROP_FPS))
batch_size = 64
num_batches = math.ceil(total_frames / batch_size)
results = []
video_name = video_path.split("/")[-1]
output_path = "inferenced_" + video_name
codec = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, codec, fps, (frame_width,frame_height))

for i in tqdm(range(num_batches)):
    frames = []
    for _ in range(batch_size):
        ret, img = cap.read()
        if ret:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            frames.append(transform(img))
        else:
            print("No more frames to read or error reading frame")  # Debug print
            break
    
    if frames:
        predictions = get_prediction(model, frames)

        # Process each frame in the batch
        for img, pred in zip(frames, predictions):
            img = transforms.functional.to_pil_image(img)
            img = transforms.functional.resize(img, (frame_width, frame_height))
            img = np.array(img)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            # Draw bounding boxes
            boxes = pred['boxes'].to('cpu').numpy()
            labels = pred['labels'].to('cpu').numpy()
            scores = pred['scores'].to('cpu').numpy()
            
            img_with_boxes = draw_boxes(img, boxes, labels, scores, frame_width, frame_height)
            out.write(img_with_boxes)
    else:
        break

out.release()
cap.release()


  0%|          | 0/6 [00:00<?, ?it/s]

[[[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 46  43  50]
  [ 54  51  58]
  [ 72  69  76]]

 [[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 51  48  55]
  [ 61  58  65]
  [ 76  73  80]]

 [[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 61  57  64]
  [ 74  71  78]
  [ 84  80  87]]

 ...

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ...
  [105 124 177]
  [103 121 175]
  [100 118 172]]

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ...
  [105 124 177]
  [103 121 175]
  [100 118 172]]

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ...
  [105 124 177]
  [103 121 175]
  [100 118 172]]]
[[[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 46  43  50]
  [ 54  51  58]
  [ 72  69  76]]

 [[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 51  48  55]
  [ 61  58  65]
  [ 76  73  80]]

 [[187 196 224]
  [189 196 224]
  [192 197 226]
  ...
  [ 61  57  64]
  [ 74  71  78]
  [ 84  80  87]]

 ...

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ..

 17%|█▋        | 1/6 [00:05<00:28,  5.61s/it]

[[[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 52  89 154]
  [ 52  89 154]
  [ 52  89 154]
  ...
  [105 125 176]
  [103 123 172]
  [ 97 118 169]]

 [[ 51  88 153]
  [ 51  88 153]
  [ 51  88 153]
  ...
  [106 126 176]
  [104 124 173]
  [ 98 119 170]]

 [[ 51  88 153]
  [ 51  88 153]
  [ 51  88 153]
  ...
  [106 126 176]
  [104 124 174]
  [ 98 119 170]]]
[[[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [189 196 223]
  [190 197 224]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 52  89 154]
  [ 52  89 154]
  [ 52  89 154]
  ..

 33%|███▎      | 2/6 [00:09<00:19,  4.81s/it]

[[[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 54  90 158]
  [ 53  89 157]
  [ 53  89 157]
  ...
  [108 127 179]
  [105 125 175]
  [ 99 120 172]]

 [[ 54  90 158]
  [ 52  88 156]
  [ 52  88 156]
  ...
  [109 128 180]
  [106 126 176]
  [100 120 172]]

 [[ 54  90 158]
  [ 52  88 156]
  [ 52  88 156]
  ...
  [110 129 180]
  [106 126 176]
  [100 120 172]]]
[[[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 55  91 159]
  [ 54  90 158]
  [ 54  90 158]
  ..

 50%|█████     | 3/6 [00:14<00:14,  4.68s/it]

[[[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 52  88 156]
  [ 52  88 156]
  [ 52  88 156]
  ...
  [109 127 180]
  [104 123 175]
  [ 99 119 173]]

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ...
  [110 128 181]
  [105 124 176]
  [ 99 119 172]]

 [[ 51  87 155]
  [ 51  87 155]
  [ 51  87 155]
  ...
  [110 128 181]
  [105 124 176]
  [ 99 119 172]]]
[[[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 55  44  53]
  [ 61  50  59]
  [ 77  66  75]]

 [[189 196 223]
  [191 198 225]
  [191 198 225]
  ...
  [ 60  49  58]
  [ 67  56  65]
  [ 81  70  79]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 69  58  67]
  [ 79  68  77]
  [ 90  79  88]]

 ...

 [[ 52  88 156]
  [ 52  88 156]
  [ 52  88 156]
  ..

 67%|██████▋   | 4/6 [00:17<00:08,  4.49s/it]

[[[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 52  43  50]
  [ 56  49  55]
  [ 69  64  70]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 57  49  55]
  [ 63  56  62]
  [ 74  69  75]]

 [[189 196 223]
  [190 197 224]
  [191 198 225]
  ...
  [ 68  60  66]
  [ 77  71  77]
  [ 85  80  86]]

 ...

 [[ 48  87 152]
  [ 49  88 153]
  [ 50  89 154]
  ...
  [107 124 176]
  [103 122 174]
  [101 120 171]]

 [[ 48  87 152]
  [ 48  87 152]
  [ 50  89 154]
  ...
  [108 125 178]
  [104 123 175]
  [102 121 172]]

 [[ 48  87 152]
  [ 48  87 152]
  [ 50  89 154]
  ...
  [109 125 179]
  [104 123 175]
  [102 121 173]]]
No more frames to read or error reading frame





In [114]:
img

In [115]:
cap = cv2.VideoCapture(output_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
total_frames

0

In [103]:
def display_compressed_video(input_path):
    output_path = "compressed_" + input_path
    os.remove(output_path) if os.path.exists(output_path) else None
    try:
        # Use subprocess to safely call FFmpeg
        subprocess.run(['ffmpeg', '-i', input_path, '-vcodec', 'libx264', output_path], check=True)

        # Read and encode the compressed video
        with open(output_path, 'rb') as file:
            mp4 = file.read()
        data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

        # Display video in HTML
        display_html = f"""
        <video width=800 controls>
            <source src="{data_url}" type="video/mp4">
        </video>
        """
        return HTML(display_html)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

In [104]:
display_compressed_video('inferenced_20230705_Game6.mp4')

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab