In [None]:
import os
# Later cells import the `src` package and read `dataset/...` through relative
# paths, so the working directory must be the folder that contains `src`
# (presumably the project root, one level above this notebook — confirm).
# Only step up when we are not already there: an unconditional chdir("..")
# climbs one more level every time this cell is re-run.
if not os.path.isdir("src"):
    os.chdir("..")
print("Current Directory:", os.getcwd())

In [None]:
from src.data_processing.dataset import iScatDataset
from src.data_processing.utils import Utils
import torch
import numpy as np
import matplotlib.pyplot as plt
# Preferred GPU is index 7. Clamp to the highest index actually present so the
# notebook also runs on CUDA machines with fewer than eight GPUs, and fall back
# to the CPU when CUDA is unavailable.
DEVICE = (
    f"cuda:{min(7, torch.cuda.device_count() - 1)}"
    if torch.cuda.is_available()
    else "cpu"
)

# Two data directories contribute image/target pairs.
data_path_1 = os.path.join('dataset', '2024_11_11', 'Metasurface', 'Chip_02')
data_path_2 = os.path.join('dataset', '2024_11_12', 'Metasurface', 'Chip_01')

image_paths = []
target_paths = []
image_indicies = 12  # forwarded as-is to Utils.get_data_paths; semantics are defined there
for data_path in [data_path_1, data_path_2]:
    # get_data_paths returns two sequences: image paths and their target paths.
    found_images, found_targets = Utils.get_data_paths(data_path, 'Brightfield', image_indicies)
    image_paths.extend(found_images)
    target_paths.extend(found_targets)

In [None]:
# Settings shared by the training and validation datasets.
image_size = 256
fluo_masks_indices = [1]
seg_method = "comdet"
normalize = False

# Everything except the last two image/target pairs is used for training,
# with augmentation switched on.
train_dataset = iScatDataset(
    image_paths[:-2],
    target_paths[:-2],
    preload_image=True,
    image_size=(image_size, image_size),
    apply_augmentation=True,
    normalize=normalize,
    device=DEVICE,
    fluo_masks_indices=fluo_masks_indices,
    seg_method=seg_method,
)

# The last two pairs are held out for validation, without augmentation.
valid_dataset = iScatDataset(
    image_paths[-2:],
    target_paths[-2:],
    preload_image=True,
    image_size=(image_size, image_size),
    apply_augmentation=False,
    normalize=normalize,
    device=DEVICE,
    fluo_masks_indices=fluo_masks_indices,
    seg_method=seg_method,
)

In [None]:
# Statistics of the preloaded training images, reduced over axes 0, 2 and 3 so
# that one value remains per index of axis 1 (presumably the channel axis —
# confirm against iScatDataset). keepdim=True keeps the reduced axes as size 1.
MEAN = torch.mean(train_dataset.images, dim=(0, 2, 3), keepdim=True)
STD = torch.std(train_dataset.images, dim=(0, 2, 3), keepdim=True)

# Only the statistics are needed from here on; drop the training set reference.
# NOTE: after this, re-running the cell requires rebuilding train_dataset first.
del train_dataset

In [None]:
# Use the first validation item as the single example to visualise.
# Later cells read sample[0] as the model input image and sample[1] as the
# ground-truth mask.
sample = valid_dataset[0]

In [None]:
# Run directories of the checkpoints to compare.
# Order matters: the inference cell gives the entry at index 2 a 3-class head
# and every other entry a 2-class head.
experiments_paths = (
    'experiments/runs/UNet_Brightfield_2025-01-12_18-05-44',
    'experiments/runs/UNet_Brightfield_2025-01-12_19-09-15',
    'experiments/runs/UNet_Brightfield_2025-01-12_20-27-14',
)

class MultiClassUNet(torch.nn.Module):
    """U-Net from the ``mateuszbuda/brain-segmentation-pytorch`` hub repo with a multi-class head.

    The hub model is built without pretrained weights and its final
    convolution is replaced by a 1x1 convolution with ``num_classes`` output
    channels.

    Parameters:
        in_channels (int): Number of input channels passed to the hub model.
        num_classes (int): Number of output channels of the replaced final layer.
        init_features (int): ``init_features`` of the hub model; also the input
            width of the replaced final layer.

    Note:
        ``torch.hub.load`` fetches the repository on first use, so constructing
        this class may need network access.
    """

    def __init__(self, in_channels=3, num_classes=2, init_features=32):
        super().__init__()

        # Build the architecture only (pretrained=False); the single-channel
        # head requested here is discarded just below.
        model = torch.hub.load(
            'mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=in_channels,
            out_channels=1,
            init_features=init_features,
            pretrained=False,
        )

        # Replace the final convolution layer to match the number of classes.
        # `torch.nn` is referenced explicitly because the notebook never imports `nn`.
        model.conv = torch.nn.Conv2d(init_features, num_classes, kernel_size=1)

        self.model = model

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        return self.model(x)

