In [1]:
# Root folder for the offline-cropped AMD samples. Every output subfolder is
# derived from it, so switching dataset only needs this one line changed.
# (Plain string formatting: `os` is not imported until the next cell.)
CROPPED_ROOT = "data/AMD_cropped"

image_dir = f"{CROPPED_ROOT}/image_cropped"
mask_dir = f"{CROPPED_ROOT}/mask_cropped"
trimap_dir = f"{CROPPED_ROOT}/trimap_cropped"
fg_dir = f"{CROPPED_ROOT}/fg_cropped"
bg_dir = f"{CROPPED_ROOT}/bg_cropped"

In [2]:
from torch.utils.data import Dataset
from datasets.MattingDataset import MattingDataset
from utils.data import Data
import torch
from PIL import Image
import numpy
import pandas
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from torch.utils.data import DataLoader, Subset

In [3]:
# Create every output folder up front. `exist_ok=True` makes the cell safe to
# re-run and avoids the check-then-create race of an `os.path.exists` guard;
# a path that exists as a regular file raises here instead of at save time.
for _out_dir in (image_dir, mask_dir, trimap_dir, fg_dir, bg_dir):
    os.makedirs(_out_dir, exist_ok=True)

In [4]:
# Build the annotation table for the "AMD" dataset.
# NOTE(review): `Data` is project code not shown here. Presumably
# `dataframe()` returns a pandas.DataFrame whose first three columns are the
# image / mask / trimap paths, since MattingDataset reads them positionally
# via `.iloc[index, 0..2]` — verify against utils.data.Data.
data = Data("AMD")
df = data.dataframe()

In [5]:
RESIZE = 224  # default side length, in pixels, of the square random crop


class MattingDataset(Dataset):
    """Offline augmentation pass over the matting annotations.

    NOTE: this definition shadows the ``MattingDataset`` imported from
    ``datasets.MattingDataset`` in the import cell.

    ``__getitem__`` does the following:

    * Loads one (image, mask, trimap) triple.
    * Builds the foreground ``image * alpha`` and background
      ``image * (1 - alpha)`` composites, with ``alpha = mask / 255``.
    * Applies one joint augmentation to all five arrays.
    * Writes them under ``image_dir`` / ``mask_dir`` / ``trimap_dir`` /
      ``fg_dir`` / ``bg_dir``, with ``_<epoch>`` inserted before the file
      extension.

    It returns the constant ``1``; the files on disk are the real output.

    Args:
        annotations_df: table whose first three columns (read positionally via
            ``.iloc``) hold the image, mask and trimap paths.
        epoch: suffix added to every saved filename. DataLoader worker
            processes copy the dataset when an iterator is created, so set
            this *before* iterating.
        train: stored on the instance; not read by this class.
        transform: albumentations pipeline, called with the keyword targets
            ``image``, ``mask``, ``trimap``, ``fg`` and ``bg``. Defaults to a
            random square crop of side ``crop_size``.
        crop_size: side of the default random crop; ``None`` means ``RESIZE``.
            Ignored when ``transform`` is given.
    """

    def __init__(
        self,
        annotations_df: pandas.DataFrame,
        epoch: int = 0,
        train: bool = True,
        transform: A.Compose = None,
        crop_size: int = None,
    ) -> None:
        super().__init__()

        self.epoch = epoch
        self.annotations_df = annotations_df
        self.train = train
        if transform is None:
            # RESIZE is looked up at call time, so editing the global still
            # affects datasets constructed afterwards.
            side = RESIZE if crop_size is None else crop_size
            transform = A.Compose(
                [A.RandomCrop(height=side, width=side)],
                # Register trimap/fg/bg as extra targets and route them (and
                # the built-in `mask`) through the image apply-function.
                additional_targets={
                    "image": "image",
                    "mask": "image",
                    "trimap": "image",
                    "fg": "image",
                    "bg": "image",
                },
            )
        self.transform = transform

    def __len__(self) -> int:
        return len(self.annotations_df)

    def __getitem__(self, index: int) -> int:
        image_path = self.annotations_df.iloc[index, 0]
        mask_path = self.annotations_df.iloc[index, 1]
        trimap_path = self.annotations_df.iloc[index, 2]

        # The alpha broadcast below needs a channel axis, so force RGB
        # (grayscale/palette inputs would otherwise mis-broadcast or be
        # composited on palette indices).
        image = self._read_float(image_path, "RGB")
        mask = self._read_float(mask_path, "L")
        trimap = self._read_float(trimap_path, "L")

        # Composite in the native 0-255 scale. Round-tripping through
        # x / 255 * 255 is not guaranteed exact in float32, and a truncating
        # uint8 cast could then lower a pixel by one.
        alpha = mask[:, :, numpy.newaxis] / 255.0
        fg = image * alpha
        bg = image * (1.0 - alpha)

        out = self.transform(image=image, mask=mask, trimap=trimap, fg=fg, bg=bg)

        image_name = self._versioned_name(image_path)
        targets = (
            (out["image"], image_dir, image_name),
            (out["mask"], mask_dir, self._versioned_name(mask_path)),
            (out["trimap"], trimap_dir, self._versioned_name(trimap_path)),
            # fg / bg reuse the image's filename inside their own folders.
            (out["fg"], fg_dir, image_name),
            (out["bg"], bg_dir, image_name),
        )
        for array, out_dir, filename in targets:
            save_path = os.path.abspath(os.path.join(out_dir, filename))
            Image.fromarray(self._to_uint8(array)).save(save_path)

        return 1

    @staticmethod
    def _read_float(path: str, mode: str) -> numpy.ndarray:
        """Open ``path``, convert to PIL ``mode`` and return a float32 copy; closes the file."""
        with Image.open(path) as img:
            return numpy.array(img.convert(mode), dtype=numpy.float32)

    def _versioned_name(self, path: str) -> str:
        """Return the basename of ``path`` with ``_<epoch>`` inserted before its extension only."""
        # splitext splits at the last dot, so "a.b.png" -> "a.b_<epoch>.png"
        # rather than suffixing every dot.
        stem, ext = os.path.splitext(os.path.basename(path))
        return f"{stem}_{self.epoch}{ext}"

    @staticmethod
    def _to_uint8(array: numpy.ndarray) -> numpy.ndarray:
        """Round to nearest and clip to [0, 255] before casting, instead of truncating or wrapping."""
        return numpy.clip(numpy.rint(array), 0, 255).astype(numpy.uint8)


In [6]:
# Cap worker processes at the CPUs actually available: PyTorch warns when
# num_workers exceeds its suggested maximum, and extra workers only add
# memory and scheduling overhead.
NUM_WORKERS = min(32, os.cpu_count() or 1)

dataset = MattingDataset(df)

dataloader = DataLoader(dataset, num_workers=NUM_WORKERS, batch_size=32)

In [7]:
for epoch in range(5):
    # Set the epoch BEFORE iterating. DataLoader workers are started when the
    # iterator is created and each gets a snapshot of `dataset` at that
    # moment, so an assignment inside the loop body would only reach the
    # workers of the following epoch.
    dataset.epoch = epoch
    for _ in dataloader:
        pass  # __getitem__ writes the augmented crops to disk; batches are unused

KeyboardInterrupt: 