In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import math
import os
from PIL import Image
from carvana_test.utils import tileImage, padImage

In [2]:
img_path = "C:\\Users\\giant\\Desktop\\aiptasia\\data\\carvana_data\\subset"
mask_path = "C:\\Users\\giant\\Desktop\\aiptasia\\data\\carvana_data\\subset_masks"

# Work on the alphabetically-first image and the alphabetically-first mask.
# NOTE(review): the two are chosen independently, so this assumes both folders
# sort into matching order — confirm the pair actually corresponds.
file_name = sorted(os.listdir(img_path))[0]
mask_file_name = sorted(os.listdir(mask_path))[0]

# OpenCV decodes to BGR; convert to RGB so matplotlib shows true colours.
img = cv2.cvtColor(cv2.imread(os.path.join(img_path, file_name)), cv2.COLOR_BGR2RGB)

# Mask: read through PIL as 8-bit greyscale ("L"), stored as float32.
mask = np.array(Image.open(os.path.join(mask_path, mask_file_name)).convert("L"), dtype=np.float32)

#function to pad an image such that its dimensions are divisible by a given tile size
def padImage(img: np.ndarray, tile_size: int) -> np.ndarray:
    """
    Pad an image so that its height and width are both multiples of tile_size.

    The padding on each axis is split as evenly as possible between the two sides
    (an odd leftover pixel goes to the bottom / right). The border is filled by
    mirror reflection that repeats the edge pixel (fedcba|abcdef|fedcba), the same
    pattern as OpenCV's cv2.BORDER_REFLECT.

    Input:
    - img: array of shape (height, width) or (height, width, channels)
    - tile_size: positive tile edge length in pixels

    Output:
    - Padded copy of img with the same dtype; axes after the first two are not padded.

    Raises:
    - ValueError if tile_size is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    H, W = img.shape[:2]  # height/width of original image

    # total padding needed to reach the next multiple of tile_size (0 if already divisible)
    H_total = -H % tile_size
    W_total = -W % tile_size

    # give the remainder to bottom/right; halving the total for both sides would
    # drop one pixel whenever the total is odd and leave the size not divisible
    top, left = H_total // 2, W_total // 2
    bottom, right = H_total - top, W_total - left

    pad_width = [(top, bottom), (left, right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="symmetric")

#function to read and store an image as a numpy array
def readImage(img_path: str, mask: bool = False) -> np.ndarray:
    """
    Read an image or mask file from disk into a numpy array.

    Input:
    - img_path: path to the image (or mask) file
    - mask: if True, read the file as a single-channel mask instead of a colour image

    Output:
    - mask=True: float32 array of shape (height, width) holding the file's 8-bit
      greyscale ("L" mode) values, i.e. 0-255
    - mask=False: uint8 RGB array of shape (height, width, 3)

    Raises:
    - FileNotFoundError if OpenCV cannot read a colour image (missing or undecodable);
      for masks, PIL raises its own errors.
    """
    if mask:  # binary mask image
        # the context manager closes the file handle that Image.open keeps open
        with Image.open(img_path) as mask_img:
            return np.array(mask_img.convert("L"), dtype=np.float32)

    # raw image: cv2.imread signals failure by returning None instead of raising,
    # which would otherwise surface as an opaque error inside cvtColor
    bgr_img = cv2.imread(img_path)
    if bgr_img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

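Use the code cell below (currently commented out) to save tiles of each image and its corresponding mask. Note that it calls the `tileImage` imported from `carvana_test.utils`, which accepts `save_tiles`/`tile_dir`; it will fail if the simpler `tileImage` defined later in this notebook has already been run and shadowed the import.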

In [None]:
# tile_size = 256 

# #looping over all images and masks
# for i, img_file in enumerate(sorted(os.listdir(img_path))):
#     print(i, img_file)
#     img_file_path = os.path.join(img_path, img_file)
#     mask_file_path = os.path.join(mask_path, img_file[:-4] + "_mask.gif")

#     #reading in images/masks
#     img = readImage(img_path=img_file_path)
#     mask = readImage(img_path=mask_file_path, mask=True)

#     #padding for non-overlapping tiles
#     img = padImage(img, tile_size)
#     mask = padImage(mask, tile_size)

#     #tiling img
#     img_tile_dir = os.path.join(img_path + "_tiles", img_file); os.makedirs(name=(img_path + "_tiles"), exist_ok=True)
#     mask_tile_dir = os.path.join(mask_path + "_tiles", img_file); os.makedirs(name=(mask_path + "_tiles"), exist_ok=True)
#     img_tiles = tileImage(img, tile_size, save_tiles=True, tile_dir=img_tile_dir)
#     mask_tiles = tileImage(mask, tile_size, save_tiles=True, tile_dir=mask_tile_dir)

In [None]:
# sample image
plt.imshow(img)
plt.show()

# its mask, drawn with the reversed binary colormap (higher values lighter)
plt.imshow(mask, cmap="binary_r")
plt.show()

In [150]:
tile_size = 256

# pad the image and its mask together so the pair stays pixel-aligned
# (padding only the image would leave the mask a different size and offset)
img = padImage(img, tile_size)
mask = padImage(mask, tile_size)

In [None]:
# padded image
plt.imshow(img)
plt.show()

# mask, reversed binary colormap (higher values lighter)
plt.imshow(mask, cmap='binary_r')
plt.show()

In [None]:
#function to obtain a list of tiles given an input image and size
def tileImage(img: np.ndarray, tile_size: int = 572) -> list[np.ndarray]:
    """
    Split an image into square tiles, row by row (top to bottom, left to right).

    When a dimension is not divisible by tile_size, the last tile along it is
    shifted back so that it ends exactly on the image border; it therefore
    overlaps its neighbour instead of running past the edge.

    Input:
    - img: image of shape (height, width) or (height, width, channels)
    - tile_size: edge length of the square tiles, in pixels

    Output:
    - List of tiles (slices of img) of shape (tile_size, tile_size[, channels]);
      an empty list if the image is smaller than tile_size in either dimension.

    Raises:
    - ValueError if tile_size is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    height, width = img.shape[:2]

    #image is smaller than given tile size
    if height < tile_size or width < tile_size:
        print("Image size is smaller than tile size")
        return []  # empty list: still falsy, but matches the declared return type

    # integer ceil-division: a partial strip at the end still gets its own (shifted) tile
    tiles_height = (height + tile_size - 1) // tile_size
    tiles_width = (width + tile_size - 1) // tile_size

    tiles = []
    for h_tile in range(tiles_height):
        # clamp so the last tile ends at the image border instead of past it
        tile_top = min(h_tile * tile_size, height - tile_size)
        for w_tile in range(tiles_width):
            tile_left = min(w_tile * tile_size, width - tile_size)
            # index only the two spatial axes so 2-D masks work as well as 3-D images
            tiles.append(img[tile_top:tile_top + tile_size, tile_left:tile_left + tile_size])

    return tiles

# tile the padded sample with large 1000 px tiles and report how many were produced
tile_size = 1000
tiles = tileImage(img=img, tile_size=tile_size)

len(tiles)

In [None]:
np.zeros(shape=(2, 2), dtype=img.dtype)  # scratch check: small zero array with img's dtype

In [None]:
def stitchTiles(tiles: list[np.ndarray], img_shape: tuple[int, ...]) -> np.ndarray:
    """
    Reassemble tiles into a single image of shape img_shape.

    Assumes the tiles are square, all the same size, and ordered row by row using
    the same convention as tileImage: the last tile along an axis that is not
    divisible by the tile size is shifted back to end on the image border. For
    such edge tiles only their trailing rows/columns (the pixels not covered by
    the previous tile) are written.

    Input:
    - tiles: list of tiles of shape (tile_size, tile_size[, channels])
    - img_shape: shape of the full image to rebuild, (height, width[, channels])

    Output:
    - Array of shape img_shape with the dtype of the tiles.

    Raises:
    - ValueError if tiles is empty or its length does not match the tile grid.
    """
    if not tiles:
        raise ValueError("tiles must be a non-empty list")

    tile_size = tiles[0].shape[0]  # size of tiles
    img_height, img_width = img_shape[:2]  # taken from img_shape, not from any global image

    # integer ceil-division: a partial strip at the end is covered by a shifted tile
    tiles_height = (img_height + tile_size - 1) // tile_size
    tiles_width = (img_width + tile_size - 1) // tile_size
    if len(tiles) != tiles_height * tiles_width:
        raise ValueError(
            f"expected {tiles_height * tiles_width} tiles for image shape {tuple(img_shape)} "
            f"with tile size {tile_size}, got {len(tiles)}"
        )

    stitched_img = np.zeros(shape=img_shape, dtype=tiles[0].dtype)

    for h_tile in range(tiles_height):
        tile_top = h_tile * tile_size
        tile_bottom = min(tile_top + tile_size, img_height)
        # edge tiles end on the border, so rows for [tile_top, tile_bottom) are the LAST rows of the tile
        row_offset = tile_size - (tile_bottom - tile_top)
        for w_tile in range(tiles_width):
            tile_left = w_tile * tile_size
            tile_right = min(tile_left + tile_size, img_width)
            col_offset = tile_size - (tile_right - tile_left)
            tile = tiles[h_tile * tiles_width + w_tile]
            # index only the spatial axes so 2-D masks work as well as 3-D images
            stitched_img[tile_top:tile_bottom, tile_left:tile_right] = tile[row_offset:, col_offset:]

    return stitched_img

# rebuild the image from its tiles and display it for a visual check
stiched_img = stitchTiles(tiles=tiles, img_shape=img.shape)

plt.imshow(stiched_img)
plt.show()

In [None]:
# debugging view: the last 80 rows of the third tile (tiles are 1000 px here)
plt.imshow(
    tiles[2][(1000 - 80):, :, :]
)
plt.show()

In [None]:
# paste the first tile into the top-left corner of a blank uint8 canvas the size of img
test = np.zeros(img.shape, np.uint8)
test[:1000, :1000, :] = tiles[0]

plt.imshow(test)
plt.show()

In [None]:
# ###Do we disregard (i.e. not process) a tile if it does not contain aiptasia?
# plt.imshow(img[0:tile_size, 0:tile_size, :])
# plt.show()
# plt.clf()

# plt.imshow(img[0:tile_size, tile_size:2*tile_size, :])
# plt.show()
# plt.clf()

In [None]:
# display every tile in its own titled figure; closing each figure explicitly
# releases it (plt.clf() after plt.show() only clears the current figure and can
# leave an empty figure behind)
for tile_idx, tile in enumerate(tiles):
    fig, ax = plt.subplots()
    ax.imshow(tile)
    ax.set_title(f"tile {tile_idx}")
    plt.show()
    plt.close(fig)

In [None]:
import platform

# storing image and mask directories (per-OS data locations)
if platform.system() == "Linux":
    root_dir = "/home/pcuriel/data/aiptasia/image_data/water_body_data/"
    image_dir = os.path.join(root_dir, "Images")
    mask_dir = os.path.join(root_dir, "Masks")
elif platform.system() == "Windows":
    root_dir = "C:\\Users\\giant\\Desktop\\aiptasia\\data\\carvana_data"
    image_dir = os.path.join(root_dir, "train")
    mask_dir = os.path.join(root_dir, "train_masks")
else:
    # fail here with a clear message instead of leaving root_dir/image_dir/mask_dir
    # undefined and surfacing a NameError in some later cell
    raise OSError(f"No data directories configured for platform {platform.system()!r}")