In [1]:
import numpy as np
import json
import pandas as pd
import re
import os
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import cv2
import shutil
from tqdm import tqdm, trange
import scipy.ndimage as ndimage
import pickle
import SimpleITK as sitk
import h5py

# filename = "/storage/valis_reg/segmentation_masks/1202_msrcr_Simple Segmentation.h5"

def read_h5_mask(filename):
    """Load a segmentation mask stored in an HDF5 file.

    Reads the whole ``'exported_data'`` dataset into memory, then removes its
    trailing axis.

    Args:
        filename: Path to the ``.h5`` file (presumably an ilastik
            "Simple Segmentation" export, judging by the file names used in
            this notebook -- TODO confirm).

    Returns:
        The dataset as a NumPy array with the last axis dropped.

    Raises:
        ValueError: from ``np.squeeze`` if the last axis does not have length 1.
    """
    with h5py.File(filename, "r") as h5_file:
        # Materialise the dataset while the file handle is still open.
        raw = np.array(h5_file['exported_data'])
    return np.squeeze(raw, axis=-1)


# mask = read_h5_mask(filename)
# Compiled pattern matching runs of decimal digits (presumably meant for pulling
# numeric section ids out of filenames).
# NOTE(review): not referenced by any cell visible in this notebook -- confirm it is still needed.
regex = re.compile(r'\d+')


def black2white(img):
    """Return a copy of ``img`` in which pure-black pixels are set to white (255).

    A pixel counts as black only when all of its colour channels (the first three
    for a multi-channel image) are 0. The previous element-wise comparison
    ``img == [0, 0, 0]`` matched individual zero channels, so e.g. ``[0, 100, 50]``
    wrongly became ``[255, 100, 50]``.

    Args:
        img: A 2-D grayscale array or an (H, W, C) array with C >= 3.

    Returns:
        A new array of the same shape and dtype. The input is not modified.
    """
    img = img.copy()
    if img.ndim == 2:
        black = img == 0
    else:
        # Collapse the channel axis: black only if every colour channel is zero.
        black = np.all(img[..., :3] == 0, axis=-1)
    img[black] = 255
    return img


In [2]:
# Load the MSRCR (retinex-enhanced) and original image stacks, one image per file.
# NOTE(review): sorted() orders filenames lexicographically, whereas the masks are
# later ordered by int(section id). If the filenames are not zero-padded the two
# orders differ and images/masks/transforms will be paired wrongly -- confirm.
retinex_imgs = [
    imageio.imread(f'/storage/valis_reg/244_processed/msrcr/{name}')
    for name in tqdm(sorted(os.listdir('/storage/valis_reg/244_processed/msrcr/')))
]
orig_imgs = [
    imageio.imread(f'/storage/valis_reg/244_processed/original/{name}')
    for name in tqdm(sorted(os.listdir('/storage/valis_reg/244_processed/original/')))
]


100%|██████████| 2495/2495 [01:12<00:00, 34.42it/s]
100%|██████████| 2495/2495 [01:06<00:00, 37.66it/s]
100%|██████████| 2496/2496 [01:51<00:00, 22.42it/s]


In [8]:
segmentation_dir = '/storage/valis_reg/segmentation_masks'
# Map section id (filename text before the first '_') -> mask array.
# NOTE(review): `fname != '2434'` compares the *whole* filename, so it does not
# skip files such as '2434_....h5'; the '2434' entry is removed in a later cell.
segmentation_masks = {}
for fname in tqdm(sorted(os.listdir(segmentation_dir))):
    if fname != '2434':
        segmentation_masks[fname.split('_')[0]] = read_h5_mask(os.path.join(segmentation_dir, fname))

100%|██████████| 2496/2496 [00:08<00:00, 309.97it/s]


In [12]:
# Drop section 2434's mask (2496 mask files are loaded vs 2495 images; presumably
# 2434 has no matching image -- confirm). Guarding the delete keeps this cell
# idempotent: re-running it no longer raises KeyError.
if '2434' in segmentation_masks:
    del segmentation_masks['2434']

In [14]:
def invert_mask(mask):
    """Zero out every value greater than 1 in ``mask``, in place.

    Despite the name, nothing is inverted: values <= 1 are kept and all larger
    labels become 0 (presumably turning a 1/2 label export into a 0/1 mask --
    TODO confirm the label meaning).

    Args:
        mask: A writable NumPy array. It is modified in place.

    Returns:
        The same array object that was passed in.
    """
    above_one = mask > 1
    mask[above_one] = 0
    return mask
