In [1]:
from dataclasses import dataclass

import torch
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode
from utils import *
import glob
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msetupishe[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
@dataclass(frozen=True)
class DatasetConfig:
    USED_CLASSES = [0, 1, 2, 3, 6, 7, 8, 9]

    IMAGE_SIZE: tuple[int,int] = (224, 224) # W, H
    BACKGROUND_CLS_ID: int = 0
    DATASET_PATH: str = 'data/default_dataset'
    AUGS = [A.HorizontalFlip(p=0.5), 
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(scale_limit=0.12, rotate_limit=0.15, shift_limit=0.12, p=0.5),
            ]


In [8]:
class OrgansDataset(Dataset):
    def __init__(self,
                dataset_path: str, 
                img_size: int,
                augs: List | None = None,
                cache: bool = False,
                clip_min: int | None = None,
                clip_max: int | None = None,
                ):
        super().__init__()
        self.use_cache = cache
        self.img_size = img_size
        self.images = []
        self.labels = []
        self.clip_min = clip_min
        self.clip_max = clip_max

        for img_path in glob.glob(dataset_path + '/**/*img.npy', recursive=True):
            lbl_path = img2label(img_path)
            self.images.append(load_npy(img_path) if self.use_cache else img_path)
            self.labels.append(load_npy(img_path) if self.use_cache else lbl_path)

        transforms = []
        if augs is not None:
            transforms.extend(augs)
        transforms.extend([
                        A.Resize(self.img_size, self.img_size, always_apply=True),
                        ToTensorV2(always_apply=True)
                    ])
        self.transforms = A.Compose(transforms)

        
        
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index: int):
        image = self.images[index]
        label = self.labels[index]

        if not self.use_cache:
            image = load_npy(image)
            label = load_npy(label)
        
        image = normalize(image, 
                          min_val=self.clip_min,
                          max_val=self.clip_max,
                          )
        image = np.expand_dims(image, 2)
        transformed = self.transforms(image=image, mask=label)
        image, label = transformed["image"], transformed["mask"].to(torch.long)
        return image, label

In [9]:
dataset = OrgansDataset('data/default_dataset', 224, augs = DatasetConfig.AUGS)

In [10]:
dataloader = DataLoader(dataset, batch_size=16)

In [11]:
for image, label in dataloader: 
    print(image.shape)
    print(label.shape)
    print('---')

torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 224])
torch.Size([16, 224, 224])
---
torch.Size([16, 1, 224, 

In [4]:
image = load_npy('data/default_dataset/images/amos_id0001_slice1_img.npy')

In [5]:
image.shape

(533, 651)

In [7]:
np.max(image)

1169.8229