In [9]:
img = '/mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image'
img_mask = '/mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image_mask'

In [None]:
from torch.utils.data import DataLoader
from transformers import MaskFormerImageProcessor

def get_preprocessor():
    # Create a preprocessor
    return MaskFormerImageProcessor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def collate_fn(batch, preprocessor=get_preprocessor()):
    inputs = list(zip(*batch))
    images = inputs[0]
    segmentation_maps = inputs[1]
    # this function pads the inputs to the same size,
    # and creates a pixel mask
    # actually padding isn't required here since we are cropping
    batch = preprocessor(
        images,
        segmentation_maps=segmentation_maps,
        return_tensors="pt",
    )

    batch["original_images"] = inputs[2]
    batch["original_segmentation_maps"] = inputs[3]
    
    return batch


  return MaskFormerImageProcessor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)


In [13]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        #for mask name append mask_ in front of image name
        mask_name = 'mask_'+img_name
        
        print(f'idx: {idx} img name :{img_name} mask_name :{mask_name}')
        
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)
        
        try:
            image = Image.open(img_path).convert("RGB")
            mask = Image.open(mask_path).convert("L")  # Assuming mask is in grayscale
        except:
            print(f"Error loading image: {img_path}")
            # Load a blank image
            image = Image.new("RGB", (256, 256))
            #Load a blank mask
            mask = Image.new("L", (256, 256))
            
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        
        return image, mask

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to the desired size
    transforms.ToTensor()           # Convert to PyTorch tensor
])

# Set directories
image_dir = img
mask_dir = img_mask

# Load dataset
dataset = SegmentationDataset(image_dir, mask_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

# Iterate through the DataLoader
for images, masks in dataloader:
    print(images.shape, masks.shape)
    # Here you would pass images and masks to your model for training


idx: 17473 img name :Prunus avium0002454.jpg mask_name :mask_Prunus avium0002454.jpgidx: 4662 img name :Betula pendula0000005.jpg mask_name :mask_Betula pendula0000005.jpgidx: 20916 img name :Sorbus aucuparia0002654.jpg mask_name :mask_Sorbus aucuparia0002654.jpg


idx: 14383 img name :Fraxinus excelsior0002494.jpg mask_name :mask_Fraxinus excelsior0002494.jpg
Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Sorbus aucuparia0002654.jpg
idx: 952 img name :Acer pseudoplatanus0000943.jpg mask_name :mask_Acer pseudoplatanus0000943.jpgError loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Prunus avium0002454.jpg

idx: 15117 img name :Prunus avium0000097.jpg mask_name :mask_Prunus avium0000097.jpg
Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Betula pendula0000005.jpg
idx: 9990 img name :Fagus sylvatica0001347.jpg mask_name :mask_Fagus sylvatica0001347.jpg
Error loading image: /mnt/

KeyboardInterrupt: 


Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Betula pendula0001191.jpg
idx: 17048 img name :Prunus avium0002029.jpg mask_name :mask_Prunus avium0002029.jpg
Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Prunus avium0001822.jpgError loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Acer pseudoplatanus0002969.jpg

idx: 1440 img name :Acer pseudoplatanus0001431.jpg mask_name :mask_Acer pseudoplatanus0001431.jpgidx: 15322 img name :Prunus avium0000302.jpg mask_name :mask_Prunus avium0000302.jpg

Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Prunus avium0002029.jpg
idx: 10859 img name :Fagus sylvatica0002185.jpg mask_name :mask_Fagus sylvatica0002185.jpg
Error loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_Folder/image/Fagus sylvatica0001535.jpgError loading image: /mnt/gsdata/projects/panops/Labeled_data_seprated_in_

In [None]:
def build_loader(config):
    
    train_dataset = config.train
    train_config = config.meta[train_dataset]

    train_dat, val_dat, test_dat, columns = parse_dataset(train_config)


    dataset_train = build_dataset(is_train=True, dataframe=train_dat, config=config, col=columns, modality=modality)
    print('successfully build train dataset')

    dataset_test = build_dataset(is_train=False, dataframe=test_dat, config=config, col=columns, modality=modality)
    print(f'local rank {local_rank} / global rank {dist.get_rank()} \
        successfully build test dataset')
    
    
    
    sampler_test = torch.utils.data.SequentialSampler(dataset_test)
    
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.batch_size,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn ### NOTEL THIS ###
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        sampler=sampler_test,
        batch_size=config.batch_size,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn ### NOTEL THIS ###
    )


    return dataset_train, data_loader_train
