In [6]:
!pip install segmentation_models_pytorch

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.0-py3-none-any.whl (97 kB)
[K     |████████████████████████████████| 97 kB 244 kB/s eta 0:00:01
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 217 kB/s eta 0:00:01
[?25hCollecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 1.8 MB/s eta 0:00:01
[?25hCollecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
Collecting munch
  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=78878e82e5c58475e168cdaac183bf2d29d6f3abf9117c1042b966adefa2590e
  St

In [7]:
import torch
import numpy as np
from glob import glob
import cv2
import os
import imgaug
import imgaug.augmenters as iaa
from tqdm.auto import tqdm
import pandas as pd
import torchvision.transforms as T
from segmentation_models_pytorch.encoders import get_preprocessing_fn

In [None]:
class ClassificationDataset(torch.utils.data.Dataset):
    """Per-video segmentation dataset yielding ``(image, mask)`` tensor pairs.

    Despite its name, this dataset serves semantic segmentation. Every PNG in
    ``<video_dir>/segmentation/`` is treated as a label mask. It is paired with
    the RGB frame obtained by replacing ``'segmentation'`` with ``'rgb'`` in the
    mask's path relative to ``video_dir``.

    Items are ``(img, mask)``:
    - ``img`` is a float32 tensor of shape ``(3, img_h, img_w)``, scaled to
      [0, 1] and then normalised according to ``config['model_type']``.
    - ``mask`` is a ``torch.long`` tensor of shape ``(img_h, img_w)``.

    Config keys read:
    - ``img_h`` and ``img_w``.
    - ``aug`` (truthy enables random flip/affine augmentation).
    - ``model_type`` (``'UNet'`` or ``'smp'``), plus either
      ``config['UNet']['encoder']['pretrained']`` or
      ``config['smp']['encoder_name']``.
    """

    def __init__(self, video_dir, config):
        self.mask_paths = sorted(glob(os.path.join(video_dir, 'segmentation/*.png')))
        self.img_paths = [self._rgb_path_for(video_dir, p) for p in self.mask_paths]
        self.img_h = config['img_h']
        self.img_w = config['img_w']
        self.aug = config['aug']
        # Geometric augmentation applied jointly to image and mask.
        # NOTE(review): imgaug samples from its own global RNG. With multi-worker
        # DataLoaders a worker_init_fn that reseeds imgaug is probably needed so
        # workers do not repeat the same augmentations. Confirm.
        self.seq = iaa.Sequential([
            iaa.Fliplr(0.5),  # 50% horizontal flip
            iaa.Affine(
                rotate=(-15, 15),  # degrees
                shear=(-10, 10),   # degrees
                scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
            ),
        ])
        self.config = config
        self._init_img_preprocess_fn(config)

    @staticmethod
    def _rgb_path_for(video_dir, mask_path):
        """Map a mask path to the path of its RGB frame.

        Only the part of ``mask_path`` below ``video_dir`` is rewritten
        (``'segmentation'`` -> ``'rgb'``). A ``video_dir`` that itself contains
        the substring ``'segmentation'`` is therefore left intact.
        """
        base_dir = video_dir or os.curdir  # relpath() rejects an empty start path
        rel_path = os.path.relpath(mask_path, base_dir)
        return os.path.join(video_dir, rel_path.replace('segmentation', 'rgb'))

    @staticmethod
    def _imread(path, flags):
        """Read an image with OpenCV; raise instead of returning ``None``."""
        img = cv2.imread(path, flags)
        if img is None:
            # cv2.imread reports a missing or unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image file: {path}')
        return img

    def __len__(self):
        return len(self.mask_paths)

    def __getitem__(self, idx):
        # OpenCV decodes colour images as BGR; convert to RGB.
        img = cv2.cvtColor(self._imread(self.img_paths[idx], cv2.IMREAD_COLOR),
                           cv2.COLOR_BGR2RGB)
        # NOTE(review): masks are assumed to be single-channel PNGs whose pixel
        # values are class ids. Confirm against the dataset's label format.
        mask = self._imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        img = self.preprocess(img)
        mask = self.preprocess_mask(mask)

        if self.aug:
            # Passing both inputs in one call applies the same sampled
            # transformation to each, keeping image and mask aligned.
            img, mask = self.seq(image=img, segmentation_maps=mask)

        return self._to_torch_tensor(img, mask)

    def preprocess(self, img):
        """Resize an image to ``(img_h, img_w)`` (cv2 takes the size as (w, h))."""
        img = cv2.resize(img, (self.img_w, self.img_h))
        return img

    def preprocess_mask(self, mask):
        """Wrap a label array in a SegmentationMapsOnImage resized to ``(img_h, img_w)``.

        Labels are stored as int32. imgaug expects non-negative class ids, and
        int8 would wrap ids >= 128 (e.g. a 255 "ignore" value) to negatives.
        """
        segmap_cls = imgaug.augmentables.segmaps.SegmentationMapsOnImage
        segmap = segmap_cls(mask.astype(np.int32), shape=mask.shape)
        # SegmentationMapsOnImage.resize() defaults to nearest-neighbour, so no
        # interpolated (non-existent) label values are introduced.
        resized = segmap.resize((self.img_h, self.img_w))
        # resize() rescales only the label array and keeps the old `shape`.
        # Rebuild so `shape` describes the resized image this map is paired with.
        return segmap_cls(resized.get_arr(), shape=(self.img_h, self.img_w))

    def _init_img_preprocess_fn(self, config):
        """Set ``self.transform``, the image normalisation for the configured model type.

        Raises:
            ValueError: if ``config['model_type']`` is neither ``'UNet'`` nor ``'smp'``.
        """
        model_type = config['model_type']
        if model_type == 'UNet':
            if config[model_type]['encoder']['pretrained']:
                # ImageNet mean/std, applied to a (C, H, W) tensor already scaled to [0, 1].
                self.transform = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            else:
                # No pretrained encoder: leave pixel values unnormalised. nn.Identity
                # (rather than a lambda) keeps the dataset picklable for DataLoader workers.
                self.transform = torch.nn.Identity()
        elif model_type == 'smp':
            # Normalisation matching the encoder's ImageNet weights. It is applied
            # to an (H, W, C) numpy array in _to_torch_tensor.
            self.transform = get_preprocessing_fn(config[model_type]['encoder_name'],
                                                  pretrained='imagenet')
        else:
            raise ValueError('Not implemented model type preprocess fn')

    def _to_torch_tensor(self, img, mask):
        """Convert an ``(H, W, C)`` image and a segmentation map to tensors.

        Returns:
            ``(img, mask)``:
            - ``img`` is float32 ``(C, H, W)``, scaled by 1/255 and then
              passed through ``self.transform``.
            - ``mask`` is the segmap's label array as ``torch.long``.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        model_type = self.config['model_type']
        if model_type == 'UNet':
            img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1) / 255.
            img = self.transform(img)
        elif model_type == 'smp':
            # The smp preprocessing fn works on (H, W, C) numpy arrays, so it is
            # applied before conversion to a channels-first tensor.
            img = self.transform(img / 255.)
            img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1)
        else:
            raise ValueError('Not implemented model type preprocess fn')
        mask = torch.tensor(mask.get_arr(), dtype=torch.long)
        return img, mask



    
def _build_seg_dataset(video_dirs, config):
    """Build one segmentation dataset per video directory and concatenate them.

    Args:
        video_dirs: iterable of video directory paths, each expected to contain
            ``segmentation/`` and ``rgb/`` sub-folders.
        config: full config dict forwarded to every per-video dataset.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over the per-video datasets, in order.

    Raises:
        ValueError: if ``video_dirs`` is empty.
    """
    # ClassificationDataset is the per-video (image, mask) segmentation dataset
    # defined above; its constructor takes (video_dir, config).
    datasets = [ClassificationDataset(video_dir, config) for video_dir in tqdm(video_dirs)]
    if not datasets:
        raise ValueError('No video directories given; cannot build a dataset.')
    return torch.utils.data.ConcatDataset(datasets)

def build_seg_datasets(config):
    """Create the train and validation segmentation datasets described by ``config``.

    ``config['data']`` supplies the following:
    - ``root_dir``: prefix joined onto every listed video directory.
    - ``train_csv`` and ``val_csv``: CSV files, each with a ``video_dir`` column.

    Returns:
        ``(train_ds, val_ds)``, each the concatenation of per-video datasets.
    """
    data_config = config['data']
    root_dir = data_config['root_dir']

    def _video_dirs_from(csv_key):
        # Read the relative video dirs listed in one CSV and anchor them at root_dir.
        listed = pd.read_csv(data_config[csv_key])['video_dir'].tolist()
        return _add_root_dir(root_dir, listed)

    train_dirs = _video_dirs_from('train_csv')
    val_dirs = _video_dirs_from('val_csv')
    print('# of Videos: ', len(train_dirs), len(val_dirs))

    return (
        _build_seg_dataset(train_dirs, config),
        _build_seg_dataset(val_dirs, config),
    )
    

def _add_root_dir(root_dir, dirs):
    return [os.path.join(root_dir, p) for p in dirs]