In [None]:
import os
import xml.etree.ElementTree as ET
import scipy.io as sio
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch

# Custom dataset class
class DogsDataset(Dataset):
    """Image dataset built from per-image XML annotation files.

    ``annotation_dir`` is walked recursively. Every file ending in ``.xml`` is
    parsed, and its ``<folder>`` and ``<filename>`` elements are combined into
    the image path ``<image_dir>/<folder>/<filename>.jpg``. Only annotations
    whose image file exists (and that contain both elements) are kept.

    Args:
        image_dir: Root directory containing the image sub-folders.
        annotation_dir: Root directory searched recursively for ``.xml`` files.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        image_paths: Paths of the images that were found, in discovery order.
        annotations: Annotation file paths, aligned index-for-index with
            ``image_paths``.
    """

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.image_paths = []
        self.annotations = []

        # Load images and annotations.
        for dirpath, _, files in os.walk(self.annotation_dir):
            for file in files:
                if not file.endswith(".xml"):
                    continue
                annotation_path = os.path.join(dirpath, file)
                # Keep the XML root under its own name: rebinding the walk's
                # directory variable would break os.path.join for the next file.
                xml_root = ET.parse(annotation_path).getroot()
                folder = xml_root.findtext('folder')
                filename = xml_root.findtext('filename')
                # findtext returns None for a missing element and '' for an
                # empty one; skip such annotations instead of crashing, the
                # same way annotations without an image file are skipped.
                if not folder or not filename:
                    continue
                image_path = os.path.join(self.image_dir, folder, f"{filename}.jpg")
                if os.path.exists(image_path):
                    self.image_paths.append(image_path)
                    self.annotations.append(annotation_path)

    def __len__(self):
        """Return the number of image/annotation pairs that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one is set.

        Only the image is returned; the annotation file is not read here.
        """
        # ``convert`` produces a new in-memory image, so the underlying file
        # can be closed right away instead of leaking a handle per sample.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Input locations; relative paths resolve against the current working directory.
image_dir = 'data/all_dogs'
annotation_dir = 'data/Annotation'

# Preprocessing: resize every image to 64x64, convert it to a tensor, then
# normalize each channel with mean 0.5 / std 0.5 (maps [0, 1] onto [-1, 1]).
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Dataset over the annotated images, plus a shuffled loader of 64-image batches.
dataset = DogsDataset(
    image_dir=image_dir,
    annotation_dir=annotation_dir,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
