In [5]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [4]:
class FireDataset(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every non-hidden sub-directory directly under ``root_dir`` is one class,
    and every non-hidden regular file inside it is one sample. Classes are
    numbered in sorted order of their directory names, so the name -> label
    mapping is the same on every run and every machine.

    Args:
        root_dir: Path of the directory that holds one sub-directory per class.
        transform: Optional callable applied to the RGB PIL image before it is
            returned from ``__getitem__``.

    Attributes:
        classes: Sorted list of class (sub-directory) names; the position of a
            name in this list is its integer label.
        image_paths: Paths of all samples, grouped by class.
        labels: Integer label for each entry of ``image_paths``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order, so sort to keep the
        # label indices reproducible. Stray files and hidden directories
        # (e.g. ".ipynb_checkpoints") are not classes and are ignored.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, img_name)
                # Hidden files (".DS_Store", ...) and nested directories are
                # not samples; skipping them here avoids a failure later when
                # __getitem__ tries to open them as images.
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` pair.

        The image is converted to RGB and, when a transform was supplied,
        passed through it; the label is the integer class index.
        """
        img_path = self.image_paths[idx]
        # convert() produces an independent in-memory image, so the underlying
        # file can be closed as soon as the ``with`` block exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [1]:

# Folder that contains one sub-directory of images per class.
dataset_path = "fire_dataset/"

# Per-image preprocessing: resize to 64x64, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Wrap the folder in a dataset and serve it in shuffled batches of 64.
fire_dataset = FireDataset(dataset_path, transform=transform)
fire_dataloader = DataLoader(dataset=fire_dataset, batch_size=64, shuffle=True)

# Sanity check: pull only the first batch (if there is one) and report it.
first_batch = next(iter(fire_dataloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f"Batch size: {images.size(0)}")
    print(f"Image shape: {images.size()}")
    print(f"Labels: {labels}")





NameError: name 'transforms' is not defined