+ 6.1: Inference on videos

In [1]:
from net import WaterNet
from data import transform
import cv2
from pathlib import Path
import torch
import numpy as np
from einops import rearrange
import pickle
import random

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

## Util fxns

In [3]:
def arr2ten(arr):
    """Convert arr2ten plus scaling"""
    ten = torch.from_numpy(arr) / 255
    ten = rearrange(ten, "h w c -> 1 c h w")
    return ten

def ten2arr(ten):
    """Convert ten2arr plus scaling"""
    arr = ten.cpu().detach().numpy()
    arr = np.clip(arr, 0, 1)
    arr = (arr * 255).astype(np.uint8)
    arr = rearrange(arr, "c h w -> h w c")
    return arr

## Load weights

In [4]:
waternet = WaterNet()

with open("../assets/waternet-exported-state-dict.pt", "rb") as f:
    exported_sd = torch.load(f)

waternet.load_state_dict(exported_sd)

<All keys matched successfully>

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [6]:
waternet = waternet.to(device)
waternet.eval();

## Video in / out setup

In [7]:
srcvids = list(Path("../data/internet/").glob("*.mp4"))

In [8]:
srcvids

[PosixPath('../data/internet/StationMBackground.mp4'),
 PosixPath('../data/internet/2022-05-15 00-15-59_Trim.mp4'),
 PosixPath('../data/internet/2022-05-15 00-15-28_Trim.mp4'),
 PosixPath('../data/internet/Australian Mesophotic Coral Examination - 4K ROV Highlights - FK210409 - 1.mp4'),
 PosixPath('../data/internet/Deep Sea Corals of PIPA - 4K ROV Highlights - Blunt Nose Sixgill Shark - FK171005.mp4'),
 PosixPath('../data/internet/2022-05-15 00-14-32_Trim.mp4'),
 PosixPath('../data/internet/ROV SuBastian Dive 163 - Point Dume Seep - Backyard Deep - 1.mp4'),
 PosixPath('../data/internet/2022-05-15 00-19-09_Trim.mp4'),
 PosixPath('../data/internet/Australian Mesophotic Coral Examination - 4K ROV Highlights - FK210409 - 2.mp4'),
 PosixPath('../data/internet/ROV Dive 419 - Seamount Exploration - World Ocean Day - Part 3.mp4')]

In [9]:
if not Path("../output/internet/").exists():
    Path("../output/internet/").mkdir()

In [10]:
for i in srcvids:
    inpath = i.as_posix()
    outpath = "../output/internet/" + i.stem + "recoloured.mp4"
    
    cap = cv2.VideoCapture(inpath)

    frames_per_second = int(cap.get(cv2.CAP_PROP_FPS))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    codec = cv2.VideoWriter.fourcc(*"avc1")
    video_writer = cv2.VideoWriter(
        outpath, codec, frames_per_second, (frame_width, frame_height)
    )
    
    frames = 0
    
    print(f"Starting video {i.name}")

    while True:
        retval, image = cap.read()

        if retval is False:
            break

        # Logic start ------------------------------------
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        wb, gc, he = transform(rgb)

        rgb_ten = arr2ten(rgb)
        wb_ten =  arr2ten(wb)
        gc_ten =  arr2ten(gc)
        he_ten =  arr2ten(he)

        with torch.no_grad():
            rgb_ten = rgb_ten.to(device)
            wb_ten = wb_ten.to(device)
            gc_ten = gc_ten.to(device)
            he_ten = he_ten.to(device)

            out = waternet(rgb_ten, wb_ten, he_ten, gc_ten)
            out_im = ten2arr(out[0])

        video_writer.write(out_im)
        frames += 1

        if frames % 50 == 0:
            print(f"Processed {frames} frames")

    cap.release()
    video_writer.release()

    cv2.destroyAllWindows()
    
    print(f"Finished video {i.name}")

Starting video StationMBackground.mp4
Processed 50 frames
Processed 100 frames
Processed 150 frames
Processed 200 frames
Processed 250 frames
Processed 300 frames
Processed 350 frames
Finished video StationMBackground.mp4
Starting video 2022-05-15 00-15-59_Trim.mp4
Processed 50 frames
Processed 100 frames
Processed 150 frames
Processed 200 frames
Finished video 2022-05-15 00-15-59_Trim.mp4
Starting video 2022-05-15 00-15-28_Trim.mp4
Processed 50 frames
Processed 100 frames
Finished video 2022-05-15 00-15-28_Trim.mp4
Starting video Australian Mesophotic Coral Examination - 4K ROV Highlights - FK210409 - 1.mp4
Finished video Australian Mesophotic Coral Examination - 4K ROV Highlights - FK210409 - 1.mp4
Starting video Deep Sea Corals of PIPA - 4K ROV Highlights - Blunt Nose Sixgill Shark - FK171005.mp4
Processed 50 frames
Processed 100 frames
Processed 150 frames
Processed 200 frames
Processed 250 frames
Processed 300 frames
Processed 350 frames
Processed 400 frames
Processed 450 frames
P

Seems to take more than one second for each frame.