In [1]:
from data_processing import SDTDataset, GradientDataset
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import v2 as transformsv2
from torch.utils.data import DataLoader
import torch
import numpy as np
from model import UNet
from instance_seg import salt_and_pepper_noise, gaussian_noise
import skimage.io as skio
from skimage.color import label2rgb as label
from skimage.filters import gaussian
import skimage.morphology as skm
import skimage.measure as skme
from PIL import Image
import tifffile
from model_evaluation import watershed_from_boundary_distance, get_inner_mask
from skimage.filters import threshold_otsu
from model_evaluation import evaluate

In [2]:
# load model from checkpoint
modelpath = '/localscratch/dl4mia-project-segmentation/logs/model_std_fmaps16.pth'
model = UNet(
        depth=4,
        in_channels=1,
        out_channels=1,
        final_activation="Tanh",
        num_fmaps=16,
        fmap_inc_factor=2,
        downsample_factor=2,
        padding="same",
        upsample_mode="nearest",
    )
checkpoint = torch.load(modelpath)
model.load_state_dict(checkpoint['model_state_dict'])
model.to('cuda')
model.eval()

UNet(
  (left_convs): ModuleList(
    (0): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): ConvBlock(
      (conv_pass): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 12

In [3]:
val_images = ["/group/dl4miacourse/projects/membrane/ecad_gfp/20240408_embryo_1/max_project_NG-ILE-C488-mp.tif",
                "/group/dl4miacourse/projects/membrane/ecad_gfp/20240419_embryo_1/max_project_rh-ILE-C488-D.tif"]
val_masks = ["/group/dl4miacourse/projects/membrane/ecad_gfp/20240408_embryo_1/cell_mesh_mask.tif",
                "/group/dl4miacourse/projects/membrane/ecad_gfp/20240419_embryo_1/cell_mesh_mask.tif"]
train_data = SDTDataset(transform=None, img_transform=None, train=True, center_crop=True, pad=256)


In [4]:
inp_transforms = transforms.Compose(
            [
                transforms.Grayscale(),
                transforms.ToTensor(),
                transforms.Normalize([train_data.mean], [train_data.std]),  # 0.5 = mean and 0.5 = variance
            ]
        )
precisions, recalls, accuracies, ious = [], [], [], []
for this_im_path, this_mask_path in zip(val_images, val_masks):
    im = tifffile.imread(this_im_path)
    msk = tifffile.imread(this_mask_path)

    for t, (this_slice, this_mask) in enumerate(zip(im, msk)):
        print(t)
        if np.sum(this_mask > 0) == 0:
            continue
        image = Image.fromarray(this_slice)
        gt_labels = Image.fromarray(this_mask)
        image = inp_transforms(image)
        image = image.to('cuda')

        # generate prediction from neural network
        with torch.no_grad():
            pred = model(image[np.newaxis, ...])

        image = np.squeeze(image.cpu())
        pred = np.squeeze(pred.cpu().detach().numpy())

        # Do watershed and compare to gt mask
        threshold = threshold_otsu(pred)

        # Get inner mask
        inner_mask = get_inner_mask(pred, threshold=threshold)

        # Get the segmentation
        seg = watershed_from_boundary_distance(pred, inner_mask, min_seed_distance=20)
        
        precision, recall, accuracy, iou = evaluate(np.array(gt_labels).astype(np.uint16), seg.astype(np.uint16))
        precisions.append(precision)
        recalls.append(recall)
        accuracies.append(accuracy)
        ious.extend(iou)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


In [5]:
print(precisions)
print(recalls)
print(accuracies)
print(ious)

[0.0021660649819494585, 0.008432888264230498, 0.004817618719889883, 0.00808080808080808, 0.00903225806451613, 0.011132940406024885, 0.011435832274459974, 0.01464035646085296, 0.013174404015056462, 0.01507537688442211, 0.01193467336683417, 0.014769230769230769, 0.015537600994406464, 0.015413070283600493, 0.016826923076923076, 0.01586333129957291, 0.016666666666666666, 0.01542111506524318, 0.014384349827387802, 0.019469026548672566, 0.01571594877764843, 0.013038548752834467, 0.010863350485991996, 0.012485811577752554, 0.012535612535612535, 0.011494252873563218, 0.008527572484366117, 0.011185682326621925, 0.011563876651982379, 0.012345679012345678, 0.010451045104510451, 0.009846827133479213, 0.011500547645125958, 0.007991475759190196, 0.008705114254624592, 0.008034279592929834, 0.007454739084132056, 0.007772020725388601, 0.006144393241167435, 0.0056036678553234845, 0.008713480266529985, 0.008655804480651732, 0.007231404958677686, 0.005549949545913219, 0.006069802731411229, 0.0055527511357

In [6]:
print(f"Precision: {np.mean(precisions)}")
print(f"Recall: {np.mean(recalls)}")
print(f"Accuracy: {np.mean(accuracies)}")
print(f"Mean IOU: {np.mean(ious)}")
ious = np.array(ious)
print(f"Mean TP IOU: {np.mean(ious[ious > 0.5])}")

Precision: 0.008975089736764131
Recall: 0.268390935878355
Accuracy: 0.00879038149505176
Mean IOU: 0.02161993272602558
Mean TP IOU: 0.6578453183174133
