In [9]:
import torch
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode
from utils import *
import glob

In [11]:
type(InterpolationMode.NEAREST)

<enum 'InterpolationMode'>

In [3]:
a = torch.tensor([1, 2])
type(a)

torch.Tensor

In [16]:
class OrgansDataset(Dataset):
    def __init__(self,
                dataset_path: str, 
                img_size: int,
                cache: bool=False):
        super().__init__()
        self.use_cache = cache
        self.img_size = img_size
        self.images = []
        self.labels = []

        for img_path in glob.glob(dataset_path + '/**/*img.npy', recursive=True):
            lbl_path = img2label(img_path)
            self.images.append(load_npy(img_path) if self.use_cache else img_path)
            self.labels.append(load_npy(img_path) if self.use_cache else lbl_path)

    def preprocess(self, 
                   img: np.ndarray,
                   interpolation: InterpolationMode=InterpolationMode.BILINEAR
                    ):
        
        transform = T.Compose([
                        T.ToTensor(),
                        T.Resize((self.img_size, self.img_size), interpolation=interpolation),
                    ])
        return transform(img)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index: int):
        image = self.images[index]
        label = self.labels[index]

        if not self.use_cache:
            image = load_npy(image)
            label = load_npy(label)

        image = self.preprocess(image)
        image = T.Normalize
        label = self.preprocess(label, InterpolationMode.NEAREST)
        print(torch.unique(image))

        return image, label

In [17]:
dataset = OrgansDataset('data/default_dataset', 224)

In [18]:
dataloader = DataLoader(dataset, batch_size=16)

In [19]:
for batch in dataloader: 
    pass

tensor([-1024.0002, -1024.0001, -1024.0000,  ...,   420.9307,   421.2984,
          421.5117])
tensor([-1873.8379, -1873.7534, -1873.5381,  ...,  1029.8955,  1086.5037,
         1092.3683])
tensor([-3024.0005, -3024.0002, -3024.0000,  ...,   486.0504,   570.1536,
          584.2993])
tensor([-1018.2236, -1016.5284, -1014.6576,  ...,  1164.7253,  1175.2423,
         1191.4984])
tensor([-1020.2327, -1019.1990, -1017.5330,  ...,  1405.8142,  1405.8909,
         1406.8896])
tensor([-2048.0005, -2048.0002, -2048.0000,  ...,  1504.4927,  1506.8960,
         1527.8000])
tensor([-1021.4098, -1020.0308, -1019.3796,  ...,  1126.5269,  1130.2515,
         1141.6259])
tensor([-1027.4874, -1027.3951, -1027.2946,  ...,  1210.2180,  1215.5945,
         1272.5886])
tensor([-1015.2702, -1011.0481, -1008.9150,  ...,  1275.6802,  1320.1464,
         1343.1035])
tensor([-1037.6245, -1031.9197, -1031.8564,  ...,   439.4113,   440.6379,
          520.1743])
tensor([-1015.6625, -1015.4820, -1012.3845,  ..., 

KeyboardInterrupt: 

In [4]:
image = load_npy('data/default_dataset/images/amos_id0001_slice1_img.npy')

In [5]:
image.shape

(533, 651)

In [7]:
np.max(image)

1169.8229