In [27]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import cv2
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os 
import pandas as pd
from torchvision import datasets, transforms
from PIL import Image
import utils
import numpy as np

In [45]:
root = './'

In [46]:
# for i in open(os.path.join(root,'5k_Healthy.txt')):
#     print(i.split("\\")[-1][:-5])

In [158]:
class HnABRATS(torch.utils.data.Dataset):
    def __init__(self, ROOT_DIR, img_size = (256, 256), is_train = True, transform = None, num_mod = 1):     
        self.D = num_mod
        self.transform = transforms.Compose(
            [transforms.ToPILImage(),        
             transforms.Resize(img_size, transforms.InterpolationMode.BILINEAR),
             # transforms.CenterCrop(256),
             transforms.ToTensor(),
             transforms.Normalize((0.5), (0.5))
            ]
        ) if not transform else transform

        self.root = os.path.join(ROOT_DIR, "DATASETS/brats2021")
        if is_train:
            self.filenames = [i for i in open(os.path.join(self.root,'Healthy.txt'))]
        else:
            self.filenames = [i for i in open(os.path.join(self.root,'Anomaly.txt'))]
            self.mask_filenames = [i for i in open(os.path.join(self.root,'Mask.txt'))]
        self.img_size = img_size
        self.is_train = is_train
        if ".DS_Store" in self.filenames:
            self.filenames.remove(".DS_Store")

    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = np.load(self.filenames[idx][:-1])
#         print(image)
        if self.D == 4:
            if self.transform:
                for i in range(image.shape[0]):
                    image1 = self.transform(image[0])
                    image2 = self.transform(image[1])
                    image3 = self.transform(image[2])
                    image4 = self.transform(image[3])
                image = torch.stack([image1, image2, image3, image4], dim = 1)
        else:
            image = self.transform(image)
        if not self.is_train:
            mask = np.load(self.mask_filenames[idx][:-1])
            sample = {'image': image, "filenames": self.filenames[idx].split("/")[-1][:-4], "mask":mask}
        else:
            sample = {'image': image, "filenames": self.filenames[idx].split("/")[-1][:-4]}
        return sample

    def __len__(self):
        return len(self.filenames)

In [174]:
class AnnoDataset(torch.utils.data.Dataset):
    def __init__(self, ROOT_DIR, set = "train", input_size = 256, channels = 1):
        super().__init__()
        self.root = os.path.join(ROOT_DIR, "DATASETS/brainMRI")
        self.se = set
        self.input_size = input_size
        self.get_directory_paths()
        self.channels = channels
    def __getitem__(self, idx):
        if self.se == "train":
            if self.channels==3:
                img = Image.open(self.file_paths[idx]).convert('RGB')
            else:
                img = Image.open(self.file_paths[idx]).convert('L')
            img = np.array(img)
            filenames = self.file_names[idx]
#             if len(img.shape) == 2:
#                 img = np.stack([img] * 3, 2)
            img = Image.fromarray(img)
            img = transforms.Resize(self.input_size)(img)
#             img = transforms.RandomResizedCrop(size=self.input_size,scale=(0.4, 0.75),ratio=(0.5,1.5))(img)
#             img = transforms.RandomCrop(self.input_size)(img)
            img = transforms.RandomHorizontalFlip()(img)
            img = transforms.ColorJitter(brightness=0.2, contrast=0.2)(img)
            img = transforms.ToTensor()(img)
#             img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
            mask = torch.zeros(self.input_size)
        elif self.se == 'test':
            if self.file_names[idx] in self.mfile_names:
                index = self.mfile_names.index(self.file_names[idx])
                mask = Image.open(self.mfile_paths[index])
                mask = np.array(mask)
                if self.channels==3:
                    mask = np.stack((mask,) * self.channels, axis=-1)
                mask = Image.fromarray(mask)             
                mask = transforms.Resize(self.input_size)(mask)
                mask = transforms.ToTensor()(mask)
                mask = torch.where(mask>0, 1.0, 0.0)
