In [1]:
import os
import torch
import random
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split


In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
# Dataset that labels each image by the class subdirectory it lives in.
class ImageFolderWithLabels(Dataset):
    """Image dataset whose labels come from the subdirectories of ``root_dir``.

    Expected layout::

        root_dir/
            <class_a>/img0.png
            <class_b>/img1.png

    Each immediate subdirectory of ``root_dir`` is one class. Class names are
    sorted so label indices are deterministic across machines and runs
    (``os.listdir`` order is arbitrary). Only directories consume a label
    index, so stray files directly under ``root_dir`` cannot shift labels.

    Args:
        root_dir: Directory whose immediate subdirectories are the classes.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: Optional iterable of file suffixes such as
            ``('.png', '.jpg')``, compared case-insensitively. When given,
            files with any other suffix are skipped. When ``None`` (the
            default), every regular file is kept. Pass a list or tuple,
            not a bare string.

    Attributes:
        classes: Sorted list of class subdirectory names.
        class_to_idx: Mapping from class name to integer label.
        images: File paths, grouped by class and sorted by filename.
        labels: Integer label for each entry of ``images``.
    """

    def __init__(self, root_dir, transform=None, extensions=None):
        self.root_dir = root_dir
        self.transform = transform
        allowed = None if extensions is None else {ext.lower() for ext in extensions}

        # Only directories are classes; sorting makes the label ids reproducible.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        self.images = []
        self.labels = []
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            for filename in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, filename)
                # Nested directories (or other non-files) would make Image.open fail later.
                if not os.path.isfile(file_path):
                    continue
                if allowed is not None and os.path.splitext(filename)[1].lower() not in allowed:
                    continue
                self.images.append(file_path)
                self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of image files found under ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(image, label)``."""
        img_path = self.images[idx]
        # The context manager releases the underlying file handle deterministically.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [2]:
# Shared preprocessing pipeline:
#   1. resize every image to a fixed 224x224;
#   2. convert the PIL image to a float tensor;
#   3. normalise each channel with the given mean/std.
# These look like the common ImageNet channel statistics, presumably chosen
# to match an ImageNet-pretrained backbone — TODO confirm.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [5]:
# Root folders for each body-part subset, in "normal" and "-art" variants.
(hand_dir, face_dir, body_dir,
 hand_art_dir, face_art_dir, body_art_dir) = (
    'images/hand', 'images/face', 'images/body',
    'images/hand-art', 'images/face-art', 'images/body-art',
)

In [None]:
# Build one dataset per root folder, all sharing the same preprocessing.
# NOTE(review): ImageFolderWithLabels derives labels from the subdirectories of
# each root and numbers them from 0 independently in every dataset. The "-art"
# datasets therefore do not automatically get a distinct "artifact" label —
# confirm the on-disk layout matches the intended labelling before concatenating.
(hand_dataset, face_dataset, body_dataset,
 hand_art_dataset, face_art_dataset, body_art_dataset) = (
    ImageFolderWithLabels(root, transform=data_transforms)
    for root in (hand_dir, face_dir, body_dir, hand_art_dir, face_art_dir, body_art_dir)
)

In [None]:
def split_sizes(n, train_frac=0.7, val_frac=0.15):
    """Return ``(train, val, test)`` sizes that partition ``n`` items.

    ``train`` and ``val`` are truncated toward zero. ``test`` takes the
    remainder, so the three sizes always sum to exactly ``n``. That is the
    invariant ``torch.utils.data.random_split`` requires for integer lengths.

    Args:
        n: Total number of items (e.g. ``len(dataset)``).
        train_frac: Fraction of items assigned to training.
        val_frac: Fraction of items assigned to validation.

    Raises:
        ValueError: if a fraction is negative or the two fractions sum above 1,
            which would make the test size negative.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )
    train = int(train_frac * n)
    val = int(val_frac * n)
    return train, val, n - train - val


# 70% / 15% / remainder split for every subset.
hand_train_size, hand_val_size, hand_test_size = split_sizes(len(hand_dataset))
face_train_size, face_val_size, face_test_size = split_sizes(len(face_dataset))
body_train_size, body_val_size, body_test_size = split_sizes(len(body_dataset))
hand_art_train_size, hand_art_val_size, hand_art_test_size = split_sizes(len(hand_art_dataset))
face_art_train_size, face_art_val_size, face_art_test_size = split_sizes(len(face_art_dataset))
body_art_train_size, body_art_val_size, body_art_test_size = split_sizes(len(body_art_dataset))