In [55]:
import matplotlib
import torch
import torch.nn as nn
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from torchsummary import summary
import lightning as L
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader
from glob import glob
import os
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import spectral
from sklearn.decomposition import PCA

In [83]:
class HyperspectralDataset(Dataset):
    """Window-based dataset over hyperspectral cubes stored as ``<cube_dir>/E*/hsi.npy``.

    Every cube is tiled on a regular grid with stride ``window_size``. Because the
    cube is zero-padded by ``window_size // 2`` on each spatial side before a
    window is cut, the element ``[p, p]`` of a window corresponds to the grid
    position in the unpadded cube.

    Only one pre-processed cube is cached at a time (see ``load_cube``), so
    accessing samples of different cubes in alternation re-loads and
    re-processes a cube on every switch.

    Args:
        cube_dir: Directory containing one sub-folder per cube (name starts with ``E``),
            each holding an ``hsi.npy`` file.
        mask_dir: Directory containing, per cube folder, ``hsi_masks/*bmp`` mask images.
        window_size: Edge length of the square windows (and the grid stride).
        n_pc: Number of principal components kept by ``apply_pca``.
        background_threshold: Spectra whose mean intensity is below this value
            are zeroed by ``remove_background``.
    """

    def __init__(self, cube_dir, mask_dir, window_size, n_pc=15, background_threshold=600):
        self.cube_dir = cube_dir
        self.mask_dir = mask_dir
        self.window_size = window_size
        self.p = window_size // 2
        # List sub-folders starting with E. basename() is separator-agnostic
        # (the previous split on a backslash only worked on Windows) and sorting
        # makes the cube index -> folder mapping deterministic.
        self.cube_files = sorted(
            os.path.basename(os.path.normpath(path))
            for path in glob(os.path.join(cube_dir, 'E*'))
            if os.path.isdir(path)
        )
        self.current_cube = None
        self.current_mask = None
        self.current_cube_index = -1
        self.image_shape = None
        self.n_pc = n_pc
        self.background_threshold = background_threshold
        self.window_indices = self.prepare_window_indices()
        self.gradient_mask = self.get_gradient_mask()

    def prepare_window_indices(self):
        """Builds the list of ``(cube_index, row, col)`` grid positions for all cubes.

        Side effect: ``self.image_shape`` is set to the spatial shape of the last
        cube that was inspected (kept for backward compatibility).
        """
        window_indices = []
        for cube_index, cube_file in enumerate(self.cube_files):
            cube_path = os.path.join(self.cube_dir, cube_file, 'hsi.npy')
            # Memory-map the file: only the header is needed to learn the shape.
            cube = np.load(cube_path, mmap_mode='r')
            height, width = cube.shape[:2]
            del cube  # release the memory map right away
            self.image_shape = (height, width)

            num_windows_x = height // self.window_size
            num_windows_y = width // self.window_size

            window_indices.extend(
                (cube_index, i * self.window_size, j * self.window_size)
                for i in range(num_windows_x)
                for j in range(num_windows_y)
            )
        return window_indices

    def load_cube(self, cube_index):
        """Loads, pre-processes and pads cube ``cube_index`` together with its mask.

        Does nothing when that cube is already the cached one. The mask is the
        logical OR of all ``hsi_masks/*bmp`` images of the cube, resized to the
        cube's own spatial shape.
        """
        if cube_index == self.current_cube_index:
            return

        cube_folder = self.cube_files[cube_index]
        cube = np.load(os.path.join(self.cube_dir, cube_folder, 'hsi.npy'))
        # Use the shape of *this* cube; self.image_shape only reflects the last
        # cube scanned and is wrong when cubes differ in size.
        height, width = cube.shape[:2]

        mask_all = np.zeros((height, width), dtype=bool)
        hsi_masks = glob(os.path.join(self.mask_dir, cube_folder, 'hsi_masks/*bmp'))
        # TODO: if more classes are introduced mask needs to be int and not bool
        for mask_file in hsi_masks:
            # Load with PIL; the context manager closes the file handle.
            with Image.open(mask_file) as mask_image:
                mask = np.array(mask_image)
            # NOTE(review): cv2.resize runs with its default interpolation; for
            # label masks nearest-neighbour may be preferable - confirm.
            mask = cv2.resize(mask, (width, height))
            mask_all = np.logical_or(mask_all, mask)

        pad = self.p
        self.current_cube = cube
        self.pre_process_cube()
        self.current_cube = np.pad(
            self.current_cube, ((pad, pad), (pad, pad), (0, 0)),
            mode='constant', constant_values=0,
        )

        self.current_mask = np.pad(
            mask_all, ((pad, pad), (pad, pad)), mode='constant', constant_values=0
        ).astype(int)
        self.current_cube_index = cube_index

    def pre_process_cube(self):
        """Runs band cropping, background removal and PCA on ``self.current_cube``."""
        # TODO: implement random occlusion, gradient masking (if necessary)
        self.crop_bands()
        self.remove_background()
        self.apply_pca()

    def apply_pca(self):
        """Applies PCA to the cube and reduces the number of bands to ``self.n_pc``.

        Note:
            - A new PCA is fitted on every call, i.e. separately for each cube.
              NOTE(review): components are therefore presumably not comparable
              between cubes - confirm this is intended.
            - Called after ``crop_bands`` in ``pre_process_cube``.
        """
        height, width, bands = self.current_cube.shape
        x = self.current_cube.reshape(-1, bands).astype(float)

        # TODO: check for optimal number of components
        pca = PCA(n_components=self.n_pc)
        x = pca.fit_transform(x)

        # Reshape with the configured component count (was a hard-coded 15).
        self.current_cube = x.reshape(height, width, self.n_pc)

    def remove_background(self):
        """Sets spectra with mean intensity below ``self.background_threshold`` to zero on all bands.

        Overall low-intensity spectra are treated as background.

        Note:
            - Changes in light intensity between cubes are not considered.
            - Called after ``crop_bands`` in ``pre_process_cube``, so the mean is
              taken over the cropped bands.
        """
        mean_intensity = np.mean(self.current_cube, axis=2)
        self.current_cube[mean_intensity < self.background_threshold] = 0

    def crop_bands(self):
        """Keeps bands 8..209 (slice ``8:210``) and drops all bands outside that range.

        For a cube of shape (w, h, 224) this removes bands 0-7 and 210-223.
        """
        self.current_cube = self.current_cube[:, :, 8:210]

    def apply_gradient_mask(self, window):
        """Applies gradient mask to window as described in https://www.mdpi.com/2072-4292/15/12/3123

        ``window`` must broadcast against the mask of shape
        ``(window_size, window_size, n_pc)``, i.e. a channel-last window.
        """
        return window * self.gradient_mask

    def get_gradient_mask(self):
        """Builds the radial weighting mask of shape ``(window_size, window_size, n_pc)``.

        ``mask[i, j] = 1 - ((i - c + 1)**2 + (j - c + 1)**2) / (2 * c**2)`` with
        ``c = (window_size + 1) / 2``; the 2-D mask is repeated for every component.
        """
        s = self.window_size
        center = (s + 1) / 2
        squared_offsets = (np.arange(s) - center + 1) ** 2
        mask = 1 - (squared_offsets[:, np.newaxis] + squared_offsets[np.newaxis, :]) / (2 * center**2)
        return np.repeat(mask[:, :, np.newaxis], self.n_pc, axis=2)

    def __len__(self):
        """Number of windows over all cubes."""
        return len(self.window_indices)

    def __getitem__(self, idx):
        """Returns ``(window, window_mask)`` for grid position ``idx``.

        ``window`` is float32 with shape ``(bands, window_size, window_size)``;
        ``window_mask`` is the matching integer mask patch of shape
        ``(window_size, window_size)``.
        """
        cube_index, i, j = self.window_indices[idx]
        self.load_cube(cube_index)
        rows = slice(i, i + self.window_size)
        cols = slice(j, j + self.window_size)
        window = self.current_cube[rows, cols, :].astype(np.float32)
        window = np.transpose(window, (2, 0, 1))  # channel-last -> channel-first
        window_mask = self.current_mask[rows, cols]

        return window, window_mask


