In [None]:
class CelebADataset(Dataset):
    """Map-style dataset over one split (train/val/test) of CelebA.

    Every image is opened as RGB, center-cropped to ``crop_size``, resized to
    ``image_size`` x ``image_size`` and converted with ``T.ToTensor()``.
    Depending on the constructor arguments, ``__getitem__`` returns:

    * ``image``                    -- no ``target_attr``, no ``NoiseTransform``
    * ``(image, label)``           -- ``target_attr`` given
    * ``(image, noisy)``           -- ``NoiseTransform`` given
    * ``(image, noisy, label)``    -- both given

    Args:
        img_dir: Directory joined with each ``image_id`` to locate the file.
        attr_path: CSV with an ``image_id`` column plus attribute columns.
            Attribute values are mapped with ``(x + 1) // 2`` (assumes they
            are -1/1, giving 0/1 -- verify against your CSV).
        partition_path: CSV with ``image_id`` and ``partition`` columns,
            where partition 0/1/2 selects train/val/test.
        target_attr: Attribute column name (or list of names) used as the
            label, cast to float32. Falsy (e.g. ``None``) means no labels.
        mode: One of ``'train'``, ``'val'``, ``'test'``.
        NoiseTransform: Optional callable applied to the transformed image
            tensor; its result is cast to ``torch.float32``.
        crop_size: Side length of the center crop (default 178).
        image_size: Side length of the square output image (default 64).

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    # Split name -> value in the partition file's ``partition`` column.
    _PARTITION_MAP = {'train': 0, 'val': 1, 'test': 2}

    def __init__(self, img_dir, attr_path, partition_path,
                 target_attr=None, mode='train', NoiseTransform=None,
                 crop_size=178, image_size=64):
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if mode not in self._PARTITION_MAP:
            raise ValueError("mode 'train', 'val' ya da 'test' olmalıdır.")
        self.img_dir = img_dir
        self.mode = mode
        self.NoiseTransform = NoiseTransform

        # Deterministic preprocessing: crop -> resize -> tensor.
        self.transform = T.Compose([
            T.CenterCrop(crop_size),
            T.Resize((image_size, image_size)),
            T.ToTensor(),
        ])

        # Attribute table indexed by image file name; -1 -> 0, 1 -> 1.
        attr_df = pd.read_csv(attr_path).set_index("image_id")
        attr_df = (attr_df + 1) // 2

        # Keep only the rows of the requested split, in partition-file order.
        partition_df = pd.read_csv(partition_path)
        split_mask = partition_df['partition'] == self._PARTITION_MAP[mode]
        split_ids = partition_df.loc[split_mask, 'image_id']
        attr_df = attr_df.loc[split_ids]

        self.image_ids = attr_df.index.tolist()
        self.labels = (
            attr_df[target_attr].astype('float32').values if target_attr else None
        )

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load sample ``idx``; see the class docstring for the return shape."""
        img_path = os.path.join(self.img_dir, self.image_ids[idx])
        # The context manager closes the file handle deterministically;
        # convert() returns a new image that remains usable after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Basic transformations (crop -> resize -> tensor).
        image = self.transform(image)

        outputs = [image]
        if self.NoiseTransform:
            outputs.append(self.NoiseTransform(image).to(torch.float32))
        if self.labels is not None:
            outputs.append(torch.tensor(self.labels[idx]))

        # A bare tensor when there is nothing extra, otherwise a tuple.
        return outputs[0] if len(outputs) == 1 else tuple(outputs)