# plt.imshow(invert_mask(segmentation_masks['1202']))

In [15]:
# Order masks by numeric section id, then binarise each one in place.
# NOTE(review): the image stacks were ordered with a plain lexicographic sorted();
# confirm both orderings agree so masks stay paired with their images.
segmentation_masks = dict(sorted(segmentation_masks.items(), key=lambda item: int(item[0])))
retinex_masks = [invert_mask(mask) for mask in segmentation_masks.values()]

In [16]:
def find_largest_image_dimensions(retinex_imgs):
    """Return ``(max_rows, max_cols)`` over all images in ``retinex_imgs``.

    Height and width are maximised independently, giving the smallest canvas
    that can hold every image placed at its top-left corner. An empty input
    yields ``(0, 0)``.
    """
    sizes = [(img.shape[0], img.shape[1]) for img in retinex_imgs]
    max_rows = max((rows for rows, _ in sizes), default=0)
    max_cols = max((cols for _, cols in sizes), default=0)
    return max_rows, max_cols

# Size the shared canvas from every stack that is padded to it below (retinex
# images, original images and masks). Measuring the retinex images alone would
# make pad_img fail if any original image or mask were larger.
largest_r, largest_c = find_largest_image_dimensions(retinex_imgs=retinex_imgs + orig_imgs + retinex_masks)


def pad_img_gray(img, r, c):
    """Place a 2-D image at the top-left of an ``r`` x ``c`` white (255) canvas.

    Returns a new float64 array. numpy raises if ``img`` is larger than the canvas.
    """
    canvas = np.full((r, c), 255.0)
    rows, cols = img.shape[0], img.shape[1]
    canvas[:rows, :cols] = img
    return canvas

def pad_mask(mask, r, c):
    """Place a 2-D mask at the top-left of an ``r`` x ``c`` canvas of zeros.

    Returns a new float64 array. numpy raises if ``mask`` is larger than the canvas.
    """
    canvas = np.zeros((r, c), dtype=np.float64)
    rows, cols = mask.shape[0], mask.shape[1]
    canvas[:rows, :cols] = mask
    return canvas

def pad_img_color(img, r, c):
    """Place an (H, W, 3) image at the top-left of an ``r`` x ``c`` x 3 white canvas.

    Returns a new float64 array. numpy raises if ``img`` does not fit the canvas
    or does not have exactly 3 channels.
    """
    canvas = np.full((r, c, 3), 255.0)
    rows, cols = img.shape[0], img.shape[1]
    canvas[:rows, :cols] = img
    return canvas

def pad_img(img, r, c, mask=False):
    """Pad an image or mask to ``r`` x ``c``, anchored at the top-left corner.

    Masks are padded with 0. 2-D images are padded with 255 via
    ``pad_img_gray``; any other image goes through ``pad_img_color`` (255 fill,
    3 channels).
    """
    if mask:
        return pad_mask(img, r, c)
    padder = pad_img_gray if img.ndim == 2 else pad_img_color
    return padder(img, r, c)


# Pad every image and mask onto the shared (largest_r, largest_c) canvas so all
# sections share one pixel grid before the transforms are applied.
retinex_imgs = [pad_img(image, largest_r, largest_c) for image in tqdm(retinex_imgs)]
orig_imgs = [pad_img(image, largest_r, largest_c) for image in tqdm(orig_imgs)]
retinex_masks = [pad_img(m, largest_r, largest_c, mask=True) for m in tqdm(retinex_masks)]

100%|██████████| 2495/2495 [01:17<00:00, 32.34it/s]
100%|██████████| 2495/2495 [05:15<00:00,  7.91it/s]
100%|██████████| 2495/2495 [00:22<00:00, 112.64it/s]


In [24]:
# NOTE(review): output_dir is not used by any visible cell.
output_dir = 'sitk_combined'
# One transform per section, applied in list order below. pickle can execute
# arbitrary code, so only load transform files produced by our own pipeline.
with open('/storage/valis_reg/sitk_combined/transforms.pkl', 'rb') as pkl_file:
    transforms = pickle.load(pkl_file)

