In [1]:
import os
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: /workspace/iscat


In [2]:
from src.data_processing.dataset import iScatDataset
from src.data_processing.utils import Utils
import torch
import numpy as np
import matplotlib.pyplot as plt
DEVICE= 'cuda:11' if torch.cuda.is_available() else 'cpu'
data_path_1 = os.path.join('dataset', '2024_11_11', 'Metasurface', 'Chip_02')
data_path_2 = os.path.join('dataset', '2024_11_12', 'Metasurface', 'Chip_01')
image_paths= []
target_paths=[]
image_indicies = 12
for data_path in [data_path_1,data_path_2]:
    i,t = Utils.get_data_paths(data_path,'Brightfield',image_indicies )
    image_paths.extend(i)
    target_paths.extend(t)

In [39]:
image_size=256
fluo_masks_indices=[1]
seg_method = "comdet"
normalize=False
train_dataset = iScatDataset(image_paths[:-2], target_paths[:-2], preload_image=True,image_size = (image_size,image_size),apply_augmentation=True,normalize=normalize,device=DEVICE,fluo_masks_indices=fluo_masks_indices,seg_method=seg_method)
valid_dataset = iScatDataset(image_paths[-2:],target_paths[-2:],preload_image=True,image_size = (image_size,image_size),apply_augmentation=False,normalize=normalize,device=DEVICE,fluo_masks_indices=fluo_masks_indices,seg_method=seg_method)
MEAN = train_dataset.images.mean(dim=(0,2,3),keepdim=True)
STD = train_dataset.images.std(dim=(0,2,3),keepdim=True)

Loading images to Memory: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:01<00:00,  7.24it/s]
Loading images to Memory: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.99it/s]


In [40]:
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.v2 import Normalize
batch_size=128
def create_dataloaders(test_dataset, batch_size=4):
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return val_loader
val_loader = create_dataloaders(valid_dataset, batch_size=batch_size)

In [41]:
n_samples = 10
samples = [valid_dataset[i] for i in range(n_samples)]
test_batch = next(iter(val_loader))

In [6]:
experiments_paths = (
    'experiments/runs/UNet_Brightfield_2025-01-12_18-05-44',
    'experiments/runs/UNet_Brightfield_2025-01-12_19-09-15',
    'experiments/runs/UNet_Brightfield_2025-01-12_20-27-14')

In [7]:
from src.models.Unet import UNet
def load_model(path, num_classes=2,device=DEVICE):
    model = UNet(in_channels=12, num_classes=num_classes, init_features=64)
    checkpoint = torch.load(path, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  
    return model
    
def predict(model, image, mean, std, device):
    model.eval()
    input_image = image.to(device) # torch.Size([1, 3, 224, 224])
    input_image = Utils.z_score_normalize(input_image, mean, std)
    with torch.no_grad():
        output = model(input_image)  # Shape: [1, num_classes, 224, 224]
    predicted_mask = torch.argmax(output, dim=1).cpu().numpy()  # Shape: (224, 224)

    return predicted_mask

In [8]:
import numpy as np
from scipy.ndimage import label
from typing import Tuple, Dict

def count_matching_particles_multiclass(
    pred_mask: np.ndarray,
    gt_mask: np.ndarray
) -> Dict[int, Tuple[int, int, int]]:
    """
    Count matching particles between prediction and ground truth masks for multiple classes.
    
    Args:
        pred_mask: Multi-class prediction mask (0 for background, 1+ for different particle classes)
        gt_mask: Multi-class ground truth mask (0 for background, 1+ for different particle classes)
        
    Returns:
        Dictionary mapping class_id to tuple of (true_positives, false_positives, false_negatives)
    """
    # Get unique classes (excluding background class 0)
    classes = sorted(set(np.unique(pred_mask)) | set(np.unique(gt_mask)))
    classes = [c for c in classes if c != 0]
    
    results = {}
    
    # Process each class separately
    for class_id in classes:
        # Create binary masks for current class
        pred_binary = (pred_mask == class_id).astype(np.int32)
        gt_binary = (gt_mask == class_id).astype(np.int32)
        
        # Label connected components
        pred_labeled, num_pred = label(pred_binary)
        gt_labeled, num_gt = label(gt_binary)
        
        # Initialize counters
        tp = 0
        matched_pred_labels = set()
        matched_gt_labels = set()
        
        # For each predicted particle of current class
        for pred_label in range(1, num_pred + 1):
            pred_particle = pred_labeled == pred_label
            
            # Find any overlap with GT particles
            overlapping_gt_labels = set(gt_labeled[pred_particle]) - {0}
            
            if overlapping_gt_labels:
                # If there's any overlap, count as TP
                tp += 1
                matched_pred_labels.add(pred_label)
                matched_gt_labels.update(overlapping_gt_labels)
        
        # Count unmatched predictions as FP and unmatched GT as FN
        fp = num_pred - len(matched_pred_labels)
        fn = num_gt - len(matched_gt_labels)
        
        results[class_id] = (tp, fp, fn)
    
    return results

def process_batch_multiclass(
    pred_masks: np.ndarray,
    gt_masks: np.ndarray
) -> Dict[int, Tuple[int, int, int]]:
    """
    Process a batch of masks and aggregate the results.
    
    Args:
        pred_masks: Batch of prediction masks [batch_size, height, width]
        gt_masks: Batch of ground truth masks [batch_size, height, width]
        
    Returns:
        Dictionary mapping class_id to aggregated (tp, fp, fn) across the batch
    """
    # Initialize results dictionary
    batch_results = {}
    
    # Process each image in the batch
    for pred_mask, gt_mask in zip(pred_masks, gt_masks):
        image_results = count_matching_particles_multiclass(pred_mask, gt_mask)
        
        # Aggregate results for each class
        for class_id, (tp, fp, fn) in image_results.items():
            if class_id not in batch_results:
                batch_results[class_id] = [0, 0, 0]
            batch_results[class_id][0] += tp
            batch_results[class_id][1] += fp
            batch_results[class_id][2] += fn
    
    # Convert lists to tuples in final results
    return {k: tuple(v) for k, v in batch_results.items()}

In [42]:
images = test_batch[0].clone()
gt_masks = test_batch[1].clone().cpu().numpy()
path = experiments_paths[1]
model_path = path+'/best_model.pth'
model = load_model(model_path,num_classes=2)
pred_masks = predict(model, images, MEAN, STD, DEVICE)

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


In [43]:
results = process_batch_multiclass(pred_masks,gt_masks)

In [44]:
results

{1: (788, 238, 39)}

In [27]:
results[1][0]

(824, 724, 923)

In [46]:
def precision(tp,fp):
    return tp/(tp+fp)
def recall(tp,fn):
    return tp/(tp+fn)
print(precision(results[1][0],results[1][1]))
print(recall(results[1][0],results[1][2]))

0.7680311890838206
0.9528415961305925


In [45]:
print(precision(results[2][0],results[2][1]))
print(recall(results[2][0],results[2][2]))

KeyError: 2

In [34]:
print(recall(results[2][0]+results[1][0],results[2][2]+results[1][2]))

0.6347531096871466


In [35]:
print(precision(results[2][0]+results[1][0],results[2][1]+results[1][1]))

0.639331814730448
