In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

In [None]:
class CustomLungDataset(Dataset):
    """Image-classification dataset laid out as ``root/<class_name>/<file>``.

    Each visible immediate sub-directory of ``root`` is one class; integer
    labels follow the sorted order of the sub-directory names. Hidden entries
    (names starting with ``.``, e.g. ``.ipynb_checkpoints`` or ``.DS_Store``)
    and loose files directly under ``root`` are ignored.

    Args:
        root: Path to the dataset root directory.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        images: List of ``(file_path, label)`` pairs, in sorted order.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Only visible sub-directories are classes: a stray file here would make
        # load_images() fail with NotADirectoryError, and hidden folders such as
        # Jupyter's .ipynb_checkpoints would otherwise become bogus classes.
        self.classes = sorted(
            entry
            for entry in os.listdir(root)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(root, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.images = self.load_images()

    def load_images(self):
        """Return ``(file_path, label)`` pairs for every file in every class.

        Files are listed in sorted order so the sample order (and therefore
        ``dataset[0]``) does not depend on the file system's listing order.
        Hidden files and nested directories inside a class folder are skipped.
        """
        images = []
        for cls in self.classes:
            cls_path = os.path.join(self.root, cls)
            cls_idx = self.class_to_idx[cls]
            for file in sorted(os.listdir(cls_path)):
                img_path = os.path.join(cls_path, file)
                if file.startswith('.') or not os.path.isfile(img_path):
                    continue
                images.append((img_path, cls_idx))
        return images

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, label)``.

        If a transform was given, ``image`` is the transform's output instead
        of the PIL image.
        """
        img_path, label = self.images[idx]
        # Image.open is lazy and keeps the file open; convert() loads the pixels
        # into a new image, so the source file can be closed right away instead
        # of leaking the handle until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Preprocessing pipeline: resize every image to 224x224, then convert it to a
# tensor (no normalization is applied here).
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Build the root path with os.path.join so it works on every OS. A backslash
# literal like 'Data\Train' is Windows-only, and '\T' is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).
lung_dataset = CustomLungDataset(root=os.path.join('Data', 'Train'), transform=data_transform)

In [None]:
# Serve the dataset in batches of 32; shuffle=True makes the DataLoader
# reshuffle the sample order on every epoch.
lung_dataloader = DataLoader(
    lung_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Summarize the dataset: number of classes, their names, and the total number
# of samples (len() of the dataset, not the dataset object's repr).
print("No of Classes:", len(lung_dataset.classes))
print("Class Name:", lung_dataset.classes)
print("No of images:", len(lung_dataset))

In [None]:
# Sanity check: load the first sample and show the transformed image's shape
# together with its integer label.
(sample_img, sample_label) = lung_dataset[0]
print(
    "Sample Image Shape:",
    sample_img.shape,
    "Label:",
    sample_label,
)