In [None]:
class CustomDataset(Dataset):
    """Dataset of video clips; each item is every decoded frame of one clip.

    Args:
        img_path_list: Sequence of video file paths, one per sample.
        label_list: Sequence of labels aligned with ``img_path_list``, or ``None``
            (then ``__getitem__`` returns only the frames).
        transforms: Optional albumentations-style callable invoked as
            ``transforms(image=..., image1=..., ...)``; it must return a dict
            containing the same keys.
    """

    def __init__(self, img_path_list, label_list, transforms=None):
        self.label_list = label_list
        self.transforms = transforms
        self.img_path_list = img_path_list

    def __getitem__(self, index):
        """Return ``(frames, label)``, or just ``frames`` when ``label_list`` is None.

        Without ``transforms`` the frames are the raw dict from ``get_frames``.
        With ``transforms`` they are copied into a ``torch.zeros`` tensor of
        shape ``(n_frames, 3, CFG["IMG_SIZE"], CFG["IMG_SIZE"])``.
        """
        images = self.get_frames(self.img_path_list[index])

        if self.transforms is not None:
            res = self.transforms(**images)
            n_frames = len(images)
            # NOTE(review): assumes the pipeline ends with a tensor conversion
            # yielding (3, IMG_SIZE, IMG_SIZE) per frame — confirm against CFG.
            images = torch.zeros((n_frames, 3, CFG["IMG_SIZE"], CFG["IMG_SIZE"]))
            images[0] = res["image"]
            for i in range(1, n_frames):
                images[i] = res[f"image{i}"]

        if self.label_list is not None:
            label = self.label_list[index]
            return images, label
        else:
            return images

    def __len__(self):
        return len(self.img_path_list)

    def get_frames(self, path):
        """Decode the frames of the video at ``path`` as RGB arrays.

        Returns:
            Dict keyed ``"image1"`` ... ``"image{n-1}"`` plus ``"image"`` (frame 0),
            ready to be splatted into an albumentations pipeline.

        Raises:
            ValueError: if no frame could be decoded (e.g. unreadable file).
        """
        cap = cv2.VideoCapture(path)
        try:
            n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            imgs = []
            for _ in range(n_frames):
                ok, img = cap.read()
                # The reported frame count is metadata and can exceed what is
                # actually decodable; a failed read returns img=None, which
                # cvtColor would reject with an opaque error.
                if not ok:
                    break
                imgs.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        finally:
            # Always free the decoder handle, even if decoding raises.
            cap.release()
        return self._frames_to_targets(imgs, path)

    @staticmethod
    def _frames_to_targets(imgs, path=None):
        """Map an ordered list of frames to albumentations target keys.

        Frame 0 becomes ``"image"``; frame i (i >= 1) becomes ``"image{i}"``.

        Raises:
            ValueError: if ``imgs`` is empty.
        """
        if not imgs:
            raise ValueError(f"no frames could be decoded from {path!r}")
        ret = {f"image{i}": img for i, img in enumerate(imgs[1:], start=1)}
        ret["image"] = imgs[0]
        return ret


In [1]:
import albumentations as A
from ipywidgets import interact
from matplotlib import pyplot as plt
import numpy as np 
import cv2



In [42]:
def frames_to_targets(imgs, path=None):
    """Map an ordered list of frames to albumentations target keys.

    Frame 0 becomes ``"image"``; frame i (i >= 1) becomes ``"image{i}"``.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if not imgs:
        raise ValueError(f"no frames could be decoded from {path!r}")
    ret = {f"image{i}": img for i, img in enumerate(imgs[1:], start=1)}
    ret["image"] = imgs[0]
    return ret


def get_frames(path):
    """Decode the frames of the video at ``path`` as RGB arrays.

    Returns:
        Dict keyed ``"image1"`` ... ``"image{n-1}"`` plus ``"image"`` (frame 0).

    Raises:
        ValueError: if no frame could be decoded (e.g. unreadable file).
    """
    cap = cv2.VideoCapture(path)
    try:
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        imgs = []
        for _ in range(n_frames):
            ok, img = cap.read()
            # The reported frame count can exceed what is decodable; a failed
            # read returns img=None, which cvtColor would reject opaquely.
            if not ok:
                break
            imgs.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the decoder handle, even if decoding raises.
        cap.release()
    return frames_to_targets(imgs, path)

def aug(transforms, images):
    """Apply one pipeline to every frame of a clip and stack the results.

    Args:
        transforms: Callable invoked as ``transforms(image=..., image1=..., ...)``
            that returns a dict with the same keys (e.g. an ``A.Compose`` with
            matching ``additional_targets``).
        images: Dict of frames keyed ``"image"``, ``"image1"`` ... ``"image{n-1}"``.

    Returns:
        ``(stacked, res)``: ``stacked`` is a ``uint8`` array of shape
        ``(n_frames, *frame_shape)``; ``res`` is the raw transform output.
    """
    res = transforms(**images)
    n_frames = len(images)
    first = np.asarray(res["image"])
    # Derive the buffer shape from the transformed frame rather than
    # hard-coding it, so the helper keeps working if the Resize target changes.
    stacked = np.empty((n_frames, *first.shape), dtype=np.uint8)
    stacked[0] = first
    for i in range(1, n_frames):
        stacked[i] = res[f"image{i}"]
    return stacked, res

In [46]:
# Preview cell: augment one training clip and browse its frames with a slider.
VIDEO_PATH = "./train/TRAIN_0001.mp4"
FRAME_H, FRAME_W = 180, 320
# One albumentations target per frame: "image" plus image1 ... image{N_FRAMES-1},
# so every frame of a clip goes through the same augmentation.
N_FRAMES = 50

transforms = A.Compose([
    A.Resize(height=FRAME_H, width=FRAME_W),
    A.VerticalFlip(p=1),
], p=1, additional_targets={f"image{i}": "image" for i in range(1, N_FRAMES)})

raw_frames = get_frames(VIDEO_PATH)
# Frames beyond the registered targets would not be augmented consistently.
assert len(raw_frames) <= N_FRAMES, (
    f"{VIDEO_PATH} has {len(raw_frames)} frames; increase N_FRAMES"
)
frames, res = aug(transforms, raw_frames)

@interact(frame=(0, len(frames) - 1))
def show_frame(frame=0):
    """Display one augmented frame of the preview clip."""
    plt.imshow(frames[frame])
    plt.title(f"{VIDEO_PATH} — frame {frame}")
    plt.axis("off")

interactive(children=(IntSlider(value=0, description='frame', max=49), Output()), _dom_classes=('widget-intera…