In [1]:
import os
import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torchvision
from torchvision import transforms

In [2]:
# Notebook-wide matplotlib defaults: white figure background, and the
# 'binary' colormap as the default for single-channel images.
plt.rcParams.update({
    'figure.facecolor': 'white',
    'image.cmap': 'binary',
})

def display_video(path, loop=True):
    """Embed an MP4 file in the notebook as an inline HTML5 video player.

    The whole file is read into memory and inlined as a base64 data URL,
    so this is intended for short clips.

    Args:
        path: Path to an ``.mp4`` file.
        loop: If True, add the ``loop`` attribute so playback restarts
            when the video ends.

    Returns:
        An ``IPython.display.HTML`` object. Make it the last expression of
        a cell (or pass it to ``display``) to render the player.
    """
    # Imported locally: neither name is imported at the top of the notebook.
    # Inside IPython, the global `display` is a function, not the module, so
    # `display.HTML` would fail. IPython is always present in a Jupyter kernel.
    from base64 import b64encode
    from IPython.display import HTML

    # Context manager so the file handle is closed immediately.
    with open(path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
    # HTML boolean attributes are enabled by presence alone. <video> is not a
    # void element, so it needs an explicit closing tag.
    loop_attr = ' loop' if loop else ''
    return HTML(f'<video src="{data_url}" controls autoplay{loop_attr}></video>')

In [3]:
# Root of the dataset; raw videos are read from {dataset_path}/raw-videos.
dataset_path = 'datasets/jams-germs'
# os.listdir returns entries in arbitrary order. Sort so random indexing and
# `vid_fnames[:1]` (the "first video") are reproducible. Hidden files such as
# `.DS_Store` are skipped because they are not videos.
vid_fnames = sorted(
    fname for fname in os.listdir(f'{dataset_path}/raw-videos')
    if not fname.startswith('.'))
# HWC uint8 image -> CHW float tensor in [0, 1] (ToTensor). Then bicubic
# resize so the shorter side is 512 px, preserving aspect ratio.
# NOTE(review): bicubic on tensors can overshoot [0, 1]; confirm downstream
# consumers clamp (torchvision.utils.save_image does).
resize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(512, transforms.InterpolationMode.BICUBIC),
])

## Test batch sampling time from video files

In [4]:
# Benchmark: sample `batch_size` random clips of `batch_frames` consecutive
# frames directly from the raw videos, decoding and resizing on the fly.
batch_size = 32
batch_frames = 16
batch = []
time_start = time.perf_counter()
for _ in range(batch_size):
    vid_id = np.random.randint(len(vid_fnames))
    vid_fname = vid_fnames[vid_id]
    vid_path = f'{dataset_path}/raw-videos/{vid_fname}'

    vidcap = cv2.VideoCapture(vid_path)
    try:
        if not vidcap.isOpened():
            raise RuntimeError(f'could not open video {vid_path}')
        # CAP_PROP_FRAME_COUNT comes back as a float, so cast it before
        # using it as an integer bound.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count < batch_frames:
            raise ValueError(
                f'{vid_fname}: only {frame_count} frames, '
                f'need at least {batch_frames}')
        start_frame = np.random.randint(frame_count - batch_frames + 1)

        vidcap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        clip = []
        for _ in range(batch_frames):
            success, img = vidcap.read()
            if not success:
                # Fail with a clear message instead of passing None to cvtColor.
                raise RuntimeError(
                    f'{vid_fname}: failed to read frame '
                    f'{start_frame + len(clip)}')
            clip.append(resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the decoder handle even if reading fails.
        vidcap.release()
    batch.append(clip)

print(f'{(time.perf_counter() - time_start):.2f}s')
# 18s

18.43s


## Convert first video to frames

In [5]:
# Decode the selected video(s) once and cache every frame as a numbered JPEG
# ({frames_path}/0.jpg, 1.jpg, ...), resized with `resize`.
# Videos whose frame directory already exists and is non-empty are skipped,
# so re-running this cell is cheap. A stray `break` used to be the only way
# to avoid re-conversion.
time_start = time.perf_counter()
for vid_fname in vid_fnames[:1]:
    # Drop the last '-'-separated chunk of the filename (presumably an ID
    # suffix; TODO confirm). Fall back to the bare stem if there is no '-',
    # so frames never land directly in {dataset_path}/frames/.
    title = ('-'.join(vid_fname.split('-')[:-1])
             or os.path.splitext(vid_fname)[0])
    vid_path = f'{dataset_path}/raw-videos/{vid_fname}'
    frames_path = f'{dataset_path}/frames/{title}'
    if os.path.isdir(frames_path) and os.listdir(frames_path):
        # NOTE: a partially written directory from an interrupted run is also
        # skipped; delete it to force re-extraction.
        continue
    os.makedirs(frames_path, exist_ok=True)

    vidcap = cv2.VideoCapture(vid_path)
    try:
        if not vidcap.isOpened():
            raise RuntimeError(f'could not open video {vid_path}')
        frame_id = 0
        while True:
            success, img = vidcap.read()
            if not success:
                break
            img = resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            torchvision.utils.save_image(img, f'{frames_path}/{frame_id}.jpg')
            frame_id += 1
    finally:
        vidcap.release()

print(f'{(time.perf_counter() - time_start):.2f}s')
# takes about the same time as video length
# 80mb -> 120mb

0.00s


## Test batch sampling time from frames

In [6]:
# Benchmark: sample the same batch shape from pre-extracted, already-resized
# JPEG frames instead of decoding video.
# NOTE: PILToTensor yields uint8 tensors, whereas the video benchmark's
# `resize` (ToTensor + Resize) yields float tensors. The two batches
# therefore differ in dtype and value scale.
batch_size = 32
batch_frames = 16
batch = []

frames_path = f'{dataset_path}/frames/Earthworm Under Microscope'
# Assumes the directory holds only frames named 0.jpg ... (N-1).jpg.
frame_count = len(os.listdir(frames_path))
if frame_count < batch_frames:
    raise ValueError(
        f'{frames_path}: only {frame_count} frames, '
        f'need at least {batch_frames}')
to_tensor = transforms.PILToTensor()

time_start = time.perf_counter()
for _ in range(batch_size):
    start_frame = np.random.randint(frame_count - batch_frames + 1)
    clip = []
    for j in range(batch_frames):
        # Context manager so each JPEG's file handle is closed promptly.
        with Image.open(f'{frames_path}/{start_frame + j}.jpg') as img:
            clip.append(to_tensor(img))
    batch.append(clip)

print(f'{(time.perf_counter() - time_start):.2f}s')
# 3s

2.48s