In [26]:

def apply_channel_transform(img_channel, transform):
    """Resample a single 2-D image channel through ``transform`` and return uint8.

    Uses linear interpolation. Pixels that map outside the input get 255 (white),
    the same fill value the padding helpers use. The resampled image keeps the
    input's pixel type and is then converted to uint8. Floating-point results are
    rounded rather than truncated, and every value is clipped to [0, 255] before
    the cast.

    Args:
        img_channel: 2-D array for one channel.
        transform: A SimpleITK transform accepted by ``sitk.Resample``.

    Returns:
        A uint8 array shaped like the resampled image.
    """
    sitk_img = sitk.GetImageFromArray(img_channel)
    moved_img = sitk.Resample(sitk_img, transform, sitk.sitkLinear, 255.0, sitk_img.GetPixelID())
    moved = sitk.GetArrayFromImage(moved_img)
    if np.issubdtype(moved.dtype, np.floating):
        # astype() truncates toward zero, biasing interpolated values downward.
        moved = np.rint(moved)
    # Clip before casting: out-of-range values would otherwise wrap in uint8.
    return np.clip(moved, 0, 255).astype(np.uint8)

def apply_transform(img, transform):
    """Warp the first three channels of an (H, W, C) image with ``transform``.

    Each channel is resampled independently via ``apply_channel_transform``, and
    the results are stacked back along a new last axis, giving an (H', W', 3)
    array. Channels beyond the third are dropped.
    """
    channels = [apply_channel_transform(img[:, :, ch], transform) for ch in range(3)]
    return np.stack(channels, axis=2)

def apply_transfrom_mask(mask, transform):
    """Warp a label mask with ``transform`` and return it as uint8.

    Nearest-neighbour interpolation keeps label values discrete (no blending).
    Pixels that map outside the input get 0.
    """
    source = sitk.GetImageFromArray(mask)
    warped = sitk.Resample(source, transform, sitk.sitkNearestNeighbor, 0.0, source.GetPixelID())
    return sitk.GetArrayFromImage(warped).astype(np.uint8)

In [27]:
# zip() in the alignment cells silently truncates to the shortest input, so verify
# that every per-section stack pairs one-to-one with the transforms.
assert len(retinex_imgs) == len(transforms), (len(retinex_imgs), len(transforms))
assert len(orig_imgs) == len(transforms), (len(orig_imgs), len(transforms))
assert len(retinex_masks) == len(transforms), (len(retinex_masks), len(transforms))

In [28]:
# Warp each padded image with its section's transform (paired by list position).
retinex_imgs_aligned = [apply_transform(image, tfm) for image, tfm in tqdm(zip(retinex_imgs, transforms))]
orig_imgs_aligned = [apply_transform(image, tfm) for image, tfm in tqdm(zip(orig_imgs, transforms))]

2495it [07:39,  5.43it/s]
2495it [07:35,  5.48it/s]


In [29]:
# Warp each padded mask with the same per-section transform as its image.
retinex_masks_aligned = [apply_transfrom_mask(m, tfm) for m, tfm in tqdm(zip(retinex_masks, transforms))]

2495it [01:13, 33.87it/s]


In [30]:
# Export the aligned stacks as sequentially numbered files (0000, 0001, ...).
# The number is the position in the aligned lists, not the original section id.
aligned_root = '244_BFIW_aligned'
for subdir in ('retinex', 'original', 'mask'):
    os.makedirs(os.path.join(aligned_root, subdir), exist_ok=True)

aligned_triples = zip(retinex_imgs_aligned, orig_imgs_aligned, retinex_masks_aligned)
for i, (retinex_img, orig_img, mask) in tqdm(enumerate(aligned_triples), total=len(retinex_imgs_aligned)):
    stem = f'{i:04d}'
    imageio.imwrite(os.path.join(aligned_root, 'retinex', f'{stem}.jpg'), retinex_img)
    imageio.imwrite(os.path.join(aligned_root, 'original', f'{stem}.jpg'), orig_img)
    # Masks hold discrete labels (presumably 0/1 after invert_mask), so write them
    # as lossless PNG. JPEG compression would introduce spurious intermediate
    # values around label boundaries.
    imageio.imwrite(os.path.join(aligned_root, 'mask', f'{stem}.png'), mask)

2495it [01:57, 21.30it/s]