class HyperspectralDataModule(LightningDataModule):
    """Lightning data module that serves ``HyperspectralDataset`` windows.

    Both the train and the test loader are built from the same ``cube_dir`` /
    ``mask_dir``; they differ only in whether the samples are shuffled. A fresh
    dataset is constructed on every loader request.

    Args:
        cube_dir: Directory with the hyperspectral cube folders.
        mask_dir: Directory with the mask folders.
        window_size: Edge length of the square windows.
        batch_size: Number of windows per batch.
    """

    def __init__(self, cube_dir, mask_dir, window_size, batch_size):
        super().__init__()
        self.cube_dir, self.mask_dir = cube_dir, mask_dir
        self.window_size, self.batch_size = window_size, batch_size

    def _build_loader(self, shuffle):
        # Shared construction path for the train and test loaders.
        dataset = HyperspectralDataset(self.cube_dir, self.mask_dir, self.window_size)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Returns a shuffled loader over all windows."""
        return self._build_loader(shuffle=True)

    def test_dataloader(self):
        """Returns an unshuffled loader over all windows."""
        return self._build_loader(shuffle=False)

In [84]:
class HyperSN(L.LightningModule):
    """Implementation of HyperSN for Hyperspectral Cubes (3D Conv) from https://ieeexplore.ieee.org/document/8736016 based on https://github.com/Pancakerr/HybridSN/blob/master/HybridSN.ipynb

    Three 3D convolution stages are followed by one 2D convolution stage and
    three dense layers producing ``class_nums`` logits per input patch.

    Args:
        in_channels: Number of spectral channels of an input patch.
        patch_size: Spatial edge length of the (square) input patch.
        class_nums: Number of output classes.
    """
    def __init__(self, in_channels, patch_size, class_nums):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.class_nums = class_nums

        self.conv1 = nn.Sequential(nn.Conv3d(1, out_channels=8, kernel_size=(7,23,23), padding=(1,1,1)),
                                   nn.BatchNorm3d(8),
                                   nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv3d(8, out_channels=16, kernel_size=(5,21,21), padding=(1,1,1)),
                                   nn.BatchNorm3d(16),
                                   nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv3d(16, out_channels=32, kernel_size=(3,19,19), padding=(1,1,1)),
                                   nn.BatchNorm3d(32),
                                   nn.ReLU(inplace=True))

        # The 2D stage consumes the 3D feature maps with channel and depth axes merged.
        self.x1_shape = self.get_shape_after_3dconv()
        self.conv4 = nn.Sequential(nn.Conv2d(self.x1_shape[1]*self.x1_shape[2], out_channels=64, kernel_size=(3,3), padding=(1,1)),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True))

        # Flattened feature count that feeds the first dense layer.
        self.x2_shape = self.get_shape_after_2dconv()

        self.dense1 = nn.Sequential(nn.Linear(self.x2_shape, 1024),
                                    nn.ReLU(inplace=True),
                                    nn.Dropout(0.4))
        self.dense2 = nn.Sequential(nn.Linear(1024, 128),
                                    nn.ReLU(inplace=True),
                                    nn.Dropout(0.4))
        self.dense3 = nn.Linear(128, self.class_nums)

    def forward(self, x):
        """Computes class logits for a batch of patches.

        Args:
            x: Tensor of shape ``(batch, in_channels, patch_size, patch_size)``.

        Returns:
            Tensor of shape ``(batch, class_nums)`` with unnormalised logits.
        """
        # Out-of-place unsqueeze: the caller's batch tensor must not be modified.
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Merge the channel and depth axes so the 2D convolution can consume them.
        x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3], x.shape[4])
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        x = self.dense1(x)
        x = self.dense2(x)
        return self.dense3(x)

    def training_step(self, batch, batch_idx):
        """Runs one optimisation step and returns the cross-entropy loss.

        The label of a patch is the mask value at its centre pixel.
        """
        x, mask = batch
        center = self.patch_size // 2
        y = mask[:, center, center].long()  # class indices

        out = self(x)

        # Class-index targets give the same loss as one-hot float targets
        # without materialising the one-hot tensor.
        loss = nn.functional.cross_entropy(out, y)
        # Logged per epoch so the plateau scheduler has a metric to monitor.
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def _probe_shape(self, layers, dummy):
        # Pushes a dummy tensor through `layers` in eval mode so the BatchNorm
        # running statistics are not updated by the all-zero probe input.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for layer in layers:
                    dummy = layer(dummy)
        finally:
            self.train(was_training)
        return dummy.shape

    def get_shape_after_2dconv(self):
        """Returns the number of features per sample after the 2D convolution stage."""
        x = torch.zeros((1, self.x1_shape[1]*self.x1_shape[2], self.x1_shape[3], self.x1_shape[4]))
        shape = self._probe_shape((self.conv4,), x)
        return shape[1]*shape[2]*shape[3]

    def get_shape_after_3dconv(self):
        """Returns the tensor shape after the three 3D convolution stages for a single sample."""
        x = torch.zeros((1, 1, self.in_channels, self.patch_size, self.patch_size))
        return self._probe_shape((self.conv1, self.conv2, self.conv3), x)

    def configure_optimizers(self):
        """Adam optimiser with a ReduceLROnPlateau scheduler driven by ``train_loss``.

        The previously monitored ``val_loss`` is never logged by this module
        (there is no validation step), so the scheduler could not be stepped.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # `verbose` is omitted: it only printed LR changes and is deprecated in
        # recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'train_loss',
            },
        }
    
    