In [None]:
def load_model(path, num_classes=2, device=DEVICE):
    """Build a 6-channel ``MultiClassUNet`` and restore weights from a checkpoint.

    Parameters:
        path (str): Checkpoint file; must contain a ``'model_state_dict'`` entry.
        num_classes (int): Number of output classes of the network head.
        device: Device the checkpoint tensors and the model are placed on.

    Returns:
        MultiClassUNet: The model on ``device``, switched to eval mode.
    """
    model = MultiClassUNet(in_channels=6, num_classes=num_classes, init_features=64)
    # map_location remaps tensors that were saved on another device (for example
    # a GPU that does not exist on this machine) onto `device`; without it the
    # load fails whenever the original device is unavailable.
    # NOTE: weights_only=False unpickles arbitrary Python objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model
    
def predict(model, image, mean, std, device):
    """Predict a per-pixel class mask for a single image.

    Parameters:
        model: Network called on the normalised batch; put into eval mode here.
        image (torch.Tensor): One image without a batch axis.
        mean, std: Statistics forwarded to ``Utils.z_score_normalize``.
        device: Device the image (and tensor statistics) are moved to.

    Returns:
        numpy.ndarray: Index of the highest-scoring entry along the first
        non-batch axis of the model output, for every remaining position.
    """
    model.eval()
    input_image = image.to(device).unsqueeze(0)  # add a leading batch axis of size 1

    # Keep tensor statistics on the same device as the input so the
    # normalisation does not mix devices.
    if torch.is_tensor(mean):
        mean = mean.to(device)
    if torch.is_tensor(std):
        std = std.to(device)
    input_image = Utils.z_score_normalize(input_image, mean, std)

    with torch.no_grad():
        output = model(input_image)  # expected (1, num_classes, H, W) — confirm against the model

    # Drop the batch axis, then take the arg-max over the class axis.
    predicted_mask = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()

    return predicted_mask


In [None]:
def normalize_image(image):
    """
    Min-max scale an image to 8-bit for visualization.

    Parameters:
        image (ndarray or torch.Tensor): Image of any numeric dtype. Tensors
            are detached and copied to host memory first.

    Returns:
        ndarray: ``uint8`` array of the same shape, with the input minimum
        mapped to 0 and the maximum to 255. A constant or empty input yields
        an all-zero array.
    """
    # Torch tensors have no .astype and may live on a GPU; convert them first.
    if hasattr(image, "detach"):
        image = image.detach().cpu().numpy()
    image = np.asarray(image, dtype=np.float64)

    if image.size == 0:
        return np.zeros(image.shape, dtype=np.uint8)

    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        # Constant image: there is no range to stretch, and dividing would give 0/0 = NaN.
        return np.zeros(image.shape, dtype=np.uint8)

    image = (image - lowest) / span  # Normalize to [0, 1]
    return (image * 255).astype(np.uint8)  # Scale to [0, 255]

