### preparation

In [None]:
from PIL import Image
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

In [None]:
def list_files(folder_path):
    """Return the sorted names of the regular files directly inside ``folder_path``.

    Sub-directories are skipped and the listing is not recursive. Only the
    names are returned, without the directory prefix.

    The result is sorted because ``os.listdir`` yields entries in arbitrary
    order, which would make the processing order (and any index-based use of
    these lists) differ between runs and machines.

    Raises:
        OSError: propagated from ``os.listdir`` when ``folder_path`` cannot
            be listed (for example, when it does not exist).
    """
    names = os.listdir(folder_path)
    return sorted(
        name for name in names
        if os.path.isfile(os.path.join(folder_path, name))
    )

In [None]:
# Dataset directories.
# NOTE(review): the empty strings look like placeholders -- point them at the
# real folders before running this cell.
train_patch_path = ''
train_mask_path = ''
val_patch_path = ''
val_mask_path = ''

# File names (not full paths) found in each of the four directories.
train_patch_list, train_mask_list, val_patch_list, val_mask_list = map(
    list_files,
    (train_patch_path, train_mask_path, val_patch_path, val_mask_path),
)

In [None]:
# Human-readable class names, one per line; the trailing comment is the
# entry's position in the list.
label_names = [
    "outside_roi",              # 0
    "tumor",                    # 1
    "stroma",                   # 2
    "lymphocytic_infiltrate",   # 3
    "necrosis or debris",       # 4
    "glandular secretion",      # 5
    "blood",                    # 6
    "exclude",                  # 7
    "metaplasia NOS",           # 8
    "fat",                      # 9
    "plasma cell",              # 10
    "other immune infiltrate",  # 11
    "mucoid material",          # 12
    "normal acinus or duct",    # 13
    "lymphatics",               # 14
    "undetermined",             # 15
    "nerve",                    # 16
    "skin adnexa",              # 17
    "blood vessel",             # 18
    "angioinvasion",            # 19
    "dcis",                     # 20
]

### crop patch image

In [None]:
def crop_patches(image, output_dir, crop_size=256, input_dir=None):
    """Split one image into a 2x2 grid of square tiles and save each tile.

    Tiles are written to ``output_dir`` as ``<base>_1.png`` .. ``<base>_4.png``
    in the order top-left, top-right, bottom-left, bottom-right, where
    ``<base>`` is the file name of ``image`` without its extension.

    Args:
        image: File name of the source image, relative to ``input_dir``.
        output_dir: Directory that receives the tiles; created if missing.
        crop_size: Edge length of each tile in pixels. With the default of
            256 the grid covers the top-left 512x512 region of the image.
        input_dir: Directory that contains ``image``. When omitted, the
            module-level ``train_patch_path`` is used.
    """
    if input_dir is None:
        input_dir = train_patch_path

    file_name = os.path.join(input_dir, image)
    base_name = os.path.splitext(os.path.basename(image))[0]
    os.makedirs(output_dir, exist_ok=True)

    # Row-major 2x2 grid of (left, upper, right, lower) boxes derived from
    # crop_size, so the parameter actually controls the tile size.
    crops = [
        (col * crop_size, row * crop_size,
         (col + 1) * crop_size, (row + 1) * crop_size)
        for row in range(2)
        for col in range(2)
    ]

    # The context manager closes the underlying file once all tiles are saved.
    with Image.open(file_name) as img:
        for i, box in enumerate(crops, 1):
            save_path = os.path.join(output_dir, f'{base_name}_{i}.png')
            img.crop(box).save(save_path)
            print(f'saved: {save_path}')

In [None]:
# Tile every training patch. train_patch_output must already be defined;
# it is not assigned in the cells shown here.
for patch_name in train_patch_list:
    crop_patches(patch_name, train_patch_output)

### create mask image

In [None]:
def mask_filter_save(mask, output_path, target_label=(1, 2, 20), label_colors=None,
                     input_dir=None, crop_size=256):
    """Tile a label mask 2x2 and save color-coded PNGs of the tiles worth keeping.

    A tile is kept when it contains at least one pixel whose value is in
    ``target_label`` and those pixels cover less than 95% of the tile. Kept
    tiles are rendered with ``label_colors`` (every other pixel stays black)
    and written to ``output_path`` as ``<base>_<n>.png``, where ``n`` is 1..4
    in the order top-left, top-right, bottom-left, bottom-right.

    Args:
        mask: File name of the label mask, relative to ``input_dir``.
        output_path: Directory that receives the PNGs; created if missing.
        target_label: Label values that make a tile interesting.
        label_colors: Mapping of label value to RGB tuple used for rendering.
            Defaults to tumor (1) red, stroma (2) green, dcis (20) blue.
        input_dir: Directory that contains ``mask``. When omitted, the
            module-level ``train_mask_path`` is used.
        crop_size: Edge length of each tile in pixels.
    """
    if label_colors is None:
        label_colors = {
            1: (255, 0, 0),   # tumor
            2: (0, 255, 0),   # stroma
            20: (0, 0, 255),  # dcis
        }
    if input_dir is None:
        input_dir = train_mask_path

    os.makedirs(output_path, exist_ok=True)
    file_name = os.path.join(input_dir, mask)
    base_name = os.path.splitext(os.path.basename(file_name))[0]

    # Materialize the pixels inside the context manager so the file can be
    # closed before the tiles are processed.
    with Image.open(file_name) as mask_img:
        label_map = np.array(mask_img)

    targets = list(target_label)
    # Row-major 2x2 grid of (x1, y1, x2, y2) boxes.
    crops = [
        (col * crop_size, row * crop_size,
         (col + 1) * crop_size, (row + 1) * crop_size)
        for row in range(2)
        for col in range(2)
    ]

    for idx, (x1, y1, x2, y2) in enumerate(crops, 1):
        patch = label_map[y1:y2, x1:x2]
        if patch.size == 0:
            # The mask does not reach this tile; nothing to measure or save.
            continue

        # Fraction of the tile covered by any of the target labels.
        found_ratio = np.count_nonzero(np.isin(patch, targets)) / patch.size
        if not 0 < found_ratio < 0.95:
            continue

        rgb_patch = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
        # Paint through a view sized like the patch so that a mask smaller
        # than the 2x2 grid yields a zero-padded tile instead of a
        # shape-mismatch error.
        painted = rgb_patch[:patch.shape[0], :patch.shape[1]]
        for label, color in label_colors.items():
            painted[patch == label] = color

        save_name = f'{base_name}_{idx}.png'
        Image.fromarray(rgb_patch).save(os.path.join(output_path, save_name))
        print(f'save at {save_name}')

