In [1]:
from typing import Tuple, List
from glob import glob

import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
import torch.utils.data
from scipy.ndimage.interpolation import rotate
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet

from data_augmentation import *
from archs import *

In [2]:
class Dataset(torch.utils.data.Dataset):
    """Paired image/mask segmentation dataset backed by ``.npy`` files.

    Each item loads one image and its mask, scales the image by its maximum,
    converts the image to channel-first layout, resizes both, optionally
    applies a joint random augmentation, and returns them as torch tensors.

    Args:
        img_paths: Paths to image ``.npy`` files.
        mask_paths: Paths to mask ``.npy`` files; ``mask_paths[i]`` must be the
            mask belonging to ``img_paths[i]``.
        train: When True, apply the joint ``BoundingOnlyDA`` augmentation.
            With ``train=False`` items are returned un-augmented.

    Raises:
        ValueError: if ``img_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self,
                 img_paths: List[str],
                 mask_paths: List[str],
                 train=False):
        # Images and masks are paired by index, so the lists must line up.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f"img_paths and mask_paths must have the same length "
                f"(got {len(img_paths)} and {len(mask_paths)})")
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.train = train

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, normalize, resize and (in train mode) augment item ``idx``.

        Returns:
            ``(image, mask)`` as torch tensors; the image is float32 and
            channel-first, the mask is uint8.
        """
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]
        resize = Resize()

        raw = np.load(img_path)
        peak = raw.max()
        image = raw.astype('float32')
        # Scale by the per-image maximum. An all-zero image would otherwise
        # produce 0/0 = NaN for every pixel, so leave it untouched.
        if peak != 0:
            image = image / peak
        if image.ndim == 2:
            # Single-channel image: add a trailing channel axis.
            image = image[:, :, np.newaxis]
        # (dim0, dim1, C) -> channel-first (C, dim0, dim1).
        image = image.transpose(2, 0, 1)
        image = resize(image)

        # NOTE(review): the mask file is assumed to already be channel-first;
        # it is not transposed here — confirm against the mask generator.
        mask = np.load(mask_path).astype('uint8')
        mask = resize(mask)

        # Image and mask are augmented together in a single call. Using a
        # dataset-level transform would draw independent random parameters
        # for each and misalign them. Augmentation is for training only.
        if self.train:
            bounding_only = BoundingOnlyDA(rate=1, classes=4)
            image, mask = bounding_only(image, mask)

        # .copy() yields contiguous arrays with positive strides (e.g. after a
        # flip), which torch.from_numpy requires.
        image = torch.from_numpy(image.copy())
        mask = torch.from_numpy(mask.copy())

        return image, mask

In [3]:
# glob() returns paths in arbitrary, filesystem-dependent order, but the
# Dataset pairs images and masks by list index, so sort both lists to keep
# them aligned. NOTE(review): assumes an image and its mask have file names
# that sort to the same position in both directories — confirm naming scheme.
IMG_DIR = 'inputs/image_gamma_1.1_0926'
MASK_DIR = 'inputs/mask_gamma_1.1_0926'
img_paths = sorted(glob(f'{IMG_DIR}/*'))
mask_paths = sorted(glob(f'{MASK_DIR}/*'))
assert len(img_paths) == len(mask_paths), (
    f'{len(img_paths)} images vs {len(mask_paths)} masks')

In [4]:
# Training dataset: images paired index-wise with their masks.
train_data = Dataset(img_paths=img_paths, mask_paths=mask_paths, train=True)

In [5]:
# Number of output channels requested from the network (the forward pass
# below yields a (1, 4, 256, 256) tensor). NOTE(review): presumably the number
# of mask classes, matching `classes=4` in the Dataset augmentation — confirm
# against EfficientNetB4NestedUNet's `args` parameter.
NUM_CLASSES = 4
model = EfficientNetB4NestedUNet(args=NUM_CLASSES)

Loaded pretrained weights for efficientnet-b4


In [6]:
# Image (index 0 of the (image, mask) pair) of the first sample, with a leading
# batch axis added. unsqueeze(0) keeps the tensor's actual shape; a hard-coded
# view(1, 3, 256, 256) would fail, or silently reinterpret memory, for any
# other channel count or resolution.
data = train_data[0][0].unsqueeze(0)

In [8]:
# Smoke-test forward pass: only the output shape is inspected afterwards, so
# skip building the autograd graph (saves memory and time).
with torch.no_grad():
    out = model(data)

In [9]:
# Output tensor dimensions (batch first, then channels and spatial dims).
out.size()

torch.Size([1, 4, 256, 256])