<a href="https://colab.research.google.com/github/taravatp/Multi_Spectral_Image_Segmentation/blob/main/MSI_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# cd /content/drive/MyDrive/Vision_Impulse_Task

In [None]:
!pip install rasterio

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage import io
import rasterio

import torch
from torch.utils.data import Dataset

import pickle

In [None]:
class MSI_data(Dataset):
  """Segmentation dataset pairing multispectral image tiles with their masks.

  Image and mask file paths are loaded from ``inputs.pkl`` and ``targets.pkl``
  in the current working directory. They are split 80/10/10 into
  train / test / validation subsets using a fixed ``random_state``, so every
  instance created with the same pickles sees the same partition.

  Args:
    flag: which subset to expose; one of ``'train'``, ``'validation'``, ``'test'``.
    multispectral: if False, only bands 2..4 (0-indexed) of each image are kept
      (intended to be the RGB channels).

  Raises:
    ValueError: if ``flag`` is not one of the supported subset names.
  """

  _SPLITS = ('train', 'validation', 'test')

  def __init__(self, flag='train', multispectral=True):
    super().__init__()

    # Fail fast, before touching the filesystem. Previously an unknown flag left
    # self.images / self.masks unset and only failed later in __len__/__getitem__.
    if flag not in self._SPLITS:
      raise ValueError(f"flag must be one of {self._SPLITS}, got {flag!r}")

    self.multispectral = multispectral

    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open('inputs.pkl', 'rb') as open_file:
      images = pickle.load(open_file)

    with open('targets.pkl', 'rb') as open_file:
      masks = pickle.load(open_file)

    # 80% train; the remaining 20% is halved into test and validation.
    train_image_paths, rest_image_paths, train_mask_paths, rest_mask_paths = train_test_split(
        images, masks, test_size=0.2, random_state=42)
    test_image_paths, validation_image_paths, test_mask_paths, validation_mask_paths = train_test_split(
        rest_image_paths, rest_mask_paths, test_size=0.5, random_state=42)

    splits = {
        'train': (train_image_paths, train_mask_paths),
        'validation': (validation_image_paths, validation_mask_paths),
        'test': (test_image_paths, test_mask_paths),
    }
    self.images, self.masks = splits[flag]

  def __len__(self):
    """Return the number of (image, mask) pairs in the selected subset."""
    return len(self.images)

  def __getitem__(self, index):
    """Load, normalise and return the ``index``-th (image, mask) pair as float32 tensors."""
    # Context managers close the raster handles; leaving them open leaks file
    # descriptors across many DataLoader iterations.
    with rasterio.open(self.images[index]) as src:
      image = src.read()
    with rasterio.open(self.masks[index]) as src:
      mask = src.read()

    if not self.multispectral:
      # Intended as the RGB channels -- TODO confirm the band order of the source rasters.
      image = image[2:5, :, :]

    image = self._normalize(image)
    mask = self._clean_mask(mask)

    image = torch.tensor(image.astype(np.float32))
    # NOTE(review): masks are returned as float32 (unchanged); nn.CrossEntropyLoss
    # expects int64 class indices, so callers may need to convert.
    mask = torch.tensor(mask.astype(np.float32))
    return image, mask

  @staticmethod
  def _normalize(image):
    """Min-max scale each band (axis 0) of a (bands, H, W) array to [0, 1].

    Bands whose values are all identical are mapped to 0 instead of NaN.
    """
    # Work in float so the subtraction cannot wrap around for unsigned integer input.
    image = image.astype(np.float64)
    band_min = image.min(axis=(1, 2), keepdims=True)
    band_range = image.max(axis=(1, 2), keepdims=True) - band_min
    # Guard against division by zero for constant bands.
    band_range[band_range == 0] = 1.0
    return (image - band_min) / band_range

  @staticmethod
  def _clean_mask(mask):
    """Remap mask value 255 to class 2 (in place) and return the mask.

    NOTE(review): presumably 255 is a no-data/ignore value folded into class 2 -- confirm.
    """
    mask[mask == 255] = 2
    return mask

In [None]:
if __name__ == "__main__":
  BATCHSIZE = 32

  train_data = MSI_data(flag='train', multispectral=False)
  validation_data = MSI_data(flag='validation', multispectral=False)
  test_data = MSI_data(flag='test', multispectral=False)

  # Only the training loader is shuffled; evaluation loaders keep a fixed order
  # so validation/test metrics are reproducible between runs.
  train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True)
  validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=BATCHSIZE, shuffle=False)
  test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCHSIZE, shuffle=False)

  # Smoke test: make sure a single sample can be read and decoded.
  image, mask = train_data[0]

  # len(dataset) is the true sample count; len(dataloader) * BATCHSIZE
  # over-counts whenever the last batch is partial.
  print('number of training samples:', len(train_data))
  print('number of validation samples', len(validation_data))
  print('number of test samples', len(test_data))

number of training samples: 4576
number of validation samples 576
number of test samples 576
