-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
80 lines (59 loc) · 2.08 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Dataset objects for MaCheX."""
import json
import os
from typing import (
Final,
Optional,
Dict,
)
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
from torchvision.transforms import Compose, ToTensor
# Root directory of the MaCheX composition, given as a relative path; it is
# the path handed to MaCheXDataset by the demo entry point of this module.
MACHEX_PATH: Final[str] = 'machex_dataset'
class ChestXrayDataset(Dataset):
    """Class for handling datasets in the MaCheX composition.

    A dataset is described by an ``index.json`` file located directly in
    ``root``. The file is loaded as a dictionary; each of its values is
    indexed with ``'path'`` to obtain the image file to open.
    """

    def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
        """Initialize ChestXrayDataset.

        Args:
            root: Directory that contains the ``index.json`` file.
            transforms: Callable applied to every opened PIL image. When
                ``None``, ``ToTensor()`` is used.

        Raises:
            OSError: If ``index.json`` cannot be opened.
            json.JSONDecodeError: If ``index.json`` is not valid JSON.
        """
        self.root = root
        json_path = os.path.join(self.root, 'index.json')
        self.index_dict = ChestXrayDataset._load_json(json_path)
        # Freeze the key order once so integer indices map to stable entries.
        self.keys = list(self.index_dict.keys())
        self.transforms = ToTensor() if transforms is None else transforms

    @staticmethod
    def _load_json(file_path: str) -> Dict:
        """Load a json file as dictionary.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The decoded JSON document.
        """
        # JSON text is UTF-8 by specification; being explicit avoids decoding
        # with whatever the platform's locale default happens to be.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.keys)

    def __getitem__(self, idx: int) -> Dict:
        """Get dataset element.

        Args:
            idx: Position of the element in ``self.keys``.

        Returns:
            A dictionary with the transformed image stored under ``'img'``.
        """
        meta = self.index_dict[self.keys[idx]]
        # The context manager guarantees the file handle is released even if
        # a transform raises. load() pulls the pixel data into memory first,
        # so the image stays usable after the underlying file is closed.
        with Image.open(meta['path']) as image:
            image.load()
            transformed = self.transforms(image)
        return {'img': transformed}
class MaCheXDataset(Dataset):
    """Massive chest X-ray dataset.

    Concatenation of one ``ChestXrayDataset`` per sub-directory of ``root``.
    """

    def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
        """Initialize MaCheXDataset.

        Args:
            root: Directory whose sub-directories are the individual
                datasets; each one is passed to ``ChestXrayDataset`` as root.
            transforms: Forwarded unchanged to every ``ChestXrayDataset``.
        """
        self.root = root
        # os.listdir returns entries in arbitrary order, which would make the
        # index -> sample mapping differ between machines or runs; sorting
        # keeps it reproducible. Plain files lying next to the sub-dataset
        # directories are skipped because they cannot contain an index.json.
        sub_dataset_roots = sorted(
            entry
            for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        datasets = [
            ChestXrayDataset(
                root=os.path.join(self.root, sub_root),
                transforms=transforms,
            )
            for sub_root in sub_dataset_roots
        ]
        # NOTE: ConcatDataset rejects an empty list, so ``root`` has to
        # contain at least one sub-dataset directory.
        self.ds = ConcatDataset(datasets)

    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.ds)

    def __getitem__(self, idx: int) -> Dict:
        """Get dataset element.

        Args:
            idx: Position of the element in the concatenated dataset.

        Returns:
            The element produced by the sub-dataset that ``idx`` falls into.
        """
        return self.ds[idx]
if __name__ == '__main__':
    # Manual smoke test: build the full composition from the default root,
    # print how many elements it holds and print one fixed element.
    full_dataset = MaCheXDataset(root=MACHEX_PATH)
    print(len(full_dataset))
    print(full_dataset[1337])