In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
import numpy as np
import os
from PIL import Image
from glob import glob
import cv2
import matplotlib.pyplot as plt



In [2]:
class HyperspectralDataset(Dataset):
    """Non-overlapping square windows cut from hyperspectral cubes, paired with binary masks.

    Directory layout read by this class:
        <cube_dir>/E*/hsi.npy            -- cube array; the first two axes are spatial
        <mask_dir>/E*/hsi_masks/*bmp     -- mask images; their union is the target mask

    Item ``idx`` is ``(window, window_mask)`` where ``window`` is
    ``cube[i:i+w, j:j+w, :]`` and ``window_mask`` is the matching ``(w, w)`` boolean
    patch. Windows that would run past the cube's bottom/right edge are dropped.

    Args:
        cube_dir: directory whose ``E*`` subfolders each contain ``hsi.npy``.
        mask_dir: directory whose ``E*`` subfolders contain ``hsi_masks/*bmp``.
        window_size: side length ``w`` of each square window, in pixels.
    """

    def __init__(self, cube_dir, mask_dir, window_size):
        self.cube_dir = cube_dir
        self.mask_dir = mask_dir
        self.window_size = window_size
        # Folder names (not paths) of the E* subfolders. basename is OS-agnostic (splitting
        # on '\\' only works on Windows), and sorting makes window indices reproducible.
        self.cube_files = sorted(
            os.path.basename(path) for path in glob(os.path.join(cube_dir, 'E*'))
        )
        self.current_cube = None
        self.current_mask = None
        self.current_cube_index = -1
        # Spatial shape of the last cube scanned; kept for backward compatibility.
        self.image_shape = None
        # Spatial (rows, cols) of every cube, so each mask is built at its own cube's size.
        self.cube_shapes = []
        self.window_indices = self.prepare_window_indices()

    def prepare_window_indices(self):
        """Enumerate the ``(cube_index, row, col)`` top-left corner of every full window.

        Also fills ``self.cube_shapes`` and leaves ``self.image_shape`` set to the
        last cube's spatial shape.
        """
        window_indices = []
        self.cube_shapes = []
        for cube_index, cube_file in enumerate(self.cube_files):
            cube_path = os.path.join(self.cube_dir, cube_file, 'hsi.npy')
            # Memory-map instead of loading: only the shape is needed here, and the
            # full cube is read later on demand by load_cube.
            shape = np.load(cube_path, mmap_mode='r').shape[:2]
            self.cube_shapes.append(shape)
            self.image_shape = shape

            num_windows_x = shape[0] // self.window_size
            num_windows_y = shape[1] // self.window_size
            for i in range(num_windows_x):
                for j in range(num_windows_y):
                    window_indices.append((cube_index, i * self.window_size, j * self.window_size))
        return window_indices

    def load_cube(self, cube_index):
        """Make cube ``cube_index`` and its mask current; no-op if it already is.

        The mask is the logical OR of every ``hsi_masks/*bmp`` file of that cube,
        resized to the cube's own spatial size. It is all-False when no mask files exist.
        """
        if cube_index == self.current_cube_index:
            return
        cube_file = self.cube_files[cube_index]
        cube = np.load(os.path.join(self.cube_dir, cube_file, 'hsi.npy'))
        # Use this cube's own size: cubes in one split need not share a shape.
        height, width = cube.shape[:2]

        # bool from the start, so the dtype is the same whether or not any mask exists.
        mask_all = np.zeros((height, width), dtype=bool)
        for mask_file in glob(os.path.join(self.mask_dir, cube_file, 'hsi_masks/*bmp')):
            with Image.open(mask_file) as img:
                mask = np.array(img)
            if mask.ndim == 3:
                # Multi-channel image: a pixel is masked if any colour channel is set
                # (an alpha channel, if present, is ignored).
                mask = mask[..., :3].any(axis=-1)
            # Binarize to uint8 (cv2.resize rejects bool) and resize with nearest-neighbour
            # so interpolation cannot smear edges into extra foreground pixels.
            mask = cv2.resize(
                (mask != 0).astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
            )
            mask_all |= mask.astype(bool)

        self.current_cube = cube
        self.current_mask = mask_all
        self.current_cube_index = cube_index

    def __len__(self):
        return len(self.window_indices)

    def __getitem__(self, idx):
        """Return ``(window, window_mask)`` for window ``idx``.

        NOTE(review): only one cube is cached at a time, so random access across cubes
        (e.g. a shuffled DataLoader) reloads a cube from disk on most calls.
        """
        cube_index, i, j = self.window_indices[idx]
        self.load_cube(cube_index)

        rows = slice(i, i + self.window_size)
        cols = slice(j, j + self.window_size)
        return self.current_cube[rows, cols, :], self.current_mask[rows, cols]


