In [52]:
import glob

import numpy as np
import os
from PIL import Image
import rasterio

from tqdm.notebook import tqdm

In [57]:
def get_sentinel_files(root_folder):
    """Collect matching Sentinel band files found under ``root_folder``.

    Searches ``<root_folder>/b2/lrs``, ``<root_folder>/b3/lrs`` and
    ``<root_folder>/b4/lrs`` for ``*.jp2`` files and groups the files that
    share the same base name across all three band directories.

    Files are matched by base name rather than by their position in the
    sorted listings, so a file that is missing from (or extra in) one band
    directory only drops that single name instead of misaligning every
    following pair or discarding the whole folder.

    Args:
        root_folder: Directory that contains the ``b2``, ``b3`` and ``b4``
            sub-directories.

    Returns:
        A list of ``(b2_path, b3_path, b4_path)`` tuples ordered by base
        name. The list is empty when no base name exists in all three bands.
    """
    # One {basename: full path} lookup per band, in b2/b3/b4 order.
    files_by_band = []
    for band in ('b2', 'b3', 'b4'):
        paths = glob.glob(os.path.join(root_folder, band, 'lrs', '*.jp2'))
        files_by_band.append({os.path.basename(path): path for path in paths})

    # Keep only the names that every band provides.
    common_names = set(files_by_band[0])
    for band_files in files_by_band[1:]:
        common_names &= band_files.keys()

    return [
        tuple(band_files[name] for band_files in files_by_band)
        for name in sorted(common_names)
    ]

def normalize_band(band):
    """Min-max normalize band data to the range [0, 1].

    The data is converted to float64 before any arithmetic so that the
    ``max - min`` span cannot overflow a narrow integer dtype.

    Args:
        band: Non-empty array-like of numeric pixel values.

    Returns:
        A float64 array with the same shape as ``band``. The minimum maps to
        0 and the maximum to 1. A constant band (max == min) yields all
        zeros instead of the NaNs a 0/0 division would produce.

    Raises:
        ValueError: If ``band`` is empty (raised by NumPy's ``min``).
    """
    band = np.asarray(band, dtype=np.float64)
    band_min, band_max = band.min(), band.max()
    span = band_max - band_min
    if span == 0:
        # Flat band: there is no contrast to stretch, avoid dividing by zero.
        return np.zeros_like(band)
    return (band - band_min) / span

def load_and_combine_bands(jp2_files):
    """Read one normalized band per file and stack them as channels.

    Each path is opened with rasterio, its first band is read and passed
    through ``normalize_band``, and the results are stacked along a new
    last axis in the order the paths were given.

    Args:
        jp2_files: Iterable of raster file paths, one per output channel.

    Returns:
        The array produced by ``np.stack(..., axis=-1)`` over the
        normalized bands.
    """
    def _read_normalized(path):
        # The context manager closes the dataset as soon as the band is read.
        with rasterio.open(path) as dataset:
            return normalize_band(dataset.read(1))

    channels = [_read_normalized(path) for path in jp2_files]
    return np.stack(channels, axis=-1)

def get_patch_coordinates(image_shape, patch_size, oversample_factor=5, rng=None):
    """Calculate patch coordinates for image tiling and random oversampling.

    The result starts with the top-left corners of a non-overlapping grid of
    ``patch_size`` tiles, followed by ``oversample_factor`` times as many
    uniformly random top-left corners whose patches lie fully inside the
    image.

    Args:
        image_shape: Shape of the image; only the first two entries
            (height, width) are used.
        patch_size: Side length of the square patches, in pixels.
        oversample_factor: Number of random coordinates generated per grid
            coordinate.
        rng: Optional ``numpy.random.Generator`` used for the random
            coordinates. When omitted, NumPy's global random state is used.

    Returns:
        A list of ``(row, col)`` tuples: grid coordinates first, random
        coordinates after. The list is empty when the image is smaller than
        one patch in either dimension.
    """
    height, width = image_shape[:2]
    rows = height // patch_size
    cols = width // patch_size
    ordered_coords = [
        (i * patch_size, j * patch_size)
        for i in range(rows)
        for j in range(cols)
    ]

    # Both samplers take an exclusive upper bound, hence the "+ 1": the last
    # valid offset is included and an image exactly one patch wide or tall
    # no longer raises "low >= high".
    draw = np.random.randint if rng is None else rng.integers
    row_limit = height - patch_size + 1
    col_limit = width - patch_size + 1
    random_coords = [
        (int(draw(0, row_limit)), int(draw(0, col_limit)))
        for _ in range(oversample_factor * len(ordered_coords))
    ]
    return ordered_coords + random_coords

