In [1]:
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [2]:
def patch_extractor(img, num_patches=100, patch_size=64, intensity_threshold=10):
    h, w, _ = img.shape
    patches = []

    for _ in range(num_patches):
        y= np.random.randint(0, h - patch_size + 1)
        x= np.random.randint(0, w - patch_size + 1)
        patch = img[y:y+patch_size, x:x+patch_size, :]

        if np.mean(patch) > intensity_threshold :
            patches.append(patch)
    
    return np.array(patches)


def reconstruct_image_from_patches(image_shape, patches, positions, labels):
    """
    Reconstruct an image from patches with color-coded overlay based on classification labels.

    Args:
        image_shape (tuple): The shape of the original image (H, W, C).
        patches (np.array): Array of extracted patches.
        positions (list): List of (y, x) coordinates for each patch.
        labels (np.array): Classification labels for each patch (0: non-columnar, 1: columnar).

    Returns:
        reconstructed_image (np.array): Reconstructed image with color overlay.
    """
    h, w, c = image_shape
    patch_size = patches.shape[1]
    reconstructed_image = np.zeros((h, w, c), dtype=np.uint8)
    count_map = np.zeros((h, w), dtype=np.uint8)

    # Color mapping for labels
    columnar_color = np.array([255, 0, 0], dtype=np.uint8)   # Red for columnar
    non_columnar_color = np.array([0, 0, 255], dtype=np.uint8)  # Blue for non-columnar

    for (y, x), label in zip(positions, labels):
        color = columnar_color if label >= 0.5 else non_columnar_color

        # Fill patch region with the respective color
        reconstructed_image[y:y+patch_size, x:x+patch_size] = color

    return reconstructed_image


In [3]:
# Load images & patches (ref/val)
path2img = '/home/yec23006/projects/research/KneeGrowthPlate/Knee_GrowthPlate/Images/CCC_K05_hK_FL1_s1_shift3_So.jpg'
path2platemask = '/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/plate_selection/growthplate_mask.png'
path2columnar = '/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/plate_selection/columnar_mask.png'
path2patch = "/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/patch_extraction/filtered_patches_nb.npy"
path2patchposition = "/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/patch_extraction/filtered_patch_positions_nb.npy"
save2 = "/home/yec23006/projects/research/KneeGrowthPlate/Embedding/results/patch_extraction"

image = cv2.imread(path2img, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
filtered_patches = np.load(path2patch) # for validation
patchposition = np.load(path2patchposition) # for recon

# ref patches (labels)
columnar = image[6900:7050, 5100:5500, :]
noncolumnar = image[6700:6850, 5100:5500]
columnar_patches = patch_extractor(columnar)
noncolumnar_patches = patch_extractor(noncolumnar)

In [45]:
# Tensor dataset for train and test
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TrainLoader
all_patches = np.vstack((columnar_patches, noncolumnar_patches)).astype(np.float32)/255.0
labels = np.array([1]*len(columnar_patches) + [0]*len(noncolumnar_patches)) # 1 for columnar, 0 for noncolumnar
train_loader = torch.tensor(all_patches).permute(0,3,1,2).to(device)
labels = torch.tensor(labels, dtype=torch.long).to(device)
data_loader = DataLoader(TensorDataset(train_loader, labels), batch_size=32, shuffle=True)

# TestLoader
test_images = torch.tensor(filtered_patches.astype(np.float32)/255.0).to(device)
test_images = test_images.permute(0, 3, 1, 2)
test_loader = DataLoader(TensorDataset(test_images), batch_size=32, shuffle=False)

In [43]:
batch.shape, images.shape

(torch.Size([3, 64, 64]), torch.Size([8, 3, 64, 64]))

In [46]:
# CNN patch classifier 
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

model = CNNClassifier().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in data_loader:
        images, labels = images.to(device), labels.float().to(device).unsqueeze(1)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        predicted = (outputs > 0.5).float()
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    
    train_accuracy = 100 * correct / total
    print(f"Train Accuracy : {train_accuracy:.2f}%")
    
# Eval
model.eval()
predictions = []

with torch.no_grad():
    for batch in test_loader:
        batch = batch[0].to(device)  # Extract the batch tensor
        output = model(batch)  # Get model predictions
        predictions.append(output.cpu().numpy())  # Move to CPU & store

# Concatenate all batch predictions
predictions = np.vstack(predictions)

# Save predictions
np.save(os.path.join(save2, "CNNPatchClassifierPred.npy"), predictions)


Train Accuracy : 72.50%
Train Accuracy : 92.50%
Train Accuracy : 81.50%
Train Accuracy : 88.50%
Train Accuracy : 95.50%
Train Accuracy : 97.50%
Train Accuracy : 98.00%
Train Accuracy : 98.00%
Train Accuracy : 99.00%
Train Accuracy : 99.50%


In [6]:
predictions = np.load(os.path.join(save2, "CNNPatchClassifierPred.npy"))
reconstructed_image = reconstruct_image_from_patches(image.shape, filtered_patches, patchposition, predictions)
overlay_img = reconstructed_image * 0.5 + image.astype(np.float32)
overlay_img = np.clip(overlay_img, 0, 255).astype(np.uint8)
Image.fromarray(overlay_img).save(os.path.join(save2, "CNNPatchClassifierResultRecon.png"))