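# Intro

Trains a U-Net (segmentation_models_pytorch) to segment filopodia, holding out fold `FOLD` of `N_FOLDS` for validation. It then evaluates the best checkpoint with pixel metrics (IoU, Dice, precision, recall, F1, MSE) and with skeleton-based filopodia counts.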

In [None]:
# Interpreter used for the original runs: /usr/bin/python (Python 3.10.12).

# Encoder name handed to smp.Unet in the model cell.
BACKBONE: str = 'timm-resnest101e'
# Cross-validation: N_FOLDS folds; FOLD is the 1-based fold held out for validation/test.
N_FOLDS: int = 10
FOLD: int = 3
# Weight of the filopodia-count penalty added to the Jaccard loss (0 => plain Jaccard loss).
PENALIZATION_LAMBDA: float = 0
# Side length, in pixels, that images and masks are resized to.
RESOLUTION: int = 320
# Training mini-batch size.
BATCH_SIZE: int = 8

In [None]:
#!pip uninstall -y segmentation-models-pytorch
!pip uninstall -y segmentation-models-pytorch
!pip install --force-reinstall --no-deps segmentation-models-pytorch==0.2.1
!pip install -U albumentations --user 
!pip install segmentation-models-pytorch
!pip install opencv-python
!pip install -U numpy
!pip install matplotlib

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
import os
import shutil
import numpy as np
import tqdm
import torch

torch.__version__  # display the installed PyTorch version as a sanity check

# build train/val data folders

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
import os
import shutil
import tqdm

# Rebuild the train/val folder layout from scratch on every run so a change of
# FOLD never leaves files from a previous split behind.
if os.path.exists("data"):
    shutil.rmtree("data")

# "test"/"testannot" are created but not written to: the held-out fold is
# used as the test set.
for split in ("train", "trainannot", "val", "valannot", "test", "testannot"):
    os.makedirs(os.path.join("data", "torchData", split))

SIZE = RESOLUTION
# Source files are named 1.tif ... (n - 1).tif.
# NOTE(review): with n = 246 the loop reads images 1..245 — confirm there is no 246.tif.
n = 246
# Width of one fold as a percentage of the dataset; with N_FOLDS = 10 this is
# 10, the value that was previously hard-coded.
fold_width = 100 / N_FOLDS
lower_bound = (FOLD - 1) * fold_width
upper_bound = FOLD * fold_width
validation_is = []
for i in tqdm.tqdm(range(1, n)):
    img_path = os.path.join("dataset", "images", f"{i}.tif")
    mask_path = os.path.join("dataset", "masks", f"{i}.tif")
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    # cv2.imread returns None for missing or unreadable files; fail loudly
    # here instead of with an obscure error inside resize().
    if img is None:
        raise FileNotFoundError(f"could not read image {img_path}")
    if mask is None:
        raise FileNotFoundError(f"could not read mask {mask_path}")

    img = resize(img, (SIZE, SIZE), mode='constant', preserve_range=True).astype(np.uint8)
    # Binarize the resized mask by thresholding at the mean of the ORIGINAL mask.
    # NOTE(review): for a sparse 0/255 mask this threshold is low, so
    # interpolated edge pixels become foreground — confirm this is intended.
    mask = resize(mask, (SIZE, SIZE), mode='constant', preserve_range=True).astype(np.uint8) > mask.mean()

    if img.ndim == 2:
        img = np.dstack([img, img, img])
    if mask.ndim == 2:
        mask = np.dstack([mask, mask, mask])
    # Stored as 0/1 so the pixel value equals Dataset.CLASSES.index('filopodia').
    mask = mask.astype(np.uint8)

    # Fold assignment: the sample's position in the dataset, as a percentage,
    # decides whether it belongs to the held-out fold.
    percentage = i / n * 100
    if lower_bound <= percentage < upper_bound:
        validation_is.append(i)
        img_dir, mask_dir = "val", "valannot"
    else:
        img_dir, mask_dir = "train", "trainannot"

    for sub_dir, array in ((img_dir, img), (mask_dir, mask)):
        out_path = os.path.join("data", "torchData", sub_dir, f"{i}.png")
        # cv2.imwrite reports failure only through its return value.
        if not cv2.imwrite(out_path, array):
            raise OSError(f"could not write {out_path}")

# load data

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
# Root of the train/val image and mask folders written by the data-folder cell.
DATA_DIR: str = './data/torchData/'

In [None]:
# Image folders (x_*) and their annotation folders (y_*).
x_train_dir, y_train_dir = (os.path.join(DATA_DIR, sub) for sub in ('train', 'trainannot'))
x_valid_dir, y_valid_dir = (os.path.join(DATA_DIR, sub) for sub in ('val', 'valannot'))
# There is no separate test split: the "test" set is the held-out validation fold.
x_test_dir, y_test_dir = x_valid_dir, y_valid_dir

In [None]:
# Helper for data visualization.
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword becomes one panel. The keyword (underscores turned into
    spaces, title-cased) is the panel title, and axis ticks are hidden.
    """
    n_panels = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, n_panels, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

### Dataloader

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset backed by an image folder and a parallel mask folder.

    Images and masks are paired by identical file name. File names are sorted
    as strings, so e.g. "10.png" comes before "9.png". Masks are read as
    single-channel images whose pixel values are indices into ``CLASSES``.

    Args:
        images_dir: folder with the input images (read with OpenCV, returned as RGB).
        masks_dir: folder with one mask per image, same file name.
        classes: class names to extract as mask channels (case-insensitive).
            ``None`` selects every entry of ``CLASSES`` (it used to raise TypeError).
        augmentation: optional callable taking ``image=``/``mask=`` keywords and
            returning a dict with those keys (albumentations style).
        preprocessing: optional callable with the same contract, applied after
            augmentation.

    Raises:
        ValueError: from ``CLASSES.index`` if a class name is unknown.
        FileNotFoundError: from ``__getitem__`` if an image or mask cannot be read.
    """

    CLASSES = ['background', 'filopodia']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # Convert class names to the pixel values used in the masks.
        if classes is None:
            classes = self.CLASSES
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``.

        The mask has one float channel per requested class: 1.0 where the raw
        mask equals that class value, 0.0 elsewhere.
        """
        image = cv2.imread(self.images_fps[i])
        # cv2.imread returns None instead of raising for unreadable files.
        if image is None:
            raise FileNotFoundError(f"could not read image {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)  # 0 = read as grayscale
        if mask is None:
            raise FileNotFoundError(f"could not read mask {self.masks_fps[i]}")

        # Extract the requested classes from the mask (e.g. filopodia).
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)

In [None]:
# Sanity-check one raw training sample (no augmentation, no preprocessing):
# print the mask's dtype and value range, then show image and mask side by side.
dataset = Dataset(x_train_dir, y_train_dir, classes=['filopodia'])
image, mask = dataset[5]
print(mask.dtype, mask.min(), mask.max())
visualize(image=image, filopodia_mask=mask.squeeze())

# augmentation

In [None]:
import albumentations as albu

In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline for training samples.

    Geometric transforms (flip, shift/scale, padding, perspective) are followed
    by pixel-level noise, contrast/gamma, sharpening and colour jitter.

    Returns:
        albu.Compose applied as ``pipeline(image=..., mask=...)``.
    """
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        # Random shift (up to 10%) and zoom (scale_limit=0.5), no rotation;
        # the uncovered border is filled with 0.
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        # p=1.0 replaces the deprecated `always_apply=True`; both mean the
        # padding always runs.
        albu.PadIfNeeded(min_height=RESOLUTION, min_width=RESOLUTION, p=1.0, border_mode=0),
        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),
        # With probability 0.9 apply exactly one contrast transform (chosen
        # with weights given by the children's p).
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightnessContrast(),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        albu.OneOf([albu.Sharpen(p=1)], p=0.9),
        albu.OneOf([albu.HueSaturationValue(p=1)], p=0.9),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Build the deterministic pipeline for validation images.

    Resizes so the longest side equals RESOLUTION (aspect ratio kept, no padding).
    NOTE(review): the data-folder cell already writes RESOLUTION x RESOLUTION
    images, and the validation loader in this notebook passes augmentation=None,
    so this pipeline is presumably redundant.
    """
    test_transform = [
        # Top-level public export instead of the deep module path, which moved
        # between albumentations releases. A scalar max_size avoids a random
        # choice between two identical sizes.
        albu.LongestMaxSize(max_size=RESOLUTION),
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Reorder an HWC array to CHW and cast it to float32 for PyTorch.

    Extra keyword arguments (passed by albumentations' Lambda) are ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Compose encoder-specific image normalization with HWC->CHW conversion.

    Args:
        preprocessing_fn: callable applied to the image only (e.g. from
            ``smp.encoders.get_preprocessing_fn``).

    Returns:
        albu.Compose that normalizes the image, then converts both image and
        mask with ``to_tensor``.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, to_chw])

In [None]:
# Sanity-check one raw validation sample: print the mask's dtype and value
# range, then show image and mask side by side.
dataset = Dataset(x_valid_dir, y_valid_dir, classes=['filopodia'])

image, mask = dataset[10]
print(mask.dtype, mask.min(), mask.max())
visualize(
    image=image,
    filopodia_mask=mask.squeeze(),
)

In [None]:
# Apply the training augmentation to the same sample three times to see the
# variety of random transforms (no preprocessing, so images stay displayable).
augmented_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=['filopodia'],
    augmentation=get_training_augmentation(),
)

for i in range(3):
    image, mask = augmented_dataset[0]
    print(image.shape)
    visualize(image=image, mask=mask.squeeze(-1))

# Load counting model (optional)

In [None]:
# Optional pre-trained counting model. In this notebook it is only mentioned
# in a comment inside custom_loss, so a missing checkpoint no longer aborts a
# full run; counting_model is None in that case.
# NOTE: torch.load unpickles arbitrary Python objects — only load trusted files.
COUNTING_MODEL_PATH = "best_counting_model.pth"
counting_model = torch.load(COUNTING_MODEL_PATH) if os.path.exists(COUNTING_MODEL_PATH) else None

# create segmentation model

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
# Model configuration.
ENCODER = BACKBONE
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['filopodia']
# Sigmoid gives per-pixel probabilities for the single output channel
# (None would return logits; 'softmax2d' suits multi-class segmentation).
ACTIVATION = 'sigmoid'
DEVICE = 'cuda'

# U-Net with a pretrained encoder. nn.Module.to() moves the parameters and
# returns the same module object.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
).to(DEVICE)

# Input normalization associated with this encoder and its weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Training data: random augmentation followed by encoder normalization.
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

# Validation data: normalization only. Images are already stored at
# RESOLUTION x RESOLUTION, so no resize augmentation is applied.
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    augmentation=None,
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=12)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=4)

In [None]:
import skimage
from skimage.morphology import thin
from collections import Counter
import torch.nn as nn

def find_branchpoints(skeleton):
    """Estimate the number of branch points of a skeleton from its endpoints.

    No junction pixels are detected: the result is ``find_endpoints(skeleton) - 2``.
    This matches a single tree whose junctions all have degree 3, since such a
    tree has #endpoints = #branchpoints + 2. The result is negative for
    skeletons with fewer than two endpoints.
    """
    n_endpoints = find_endpoints(skeleton)
    return n_endpoints - 2

def find_endpoints(img):
    """Count the endpoint pixels of a binary skeleton.

    A pixel is an endpoint when it is non-zero and exactly one of its eight
    neighbours is non-zero, i.e. its 3x3 window (itself included) holds exactly
    two non-zero pixels.

    The image is zero-padded internally, so pixels on the image border are
    handled correctly. Previously they indexed out of range or wrapped around
    to the opposite edge through negative indices.

    Args:
        img: 2-D array; any non-zero value counts as foreground.

    Returns:
        int: number of endpoint pixels.

    Raises:
        ValueError: if ``img`` is not 2-D.
    """
    foreground = np.asarray(img) != 0
    if foreground.ndim != 2:
        raise ValueError(f"expected a 2-D image, got shape {foreground.shape}")
    height, width = foreground.shape
    padded = np.pad(foreground, 1).astype(np.int32)
    # 3x3 window sum around every pixel (including the pixel itself), built
    # from nine shifted views instead of a per-pixel Python loop.
    window_sum = sum(
        padded[dr:dr + height, dc:dc + width]
        for dr in range(3)
        for dc in range(3)
    )
    return int(np.count_nonzero(foreground & (window_sum == 2)))

def detect_fused(img):
    """Classify each skeleton component of a binary mask as single or fused.

    The mask is thinned to a 1-pixel skeleton, and every 8-connected skeleton
    component is analysed on its own. A component is FUSED when it has more
    than one (estimated) branch point and at least four endpoints; otherwise it
    is SINGLE.

    Args:
        img: 2-D binary mask (bool or 0/1 numeric).

    Returns:
        (finish, n_single, n_fused):
            finish: image where pixels of original mask components are 1
                (single) or 2 (fused), 0 elsewhere. It is all zeros, with the
                dtype of ``img``, when there are no components.
            n_single, n_fused: number of skeleton components of each kind.
    """
    n_fused = 0
    n_single = 0

    img_thinned = thin(img)  # or skeletonize, small difference
    # Drop skeleton pixels on the first row.
    # NOTE(review): presumably to ignore structures touching the top border — confirm.
    img_thinned[0, :] = 0
    # Keep the native integer label dtype: the previous uint8 cast wrapped
    # around after 255 components and merged unrelated regions.
    img_thin_labeled = skimage.measure.label(img_thinned.astype(np.uint8), connectivity=2)
    img_labeled = skimage.measure.label(img.astype(np.uint8), connectivity=2)
    regions = skimage.measure.regionprops(img_thin_labeled)

    fused_image = np.zeros_like(img)
    singles_image = np.zeros_like(img)
    finish = np.zeros_like(img)

    for region in regions:
        # Pixels of THIS skeleton component within its bounding box. The old
        # "second most frequent value in the box" could pick the background or
        # a neighbouring component.
        bbox_region = region.image.astype(int)
        # Pad so no component touches the array border during endpoint detection.
        bbox_region_padded = np.pad(bbox_region, pad_width=4, mode='constant', constant_values=0)
        n_endpoints = find_endpoints(bbox_region_padded)
        n_branchpoints = find_branchpoints(bbox_region_padded)
        is_fused = n_branchpoints > 1 and n_endpoints >= 4

        # Thinning only removes pixels, so a skeleton pixel lies inside the
        # original mask component it came from. Use that component's label,
        # because the two labelings need not number components in the same order.
        row, col = region.coords[0]
        label = img_labeled[row, col]
        component = (img_labeled == label) & bool(label)
        if is_fused:
            fused_image += component
            n_fused += 1
        else:
            singles_image += component
            n_single += 1

    if regions:
        finish = singles_image + fused_image * 2

    return finish, n_single, n_fused

def num_filopodia_demerged(mask):
    """Estimate the number of filopodia in a binary mask, splitting merged ones.

    The mask is thinned to a 1-pixel skeleton. Each 8-connected skeleton
    component with E endpoints counts as E - 1 filopodia (a simple line has 2
    endpoints -> 1, a Y shape 3 -> 2). Components with fewer than two endpoints
    (isolated pixels, closed loops) contribute 0 instead of a negative number.

    NOTE: this function is defined again, identically, in the metrics section;
    keep both copies in sync.

    Args:
        mask: 2-D binary mask (bool or 0/1 numeric).

    Returns:
        int: estimated filopodia count, never negative.
    """
    thinned = thin(mask)
    # Native label dtype: the previous uint8 cast wrapped after 255 components.
    img_thin_labeled = skimage.measure.label(thinned.astype(np.uint8), connectivity=2)
    filopodia_count = 0
    for region in skimage.measure.regionprops(img_thin_labeled):
        # Only this component's own pixels, padded so endpoint detection never
        # touches the array border.
        bbox_region_padded = np.pad(region.image.astype(int), pad_width=4, mode='constant', constant_values=0)
        n_endpoints = find_endpoints(bbox_region_padded)
        filopodia_count += max(n_endpoints - 1, 0)
    return filopodia_count

def custom_loss(y_true, y_pred):
    """Penalty on the mismatch between predicted and true filopodia counts.

    For each sample, channel 0 of the prediction is thresholded at 0.5 and
    counted with ``num_filopodia_demerged``; the target is counted the same way.
    The penalty adds ``|log(n_pred / n_true)|`` when both counts are positive
    and nothing otherwise.

    NOTE: the counts are computed in NumPy on detached tensors, so the penalty
    carries no gradient. It changes the reported loss value but cannot
    influence training.

    Args:
        y_true: batch of target masks indexed as [sample][channel] -> (H, W).
        y_pred: batch of predicted probabilities with the same layout.

    Returns:
        Summed penalty over the batch (Python/NumPy scalar).
    """
    filopodia_penalization = 0
    for pred, true in zip(y_pred, y_true):
        # Use the sample's own (H, W) instead of reshaping to RESOLUTION.
        pred = (pred[0].detach().cpu().numpy() > 0.5).astype(np.float64)
        true = true[0].detach().cpu().numpy()
        n_filo_pred = num_filopodia_demerged(pred)  # or counting_model(torch.Tensor(pred))
        n_filo_true = num_filopodia_demerged(true)
        if n_filo_true > 0 and n_filo_pred > 0:
            filopodia_penalization += np.abs(np.log(n_filo_pred / n_filo_true))
    return filopodia_penalization


class CustomLoss(nn.Module):
    """Jaccard loss plus PENALIZATION_LAMBDA times the filopodia-count penalty.

    ``__name__`` is "jaccard_loss" so smp's epoch runners log this loss under
    the same key as the plain JaccardLoss, which is the key the training loop reads.

    NOTE: ``custom_loss`` works on detached NumPy arrays, so the penalty adds
    no gradient; only the Jaccard term drives training.
    """

    iou_loss = smp.utils.losses.JaccardLoss(eps=0.1)

    def __init__(self):
        super().__init__()
        self.__name__ = "jaccard_loss"

    def forward(self, output, target):
        # Call the module rather than .forward() so PyTorch hooks are honoured.
        iou = self.iou_loss(output, target)
        if PENALIZATION_LAMBDA == 0:
            # Skip the (expensive, gradient-free) skeleton counting entirely.
            return iou
        penalization = custom_loss(target, output) * PENALIZATION_LAMBDA
        # Match the Jaccard term's dtype and device instead of building a
        # float64 tensor on the global DEVICE.
        return iou + torch.as_tensor(penalization, dtype=iou.dtype, device=iou.device)


In [None]:
# Plain Jaccard loss when no count penalty is requested, otherwise the combined loss.
if PENALIZATION_LAMBDA == 0:
    loss = smp.utils.losses.JaccardLoss(eps=0.1)
else:
    loss = CustomLoss()

# IoU of predictions binarized at 0.5.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 0.0001}])

# training

In [None]:
# Epoch runners from smp.utils: each .run(loader) makes one pass over the
# loader and returns a dict of logs keyed by loss/metric name.
_runner_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_runner_kwargs)

In [None]:
# Release unreachable Python objects, then PyTorch's cached CUDA memory, before training.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Training schedule.
N_EPOCHS = 100
LR_DECAY_AFTER_EPOCH = 70  # decay is applied after every epoch whose index exceeds this
LR_DECAY_FACTOR = 0.8

max_score = 0
train_loss_progession = []
val_loss_progression = []

for i in range(N_EPOCHS):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Both loss variants are logged under 'jaccard_loss' (CustomLoss reuses that name).
    train_loss_progession.append(train_logs['jaccard_loss'])
    val_loss_progression.append(valid_logs['jaccard_loss'])

    # Keep the checkpoint (whole pickled model) with the best validation IoU.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    # Exponential decay over the final epochs, applied to every parameter group
    # (the single group holds the whole model, not just the decoder).
    if i > LR_DECAY_AFTER_EPOCH:
        for group in optimizer.param_groups:
            group['lr'] *= LR_DECAY_FACTOR
        print('Decrease learning rate')
        print('Decrease decoder learning rate')

In [None]:
# Learning curves: per-epoch loss (logged as 'jaccard_loss') on train and validation.
fig, ax = plt.subplots()
ax.plot(train_loss_progession, label="train loss")
ax.plot(val_loss_progression, label="val loss")
ax.set(xlabel="epoch", ylabel="loss", title="Training curves")
ax.legend()
plt.show()

# Test best saved model

In [None]:
# Reload the checkpoint with the best validation IoU saved during training.
# NOTE: torch.load unpickles the whole model object — only load trusted files.
best_model = torch.load('./best_model.pth')

# Must match the training configuration. Derive the encoder from BACKBONE so
# the preprocessing can never drift from the encoder that was trained.
ENCODER = BACKBONE
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['filopodia']
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# The held-out fold doubles as the test set (same folders as validation).
DATA_DIR = './data/torchData/'
x_test_dir, y_test_dir = (os.path.join(DATA_DIR, sub) for sub in ('val', 'valannot'))

test_dataset = Dataset(
    x_test_dir,
    y_test_dir,
    classes=["filopodia"],
    augmentation=None,
    preprocessing=get_preprocessing(preprocessing_fn),
)
# DataLoader defaults: batch size 1, no shuffling.
test_dataloader = DataLoader(test_dataset)

In [None]:
# Score the best checkpoint on the held-out fold with the training loss and metrics.
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model, loss=loss, metrics=metrics, device=DEVICE
)
logs = test_epoch.run(test_dataloader)

## Visualize predictions

In [None]:
# Rebuild the test dataset and loader so this section can be re-run on its own.
# Same configuration as the evaluation cell: held-out fold, no augmentation.
DATA_DIR = './data/torchData/'
x_test_dir = os.path.join(DATA_DIR, 'val')
y_test_dir = os.path.join(DATA_DIR, 'valannot')

test_preprocessing = get_preprocessing(preprocessing_fn)
test_dataset = Dataset(x_test_dir, y_test_dir, augmentation=None, preprocessing=test_preprocessing, classes=["filopodia"])
test_dataloader = DataLoader(test_dataset)

In [None]:
# Show and save the binarized prediction for every test image.
# NOTE(review): this cell thresholds at 0.999 while the metrics cell uses 0.99 — confirm.
PREDICTIONS_DIR = "./finalPredictions512/"
# cv2.imwrite silently returns False when the folder is missing; create it first.
os.makedirs(PREDICTIONS_DIR, exist_ok=True)

for i in range(len(test_dataset)):
    image, gt_mask = test_dataset[i]
    gt_mask = gt_mask.squeeze().astype(bool)

    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    print(image.shape, image.dtype, image.min(), image.max(), image.mean())
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy() > 0.999)
    # NOTE: a 1x1 structuring element makes this erosion a no-op; the line only
    # converts the mask to uint8 with values 0/255. TODO: small erosion.
    pr_mask = cv2.erode(pr_mask.astype(np.uint8), cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (1, 1))) * 255

    # XOR against a boolean view of the prediction. XOR-ing bool with the
    # 0/255 uint8 mask gave 254/255 values and effectively showed the prediction.
    visualize(image=image[0, :, :], ground_truth_mask=gt_mask, predicted_mask=pr_mask, diff=(gt_mask ^ (pr_mask > 0)))

    # Name the output after its source image (e.g. "50.png"). The previous
    # index arithmetic assumed equal-sized folds and numeric file order, but
    # Dataset.ids is sorted as strings.
    out_name = test_dataset.ids[i]
    print(os.path.splitext(out_name)[0])
    out_path = os.path.join(PREDICTIONS_DIR, out_name)
    if not cv2.imwrite(out_path, pr_mask):
        raise OSError(f"could not write {out_path}")

# measure metrics

In [None]:
def iou(prediction, true_mask):
    """Intersection-over-union of two binary masks (non-zero = foreground).

    Returns 1.0 when both masks are empty (perfect agreement) instead of the
    NaN, with a RuntimeWarning, that 0 / 0 produced.
    """
    intersection = np.logical_and(prediction, true_mask).sum()
    union = np.logical_or(prediction, true_mask).sum()
    if union == 0:
        return 1.0
    return intersection / union

def dice(prediction, true_mask):
    """Dice coefficient 2|A∩B| / (|A| + |B|) of two binary masks.

    Foreground sizes are counted with ``np.count_nonzero``, so 0/255 masks work
    as well as bool or 0/1 masks; the result for the latter is unchanged.
    Returns 1.0 when both masks are empty instead of NaN.
    """
    intersection = np.logical_and(prediction, true_mask).sum()
    denominator = np.count_nonzero(prediction) + np.count_nonzero(true_mask)
    if denominator == 0:
        return 1.0
    return (2. * intersection) / denominator

def precision(prediction, true_mask):
    """Fraction of predicted foreground pixels that are truly foreground.

    Returns NaN when the prediction has no foreground pixels, since precision
    is undefined there; the summary cells aggregate it with ``np.nanmean``.
    The NaN is returned explicitly instead of through a warning-raising 0 / 0.
    """
    true_positives = np.logical_and(prediction, true_mask).sum()
    false_positives = np.logical_and(prediction, np.logical_not(true_mask)).sum()
    predicted_positives = true_positives + false_positives
    if predicted_positives == 0:
        return np.nan
    return true_positives / predicted_positives


def recall(prediction, true_mask):
    """Fraction of true foreground pixels that the prediction recovers.

    Returns 0 when the ground truth has no foreground pixels.
    """
    hits = np.logical_and(prediction, true_mask).sum()
    misses = np.logical_and(np.logical_not(prediction), true_mask).sum()
    positives = hits + misses
    return 0 if int(positives) == 0 else hits / positives

def f1_score(prediction, true_mask):
    """Harmonic mean of precision and recall.

    Returns 0 when precision and recall are both 0. The previous guard
    compared the ``precision`` function object to 0, which is never true, so
    this case produced NaN from 0 / 0. A NaN from undefined precision (empty
    prediction) is still propagated.
    """
    p = precision(prediction, true_mask)
    r = recall(prediction, true_mask)
    if p + r == 0:
        return 0
    return 2 * (p * r) / (p + r)

def mse(prediction, true_mask):
    """Mean squared error between two masks, compared as float arrays.

    Both inputs are cast to float first. Subtracting two boolean arrays raises
    TypeError in NumPy, and unsigned-integer masks would wrap around.
    """
    diff = np.asarray(prediction, dtype=float) - np.asarray(true_mask, dtype=float)
    return np.mean(diff ** 2)

def num_filopodia_blobs(mask):
    """Label the connected components of ``mask`` with skimage.measure.label.

    NOTE(review): despite its name this returns the labeled image, not a count
    (the count would be its maximum label). It is not called anywhere in the
    visible notebook.
    """
    labeled = skimage.measure.label(mask)
    return labeled

def num_filopodia_demerged(mask):
    """Estimate the number of filopodia in a binary mask, splitting merged ones.

    The mask is thinned to a 1-pixel skeleton. Each 8-connected skeleton
    component with E endpoints counts as E - 1 filopodia (a simple line has 2
    endpoints -> 1, a Y shape 3 -> 2). Components with fewer than two endpoints
    (isolated pixels, closed loops) contribute 0 instead of a negative number.

    NOTE: this redefines the function from the loss section with identical
    behavior; keep both copies in sync.

    Args:
        mask: 2-D binary mask (bool or 0/1 numeric).

    Returns:
        int: estimated filopodia count, never negative.
    """
    thinned = thin(mask)
    # Native label dtype: the previous uint8 cast wrapped after 255 components.
    img_thin_labeled = skimage.measure.label(thinned.astype(np.uint8), connectivity=2)
    filopodia_count = 0
    for region in skimage.measure.regionprops(img_thin_labeled):
        # Only this component's own pixels, padded so endpoint detection never
        # touches the array border.
        bbox_region_padded = np.pad(region.image.astype(int), pad_width=4, mode='constant', constant_values=0)
        n_endpoints = find_endpoints(bbox_region_padded)
        filopodia_count += max(n_endpoints - 1, 0)
    return filopodia_count

def filopodia_length_sum(mask):
    """Total skeleton length in pixels: the number of pixels left after thinning ``mask``."""
    skeleton = thin(mask)
    return np.count_nonzero(skeleton)

In [None]:
# Per-image metrics on the held-out fold.
IOUs, DICEs, PRECISIONs, RECALLs, F1SCOREs, MSEs = [], [], [], [], [], []
filo_N_diffs, filo_N_abs_diffs, filo_len_diffs, filo_len_abs_diffs = [], [], [], []
single_filo_N_diff, single_filo_N_abs_diff, merged_filo_N_diff, merged_filo_N_abs_diff = [], [], [], []

for i in range(len(test_dataset)):
    image, gt_mask = test_dataset[i]
    plt.imshow(image[0, :, :])
    plt.show()

    gt_mask = gt_mask.squeeze()

    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    # NOTE(review): threshold 0.99 here vs 0.999 in the visualization cell — confirm.
    pr_mask = (pr_mask.squeeze().cpu().numpy() > 0.99)

    pred = pr_mask
    mask = gt_mask
    # Removing predicted blobs with contour area <= 25 px was tried but is
    # disabled; its unused computation was dropped.

    fused_pred, n_single_p, n_fused_p = detect_fused(pred)
    fused_true, n_single_t, n_fused_t = detect_fused(mask)

    # Skeleton-based counts and lengths are costly; compute each once per mask.
    n_filo_pred = num_filopodia_demerged(pred)
    n_filo_true = num_filopodia_demerged(mask)
    len_pred = filopodia_length_sum(pred)
    len_true = filopodia_length_sum(mask)

    IOUs.append(iou(pred, mask))
    DICEs.append(dice(pred, mask))
    PRECISIONs.append(precision(pred, mask))
    RECALLs.append(recall(pred, mask))
    F1SCOREs.append(f1_score(pred, mask))
    MSEs.append(mse(pred, mask))
    filo_N_diffs.append(n_filo_pred - n_filo_true)
    filo_N_abs_diffs.append(abs(n_filo_pred - n_filo_true))
    filo_len_diffs.append(len_pred - len_true)
    filo_len_abs_diffs.append(abs(len_pred - len_true))
    single_filo_N_diff.append(n_single_p - n_single_t)
    single_filo_N_abs_diff.append(abs(n_single_p - n_single_t))
    merged_filo_N_diff.append(n_fused_p - n_fused_t)
    merged_filo_N_abs_diff.append(abs(n_fused_p - n_fused_t))

In [None]:
# Per-image rows: one comma-separated line per test image with all 14 metrics,
# in the order of `lists`. All lists must cover the same images.
lists = [IOUs, DICEs, PRECISIONs, RECALLs, F1SCOREs, MSEs,
         filo_N_diffs, filo_N_abs_diffs, filo_len_diffs, filo_len_abs_diffs,
         single_filo_N_diff, single_filo_N_abs_diff,
         merged_filo_N_diff, merged_filo_N_abs_diff]

lengths = [len(lst) for lst in lists]
assert len(set(lengths)) <= 1, "All lists must have the same length"

for row in zip(*lists):
    print(", ".join(str(value) for value in row))


In [None]:
# Summary: mean ± std over test images for every metric. Precision and F1 use
# NaN-aware statistics because they can be undefined for individual images.
summary = [
    ("IOU", IOUs, np.mean, np.std),
    ("DICE", DICEs, np.mean, np.std),
    ("PRECISION", PRECISIONs, np.nanmean, np.nanstd),
    ("RECALL", RECALLs, np.mean, np.std),
    ("F1", F1SCOREs, np.nanmean, np.nanstd),
    ("MSE", MSEs, np.mean, np.std),
    ("Filo # difference", filo_N_diffs, np.mean, np.std),
    ("Filo # abs difference", filo_N_abs_diffs, np.mean, np.std),
    ("Filo len difference", filo_len_diffs, np.mean, np.std),
    ("Filo len abs difference", filo_len_abs_diffs, np.mean, np.std),
    ("Single filo # diff", single_filo_N_diff, np.mean, np.std),
    ("Single filo # abs diff", single_filo_N_abs_diff, np.mean, np.std),
    ("Fused filo # diff", merged_filo_N_diff, np.mean, np.std),
    ("Fused filo # abs diff", merged_filo_N_abs_diff, np.mean, np.std),
]

# One labelled line per metric.
for label, values, mean_fn, std_fn in summary:
    print(label, mean_fn(values), "±", std_fn(values))

# The same statistics on one line ("mean ± std ," per metric) for copy-pasting.
one_line = []
for _, values, mean_fn, std_fn in summary:
    one_line += [mean_fn(values), "±", std_fn(values), ","]
print(*one_line)

# Means only, three decimals, comma-separated.
means = [mean_fn(values) for _, values, mean_fn, _ in summary]
print(', '.join('{:.3f}'.format(mean) for mean in means))