In [None]:
# Target size for Resize. The U-Net below uses a resnet34 encoder with 5
# downsampling stages (total stride 2**5 = 32), so both sides must be multiples
# of 32 or the decoder's skip connections cannot line up. 3024 is not
# (3024 / 32 = 94.5), so it is rounded up to the next multiple, 3040 = 95 * 32.
INPUT_IMAGE_HEIGHT = 3040
INPUT_IMAGE_WIDTH = 672  # 21 * 32

class Composer:
    """Run one list of transforms over an image and its mask, then tensorize both.

    Every transform is called on the image first and then on the target, in list
    order. The same transform object is used for both, so only deterministic
    transforms keep image and mask aligned. The resulting target is turned into
    an int64 tensor of its pixel values. The image goes through
    ``transforms.ToTensor()``.
    """

    def __init__(self, transforms):
        # Ordered sequence of single-argument callables.
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            # Left-to-right evaluation: image first, then target, as one step.
            image, target = step(image), step(target)
        target_tensor = torch.tensor(np.array(target), dtype=torch.int64)
        to_tensor = transforms.ToTensor()
        image_tensor = to_tensor(image)
        return image_tensor, target_tensor

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs gathered from one or more root directories.

    For every root directory, images are read from ``<root>/<split>/imgs`` and
    masks from ``<root>/<split>/annos/int_maps``. Only files ending in ``.png``
    are considered. An image and a mask are paired when the second
    ``_``-separated field of their file stems is equal; for example
    ``img_0001.png`` pairs with ``mask_0001.png``. Pairing is done per root
    directory, so equal keys in different roots never cross-pair.

    Args:
        root_dirs: One root directory, or a list/tuple of root directories.
        split: Name of the split sub-directory (e.g. ``'Train'`` or ``'Test'``).
        transforms: Optional callable ``(image, mask) -> (image, mask)``
            applied in ``__getitem__``. Without it, PIL images are returned.

    Raises:
        FileNotFoundError: if an expected ``imgs`` or ``int_maps`` directory
            does not exist (raised by ``os.listdir``).
    """

    def __init__(self, root_dirs, split='Train', transforms=None):
        # A tuple of roots is accepted like a list; anything else is one root.
        if isinstance(root_dirs, tuple):
            root_dirs = list(root_dirs)
        self.root_dirs = root_dirs if isinstance(root_dirs, list) else [root_dirs]
        self.split = split
        self.transforms = transforms
        self.image_mask_pairs = self._collect_image_mask_pairs()
        print(f"Matched {len(self.image_mask_pairs)} image-mask pairs.")

    @staticmethod
    def _list_pngs(directory):
        """Return the sorted full paths of the ``.png`` files directly in *directory*."""
        return sorted(
            os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.png')
        )

    @staticmethod
    def _pair_key(path):
        """Return the second ``_``-separated field of the file stem, or None if absent."""
        # splitext strips only the trailing extension, unlike str.replace('.png', '').
        stem = os.path.splitext(os.path.basename(path))[0]
        fields = stem.split("_")
        return fields[1] if len(fields) > 1 else None

    def _collect_image_mask_pairs(self):
        """Pair each image with the mask sharing its key, one root directory at a time."""
        image_mask_pairs = []
        for root_dir in self.root_dirs:
            image_dir = os.path.join(root_dir, self.split, 'imgs')
            mask_dir = os.path.join(root_dir, self.split, 'annos', 'int_maps')
            images = self._list_pngs(image_dir)
            mask_dict = {}
            for mask in self._list_pngs(mask_dir):
                key = self._pair_key(mask)
                if key is None:
                    # One oddly named file should not abort the whole dataset build.
                    print(f"Skipping mask with no '_'-separated key: {mask}")
                    continue
                if key in mask_dict:
                    # The last mask in sorted order wins; report it instead of silently overwriting.
                    print(f"Duplicate mask key '{key}': {mask_dict[key]} replaced by {mask}")
                mask_dict[key] = mask
            for img in images:
                # A None key (no '_' in the name) is never stored, so it finds no mask.
                mask = mask_dict.get(self._pair_key(img))
                if mask is None:
                    print(f"No matching mask for image: {img}")
                else:
                    image_mask_pairs.append((img, mask))
        return image_mask_pairs

    def __len__(self):
        return len(self.image_mask_pairs)

    def __getitem__(self, idx):
        """Load pair *idx* as (RGB image, "L" mask), applying ``transforms`` if set."""
        img_path, mask_path = self.image_mask_pairs[idx]
        # convert() loads the pixel data, so the file can be closed right after.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        # NOTE(review): convert("L") maps palette ("P") PNGs through luminance,
        # which would scramble class ids; confirm the int_maps are already
        # single-channel index images.
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")
        if self.transforms is not None:
            image, mask = self.transforms(image, mask)
        return image, mask

# Transformation definition.
# SegmentationDataset calls ``transforms(image, mask)`` with two arguments and
# the DataLoader needs tensors back. The two-argument ``Composer`` defined in
# this file does both. A single-argument ``Compose`` (e.g. torchvision's) would
# raise a TypeError on that call and would return un-collatable PIL images.
# NEAREST interpolation keeps the mask's integer class ids from being blended.
transform = Composer([transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH), interpolation=Image.NEAREST)])

# List of root directories
root_dirs = ['CAT/mixed', 'CAT/Brown Field', 'CAT/Main_Trail', 'CAT/Power Line']

# Creating dataset and dataloader instances
trainDs = SegmentationDataset(root_dirs, split="Train", transforms=transform)
testDS = SegmentationDataset(root_dirs, split='Test', transforms=transform)
train_loader = DataLoader(trainDs, batch_size=4, shuffle=True)
test_loader = DataLoader(testDS, batch_size=4, shuffle=False)

# Debugging: Load a batch of data
for images, masks in train_loader:
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of masks shape: {masks.shape}")
    break

# Instantiate the model.
# Images are loaded with .convert("RGB"), so the encoder takes 3 input channels.
# NOTE(review): classes=4 assumes the int_maps hold class ids 0..3; confirm.
model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=4).cuda()