In [5]:
from data_processing import SDTDataset, GradientDataset
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import v2 as transformsv2
from torch.utils.data import DataLoader
import torch
import numpy as np
from model import UNet
from instance_seg import salt_and_pepper_noise, gaussian_noise
import skimage.io as skio
from skimage.color import label2rgb as label
from skimage.filters import gaussian
import skimage.morphology as skm
import skimage.measure as skme
from PIL import Image
import tifffile
from model_evaluation import watershed_from_boundary_distance, get_inner_mask
from skimage.filters import threshold_otsu
from model_evaluation import evaluate

In [2]:
# load model from checkpoint
modelpath = '/localscratch/dl4mia-project-segmentation/logs/model.pth'
model = UNet(
        depth=4,
        in_channels=1,
        out_channels=1,
        final_activation="Sigmoid",
        num_fmaps=16,
        fmap_inc_factor=2,
        downsample_factor=2,
        padding="same",
        upsample_mode="nearest",
    )
checkpoint = torch.load(modelpath)
model.load_state_dict(checkpoint['model_state_dict'])
model.to('cuda')
model.eval()

UNet(
  (left_convs): ModuleList(
    (0): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 12

In [15]:
val_images = ["/group/dl4miacourse/projects/membrane/ecad_gfp/20240408_embryo_1/max_project_NG-ILE-C488-mp.tif",
                "/group/dl4miacourse/projects/membrane/ecad_gfp/20240419_embryo_1/max_project_rh-ILE-C488-D.tif"]
val_masks = ["/group/dl4miacourse/projects/membrane/ecad_gfp/20240408_embryo_1/cell_mesh_mask.tif",
                "/group/dl4miacourse/projects/membrane/ecad_gfp/20240419_embryo_1/cell_mesh_mask.tif"]
train_data = GradientDataset(transform=None, img_transform=None, train=True, center_crop=True, pad=256)


In [23]:
inp_transforms = transforms.Compose(
            [
                transforms.Grayscale(),
                transforms.ToTensor(),
                transforms.Normalize([train_data.mean], [train_data.std]),  # 0.5 = mean and 0.5 = variance
            ]
        )
precisions, recalls, accuracies, ious = [], [], [], []
for this_im_path, this_mask_path in zip(val_images, val_masks):
    im = tifffile.imread(this_im_path)
    msk = tifffile.imread(this_mask_path)

    for t, (this_slice, this_mask) in enumerate(zip(im, msk)):
        print(t)
        if np.sum(this_mask > 0) == 0:
            continue
        image = Image.fromarray(this_slice)
        gt_labels = Image.fromarray(this_mask)
        image = inp_transforms(image)
        image = image.to('cuda')

        # generate prediction from neural network
        with torch.no_grad():
            pred = model(image[np.newaxis, ...])

        image = np.squeeze(image.cpu())
        pred = np.squeeze(pred.cpu().detach().numpy())

        # Do watershed and compare to gt mask
        threshold = threshold_otsu(pred)

        # Get inner mask
        inner_mask = get_inner_mask(pred, threshold=threshold)

        # Get the segmentation
        seg = watershed_from_boundary_distance(1 - pred, 1 - inner_mask, min_seed_distance=20)
        seg = skm.binary_dilation(seg > 0, skm.disk(1))
        seg = skme.label(seg)
        new_seg = np.zeros(seg.shape)
        for this_region in np.unique(seg):
            if this_region == 0:
                continue
            region_mask = seg == this_region
            region_mask = skm.binary_dilation(region_mask, skm.disk(3))
            new_seg = np.where(region_mask, this_region, new_seg)
        seg = new_seg * (np.array(gt_labels) > 0)

        these_metrics = dict()
        precision, recall, accuracy, iou = evaluate(np.array(gt_labels).astype(np.uint16), seg.astype(np.uint16))
        precisions.append(precision)
        recalls.append(recall)
        accuracies.append(accuracy)
        ious.extend(iou)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


In [24]:
print(precisions)
print(recalls)
print(accuracies)
print(ious)

[0.5789473684210527, 0.6785714285714286, 0.625, 0.6410256410256411, 0.660377358490566, 0.65, 0.5806451612903226, 0.59375, 0.6470588235294118, 0.625, 0.6470588235294118, 0.640625, 0.7068965517241379, 0.6271186440677966, 0.6923076923076923, 0.676923076923077, 0.589041095890411, 0.6417910447761194, 0.7, 0.6923076923076923, 0.6666666666666666, 0.6721311475409836, 0.6507936507936508, 0.7096774193548387, 0.6909090909090909, 0.6727272727272727, 0.631578947368421, 0.6779661016949152, 0.6724137931034483, 0.6666666666666666, 0.6206896551724138, 0.6071428571428571, 0.6229508196721312, 0.6470588235294118, 0.6415094339622641, 0.62, 0.6326530612244898, 0.6666666666666666, 0.6976744186046512, 0.5660377358490566, 0.5681818181818182, 0.6, 0.5, 0.5, 0.55, 0.48484848484848486, 0.5833333333333334, 0.6363636363636364, 0.5483870967741935, 0.5666666666666667, 0.46153846153846156, 0.4074074074074074, 0.6551724137931034, 0.3888888888888889, 0.25, 0.3, 0.5, 0.46153846153846156, 0.75, 0.375, 0.4375, 0.7, 0.5, 0.

In [25]:
print(f"Precision: {np.mean(precisions)}")
print(f"Recall: {np.mean(recalls)}")
print(f"Accuracy: {np.mean(accuracies)}")
print(f"Mean IOU: {np.mean(ious)}")
ious = np.array(ious)
print(f"Mean TP IOU: {np.mean(ious[ious > 0.5])}")

Precision: 0.5452955754908374
Recall: 0.4256913612233855
Accuracy: 0.3226891798345228
Mean IOU: 0.5431439876556396
Mean TP IOU: 0.7780895829200745