class HyperspectralDataModule(LightningDataModule):
    """LightningDataModule serving ``HyperspectralDataset`` windows for training.

    Args:
        cube_dir: directory containing the ``E*`` cube folders.
        mask_dir: directory containing the matching ``E*`` mask folders.
        window_size: side length, in pixels, of each square window.
        batch_size: number of windows per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(self, cube_dir, mask_dir, window_size, batch_size, num_workers=0):
        super().__init__()
        self.cube_dir = cube_dir
        self.mask_dir = mask_dir
        self.window_size = window_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Built lazily by setup(): constructing the dataset scans every cube on disk,
        # so it should happen once, not on every train_dataloader() call.
        self.train_dataset = None

    def setup(self, stage=None):
        """Create the training dataset if it does not exist yet; safe to call repeatedly."""
        if self.train_dataset is None:
            self.train_dataset = HyperspectralDataset(self.cube_dir, self.mask_dir, self.window_size)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training windows."""
        # Lightning normally calls setup() first, but direct callers may not.
        self.setup()
        # NOTE(review): shuffle=True interleaves windows from different cubes, and the
        # dataset caches only one cube at a time; consider a cube-grouped sampler.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

In [3]:
# Example usage: cube folders (<dir>/E*/hsi.npy) and their masks (<dir>/E*/hsi_masks/*bmp)
# live under the same processed validation split, so both roots point at one directory.
cube_dir = mask_dir = '../../biocycle/data/processed/bcd_val/data/'
window_size, batch_size = 7, 2

data_module = HyperspectralDataModule(
    cube_dir=cube_dir,
    mask_dir=mask_dir,
    window_size=window_size,
    batch_size=batch_size,
)
train_loader = data_module.train_dataloader()

In [4]:
# Reuse the dataset already wrapped by train_loader instead of scanning every cube a second time.
train_dataset = train_loader.dataset
window, window_mask = train_dataset[0]

In [5]:
# (window_size, window_size, channels): a spatial crop spanning every channel of the cube
np.shape(window)

(7, 7, 224)

In [6]:
# The mask patch must cover exactly the window's spatial extent.
assert window_mask.shape == window.shape[:2], (window_mask.shape, window.shape)
window_mask.shape

(7, 7)

In [7]:
# Boolean mask patch for the first window: True where any hsi_masks/*bmp file is nonzero.
np.asarray(window_mask)

array([[False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False]])

In [8]:
# The full-scene mask must have the same spatial size as the cube currently loaded.
assert train_dataset.current_mask.shape == train_dataset.current_cube.shape[:2], (
    train_dataset.current_mask.shape,
    train_dataset.current_cube.shape,
)
train_dataset.current_mask.shape

(679, 461)

In [None]:
# Show the union of all mask files for the cube currently held by the dataset.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(train_dataset.current_mask, cmap='gray')
ax.set(
    title=f'Union of hsi_masks for cube {train_dataset.cube_files[train_dataset.current_cube_index]}',
    xlabel='column (px)',
    ylabel='row (px)',
)
plt.show()

<matplotlib.image.AxesImage at 0x20c48feab20>