In [72]:
import torch
import os
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from typing import List, Tuple, Dict 

# Record library versions so the outputs below can be tied to a specific environment.
print(f"PyTorch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

PyTorch version: 2.0.1
torchvision version: 0.15.2


In [73]:
# Integer class id <-> human-readable class name. The inverse is derived from the
# forward mapping so the two tables cannot drift apart.
img_class_name = {0: 'clean', 1: 'dirty'}
img_class_name_inverse = {name: class_id for class_id, name in img_class_name.items()}

In [74]:
# Integer partition id <-> partition name; the inverse is built from the forward table.
img_partition = {0: 'train', 1: 'test', 2: 'hidden'}
img_partition_inverse = {name: part_id for part_id, name in img_partition.items()}

In [75]:
class CudaHelper:
    """Device-selection helper used by the image-loading code."""

    @staticmethod
    def get_device() -> str:
        """Return ``"cuda"`` if ``torch.cuda.is_available()``, otherwise ``"cpu"``.

        Declared ``@staticmethod`` so it can be called either on the class
        (``CudaHelper.get_device()``) or on an instance; without the decorator
        an instance call raises TypeError because no ``self`` is accepted.
        """
        return "cuda" if torch.cuda.is_available() else "cpu"

In [76]:
class ImageItem:
    """One loaded image together with its integer class and partition labels."""

    def __init__(self, image: torch.Tensor, image_class: int, image_partition: int):
        """
        Args:
            image: decoded image tensor.
            image_class: class id; a key of the module-level ``img_class_name``
                mapping (0 = 'clean', 1 = 'dirty').
            image_partition: partition id; a key of the module-level
                ``img_partition`` mapping (0 = 'train', 1 = 'test', 2 = 'hidden').
        """
        self.image = image
        self.image_class = image_class
        self.image_partition = image_partition

    def stringify(self) -> str:
        """Return a readable label such as ``'dirty; Class int: 1'``.

        Raises:
            KeyError: if ``image_class`` has no entry in ``img_class_name``.
        """
        # The id -> name table is the module-level ``img_class_name``; the bare
        # name ``class_name`` is not defined at module scope and raised NameError.
        return f'{img_class_name[self.image_class]}; Class int: {self.image_class}'

In [77]:
class ImageHelper:
    """Static helpers for loading plate images from disk and plotting them."""

    @staticmethod
    def plot_img(img: ImageItem):
        """Draw one ImageItem on the current matplotlib axes, titled with its label.

        Args:
            img: item whose ``image`` tensor is shown. Assumes a channel-first
                ``(C, H, W)`` tensor as produced by ``read_image`` in ``do_load``,
                or an already 2-D ``(H, W)`` tensor — TODO confirm for other callers.
        """
        # matplotlib cannot read CUDA tensors, and do_load may place images on the GPU.
        pixels = img.image.detach().cpu()
        if pixels.ndim == 3:
            # imshow wants channels last: (C, H, W) -> (H, W, C).
            pixels = pixels.permute(1, 2, 0)
        # A single channel collapses to (H, W) so that cmap='gray' takes effect.
        pixels = pixels.squeeze()
        plt.imshow(pixels, cmap='gray')
        plt.title(img.stringify())

    @staticmethod
    def do_load(data_dir: str, partition: str, class_name: str) -> List[ImageItem]:
        """Read every regular file in ``data_dir`` and tag it with class and partition.

        Args:
            data_dir: directory whose files are decoded with ``read_image``.
            partition: key of ``img_partition_inverse`` ('train', 'test', 'hidden').
            class_name: key of ``img_class_name_inverse`` ('clean', 'dirty').

        Returns:
            One ImageItem per file, in sorted filename order, with tensors moved to
            the device returned by ``CudaHelper.get_device()``.

        Raises:
            KeyError: if ``partition`` or ``class_name`` is not a known label.
        """
        # Resolve labels and device once: a bad label fails before any file is
        # read, and CUDA availability is not re-queried for every image.
        class_id = img_class_name_inverse[class_name]
        partition_id = img_partition_inverse[partition]
        device = CudaHelper.get_device()

        all_images = []
        # os.listdir order is arbitrary; sort it so runs are reproducible.
        for file_name in sorted(os.listdir(data_dir)):
            full_path = os.path.join(data_dir, file_name)
            # Sub-directories cannot be decoded by read_image; skip them.
            if not os.path.isfile(full_path):
                continue
            img_tensor = read_image(full_path).to(device)
            all_images.append(ImageItem(img_tensor, class_id, partition_id))

        return all_images

    @staticmethod
    def load_images(base_dir: str, partition: str) -> Tuple[List[ImageItem], List[ImageItem]]:
        """Load the ``clean`` and ``dirty`` sub-directories of ``base_dir``.

        Args:
            base_dir: directory containing ``clean`` and ``dirty`` sub-directories.
            partition: partition label applied to every loaded image.

        Returns:
            ``(cleans, dirtys)`` — one list of ImageItem per class.
        """
        cleans = ImageHelper.do_load(os.path.join(base_dir, 'clean'), partition, 'clean')
        dirtys = ImageHelper.do_load(os.path.join(base_dir, 'dirty'), partition, 'dirty')
        return cleans, dirtys

In [78]:
# Root of the labelled plate images. Set the PLATES_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific absolute path.
main_dir = os.environ.get(
    'PLATES_DATA_DIR',
    'C:\\Users\\Valentine\\Downloads\\NeuralNetworks\\plates\\plates\\train',
)

# NOTE(review): all three calls read the same `main_dir`; only the partition tag
# differs, so the test and hidden sets currently duplicate the train images.
# Presumably each partition should load its own directory — confirm the layout.
clean_train, dirty_train = ImageHelper.load_images(main_dir, 'train')
clean_test, dirty_test = ImageHelper.load_images(main_dir, 'test')
clean_hidden, dirty_hidden = ImageHelper.load_images(main_dir, 'hidden')