In [1]:
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
from config import Config

class DeepfakeDataset(Dataset):
    """Image dataset for real-vs-fake classification or unlabeled test inference.

    Labeled layout (``is_test=False``): ``root_dir`` contains the sub-directories
    ``real`` and ``fake``; images under ``real`` get label 1, images under
    ``fake`` get label 0.  Samples are returned as ``(image, label)`` where
    ``label`` is a float32 scalar tensor.

    Test layout (``is_test=True``): ``root_dir`` directly contains the images,
    whose file stems must parse as integers (e.g. ``17.jpg``).  Samples are
    returned as ``(image, image_id)``.

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store``) and anything
    that is not a regular file are skipped, and every listing is sorted by file
    name so the sample order is reproducible across filesystems.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        is_test: Whether ``root_dir`` uses the unlabeled test layout.
        fallback_size: Height/width of the all-zero ``(3, size, size)``
            placeholder returned when an image cannot be loaded.  When ``None``
            (the default) ``Config.img_size`` is used, looked up only at the
            moment a placeholder is actually needed.
    """

    def __init__(self, root_dir, transform=None, is_test=False, fallback_size=None):
        self.is_test = is_test
        self.transform = transform
        self.fallback_size = fallback_size

        if not is_test:
            self.real_dir = os.path.join(root_dir, 'real')
            self.fake_dir = os.path.join(root_dir, 'fake')
            self.image_paths = []
            self.labels = []
            # All real images first (label 1), then all fake images (label 0).
            for directory, label in ((self.real_dir, 1), (self.fake_dir, 0)):
                for img_name in self._list_images(directory):
                    self.image_paths.append(os.path.join(directory, img_name))
                    self.labels.append(label)
        else:
            # List the directory exactly once so paths and ids always align.
            names = self._list_images(root_dir)
            self.image_paths = [os.path.join(root_dir, f) for f in names]
            self.image_ids = [int(os.path.splitext(f)[0]) for f in names]

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of non-hidden regular files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, path):
        """Open ``path``, convert it to RGB and apply ``self.transform`` if set."""
        # The context manager releases the file handle immediately instead of
        # leaving it to garbage collection; convert() returns an independent copy.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            image = self._load_image(path)
        except Exception as e:
            # Deliberate best-effort: one unreadable file should not abort a
            # whole epoch, so substitute an all-zero placeholder image.
            # NOTE(review): when transform is None a successful load yields a
            # PIL image while this fallback yields a tensor.
            print(f"Error loading image {path}: {e}")
            size = self.fallback_size if self.fallback_size is not None else Config.img_size
            image = torch.zeros((3, size, size))

        if self.is_test:
            return image, self.image_ids[idx]
        return image, torch.tensor(self.labels[idx], dtype=torch.float32)


ModuleNotFoundError: No module named 'config'