# PyTorch Custom Datasets
Reference book --> https://www.learnpytorch.io/04_pytorch_custom_datasets/

So far we've used the datasets that PyTorch provides, but how do we use our own data for training and testing? Let's find out.

## Domain Libraries
Depending on the kind of problem we're working on (text, audio, vision, recommendation), we'll look into the matching PyTorch domain library for existing data loading functions and customizable data loading functions.

We're working on a vision problem, so we'll be checking out the custom data loading functions for `torchvision`.
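To make the "pick the library that matches the problem" idea concrete, here is a minimal sketch of that lookup. The library names are the PyTorch domain libraries for each problem type; `DOMAIN_LIBRARIES` and `pick_domain_library` are hypothetical names used only for this illustration and are not part of PyTorch.

```python
# Illustrative lookup: problem type -> PyTorch domain library.
# DOMAIN_LIBRARIES and pick_domain_library are hypothetical helpers for this sketch.
DOMAIN_LIBRARIES = {
    "vision": "torchvision",
    "text": "torchtext",
    "audio": "torchaudio",
    "recommendation": "torchrec",
}


def pick_domain_library(problem_type: str) -> str:
    """Return the domain library name for a problem type (case-insensitive)."""
    key = problem_type.strip().lower()
    if key not in DOMAIN_LIBRARIES:
        raise ValueError(f"No domain library listed for {problem_type!r}")
    return DOMAIN_LIBRARIES[key]


# We're working on a vision problem, so this selects torchvision.
print(pick_domain_library("vision"))
```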

## 0. Importing PyTorch and setting up device-agnostic code


In [1]:
import torch
from torch import nn

torch.__version__

'2.2.1+cu121'

In [2]:
# Set up device-agnostic code: prefer CUDA, then Apple's MPS backend, then fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
device

device(type='cpu')

In [3]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


## 1. Get Data
Our dataset is a subset of the Food101 dataset.

Food101 has 101 different classes of food and 1,000 images per class (750 training, 250 testing).

Our dataset has only 3 classes (pizza, steak, sushi) and 10% of the images (roughly 75 training images and 25 testing images per class). We start small to speed up our experiments, since training on the full dataset would take much longer.
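The subset sizes can be checked with a little arithmetic. The sketch below assumes the standard Food101 split of 750 training and 250 testing images per class. The variable names are made up for this illustration, and the results are nominal counts; the actual number of files per class in the downloaded subset may differ slightly.

```python
# Nominal size of the 3-class, 10% subset.
# Assumption: Food101 has 750 training and 250 testing images per class.
FULL_TRAIN_PER_CLASS = 750
FULL_TEST_PER_CLASS = 250
NUM_CLASSES = 3   # pizza, steak, sushi
PERCENT_KEPT = 10  # we keep 10% of the images

# Integer arithmetic avoids floating-point rounding surprises.
train_per_class = FULL_TRAIN_PER_CLASS * PERCENT_KEPT // 100
test_per_class = FULL_TEST_PER_CLASS * PERCENT_KEPT // 100

total_train = train_per_class * NUM_CLASSES
total_test = test_per_class * NUM_CLASSES

print(f"Per class: {train_per_class} training, {test_per_class} testing")
print(f"Whole subset: {total_train} training, {total_test} testing")
```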

In [7]:
import requests
import zipfile
from pathlib import Path

# Set up the path to the data folder.
data_path = Path('data/')
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, download and unzip
if image_path.is_dir():
  print(f'{image_path} already exists... Skipping Download.')
else:
  print(f'{image_path} does not exist. Creating directory...')
  image_path.mkdir(parents=True, exist_ok=True)

  # Download pizza, steak, sushi data
  print("Downloading pizza, steak, sushi data...")
  request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
  request.raise_for_status()  # Fail loudly on an HTTP error instead of saving an error page as the zip
  with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
    f.write(request.content)

  # Unzip pizza, steak, sushi data
  with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
    print("Unzipping pizza, steak, sushi data...")
    zip_ref.extractall(image_path)

data/pizza_steak_sushi does not exist. Creating directory...
Downloading pizza, steak, sushi data...
Unzipping pizza, steak, sushi data...
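After unzipping, it helps to check what actually landed on disk. The following is a minimal sketch that walks the extracted folder and counts subdirectories and files at each level. `summarize_dir` and `print_dir_summary` are hypothetical helper names introduced here, and the sketch makes no assumption about the folder layout inside the zip.

```python
import os
from pathlib import Path


def summarize_dir(dir_path):
    """Return (directory, n_subdirs, n_files) for dir_path and every folder below it."""
    summary = []
    for dirpath, dirnames, filenames in os.walk(dir_path):
        summary.append((Path(dirpath), len(dirnames), len(filenames)))
    return summary


def print_dir_summary(dir_path):
    """Print one line per folder with its subdirectory and file counts."""
    for dirpath, n_dirs, n_files in summarize_dir(dir_path):
        print(f"{n_dirs} directories and {n_files} files in '{dirpath}'")


# Same path as in the download cell; only inspect it if the data is present.
image_path = Path("data/") / "pizza_steak_sushi"
if image_path.is_dir():
    print_dir_summary(image_path)
else:
    print(f"{image_path} not found. Run the download cell first.")
```

Returning the counts from `summarize_dir` rather than only printing them keeps the helper easy to reuse, for example to compare the file counts against the expected subset size.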
