In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ADE20KDataset(Dataset):
    """Semantic-segmentation dataset for the ADEChallengeData2016 directory layout.

    Images are read from ``<root_dir>/images/<split>/*.jpg`` and each mask from
    ``<root_dir>/annotations/<split>/<same stem>.png``. Every item is a pair
    ``(image, mask)``:

    * ``image`` -- float tensor of shape ``(3, H, W)``, normalized with the mean/std
      configured in ``__init__`` (the usual ImageNet statistics);
    * ``mask`` -- int64 tensor of shape ``(H, W)``;

    where ``(H, W)`` is ``image_size``.
    """

    def __init__(self, root_dir, split='training', transform=None,
                 image_size=(512, 512), reduce_zero_label=False, ignore_index=255):
        """
        Args:
            root_dir (str): Path to the ADE20K dataset root.
            split (str): 'training' or 'validation' (sub-directory name under
                ``images/`` and ``annotations/``).
            transform (callable, optional): Applied to the resized PIL image only,
                before tensor conversion and normalization. It must return a PIL
                image and must not change the image geometry (e.g. ColorJitter is
                fine, random crops/flips are not), because the mask is not passed
                through it.
            image_size (tuple[int, int]): ``(height, width)`` that both image and
                mask are resized to. Defaults to ``(512, 512)``.
            reduce_zero_label (bool): If True, raw label 0 is mapped to
                ``ignore_index`` and every other label is shifted down by one.
                NOTE(review): ADEChallengeData2016 masks are expected to use 0 for
                unlabeled pixels and 1..150 for classes -- verify against your copy;
                if so, True yields labels in ``[0, num_classes)`` plus
                ``ignore_index``. Default False keeps the raw mask values.
            ignore_index (int): Label written for raw value 0 when
                ``reduce_zero_label`` is True.

        Raises:
            FileNotFoundError: If ``<root_dir>/images/<split>`` does not exist
                (raised by ``os.listdir``).
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.image_size = tuple(image_size)
        self.reduce_zero_label = reduce_zero_label
        self.ignore_index = ignore_index

        # ADE20K has 150 semantic classes. With reduce_zero_label=False the raw masks
        # may also contain the unlabeled value 0 (see the note in the docstring).
        self.num_classes = 150

        self.image_dir = os.path.join(root_dir, 'images', split)
        self.mask_dir = os.path.join(root_dir, 'annotations', split)

        # Sorted so that indices are stable across runs and workers.
        self.images = sorted(f for f in os.listdir(self.image_dir) if f.endswith('.jpg'))

        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        # Bilinear for the photo; nearest for the mask so no new (invalid) label
        # values are interpolated between neighbouring classes.
        self.resize = transforms.Resize(self.image_size,
                                        interpolation=transforms.InterpolationMode.BILINEAR)
        self.resize_mask = transforms.Resize(self.image_size,
                                             interpolation=transforms.InterpolationMode.NEAREST)

    def __len__(self):
        """Return the number of ``.jpg`` images found in the split directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and encode the ``idx``-th image/mask pair.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(image, mask)`` as described in the
            class docstring.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # The mask shares the image's stem; only the extension differs. splitext
        # avoids rewriting a '.jpg' that happens to occur inside the stem.
        mask_path = os.path.join(self.mask_dir, os.path.splitext(img_name)[0] + '.png')

        # Context managers close the underlying files deterministically, which
        # matters when many DataLoader workers open thousands of files.
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw_image:
            image = self.resize(raw_image.convert('RGB'))
        # Materialize the mask as an array while the file is still open: Resize may
        # hand back the same (lazily loaded) object when the size already matches.
        with Image.open(mask_path) as raw_mask:
            mask_array = np.array(self.resize_mask(raw_mask))

        if self.transform is not None:
            image = self.transform(image)

        image = self.normalize(self.to_tensor(image))
        mask = self._encode_mask(mask_array, self.reduce_zero_label, self.ignore_index)
        return image, mask

    @staticmethod
    def _encode_mask(mask_array, reduce_zero_label=False, ignore_index=255):
        """Convert an integer label array to an int64 tensor of the same shape.

        Args:
            mask_array (array-like): Integer label map, e.g. ``np.array(PIL mask)``.
            reduce_zero_label (bool): If True, map label 0 to ``ignore_index`` and
                shift every other label down by one.
            ignore_index (int): Replacement value for label 0 when reducing.

        Returns:
            torch.Tensor: int64 tensor with the (possibly remapped) labels.
        """
        # Widen to int64 *before* any arithmetic. Doing math on the raw uint8 array
        # (e.g. the former ``mask % 256``) is a no-op under NumPy 1 and raises
        # OverflowError under NumPy 2, where out-of-range Python ints are rejected.
        labels = np.asarray(mask_array).astype(np.int64)
        if reduce_zero_label:
            labels = np.where(labels == 0, ignore_index, labels - 1)
        return torch.from_numpy(labels)

def get_ade20k_loaders(root_dir, batch_size=16, num_workers=4,
                       pin_memory=None, train_transform=None):
    """
    Create the training and validation data loaders for the ADE20K dataset.

    Args:
        root_dir (str): Path to the ADE20K dataset root.
        batch_size (int): Batch size used by both loaders.
        num_workers (int): Number of worker processes per loader.
        pin_memory (bool, optional): Whether the loaders use pinned host memory.
            Defaults to ``torch.cuda.is_available()``: pinning only speeds up
            host-to-GPU copies, so it is skipped when no CUDA device is present.
        train_transform (callable, optional): Passed as ``transform`` to the
            training dataset only; the validation dataset gets no transform.

    Returns:
        tuple[DataLoader, DataLoader]: ``(train_loader, val_loader)``. The training
        loader shuffles every epoch; the validation loader keeps dataset order.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Both loaders share every setting except split, shuffling and transform.
    def make_loader(split, shuffle, transform=None):
        dataset = ADE20KDataset(root_dir=root_dir, split=split, transform=transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = make_loader('training', shuffle=True, transform=train_transform)
    val_loader = make_loader('validation', shuffle=False)
    return train_loader, val_loader

In [29]:
# Example usage: build the loaders and sanity-check a single batch rather than
# iterating over the whole training split (the old loop did no work per batch).
#
# NOTE(review): a worker error "TypeError: expected np.ndarray (got numpy.ndarray)"
# usually indicates a binary mismatch between torch and NumPy (e.g. a torch build
# compiled against NumPy 1.x running with NumPy 2.x). Fix the environment -- e.g.
# `%pip install "numpy<2"` or upgrade torch to a NumPy-2-compatible build -- then
# restart the kernel; it is not caused by the dataset code itself.
root_dir = "ADEChallengeData2016"
BATCH_SIZE = 16
train_loader, val_loader = get_ade20k_loaders(root_dir, batch_size=BATCH_SIZE)

images, masks = next(iter(train_loader))
# images: [batch, 3, 512, 512] float; masks: [batch, 512, 512] int64
assert images.shape[1:] == (3, 512, 512), images.shape
assert masks.shape == (images.shape[0], 512, 512), masks.shape
assert masks.dtype == torch.int64, masks.dtype
# Next step: feed `images` to the FPN model and compare its logits with `masks`.

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/envs/timm/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/root/miniconda3/envs/timm/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/miniconda3/envs/timm/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_3818/4114796173.py", line 52, in __getitem__
    image = torch.from_numpy(image.transpose((2, 0, 1))).float()
TypeError: expected np.ndarray (got numpy.ndarray)
