In [136]:
#| default_exp data

In [137]:
#| hide
from nbdev.showdoc import *

In [138]:
#| export 

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader

from typing import Any, Union, Tuple

In [139]:
#| export
# Type alias for a 2-D size argument: either a single int or an explicit pair of ints.
# NOTE(review): not referenced by any visible cell of this notebook; kept because it is exported.
_size_2_t = Union[int, Tuple[int, int]]

In [140]:
# | export

def get_cifar(train_root='../.sample/train', test_root='../.sample/test', download=True):
    """Load the raw CIFAR-10 train and test splits via torchvision.

    Args:
        train_root: directory holding (or receiving) the training split.
            Relative paths are resolved against the current working directory.
        test_root: directory holding (or receiving) the test split.
        download: forwarded to `torchvision.datasets.CIFAR10`; when True the
            archive is fetched if it is not already present under the root.

    Returns:
        (train_ds, test_ds): the two `torchvision.datasets.CIFAR10` objects,
        created without any transform.
    """
    train_ds = torchvision.datasets.CIFAR10(root=train_root, train=True, download=download)
    test_ds = torchvision.datasets.CIFAR10(root=test_root, train=False, download=download)
    return train_ds, test_ds

In [141]:
# Raw torchvision CIFAR10 splits (no transform applied); used by the sanity-check cells below.
train_ds, test_ds = get_cifar()

Files already downloaded and verified
Files already downloaded and verified


In [142]:
#| export

class CIFARDataset(Dataset):
    """Wrap a torchvision-style CIFAR dataset so items come out as float32 tensors.

    Each item is built from `ds.data[index]` and `ds.targets[index]`:
    - `img`: the image moved from channels-last to channels-first via
      `permute(2, 0, 1)` (presumably (H, W, C) -> (C, H, W)), as float32 with
      the raw pixel values (no scaling), then passed through `transform` if given.
    - `target`: a float32 tensor of shape (1,) holding the label.

    NOTE(review): `nn.CrossEntropyLoss` expects integer class indices of shape
    (batch,); a training loop using it must apply `target.squeeze(-1).long()`.
    The float (1,) format is kept because that is what callers currently receive.
    """

    def __init__(self, ds, transform=None) -> None:
        """
        Args:
            ds: object exposing `.data` (int-indexable images), `.targets`
                (int-indexable labels) and `__len__`, e.g. a
                `torchvision.datasets.CIFAR10` as returned by `get_cifar`.
            transform: optional callable applied to the float32 channels-first
                image tensor (e.g. scaling/normalisation). `None` leaves the
                image unchanged.
        """
        super().__init__()
        self.ds = ds
        self.transform = transform

    def __getitem__(self, index) -> Any:
        # `torch.tensor` always copies, so the returned image never aliases the
        # dataset's storage; an explicit dtype keeps float32 regardless of the
        # global default dtype.
        img = torch.tensor(self.ds.data[index], dtype=torch.float32).permute(2, 0, 1)
        target = torch.tensor([self.ds.targets[index]], dtype=torch.float32)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.ds)


In [143]:
# Fetch one sample once (the original built the dataset and indexed it twice)
# and assert the conversions CIFARDataset is expected to perform.
sample_img, sample_target = CIFARDataset(train_ds)[0]
assert sample_img.dtype == torch.float32 and sample_target.dtype == torch.float32
assert sample_img.shape == (3, 32, 32) and sample_target.shape == (1,)
sample_img.dtype, sample_target

(torch.float32, tensor([6.]))

In [144]:
# Sample counts reported by the wrapper for each split, in (train, test) order.
tuple(len(CIFARDataset(split)) for split in (train_ds, test_ds))

(50000, 10000)

In [145]:
#| export 

def get_dls(batch_size=100, num_workers=4):
    """Build CIFAR-10 train/test DataLoaders over `CIFARDataset`.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader; 0 loads data in the main
            process.

    Returns:
        (train_dl, test_dl): the training loader shuffles and drops the last
        incomplete batch; the test loader keeps dataset order and every sample.
    """
    train_data, test_data = get_cifar()
    train_dl = DataLoader(CIFARDataset(train_data), batch_size=batch_size,
                          shuffle=True, drop_last=True, num_workers=num_workers)
    test_dl = DataLoader(CIFARDataset(test_data), batch_size=batch_size,
                         shuffle=False, drop_last=False, num_workers=num_workers)
    return train_dl, test_dl


In [146]:
# Train/test DataLoaders using get_dls's default arguments (batch_size=100).
train_dls, test_dls = get_dls()

Files already downloaded and verified
Files already downloaded and verified


In [147]:
# Pull a single batch and display its shapes/dtypes (rich display instead of print).
img, label = next(iter(train_dls))
img.shape, label.shape, img.dtype, label.dtype

torch.Size([100, 3, 32, 32]) torch.Size([100, 1]) torch.float32 torch.float32


In [148]:
#| hide
# Run nbdev's export so the `#| export` cells are written to the library module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()