In [1]:
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
# Inputs produced by the patch-extraction / classification step.
# NOTE(review): absolute, user-specific paths — consider making RESULTS_DIR configurable.
RESULTS_DIR = "/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/patch_extraction"
path2patch = os.path.join(RESULTS_DIR, "filtered_patches_nb.npy")
path2patchposition = os.path.join(RESULTS_DIR, "filtered_patch_positions_nb.npy")
path2prediction = os.path.join(RESULTS_DIR, "predicted_labels.npy")
path2img = '/home/yec23006/projects/research/KneeGrowthPlate/Knee_GrowthPlate/Images/CCC_K05_hK_FL1_s1_shift3_So.jpg'

patch = np.load(path2patch)                   # per-patch pixel arrays
patchposition = np.load(path2patchposition)   # (y, x) top-left corner of each patch
prediction = np.load(path2prediction)         # one class label per patch

# The downstream functions pair these arrays with zip(), which would silently
# truncate on a length mismatch, so fail loudly here instead.
assert len(patch) == len(patchposition) == len(prediction), (
    f"length mismatch: {len(patch)} patches, {len(patchposition)} positions, "
    f"{len(prediction)} predictions"
)

# cv2.imread returns None (it does not raise) when the file is missing or unreadable.
image_bgr = cv2.imread(path2img, cv2.IMREAD_COLOR)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {path2img}")
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)  # work in RGB from here on

In [5]:
def refine_patch_labels(image_shape, patch_size, positions, labels,
                        target_label=1, fill_ratio=0.7, connectivity=4):
    """
    Clean up binary patch labels so the ``target_label`` region is one connected area.

    Each patch is mapped to a coarse grid cell ``(y // patch_size, x // patch_size)``.
    Two passes then run over that grid:
      1. Hole filling: a non-target patch whose in-bounds 4-neighbour cells are
         target in more than ``fill_ratio`` of cases is relabelled as target.
         Neighbour cells holding no patch count toward the total but never as
         target. All decisions use the unmodified input grid, so the result does
         not depend on patch order.
      2. Connectivity: target patches outside the largest connected target
         component are relabelled as the other class.

    NOTE(review): the previous version mixed two label conventions. Its docstring
    said both "blue (0)" and "1: blue, 0: red". As a result, step 1 never changed
    any label, and step 2 turned every 0-patch into 1. ``target_label`` defaults
    to 1, matching the "1: blue" Args description and the ``grid == 1`` component
    mask. Confirm that this is the class that should stay connected.

    Args:
        image_shape (tuple): Shape of the original image, (H, W) or (H, W, C).
        patch_size (int): Side length of the square patches, in pixels.
        positions (array-like): (N, 2) top-left (y, x) positions of the patches.
        labels (array-like): (N,) binary patch labels (0 or 1).
        target_label (int): Label (0 or 1) whose region is made connected.
        fill_ratio (float): A non-target patch is filled only when the fraction
            of target neighbours is strictly greater than this value.
        connectivity (int): 4 or 8; connectivity used to find target components.

    Returns:
        np.ndarray: A new (N,) label array. The input ``labels`` is not modified.

    Raises:
        ValueError: If ``target_label`` is not 0 or 1.
    """
    if target_label not in (0, 1):
        raise ValueError(f"target_label must be 0 or 1, got {target_label!r}")
    other_label = 1 - target_label

    labels = np.asarray(labels)
    refined_labels = labels.copy()
    if refined_labels.size == 0:
        return refined_labels

    h, w = image_shape[:2]
    grid_h, grid_w = h // patch_size, w // patch_size
    # Grid (row, col) of every patch; positions are (y, x).
    cells = np.asarray(positions, dtype=int) // patch_size

    # -1 marks grid cells that hold no patch (e.g. filtered-out areas).
    grid = np.full((grid_h, grid_w), -1, dtype=int)
    for (gy, gx), label in zip(cells, labels):
        grid[gy, gx] = label

    # --- Step 1: fill isolated non-target patches surrounded by target ---
    original_grid = grid.copy()  # snapshot so earlier fills don't influence later decisions
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))  # 4-neighbourhood
    for i, (gy, gx) in enumerate(cells):
        if labels[i] == target_label:
            continue
        neighbours = [
            original_grid[gy + dy, gx + dx]
            for dy, dx in offsets
            if 0 <= gy + dy < grid_h and 0 <= gx + dx < grid_w
        ]
        n_target = sum(1 for value in neighbours if value == target_label)
        if neighbours and n_target / len(neighbours) > fill_ratio:
            refined_labels[i] = target_label
            grid[gy, gx] = target_label

    # --- Step 2: keep only the largest connected target component ---
    target_mask = (grid == target_label).astype(np.uint8)
    num_components, component_map = cv2.connectedComponents(target_mask, connectivity=connectivity)
    if num_components > 1:  # component 0 is the mask background
        sizes = np.bincount(component_map.ravel())[1:]
        largest = int(np.argmax(sizes)) + 1
        # Check every patch (not just one per cell) against its cell's component.
        for i, (gy, gx) in enumerate(cells):
            if refined_labels[i] == target_label and component_map[gy, gx] != largest:
                refined_labels[i] = other_label

    return refined_labels

