# Custom Dataset

In [37]:
from torchvision.datasets import VisionDataset
from typing import Any, Callable, Dict, List, Optional, Tuple
import os

from tqdm import tqdm
import os
import sys
from pathlib import Path
import requests

from skimage import io, transform
import matplotlib.pyplot as plt

In [38]:
import tarfile

class NotMNIST(VisionDataset):
    resource_url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'
    
    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(NotMNIST, self).__init__(root, transform=transform,
                                       target_transform=target_transform)
        
        if not self._check_exists():
            self.download()
            
        if download:
            self.download()
            
        self.data, self.targets = self._load_data()
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        image_name = self.data[index]
        image = io.imread(image_name)
        label = self.targets[index]
        if self.transform:
            image = self.transform(image)
        return image, label

    def _load_data(self):
        filepath = self.image_folder
        data = []
        targets = []
        
        for target in os.listdir(filepath):
            filenames = [os.path.abspath(
                os.path.join(filepath, target, x)) for x in os.listdir(
                os.path.join(filepath, target))]
            
            targets.extend([target] * len(filenames))
            data.extend(filenames)
        return data, targets
    
    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')
    
    @property
    def image_folder(self) -> str:
        return os.path.join(self.root, 'notMNIST_large')
    
    def download(self) -> None:
        os.makedirs(self.raw_folder, exist_ok=True)
        fname = self.resource_url.split("/")[-1]
        chunk_size = 1024
        
#        filesize = int(requests.head(self.resource_url).headers["Content-Length"])
        
        with requests.get(self.resource_url, stream=True) as r, open(
            os.path.join(self.raw_folder, fname), "wb") as f, tqdm(
            unit="B", # unit string to be displayed.
            unit_scale=True, # let tqdm to determine the scale in Kilo, mega, ..etc
            unit_divisor=1024, # is used when unit_scale is true
#            total=filesize, # the total iteration.
            file=sys.stdout, # default goes to stderr, this is the display on console.
            desc=fname # prefix to be displayed on progress bar.
        ) as progress:
            for chunk in r.iter_content(chunk_size=chunk_size):
                # download the file chunk by chunk
                datasize = f.write(chunk)
                # on each chunk update the progress bar.
                progress.update(datasize)
            
        self._extract_file(os.path.join(self.raw_folder, fname), target_path=self.root)
        
    def _extract_file(self, fname, target_path) -> None:
        if fname.endswith("tar.gz"):
            tag = "r:gz"
        elif fname.endswith("tar"):
            tag = "r:"
        tar = tarfile.open(fname, tag)
        tar.extractall(path=target_path)
        tar.close()
        
    def _check_exists(self) -> bool:
        return os.path.exists(self.raw_folder)

In [39]:
dataset = NotMNIST("data", download=True)

notMNIST_large.tar.gz: 226B [00:00, 226kB/s]


ReadError: not a gzip file

In [28]:
requests.head('http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz').headers

{'Date': 'Thu, 16 Mar 2023 12:00:25 GMT', 'Server': 'Apache', 'Keep-Alive': 'timeout=5, max=75', 'Connection': 'Keep-Alive', 'Content-Type': 'text/html; charset=iso-8859-1'}