#                 print("mask1", mask.shape)
            else:
                mask = torch.zeros(self.channels, *self.input_size)
#                 print("mask2", mask.shape)
            if self.channels==3:
                img = Image.open(self.file_paths[idx]).convert('RGB')
            else:
                img = Image.open(self.file_paths[idx]).convert('L')
#             img = Image.open(self.file_paths[idx]).convert('L')
            img = np.array(img)
            filenames = self.file_names[idx]
#             if len(img.shape) == 2:
#                 img = np.stack([img] * 3, 2)
            img = Image.fromarray(img)
            img = transforms.Resize(self.input_size)(img)
#             img = transforms.RandomResizedCrop(size=self.input_size,scale=(0.4, 0.75),ratio=(0.5,1.5))(img)
#             img = transforms.RandomCrop(self.input_size)(img)
#             img = transforms.RandomHorizontalFlip()(img)
#             img = transforms.ColorJitter(brightness=0.2, contrast=0.2)(img)
            img = transforms.ToTensor()(img)
#             img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
#             target  = None
#         print(img.max())
#         print(img.min())
        return  {'image': img, "filenames": filenames, "mask": mask, "label" : 1 if mask.max()>0 else 0}
    def get_directory_paths(self):
        self.file_paths = []
        self.file_names = []
        if self.se == 'train':
            root_dir = os.path.join(self.root, self.se)
            for rout, dirs, files in os.walk(root_dir):
                if not dirs:
                    for f in files:
                        self.file_paths.append(os.path.join(rout, f))
                        self.file_names.append(f)
        elif self.se == 'test':
            self.mfile_paths = []
            self.mfile_names = []
            root_dir = os.path.join(self.root, self.se)
            mask_dir = os.path.join(self.root, "ground_truth")
            for rout, dirs, files in os.walk(root_dir):
                if not dirs:
                    for f in files:
                        self.file_paths.append(os.path.join(rout, f))
                        self.file_names.append(f)
            for rout, dirs, files in os.walk(mask_dir):
                if not dirs:
                    for f in files:
                        self.mfile_paths.append(os.path.join(rout, f))
                        self.mfile_names.append(f)    
    def __len__(self):
        return len(self.file_paths)

In [164]:
Data = HnABRATS(root, is_train=True)

In [165]:
len(Data)

92631

In [171]:
(0.0337*11.8) /104.4

0.0038090038314176245

In [187]:
data = AnnoDataset(root, set='test')

In [188]:
len(data)

154

In [178]:
import numpy as np

# Assuming dataset is your list of images and masks
total_abnormal_pixels = 0
total_pixels = 0

for j,i in enumerate(data):
    mask = np.float16(i["mask"]>0)   #[0][0]
    abnormal_pixel_count = np.sum(mask)
    total_abnormal_pixels += abnormal_pixel_count
    total_pixels += mask.size
    print(j)

# Calculate average percentage
average_percentage_abnormal = (total_abnormal_pixels / total_pixels) * 100

print(f"Average percentage of abnormal pixels: {average_percentage_abnormal:.2f}%")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
Average percentage of abnormal pixels: 7.31%


In [193]:
total_pixels/153

80400.73202614379

In [196]:
3.72*(5070/11477)

1.6433214254596151

In [194]:
total_abnormal_pixels/154

5837.012987012987

In [136]:
path = r"C:\Users\Admin\Documents\Anomaly Detection\AnoDDPM\DATASETS\brats2021\traning_data\BraTS2021_00005\BraTS2021_00005+_Healthy_+30.npy"

In [137]:
img = np.load(path)

In [139]:
img.max()

0.79723364

In [77]:
np.sum(i["mask"][0])

0.0

In [85]:
np.float16(i["mask"]>0).max()

1.0

In [61]:
for i in Data:
    print(i["mask"].shape)

torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1, 4, 256, 256])
torch.Size([1,

KeyboardInterrupt: 

In [62]:
i['mask'][200]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.], dtype=float32)

In [63]:
i['mask'].shape

(240, 155)

In [64]:
i['mask'].max()

4.0