def extract_patch(image, top_left, patch_size):
    """Return the square window of ``image`` that starts at ``top_left``.

    Args:
        image: Array indexed as ``image[row, col, ...]``.
        top_left: ``(row, col)`` of the window's first pixel.
        patch_size: Number of rows and columns requested.

    Returns:
        ``image`` sliced to ``patch_size`` rows and columns from
        ``top_left``; any further axes are kept whole. For NumPy arrays a
        window that runs past the border comes back smaller than requested,
        because slicing truncates at the array edge.
    """
    row, col = top_left
    row_window = slice(row, row + patch_size)
    col_window = slice(col, col + patch_size)
    return image[row_window, col_window]

def save_patches(patches, base_filename, folder):
    """Save image patches as 8-bit PNG files in ``folder``.

    The folder is created when it does not exist. A single patch is written
    to ``<base_filename>.png``. When several patches are passed, each one is
    written to ``<base_filename>_<index>.png`` so that later patches do not
    overwrite earlier ones.

    Conversion to 8-bit depends on the patch dtype:
      * ``uint8``  - written unchanged.
      * ``uint16`` - divided by 256.
      * floating   - assumed to be scaled to [0, 1]; NaN becomes 0, values
        outside the range are clipped, then the result is multiplied by 255.

    Args:
        patches: Iterable of NumPy arrays to save.
        base_filename: File name stem, without extension.
        folder: Destination directory.

    Raises:
        ValueError: If a patch has a dtype other than the ones listed above.
    """
    os.makedirs(folder, exist_ok=True)
    patches = list(patches)  # accept any iterable; the count decides the naming
    use_index_suffix = len(patches) > 1

    for idx, patch in enumerate(patches):
        if patch.dtype == np.uint8:
            patch_8bit = patch
        elif patch.dtype == np.uint16:
            # Convert 16-bit to 8-bit
            patch_8bit = (patch / 256).astype(np.uint8)
        elif np.issubdtype(patch.dtype, np.floating):
            # Clip before the cast so out-of-range values saturate instead of
            # wrapping around in uint8.
            unit_range = np.clip(np.nan_to_num(patch, nan=0.0), 0.0, 1.0)
            patch_8bit = (unit_range * 255).astype(np.uint8)
        else:
            raise ValueError(f"Unsupported image data type: {patch.dtype}")

        if use_index_suffix:
            filename = f'{base_filename}_{idx}.png'
        else:
            filename = f'{base_filename}.png'
        Image.fromarray(patch_8bit).save(os.path.join(folder, filename))

