In [10]:
import pandas as pd
import numpy as np
import shutil
import cv2
from PIL import Image
import os
from sklearn.model_selection import train_test_split

In [5]:
# Load the sample manifest. Later cells read its "img_path", "mask_rle" and
# "ratio_type" columns.
df = pd.read_csv("./train_ratio.csv")

In [6]:
# Hold out 20% of the rows for validation. The split is stratified on
# "ratio_type" and seeded so that re-running the notebook yields the same
# partition.
x_train, x_val, y_train, y_val = train_test_split(
    df["img_path"],
    df["mask_rle"],
    test_size=0.2,
    random_state=42,
    stratify=df["ratio_type"],
)

In [7]:
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask string into a 2-D binary array.

    The encoding is a whitespace-separated sequence of ``start length``
    pairs, where ``start`` is a 1-based index into the row-major flattened
    mask and ``length`` is the number of consecutive foreground pixels.

    Args:
        mask_rle: RLE string such as ``"1 3 10 2"``. ``None`` and float NaN
            (what pandas yields for an empty CSV cell) are treated as
            "no foreground" and decode to an all-zero mask, as does an
            empty string.
        shape: ``(height, width)`` of the mask to produce.

    Returns:
        ``np.uint8`` array of shape ``shape`` holding 1 inside the encoded
        runs and 0 everywhere else.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # A missing annotation has no ``.split``; return the empty mask instead
    # of failing with an AttributeError.
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        return img.reshape(shape)

    s = mask_rle.split()
    # Even positions are run starts (converted to 0-based), odd positions
    # are run lengths.
    starts = np.asarray(s[0::2], dtype=int) - 1
    lengths = np.asarray(s[1::2], dtype=int)
    ends = starts + lengths
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

# RLE encoding function
def rle_encode(mask):
    """Encode a mask as a run-length string of ``start length`` pairs.

    The mask is flattened, padded with a zero on each side, and every
    position where the value changes is recorded. Starts are 1-based
    indices into the flattened mask; each start is followed by the length
    of its run. A mask with no value changes encodes to an empty string.

    Args:
        mask: NumPy array holding the mask (anything exposing ``flatten``).

    Returns:
        Space-separated string ``"start length start length ..."``.
    """
    flat = mask.flatten()
    # Zero padding guarantees that a run touching either end of the mask
    # still produces both an opening and a closing transition.
    padded = np.concatenate([[0], flat, [0]])
    change_points = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every "end" position (odd slots) into a length, in place.
    change_points[1::2] -= change_points[::2]
    return ' '.join(map(str, change_points))

In [11]:
# Convert the dataset annotations to semantic segmentation maps.
#
# Layout written by this cell:
#   <data_root>/<img_dir>/{train,val}/<name>       copy of the source image
#   <data_root>/<ann_dir>/{train,val}/<stem>.png   palette-mode ("P") mask
data_root = 'data'
img_dir = "images"
ann_dir = "annotations"
# Flat RGB palette for better visualization of the class-index masks:
# index 0 -> (128, 0, 0), index 1 -> (0, 128, 0).
palette = [128, 0, 0, 0, 128, 0]


def _export_split(img_paths, mask_rles, split):
    """Copy images and write their decoded masks for one dataset split.

    Args:
        img_paths: Iterable of image paths relative to ``data_root``.
        mask_rles: Iterable of RLE mask strings, aligned with ``img_paths``.
        split: Sub-directory name for this split (e.g. ``"train"``).

    Raises:
        OSError: If an image cannot be read by OpenCV (missing or corrupt).
    """
    split_img_dir = os.path.join(data_root, img_dir, split)
    split_ann_dir = os.path.join(data_root, ann_dir, split)
    os.makedirs(split_img_dir, exist_ok=True)
    os.makedirs(split_ann_dir, exist_ok=True)

    for rel_path, mask_rle in zip(img_paths, mask_rles):
        src_path = os.path.join(data_root, rel_path)
        img_name = os.path.basename(rel_path)

        # The image is only read to learn the mask dimensions.
        img = cv2.imread(src_path)  # BGR
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2.imread could not read image: {src_path}")
        h, w = img.shape[:2]

        ann_img = rle_decode(mask_rle, (h, w))
        png = Image.fromarray(ann_img).convert('P')
        png.putpalette(palette)

        # Write outputs only after both inputs decoded, so a bad sample does
        # not leave an image without its annotation.
        shutil.copy(src_path, os.path.join(split_img_dir, img_name))
        # Palette-mode images cannot be stored as e.g. JPEG, so the
        # annotation is always saved as PNG regardless of the image type.
        ann_name = os.path.splitext(img_name)[0] + ".png"
        png.save(os.path.join(split_ann_dir, ann_name))


_export_split(x_train, y_train, "train")
_export_split(x_val, y_val, "val")