In [1]:
import torch
# Display the installed PyTorch version so the environment used for this run is recorded
torch.__version__

'2.6.0'

In [7]:
## 1. Get Data

import os
import zipfile

from pathlib import Path

import requests

# Setup path to data folder
data_path = Path('../data')
image_path = data_path / "pizza_steak_sushi"
zip_path = data_path / "pizza_steak_sushi.zip"

# Make sure the image directory exists (parents=True also creates data_path if needed)
if image_path.is_dir():
    print(f"{image_path} directory exists")
else:
    print(f"Didn't find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

# Download pizza, steak, sushi data
# NOTE: the archive is fetched on every run, whether or not image_path already existed.
print("Downloading pizza, steak, sushi data...")
request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
# Fail loudly on an HTTP error instead of writing an error page to disk as a ".zip"
request.raise_for_status()
# Only open the output file once the download has succeeded, so a failed
# request never leaves an empty or truncated archive behind
with open(zip_path, "wb") as f:
    f.write(request.content)

# Unzip pizza, steak, sushi data
with zipfile.ZipFile(zip_path, "r") as zip_ref:
    print("Unzipping pizza, steak, sushi data...")
    zip_ref.extractall(image_path)

f{image_path} directory exists
Downloading pizza, steak, sushi data...
Unzipping pizza, steak, sushi data...


In [8]:
# Remove the downloaded archive now that its contents have been extracted.
# missing_ok=True keeps this cell safe to re-run when the zip is already gone.
(data_path / "pizza_steak_sushi.zip").unlink(missing_ok=True)

In [9]:
# Point at the "train" and "test" subfolders of the image directory
train_dir, test_dir = (image_path / split for split in ("train", "test"))

## 2. Create Dataset and DataLoader

Turn the image folders into PyTorch Datasets and DataLoaders.

In [19]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Create a simple transform: resize each image to 64x64, then convert it to a tensor
data_transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor()
])

# Use ImageFolder to create dataset(s); both splits share the same transform
train_data = datasets.ImageFolder(
    root=train_dir,
    transform=data_transform
)

test_data = datasets.ImageFolder(
    root=test_dir,
    transform=data_transform
)

BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined, but
# DataLoader needs an integer; fall back to 0 (load data in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Training batches are shuffled
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS
)

# Test batches keep the dataset order
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [20]:
# Get class names as a list (as reported by the training dataset's `classes` attribute)
class_names = train_data.classes
# Display the list as the cell output
class_names

['pizza', 'steak', 'sushi']

In [21]:
# Get class names as a dict mapping each class name to its integer label index
class_dict = train_data.class_to_idx
# Display the mapping as the cell output
class_dict

{'pizza': 0, 'steak': 1, 'sushi': 2}

In [22]:
# Show the number of samples in the training and test datasets as a (train, test) tuple
len(train_data), len(test_data)

(225, 75)