def overlay_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask`` is truthy.

    A 2-D image is first replicated into three identical channels; any other
    image is copied, so the input is never modified. Inside the mask each of
    the first three channels becomes ``pixel * (1 - alpha) + color * alpha``;
    outside the mask pixels are left as they are. The result keeps the dtype
    of the working copy.
    """
    if image.ndim != 2:
        canvas = image.copy()
    else:
        canvas = np.stack((image, image, image), axis=-1)

    keep = 1 - alpha
    for channel in (0, 1, 2):
        plane = canvas[:, :, channel]
        blended = plane * keep + color[channel] * alpha
        canvas[:, :, channel] = np.where(mask, blended, plane)
    return canvas

def image_with_masks(image, predicted_mask, ground_truth_mask, channel=0):
    """Return an RGB overlay of predicted and ground-truth class masks on an image.

    Parameters:
        image (ndarray or torch.Tensor): Image to draw on. A 3-D input whose
            last axis is not 3 or 4 is treated as a stack of planes and only
            ``image[channel]`` is shown (assumes channel-first layout — confirm
            against the dataset).
        predicted_mask (ndarray or torch.Tensor): Predicted class index per pixel.
        ground_truth_mask (ndarray or torch.Tensor): Reference class index per pixel.
        channel (int): Plane displayed for multi-channel input. Defaults to 0.

    Returns:
        ndarray: The normalized image with class 1 and 2 of the prediction
        (green, blue) and of the ground truth (red, yellow) blended in at
        alpha 0.5. Ground-truth colors are applied last.
    """
    def _to_numpy(array):
        # Torch tensors (possibly on a GPU) are detached and copied to the host;
        # anything else goes through np.asarray.
        if hasattr(array, "detach"):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    predicted_colors = {
        1: (0, 255, 0),      # Green for class 1
        2: (0, 0, 255),      # Blue for class 2
    }
    gt_colors = {
        1: (255, 0, 0),      # Red for class 1
        2: (255, 255, 0),    # Yellow for class 2
    }

    image = _to_numpy(image)
    predicted_mask = _to_numpy(predicted_mask)
    ground_truth_mask = _to_numpy(ground_truth_mask)

    # overlay_mask expects a 2-D grayscale or an (H, W, 3/4) image; reduce a
    # multi-plane stack to the single plane selected by `channel`.
    if image.ndim == 3 and image.shape[-1] not in (3, 4):
        image = image[channel]

    # Normalize image for visualization
    combined_image = normalize_image(image)

    # Overlay masks for each class: prediction first, ground truth on top.
    for colors, mask in ((predicted_colors, predicted_mask), (gt_colors, ground_truth_mask)):
        for class_label, color in colors.items():
            class_mask = (mask == class_label)
            combined_image = overlay_mask(combined_image, class_mask, color=color, alpha=0.5)

    # The original returned the undefined name `combined`.
    return combined_image


In [None]:
# Run every checkpoint on the same validation sample and collect one overlay
# image per experiment, in the order of `experiments_paths`.
preds = []
for idx, path in enumerate(experiments_paths):
    # The run at index 2 is loaded with a 3-class head, all others with 2 classes.
    num_classes = 3 if idx == 2 else 2
    model_path = path + '/best_model.pth'
    # The loader defined in this notebook is `load_model`; `load` does not exist.
    model = load_model(model_path, num_classes=num_classes)
    pred_mask = predict(model, sample[0], MEAN, STD, DEVICE)
    combined = image_with_masks(sample[0], pred_mask, sample[1])
    preds.append(combined)

In [None]:
# Plot one overlay panel per experiment, side by side.
# The earlier version unpacked three axes from a 1x2 grid and referenced names
# that only exist inside image_with_masks; this cell now uses `preds` only.
fig, axes = plt.subplots(1, len(preds), figsize=(6 * len(preds), 5), squeeze=False)
axes = axes[0]

for ax, overlay, run_path in zip(axes, preds, experiments_paths):
    ax.imshow(overlay)
    ax.set_title(os.path.basename(run_path))  # run directory name identifies the model
    ax.axis("off")

# Legend colors — must stay in sync with the dictionaries inside image_with_masks.
predicted_colors = {
    1: (0, 255, 0),      # Green for class 1
    2: (0, 0, 255),      # Blue for class 2
}
gt_colors = {
    1: (255, 0, 0),      # Red for class 1
    2: (255, 255, 0),    # Yellow for class 2
}
legend_elements = []
for class_label, color in predicted_colors.items():
    legend_elements.append(plt.Line2D([0], [0], color=np.array(color) / 255, lw=4, label=f'Predicted Class {class_label}'))
for class_label, color in gt_colors.items():
    legend_elements.append(plt.Line2D([0], [0], color=np.array(color) / 255, lw=4, linestyle='dashed', label=f'GT Class {class_label}'))

# Attach the legend before saving so it is part of the written figure.
axes[-1].legend(handles=legend_elements, loc='lower right', fontsize='small')

# `output_path` was never defined anywhere in the notebook; this default writes
# next to the run directories — adjust as needed.
output_path = os.path.join('experiments', 'runs', 'prediction_overlays.png')

plt.tight_layout()
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()
