In [1]:
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

import rasterio

import warnings
warnings.filterwarnings("ignore")

In [27]:
# Define torch dataset Class
class Dataset(torch.utils.data.Dataset):
    """Paired SPOT6 / sen2 raster dataset described by a pickled DataFrame.

    Each row of the DataFrame names one SPOT6 raster (column
    ``spot6_filenames``, read from ``<folder_path>y/``) and a sequence of
    sen2 rasters (column ``sen2_filenames``, read from ``<folder_path>x/``).
    Indexing returns ``(spot6, sen2)`` as normalised float tensors, where
    ``sen2`` is ``sen2_amount`` rasters concatenated along the first axis.

    The base class is written out as ``torch.utils.data.Dataset`` because this
    class reuses the name ``Dataset``; with the bare name, re-running the
    notebook cell would make the class inherit from its own previous
    definition.

    Parameters
    ----------
    folder_path : str
        Root folder (with trailing separator) holding the ``x/`` and ``y/``
        sub-folders.
    dataset_file : str
        Path of the pickled DataFrame listing the samples.
    sen2_amount : int, default 1
        Number of sen2 rasters stacked per sample.  If a sample lists fewer
        files, the last one read is repeated until the count is reached.

    Raises
    ------
    ValueError
        If ``sen2_amount`` is smaller than 1.
    """

    # Normalisation statistics, one value per band.
    SPOT6_MEAN = [479.0, 537.0, 344.0]
    SPOT6_STD = [430.0, 290.0, 229.0]
    # Three values per sen2 raster, so each raster is presumably expected to
    # carry three bands -- TODO confirm against the data.
    SEN2_MEAN = [78.0, 91.0, 62.0]
    SEN2_STD = [36.0, 28.0, 30.0]

    def __init__(self, folder_path, dataset_file, sen2_amount=1):
        if sen2_amount < 1:
            raise ValueError(f"sen2_amount must be at least 1, got {sen2_amount!r}")

        self.folder_path = folder_path
        self.sen2_amount = sen2_amount

        # NOTE(review): unpickling can execute arbitrary code -- only load
        # dataset files that come from a trusted source.
        frame = pd.read_pickle(dataset_file)

        # Keep only the rows whose "sen2_no" value is greater than 2.
        frame = frame[frame["sen2_no"] > 2]
        # Drop a "level_0" column if one is present so that reset_index()
        # below cannot collide with it.
        frame = frame.drop(columns=["level_0"], errors="ignore")
        self.df = frame.reset_index()

    def __len__(self):
        """Return the number of samples left after filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, stack and normalise the sample at position ``idx``.

        Returns
        -------
        tuple of torch.Tensor
            ``(spot6, sen2)`` as float tensors.

        Raises
        ------
        IndexError
            If the sample lists no sen2 files.
        """
        row = self.df.iloc[idx]

        spot6 = self._read_raster(self.folder_path + "y/" + row["spot6_filenames"])
        sen2 = self._load_sen2_stack(row["sen2_filenames"])

        spot6 = torch.from_numpy(spot6).float()
        sen2 = torch.from_numpy(sen2).float()

        return (self._normalise_spot6(spot6), self._normalise_sen2(sen2))

    @staticmethod
    def _read_raster(path):
        """Read all bands of the raster at ``path`` and close the file again."""
        # The context manager releases the file handle right away instead of
        # leaving it open until garbage collection.
        with rasterio.open(path) as src:
            return src.read()

    def _load_sen2_stack(self, sen2_files):
        """Read up to ``sen2_amount`` sen2 rasters and concatenate them."""
        if len(sen2_files) == 0:
            raise IndexError("sample lists no sen2 files")

        chips = [
            self._read_raster(self.folder_path + "x/" + name)
            for name in sen2_files[:self.sen2_amount]
        ]
        # Too few files for this sample: repeat the last chip that was read
        # until the requested amount is reached.  This also covers samples
        # with a single file.
        chips.extend([chips[-1]] * (self.sen2_amount - len(chips)))
        return np.concatenate(chips)

    def _normalise_spot6(self, tensor):
        """Normalise a SPOT6 tensor with the per-band SPOT6 statistics."""
        normalise = transforms.Normalize(mean=self.SPOT6_MEAN, std=self.SPOT6_STD)
        return normalise(tensor)

    def _normalise_sen2(self, tensor):
        """Normalise a stacked sen2 tensor.

        The statistics are repeated ``sen2_amount`` times so that they match
        the number of channels in the stacked tensor.
        """
        normalise = transforms.Normalize(
            mean=self.SEN2_MEAN * self.sen2_amount,
            std=self.SEN2_STD * self.sen2_amount,
        )
        return normalise(tensor)


In [57]:
# Build the dataset with four sen2 rasters stacked per sample and wrap it in a
# shuffling loader that yields full batches of 64.
dataset = Dataset("data_f4/", "data_f4_pkls/df_saved_images.pkl", sen2_amount=4)
# prefetch_factor is deliberately not passed: it only applies to worker
# processes, and recent PyTorch releases raise a ValueError when it is set
# while num_workers=0.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0,
                    pin_memory=True, drop_last=True)

In [58]:
%%timeit
for i in loader:
    a,b = i
    break

2.4 s ± 96.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