In [None]:
# Render and filter every training mask. train_mask_output_3 must already be
# defined; it is not assigned in the cells shown here.
for mask_name in train_mask_list:
    mask_filter_save(mask_name, train_mask_output_3)

### create npy files

In [None]:
# Re-read the file listing of train_mask_path.
# NOTE(review): presumably train_mask_path has been re-pointed at the
# color-coded masks written by the previous step before this runs -- confirm.
train_mask_list = list_files(train_mask_path)

In [None]:
# png -> npy: RGB color found in the rendered mask PNGs -> integer class id
# stored in the label array.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,      # background
    (255, 0, 0): 1,    # tumor
    (0, 255, 0): 2,    # stroma
    (0, 0, 255): 20,   # dcis
}

def convert_png_to_npy(png_file, output_path, visualize=True, input_dir=None):
    """Convert a color-coded mask PNG into a 2-D class-id array saved as ``.npy``.

    Every pixel whose RGB value is a key of ``COLOR_TO_CLASS`` receives the
    mapped class id; all other pixels stay 0. The array is stored as uint8
    under ``output_path`` using the PNG's base name with a ``.npy`` extension.

    Args:
        png_file: File name of the mask PNG, relative to ``input_dir``.
        output_path: Directory that receives the ``.npy``; created if missing.
        visualize: Accepted for backward compatibility; currently has no
            effect.
        input_dir: Directory that contains ``png_file``. When omitted, the
            module-level ``train_mask_path`` is used.
    """
    if input_dir is None:
        input_dir = train_mask_path
    png_path = os.path.join(input_dir, png_file)

    # Convert inside the context manager so the file handle is released as
    # soon as the pixels are in memory.
    with Image.open(png_path) as png:
        rgb_pixels = np.array(png.convert('RGB'))

    label_mask = np.zeros(rgb_pixels.shape[:2], dtype=np.uint8)
    for rgb, class_id in COLOR_TO_CLASS.items():
        # A pixel matches only when all three channels equal the color.
        label_mask[np.all(rgb_pixels == rgb, axis=-1)] = class_id

    os.makedirs(output_path, exist_ok=True)
    npy_filename = os.path.splitext(png_file)[0] + '.npy'
    np.save(os.path.join(output_path, npy_filename), label_mask)
    print(f'saved at {npy_filename}')

In [None]:
# Convert every listed mask PNG. output_path must already be defined; it is
# not assigned in the cells shown here.
for mask_png in train_mask_list:
    convert_png_to_npy(mask_png, output_path)

### generate combined npy

In [None]:
# File names found in train_npy_path.
# NOTE(review): train_npy_path is not assigned in the cells shown here --
# presumably the directory the label .npy files were written to; confirm.
train_npy_list = list_files(train_npy_path)

In [None]:
def convert_combined(npy_file, save_path, img_dir=None, npy_dir=None):
    """Bundle an image and its label array into a single ``.npy`` file.

    The image ``<stem>.png`` (same stem as ``npy_file``) is loaded as RGB and
    stored together with the label array as
    ``{'input': image_array, 'label': label_array}`` in
    ``save_path/<npy_file>``.

    Args:
        npy_file: File name of the label array, relative to ``npy_dir``.
        save_path: Directory that receives the combined file; created if
            missing. If it equals ``npy_dir`` the label file is overwritten.
        img_dir: Directory that contains the PNG. When omitted, the
            module-level ``train_img_path`` is used.
        npy_dir: Directory that contains ``npy_file``. When omitted, the
            module-level ``train_npy_path`` is used.
    """
    if img_dir is None:
        img_dir = train_img_path
    if npy_dir is None:
        npy_dir = train_npy_path

    stem = os.path.splitext(npy_file)[0]
    img_path = os.path.join(img_dir, stem + '.png')
    npy_path = os.path.join(npy_dir, npy_file)

    # Convert inside the context manager so the file handle is released as
    # soon as the pixels are in memory.
    with Image.open(img_path) as img:
        img_array = np.array(img.convert('RGB'))
    label_array = np.load(npy_path)

    combined_data = {
        'input': img_array,
        'label': label_array,
    }

    os.makedirs(save_path, exist_ok=True)
    out_file = os.path.join(save_path, npy_file)
    # np.save stores a dict as a pickled 0-d object array; read it back with
    # np.load(out_file, allow_pickle=True).item().
    np.save(out_file, combined_data)
    # Report the file that was written, not the input label path.
    print(f'save at {out_file}')


In [None]:
# Combine every label file with its image. save_path must already be
# defined; it is not assigned in the cells shown here.
for label_file in train_npy_list:
    convert_combined(label_file, save_path)