def process_folder(sentinel_folder, hr_folder, output_folder, patch_size=256, hr_patch_size=768):
    """Process Sentinel and HR images, extract and save patches.

    For every band triple returned by ``get_sentinel_files`` the bands are
    combined into one image and cut into patches at the coordinates from
    ``get_patch_coordinates``. Patch number ``p`` is written to
    ``<output_folder>/imgset<base + p>/`` where ``base`` is 100, 200, ...
    and advances whenever the Sentinel image shape changes (new coordinates
    are drawn at the same time). Each ``imgset`` folder receives:

      * ``LR<idx>.png`` - the Sentinel patch of image number ``idx``;
      * ``HR.png``      - the matching patch of the high-resolution image;
      * ``QM<idx>.png`` - an all-255 mask of ``patch_size`` x ``patch_size``.

    The HR image is built from the files in ``hr_folder`` ending in
    ``_2.tiff``, ``_3.tiff`` or ``_4.tiff``. It is loaded once, on the first
    Sentinel image, and ``HR.png`` is written once per coordinate set,
    because neither changes between Sentinel images.

    Args:
        sentinel_folder: Folder handed to ``get_sentinel_files``.
        hr_folder: Folder containing the high-resolution band files.
        output_folder: Root folder for the ``imgset*`` directories.
        patch_size: Side length of the Sentinel patches.
        hr_patch_size: Side length of the HR patches; must be a positive
            multiple of ``patch_size`` (the ratio is the scale factor used
            to map Sentinel coordinates to HR coordinates).

    Raises:
        ValueError: If ``patch_size`` is not positive or ``hr_patch_size``
            is not a positive multiple of it.
    """
    if patch_size <= 0 or hr_patch_size <= 0 or hr_patch_size % patch_size != 0:
        raise ValueError(
            f"hr_patch_size ({hr_patch_size}) must be a positive multiple of "
            f"patch_size ({patch_size})"
        )
    scale = hr_patch_size // patch_size  # 3 with the default sizes

    sentinel_files = get_sentinel_files(sentinel_folder)
    hr_files = sorted(
        os.path.join(hr_folder, name)
        for name in os.listdir(hr_folder)
        if name.endswith(('_2.tiff', '_3.tiff', '_4.tiff'))
    )

    hr_image = None  # loaded lazily so an empty Sentinel folder costs nothing
    shape = None
    folder_counts = 1
    folder_name = 0
    patch_coordinates = []
    # The mask is identical for every patch, so build it once.
    mask = Image.fromarray(np.full((patch_size, patch_size), 255, dtype=np.uint8))

    # Wrapping the list (not the enumerate object) lets tqdm show the total.
    for idx, sentinel_bands in enumerate(tqdm(sentinel_files)):
        sentinel_image = load_and_combine_bands(sentinel_bands)
        if hr_image is None:
            hr_image = load_and_combine_bands(hr_files)

        new_patch_set = shape != sentinel_image.shape
        if new_patch_set:
            # Patch coordinates
            patch_coordinates = get_patch_coordinates(sentinel_image.shape, patch_size)
            shape = sentinel_image.shape

            # NOTE(review): folder numbers advance by 100 per shape change, so
            # more than 100 patch coordinates would overlap the next range -
            # confirm this cannot happen for the images being processed.
            folder_name = folder_counts * 100
            folder_counts += 1

        for patch_idx, coord in enumerate(patch_coordinates):
            imgset_folder = os.path.join(output_folder, f'imgset{folder_name + patch_idx}')

            # Extract and save Sentinel patch
            sentinel_patch = extract_patch(sentinel_image, coord, patch_size)
            save_patches([sentinel_patch], f'LR{idx}', imgset_folder)

            # Extract and save HR patch (only when the coordinates are new;
            # later images would rewrite the very same file)
            if new_patch_set:
                hr_coord = (coord[0] * scale, coord[1] * scale)
                hr_patch = extract_patch(hr_image, hr_coord, hr_patch_size)
                save_patches([hr_patch], 'HR', imgset_folder)

            # Save mask patch
            mask.save(os.path.join(imgset_folder, f'QM{idx}.png'))


In [58]:
# Example usage
root_folder = '../MuS2/image_data'
output_root = 'MuS2_data'

# Patch coordinates are partly random; seed NumPy's global generator so a
# re-run of this cell samples the same patches.
SEED = 42
np.random.seed(SEED)

# Only directories are image sets; sorting keeps the processing order stable.
image_sets = sorted(
    path
    for path in glob.glob(os.path.join(root_folder, '*'))
    if os.path.isdir(path)
)

for image_set in image_sets:
    sentinel_folder = image_set
    hr_folder = os.path.join(image_set, 'hr_resized')
    # process_folder numbers its imgset folders from 100 on every call, so
    # each image set gets its own output directory; a shared one would let
    # later sets overwrite and mix with the files of earlier sets.
    set_output_folder = os.path.join(output_root, os.path.basename(os.path.normpath(image_set)))
    process_folder(sentinel_folder, hr_folder, set_output_folder)

0it [00:00, ?it/s]

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

KeyboardInterrupt: 

: 