def reconstruct_image_from_patches(image_shape, patches, positions, labels,
                                   columnar_color=(255, 0, 0), non_columnar_color=(0, 0, 255)):
    """
    Paint a label map by filling each patch footprint with a solid class color.

    The original pixel values are not used or blended in. ``patches`` is only
    consulted for its patch size (``patches.shape[1]``). Where patches overlap,
    later patches overwrite earlier ones.

    Args:
        image_shape (tuple): Shape of the original image, (H, W, C). C must
            match the length of the color tuples (3 by default).
        patches (np.array): Extracted patches. Only ``patches.shape[1]`` is
            read, presumably laid out as (N, patch_size, patch_size, C).
        positions (array-like): (y, x) top-left coordinates, one per patch.
        labels (array-like): Class label per patch (0: non-columnar, 1: columnar).
            Any label other than 1 is painted with ``non_columnar_color``.
        columnar_color (tuple): Fill for label 1; default is red in RGB order.
        non_columnar_color (tuple): Fill for other labels; default is blue in RGB order.

    Returns:
        np.ndarray: uint8 (H, W, C) image, black where no patch lies. The
        default colors are RGB, so convert to BGR before passing the result
        to ``cv2.imwrite``.
    """
    h, w, c = image_shape
    patch_size = patches.shape[1]
    reconstructed_image = np.zeros((h, w, c), dtype=np.uint8)

    columnar = np.asarray(columnar_color, dtype=np.uint8)
    non_columnar = np.asarray(non_columnar_color, dtype=np.uint8)

    for (y, x), label in zip(positions, labels):
        # Slicing clips automatically at the image border.
        reconstructed_image[y:y + patch_size, x:x + patch_size] = (
            columnar if label == 1 else non_columnar
        )

    return reconstructed_image

In [6]:
# Update labels based on surrounding patches. Use the same patch size that
# reconstruct_image_from_patches derives from the patch array, instead of a
# separately hard-coded 64, so the two functions always agree.
patch_size = patch.shape[1]
refined_labels = refine_patch_labels(image.shape, patch_size, patchposition, prediction)

# Use the refined labels to reconstruct the image (colors are in RGB order).
refined_image = reconstruct_image_from_patches(image.shape, patch, patchposition, refined_labels)

# cv2.imwrite expects BGR; without this conversion red and blue are swapped on disk.
output_path = os.path.join(os.path.dirname(path2patch), "Postprocessing_8.png")
saved = cv2.imwrite(output_path, cv2.cvtColor(refined_image, cv2.COLOR_RGB2BGR))
if not saved:
    raise IOError(f"cv2.imwrite failed for {output_path}")
saved


True