# Binary classifier on 128x128 patches with 15 input channels.
model = HyperSN(
    in_channels=15,
    patch_size=128,
    class_nums=2,
)

In [85]:
# Cubes and masks are read from the same processed-data folder.
cube_dir = mask_dir = '../../biocycle/data/processed/bcd_val/data/'
# Edge length of the square windows and number of windows per batch.
window_size, batch_size = 128, 1

In [86]:
# Build the data module from the configured paths and sizes, then take its
# (shuffled) training loader.
data_module = HyperspectralDataModule(
    cube_dir=cube_dir,
    mask_dir=mask_dir,
    window_size=window_size,
    batch_size=batch_size,
)
train_loader = data_module.train_dataloader()

In [87]:
# wandb_logger = WandbLogger(project='pixelclassifier_hyperSN', entity='biocycle', log_model="all")

In [88]:
# trainer = L.Trainer(max_epochs=1, logger=wandb_logger)
# Single-epoch run with the default logger; the loader is passed directly
# instead of the data module.
trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | conv1  | Sequential | 29.6 K
1 | conv2  | Sequential | 282 K 
2 | conv3  | Sequential | 554 K 
3 | conv4  | Sequential | 166 K 
4 | dense1 | Sequential | 358 M 
5 | dense2 | Sequential | 131 K 
6 | dense3 | Linear     | 258   
--------------------------------------
360 M     Trainable params
0         Non-trainable params
360 M     Total params
1,440.161 Total estimated model params size (MB)
C:\Users\stude\anaconda3\envs\biocycling_px\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

C:\Users\stude\anaconda3\envs\biocycling_px\Lib\site-packages\lightning\pytorch\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
