In [None]:
import torch
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
import cv2
from typing import Tuple, Dict
import random
from skimage.util import random_noise
from skimage.io import imsave, imread
from skimage.transform import resize

In [None]:
class NoisyImageNetDataset(IterableDataset):
    """Stream (original, noisy, denoised) grayscale image triples from an image dataset.

    Each item of the wrapped ``dataset`` must be a mapping with ``'image'`` and
    ``'label'`` keys. For every item a noise type is drawn uniformly from
    ``noise_params``, applied to the preprocessed image, and then filtered with
    the denoiser configured for that type in ``noise_filter_mapping``.

    Args:
        dataset: Iterable of ``{'image': ..., 'label': ...}`` mappings
            (e.g. a streaming Hugging Face dataset).
        image_size: Target size passed directly to ``cv2.resize``, which
            interprets it as ``(width, height)``.
        noise_level: Noise strength on a [0, 1] intensity scale: the fraction
            of corrupted pixels for impulse noise, and the standard deviation
            for Gaussian / fine-grained noise.
    """

    def __init__(self, dataset, image_size=(270, 512), noise_level=0.036):
        self.dataset = dataset
        # NOTE(review): cv2.resize reads this as (width, height), so the default
        # produces 270 px wide x 512 px tall images -- confirm that is intended.
        self.image_size = image_size

        # Per-type noise parameters. Keys must match what ``add_noise`` reads.
        # Intensity values ('mean'/'std') are on the 0-255 scale of uint8 images;
        # 'amount' is a per-pixel probability.
        self.noise_params = {
            'salt_and_pepper': {'amount': noise_level},
            'gaussian': {'mean': 0.0, 'std': noise_level * 255},
            'salt': {'amount': noise_level},
            'pepper': {'amount': noise_level},
            'fine_grained': {'std': noise_level * 255},
        }

        # Noise filter mapping for denoising images: (filter_type, filter_params)
        self.noise_filter_mapping = {
            'salt_and_pepper': ('fastNLM', {'h': 25}),
            'gaussian': ('fastNLM', {'h': 23}),
            'salt': ('fastNLM', {'h': 25}),
            'pepper': ('fastNLM', {'h': 23}),
            'fine_grained': ('fastNLM', {'h': 25})
        }

    @staticmethod
    def _to_uint8(values: np.ndarray) -> np.ndarray:
        """Round to the nearest integer and clip into the uint8 range.

        Rounding (rather than the truncation done by a bare ``astype``) avoids
        a systematic downward bias of up to one intensity level.
        """
        return np.clip(np.rint(values), 0, 255).astype(np.uint8)

    def add_noise(self, image: np.ndarray, noise_type: str) -> np.ndarray:
        """Return a noisy copy of ``image``; the input array is never modified.

        Args:
            image: uint8 image, either H x W or H x W x C.
            noise_type: Key into ``noise_params``.

        Returns:
            uint8 array with the same shape as ``image``.

        Raises:
            KeyError: if ``noise_type`` has no entry in ``noise_params``.
            ValueError: if ``noise_type`` is configured but not implemented here.
        """
        params = self.noise_params[noise_type]
        noisy_image = image.copy()

        if noise_type == 'salt_and_pepper':
            # One draw per pixel location, so all channels of a pixel flip together.
            mask = np.random.random(image.shape[:2])
            half = params['amount'] / 2
            noisy_image[mask < half] = 255
            noisy_image[mask > 1 - half] = 0

        elif noise_type == 'gaussian':
            noise = np.random.normal(params['mean'], params['std'], image.shape)
            noisy_image = self._to_uint8(image + noise)

        elif noise_type == 'salt':
            noisy_image[np.random.random(image.shape[:2]) < params['amount']] = 255

        elif noise_type == 'pepper':
            noisy_image[np.random.random(image.shape[:2]) < params['amount']] = 0

        elif noise_type == 'uniform':
            # Not part of the default config: requires 'low'/'high' (0-255 scale)
            # in noise_params['uniform'] and a matching noise_filter_mapping entry.
            noise = np.random.uniform(params['low'], params['high'], image.shape)
            noisy_image = self._to_uint8(image + noise)

        elif noise_type == 'fine_grained':
            rows, cols = image.shape[:2]
            std = params['std']
            # Full-resolution component.
            fine = np.random.normal(0.0, std, (rows, cols))
            # Half-resolution component with twice the std (needs tuning...),
            # upscaled back to the image size to give spatially correlated grain.
            coarse = np.random.normal(0.0, 2 * std, (max(rows // 2, 1), max(cols // 2, 1)))
            coarse = resize(coarse, (rows, cols))
            noise = fine + coarse
            if image.ndim == 3:
                # Same grain pattern on every channel.
                noise = noise[..., np.newaxis]
            noisy_image = self._to_uint8(image + noise)

        else:
            # Returning the clean image here would silently mislabel training data.
            raise ValueError(f"Unsupported noise type: {noise_type!r}")

        return noisy_image

    def denoise_image(self, image: np.ndarray, noise_type: str) -> np.ndarray:
        """Apply the filter configured for ``noise_type`` in ``noise_filter_mapping``.

        Unknown filter types fall back to returning ``image`` unchanged.

        Raises:
            KeyError: if ``noise_type`` has no entry in ``noise_filter_mapping``.
        """
        filter_type, filter_params = self.noise_filter_mapping[noise_type]

        if filter_type == 'median':
            return cv2.medianBlur(image, filter_params['ksize'])
        if filter_type == 'gaussian':
            return cv2.GaussianBlur(image, filter_params['ksize'], filter_params['sigmaX'])
        if filter_type in ('min', 'max'):
            kernel = np.ones((filter_params['ksize'], filter_params['ksize']), np.uint8)
            morph = cv2.erode if filter_type == 'min' else cv2.dilate
            return morph(image, kernel)
        if filter_type == 'fastNLM':
            # The binding's signature is (src, dst, h, ...): passing h positionally
            # would land in `dst` and the default h would be used instead.
            return cv2.fastNlMeansDenoising(image, None, h=filter_params['h'])

        return image

    def preprocess_image(self, image) -> np.ndarray:
        """Convert ``image`` to a resized, single-channel uint8 array.

        PIL images are converted to luma by PIL itself, since they are RGB-ordered
        (or RGBA/CMYK/palette). ndarray inputs with channels are assumed to be in
        OpenCV's BGR order.
        """
        if not isinstance(image, np.ndarray):
            if hasattr(image, 'convert'):
                # PIL 'L' uses the same ITU-R 601-2 weights as cv2's *2GRAY.
                image = image.convert('L')
            image = np.array(image)

        image = cv2.resize(image, self.image_size)
        if image.ndim == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        return image

    @staticmethod
    def _to_tensor(image: np.ndarray) -> torch.Tensor:
        """Convert a uint8 H x W (or H x W x C) array to a float32 C x H x W tensor in [0, 1]."""
        if image.ndim == 2:
            chw = image[np.newaxis, ...]
        else:
            chw = image.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(chw)).float() / 255.0

    def __iter__(self):
        """Yield dicts with 'original', 'noisy', 'denoised', 'noise_type' and 'label'."""
        # NOTE(review): with DataLoader(num_workers > 1) every worker runs this loop;
        # avoiding duplicate samples relies on self.dataset sharding itself per
        # worker -- confirm for the wrapped dataset.
        for item in self.dataset:
            image = item['image']
            label = item['label']

            image = self.preprocess_image(image)
            noise_type = random.choice(list(self.noise_params))
            noisy_image = self.add_noise(image, noise_type)
            denoised_image = self.denoise_image(noisy_image, noise_type)

            yield {
                'original': self._to_tensor(image),
                'noisy': self._to_tensor(noisy_image),
                'denoised': self._to_tensor(denoised_image),
                'noise_type': noise_type,
                'label': label,
            }

In [None]:
from datasets import load_dataset

# imagenet-1k requires Hugging Face authentication; token=True uses the locally
# stored token (replaces the deprecated use_auth_token=True).
imagenet_train = load_dataset('imagenet-1k', split='train', streaming=True, token=True)
dataset = NoisyImageNetDataset(imagenet_train)
dataloader = DataLoader(dataset, batch_size=4, num_workers=4, pin_memory=True)