
# 04. PyTorch Custom Datasets Video Notebook 

How do you get your own data into PyTorch? 

One of the ways to do so is via custom datasets. 

## 0. Importing PyTorch and setting up device-agnostic code

In [2]:
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

## 1. Get data

Our dataset is a subset of the Food101 dataset. Food101 has 101 classes of food with 1,000 images per class (750 training, 250 testing).

Our subset starts with 3 of those classes (pizza, steak and sushi) and only 10% of the images.
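As a quick sanity check on those numbers, the arithmetic below works out roughly how many images a 3-class, 10% subset should contain. It is only a sketch: it assumes the 10% is taken over the three classes as a whole, and the variable names are illustrative rather than part of the original notebook.

```python
# Food101 per-class counts, as stated above
TRAIN_PER_CLASS = 750
TEST_PER_CLASS = 250

# Assumed subset parameters: 3 classes, 10% of the images
NUM_CLASSES = 3
PERCENT_KEPT = 10

# Integer arithmetic avoids floating-point rounding surprises
expected_train = NUM_CLASSES * TRAIN_PER_CLASS * PERCENT_KEPT // 100
expected_test = NUM_CLASSES * TEST_PER_CLASS * PERCENT_KEPT // 100

print(f"Expected training images: {expected_train}")  # 225
print(f"Expected testing images: {expected_test}")    # 75
```

These totals agree with the directory listing later in this notebook (72 + 75 + 78 = 225 training images and 31 + 19 + 25 = 75 testing images), even though the images are not split evenly across the classes.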

When starting out on ML projects, it's important to try things on a small scale first and then increase the scale when necessary.

In [3]:
import requests
import zipfile
from pathlib import Path

# Setup path to a data folder 
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, create it, then download and unzip the data
if image_path.is_dir():
  print(f"{image_path} directory already exists ... skipping download")
else:
  print(f"{image_path} directory doesn't exist, creating one...")
  image_path.mkdir(parents=True, exist_ok=True)

  # Download pizza, steak and sushi data
  with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
    request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
    request.raise_for_status()  # stop here if the download failed
    print("Downloading pizza, steak, and sushi data...")
    f.write(request.content)

  # Unzip pizza, steak, sushi data (the `with` blocks close both files automatically)
  with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
    print("Unzipping pizza, steak, and sushi data...")
    zip_ref.extractall(image_path)

data/pizza_steak_sushi directory doesn't exist, creating one...
Downloading pizza, steak, and sushi data...
Unzipping pizza, steak, and sushi data...


## 2. Becoming one with the data (data preparation and data exploration)

In [4]:
import os 
def walkThroughDir(dirPath):
  """
  Walks through dirPath, printing the number of subdirectories and files in each directory.
  """
  for dirpath, dirnames, filenames in os.walk(dirPath):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [5]:
walkThroughDir(image_path)

There are 2 directories and 0 images in 'data/pizza_steak_sushi'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/train'.
There are 0 directories and 72 images in 'data/pizza_steak_sushi/train/sushi'.
There are 0 directories and 75 images in 'data/pizza_steak_sushi/train/steak'.
There are 0 directories and 78 images in 'data/pizza_steak_sushi/train/pizza'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/test'.
There are 0 directories and 31 images in 'data/pizza_steak_sushi/test/sushi'.
There are 0 directories and 19 images in 'data/pizza_steak_sushi/test/steak'.
There are 0 directories and 25 images in 'data/pizza_steak_sushi/test/pizza'.
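The listing above follows the usual image classification layout: one folder per split, containing one folder per class. When the counts are needed as data rather than printed text, a small helper along these lines could collect them. This is a sketch only: `count_images_per_class` and its `extensions` parameter are hypothetical names, not part of the original notebook, and the helper assumes images can be recognised by file extension.

```python
from pathlib import Path


def count_images_per_class(split_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Return {class_name: image_count} for a folder laid out as split_dir/class_name/image."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            # Count only regular files whose extension looks like an image
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in extensions
            )
    return counts


# Example usage, assuming the data has already been downloaded:
# count_images_per_class("data/pizza_steak_sushi/train")
```

Returning a dictionary makes it easy to spot class imbalance, for example by comparing the smallest and largest values.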


In [7]:
# Setup train and testing paths
trainDir = image_path / "train"
testDir = image_path / "test"

trainDir, testDir

(PosixPath('data/pizza_steak_sushi/train'),
 PosixPath('data/pizza_steak_sushi/test'))

In [None]:
# Visualize some data
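
One possible way to start on the visualization step is to pick a random image path and read its class from the parent folder name. The sketch below assumes the `split/class/image.jpg` layout shown earlier and that the images are `.jpg` files; `pick_random_image` is a hypothetical helper, not the notebook's own code.

```python
import random
from pathlib import Path


def pick_random_image(image_dir, seed=None):
    """Pick a random .jpg under image_dir/<split>/<class>/ and return (path, class_name)."""
    # Sort so that a fixed seed gives a reproducible choice
    image_paths = sorted(Path(image_dir).glob("*/*/*.jpg"))
    if not image_paths:
        raise FileNotFoundError(f"No .jpg files found under {image_dir}")
    rng = random.Random(seed)
    random_path = rng.choice(image_paths)
    # The class name is the name of the folder that holds the image
    return random_path, random_path.parent.name


# Example usage, assuming the data has been downloaded and Pillow is installed:
# from PIL import Image
# path, label = pick_random_image("data/pizza_steak_sushi", seed=42)
# img = Image.open(path)
# print(label, img.size)
```

Passing a `seed` keeps the choice repeatable between runs, and leaving it as `None` gives a different image each time.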