# 04. PyTorch Custom Datasets

In [None]:
import torch
from torch import nn

print(torch.__version__)

# Use the Apple-silicon GPU backend ("mps") when PyTorch can use it, otherwise
# fall back to the CPU. torch.backends.mps.is_available() is the documented
# availability check; torch.mps.is_available() is not present in every
# PyTorch release and raises AttributeError where it is missing.
device = "mps" if torch.backends.mps.is_available() else "cpu"
device


## 1. Get data
The data we're going to be using is a subset of the Food101 dataset.

Food101 is a popular computer vision benchmark as it contains 1,000 images of each of 101 different kinds of food, totaling 101,000 images (75,750 train and 25,250 test).

Can you think of 101 different foods?

Can you think of a computer program to classify 101 foods?

In [None]:
import requests
import zipfile
from pathlib import Path

# Setup paths to the data folder, the extracted images and the downloaded archive
data_path = Path("data")
image_path = data_path / "pizza_steak_sushi"
zip_path = data_path / "pizza_steak_sushi.zip"

# Only download when the image folder is missing or empty, so re-running the
# cell does not fetch and extract the archive again.
if image_path.is_dir() and any(image_path.iterdir()):
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, create one ...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download pizza, steak and sushi data. The request is made (and checked)
    # before the archive is written, so a failed download never leaves an
    # empty or HTML-error "zip" behind.
    response = requests.get(
        "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
        timeout=60,
    )
    response.raise_for_status()
    zip_path.write_bytes(response.content)

    # Unzip into data/pizza_steak_sushi
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(image_path)



## 2. Become one with the data (data preparation)

The goal will be to take this data storage structure and turn it into a dataset usable with PyTorch.

We can inspect what's in our data directory by writing a small helper function to walk through each of the subdirectories and count the files present.

To do so, we'll use Python's in-built os.walk().

In [None]:
import os
def walk_through_dir(dir_path):
    """Print a one-line summary for every directory at or below dir_path.

    Args:
        dir_path: Root directory to traverse (anything os.walk() accepts).

    For the root and each nested directory, the number of immediate
    subdirectories and files is printed. Nothing is returned.
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in {current_dir}")

# Print the directory/file counts for the whole downloaded image folder.
walk_through_dir(image_path)

In [None]:
# Setup training and testing paths: each split lives in its own
# sub-directory of image_path.
train_dir = image_path / "train"
test_dir = image_path / "test"

# Last expression of the cell: shows both paths as the cell output.
train_dir, test_dir

### 2.2 Visualize an image

Let's write some code to:

1. Get all of the image paths using `pathlib.Path.glob()` to find all of the files ending in .jpg.
2. Pick a random image path using Python's `random.choice()`.
3. Get the image class name using `pathlib.Path.parent.stem`.
4. And since we're working with images, we'll open the random image path using `PIL.Image.open()` (PIL stands for Python Imaging Library).
5. We'll then show the image and print some metadata.

In [None]:
import random
from PIL import Image

# Uncomment to make the random pick below reproducible
# random.seed(42)

# 1. Get all image paths: every .jpg exactly two levels below image_path
image_path_list = list(image_path.glob("*/*/*.jpg"))
if not image_path_list:
    # random.choice() on an empty list only raises a bare IndexError;
    # fail with a message that points at the real problem instead.
    raise FileNotFoundError(f"No .jpg images found under {image_path}.")

# 2. Get a random image path
random_image_path = random.choice(image_path_list)

# 3. Get the image class from the path: it is the name of the parent directory
image_class = random_image_path.parent.stem

# 4. Open the image (left open on purpose: later cells reuse `img`)
img = Image.open(random_image_path)

# 5. Print metadata
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")



In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Turn the PIL image into a NumPy array so its shape can be reported below
img_as_array = np.asarray(img)

# Plot the image with matplotlib, showing the class name and the array shape
# in the title and hiding the axes
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape}")
plt.axis(False)

In [None]:
# Show the raw array values of the image as the cell output.
img_as_array

## 3. Transforming data

Before we can use our image data with PyTorch we need to:

1. Turn it into tensors (numerical representations of our images).
2. Turn it into a `torch.utils.data.Dataset` and subsequently a `torch.utils.data.DataLoader`, we'll call these Dataset and DataLoader for short.

There are several different kinds of pre-built datasets and dataset loaders for PyTorch, depending on the problem you're working on.

|Problem space	|Pre-built Datasets and Functions|
|-|-|
|Vision|	torchvision.datasets|
|Audio	|torchaudio.datasets|
|Text	|torchtext.datasets|
|Recommendation system	|torchrec.datasets|

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### 3.1 Transforming data with torchvision.transforms
We've got folders of images but before we can use them with PyTorch, we need to convert them into tensors.

One of the ways we can do this is by using the torchvision.transforms module.

torchvision.transforms contains many pre-built methods for formatting images, turning them into tensors and even manipulating them for data augmentation purposes (the practice of altering data to make it harder for a model to learn; we'll see this later on).

To get experience with torchvision.transforms, let's write a series of transform steps that:

1. Resize the images using transforms.Resize() (from about 512x512 to 64x64, the same shape as the images on the CNN Explainer website).
2. Flip our images randomly on the horizontal using transforms.RandomHorizontalFlip() (this could be considered a form of data augmentation because it will artificially change our image data).
3. Turn our images from a PIL image to a PyTorch tensor using transforms.ToTensor().

In [None]:
# Write transforms for an image; Compose applies them in the listed order
data_transform = transforms.Compose([
    # Resize the image to 64x64
    transforms.Resize(size=(64, 64)),
    # Flip the image horizontally with probability 0.5
    transforms.RandomHorizontalFlip(p=0.5),
    # Turn the PIL image into a torch tensor
    transforms.ToTensor()]
)

# Apply the pipeline to the image opened earlier and show the resulting type
type(data_transform(img))

In [None]:
def plot_transformed_images(image_paths: list, transform, n: int=3, seed: int=42):
    """Plots random images next to their transformed versions.

    Picks up to n paths at random from image_paths, opens each image, applies
    transform to it and draws the original and the transformed image side by
    side in a new two-panel figure (one figure per image).

    Args:
        image_paths (list): List of target image paths. Each path's parent
            directory name is used as the class shown in the figure title.
        transform (PyTorch Transforms): Transforms to apply to each opened
            image. The result must support ``.permute(1, 2, 0)``.
        n (int, optional): Number of images to plot. Clamped to
            len(image_paths) when larger. Defaults to 3.
        seed (int, optional): Seed for Python's global random generator.
            Passing None reseeds it non-deterministically. Defaults to 42.
    """
    random.seed(seed)
    # random.sample() raises ValueError when k exceeds the population size
    n = min(n, len(image_paths))
    random_image_paths = random.sample(image_paths, k=n)
    # `path` (not `image_path`) so the loop does not reuse the name of the
    # notebook-level dataset folder variable
    for path in random_image_paths:
        with Image.open(path) as f:
            fig, ax = plt.subplots(nrows=1, ncols=2)
            ax[0].imshow(f)
            ax[0].set_title(f"Original\nSize: {f.size}")
            ax[0].axis(False)
            # Note: permute() will change shape of image to suit matplotlib
            # (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
            transformed_image = transform(f).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed\nSize: {transformed_image.shape}")
            ax[1].axis(False)

            fig.suptitle(f"class: {path.parent.stem}")

# Plot three random images (default n and seed) before/after data_transform.
plot_transformed_images(image_path_list, data_transform)


## 4. Option 1: Loading Image Data Using ImageFolder

In [None]:
# Use ImageFolder to create datasets from the train and test directories
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir,
                                  transform=data_transform,
                                  target_transform=None
                                  )
# NOTE(review): the test set reuses data_transform, which includes the random
# horizontal flip, so test images are randomly flipped too - confirm intended.
test_data = datasets.ImageFolder(root=test_dir,
                                    transform=data_transform)

print(f"Train data:\n{train_data}")
print(f"Test data:\n{test_data}")

In [None]:
# Class names as exposed by the ImageFolder dataset
class_names = train_data.classes
class_names

In [None]:
# Mapping from class name to integer label as exposed by the dataset
class_dict = train_data.class_to_idx
class_dict

In [None]:
# Number of samples in the train and test datasets
len(train_data), len(test_data)

In [None]:
# Take the first training sample. Note this rebinds `img`, which previously
# held the PIL image opened in the visualisation cell.
img, label = train_data[0][0], train_data[0][1]
print(f"Image tensor:\n{img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Image label: {class_names[label]}")
print(f"Label datatype: {type(label)}")

In [None]:
# Reorder the dimensions (moving dim 0 last) so matplotlib can plot the image
img_permute= img.permute(1, 2, 0)
plt.imshow(img_permute)
plt.title(class_names[label])
plt.axis(False)

### 4.1 turn loaded image into DataLoader

In [None]:
from torch.utils.data import DataLoader

# One sample per batch for this first look at DataLoader
BATCH_SIZE = 1
# Training batches are reshuffled every epoch
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=1)

# Test batches keep the dataset order
test_dataloader = DataLoader(dataset=test_data,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=1)
train_dataloader, test_dataloader

In [None]:
# Number of batches in each DataLoader. Printed explicitly: a notebook only
# displays a cell's *last* expression, so a bare tuple here would be discarded.
print(len(train_dataloader), len(test_dataloader))

# Pull one batch from the training DataLoader and inspect its shapes
img, label = next(iter(train_dataloader))
print(img.shape)
print(label.shape)

## 5. Option 2: Loading image data with a custom dataset

|Pros of creating a custom Dataset	|Cons of creating a custom Dataset|
|-|-|
|Can create a Dataset out of almost anything.	|Even though you could create a Dataset out of almost anything, it doesn't mean it will work.|
|Not limited to PyTorch pre-built Dataset functions.	|Using a custom Dataset often results in writing more code, which could be prone to errors or performance issues.|

To see this in action, let's work towards replicating torchvision.datasets.ImageFolder() by subclassing torch.utils.data.Dataset (the base class for all Dataset's in PyTorch).

We'll start by importing the modules we need:

* Python's `os` for dealing with directories (our data is stored in directories).
* Python's `pathlib` for dealing with filepaths (each of our images has a unique filepath).
* `torch` for all things PyTorch.
* PIL's `Image` class for loading images.
* `torch.utils.data.Dataset` to subclass and create our own custom Dataset.
* `torchvision.transforms` to turn our images into tensors.
* Various types from Python's `typing` module to add type hints to our code.

In [None]:
import os
import pathlib
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
# Reference values from the torchvision.datasets.ImageFolder() instance;
# the custom dataset built below should reproduce these.
train_data.classes, train_data.class_to_idx

### 5.1 Creating a helper function to get class names

Let's write a helper function capable of creating a list of class names and a dictionary of class names and their indexes given a directory path.

To do so, we'll:

* Get the class names using os.scandir() to traverse a target directory (ideally the directory is in standard image classification format).
* Raise an error if the class names aren't found (if this happens, there might be something wrong with the directory structure).
* Turn the class names into a dictionary of numerical labels, one for each class.


In [None]:
# Setup path for target directory
target_directory = train_dir
print(f"Target directory:{target_directory}")

# Get the class names from the target directory: one class per sub-directory.
# Plain files (e.g. macOS ".DS_Store") are skipped so they never become classes.
class_names_found = sorted(entry.name for entry in os.scandir(target_directory) if entry.is_dir())
print(class_names_found)

In [None]:
# Show the raw os.DirEntry objects found in the target directory.
list(os.scandir(target_directory))

In [None]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folder names in a target directory.

    Assumes target directory is in standard image classification format,
    i.e. one sub-directory per class. Entries that are not directories
    (stray files such as ".DS_Store") are ignored.

    Args:
        directory (str): target directory to load classnames from.

    Returns:
        Tuple[List[str], Dict[str, int]]: (list_of_class_names, dict(class_name: idx...)).
        Class names are sorted and indexes follow that sorted order.

    Raises:
        FileNotFoundError: if directory contains no sub-directories.

    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, ...})
    """
    # 1. Get the class names by scanning the directory passed in
    # (not a notebook-level variable); the context manager closes the scan.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    # 2. Raise an error if class names not found
    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
    # 3. Create a dictionary of index labels
    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
    return classes, class_to_idx

# Try the helper on the training directory.
find_classes(target_directory)

### 5.2 Create a custom Dataset to replicate ImageFolder

1. Subclass `torch.utils.data.Dataset`.
2. Initialize our subclass with a `target_directory` parameter (the target data directory) and a `transform` parameter (so we have the option to transform our data if needed).
3. Create several attributes for `paths` (the paths of our target images), `transform` (the transforms we might like to use, this can be `None`), `classes` and `class_to_idx` (from our `find_classes()` function).
4. Create a function to load images from file and return them, this could be using `PIL` or `torchvision.io` (for input/output of vision data).
5. Overwrite the `__len__` method of `torch.utils.data.Dataset` to return the number of samples in the `Dataset`, this is recommended but not required. This is so you can call len(`Dataset`).
6. Overwrite the `__getitem__` method of `torch.utils.data.Dataset` to return a single sample from the `Dataset`, this is required.


In [None]:
# Write a custom dataset class
from torch.utils.data import Dataset

# 1. Subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
    """Custom Dataset that mimics torchvision.datasets.ImageFolder.

    Expects images laid out as target_directory/class_name/image.jpg.
    Indexing returns an (image, class_index) pair, where the image has been
    passed through transform when one was given.
    """

    # 2. Initialize our custom dataset
    def __init__(self, target_directory: str, transform=None) -> None:
        # 3. Create class attributes
        # Get all image paths. Sorted, because glob order is arbitrary and a
        # given index should always refer to the same file.
        self.paths = sorted(pathlib.Path(target_directory).glob("*/*.jpg"))
        # Setup transforms (may be None)
        self.transform = transform
        # Create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(target_directory)

    # 4. Create a function to load images
    def load_image(self, index: int) -> Image.Image:
        "Opens the image at self.paths[index] and returns it as an RGB image."
        image_path = self.paths[index]
        # convert("RGB") reads the pixel data, so the file can be closed right
        # away, and gives every sample three channels.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    # 5. Overwrite the __len__() method
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    # 6. Overwrite the __getitem__() method
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        # Expects path in data_folder/class_name/image.jpg
        class_name = self.paths[index].parent.name
        class_idx = self.class_to_idx[class_name]

        # Transform if necessary. Compared against None so that a transform
        # object which happens to be falsy is still applied.
        if self.transform is not None:
            return self.transform(img), class_idx
        return img, class_idx
        

In [None]:
# Augment train data: resize, random horizontal flip, then convert to tensor
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# Test data is only resized and converted to a tensor (no random flip)
test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

In [None]:
# Build the custom datasets for both splits, each with its own transform
train_data_custom = ImageFolderCustom(target_directory=train_dir,
                                      transform=train_transform)
test_data_custom = ImageFolderCustom(target_directory=test_dir,
                                     transform=test_transform)
train_data_custom, test_data_custom

In [None]:
# Compare sample counts: ImageFolder vs. custom dataset (train split)
len(train_data), len(train_data_custom)

In [None]:
# Compare sample counts: ImageFolder vs. custom dataset (test split)
len(test_data), len(test_data_custom)

In [None]:
# Check that the custom dataset found the same classes as ImageFolder,
# then print the class/index attributes of the custom datasets
print(train_data_custom.classes == train_data.classes)
print(train_data_custom.class_to_idx)
print(test_data_custom.classes)
print(test_data_custom.class_to_idx)

### 5.3 Create a function to display random images

Let's create a helper function called display_random_images() that helps us visualize images in our Dataset's.

Specifically, it'll:

1. Take in a Dataset and a number of other parameters such as classes (the names of our target classes), the number of images to display (n) and a random seed.
2. To prevent the display getting out of hand, we'll cap n at 10 images.
3. Set the random seed for reproducible plots (if seed is set).
4. Get a list of random sample indexes (we can use Python's random.sample() for this) to plot.
5. Setup a matplotlib plot.
6. Loop through the random sample indexes found in step 4 and plot them with matplotlib.
7. Make sure the sample images are of shape HWC (height, width, color channels) so we can plot them.

In [None]:
# 1. Create a function to take in a dataset
def display_random_images(dataset: torch.utils.data.Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):
    """Plots up to 10 random samples from dataset in a single row.

    Args:
        dataset: Dataset whose items are indexable as (image, label, ...),
            where image supports ``.permute(1, 2, 0)``.
        classes: Optional class names, indexed by label, used in the titles.
        n: Number of samples to plot. Capped at 10 (which also turns off the
            shape display) and at len(dataset).
        display_shape: Whether to add the plotted image shape to each title.
        seed: Optional seed for Python's global random generator. Any value
            other than None (including 0) is used.
    """
    # 2. Adjust display if n is too high
    if n > 10:
        n = 10
        display_shape = False
        print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")
    # random.sample() raises ValueError when k exceeds the population size
    n = min(n, len(dataset))

    # 3. Set random seed (compare against None so that seed=0 is honoured)
    if seed is not None:
        random.seed(seed)

    # 4. Get random sample indexes
    random_samples_idx = random.sample(range(len(dataset)), k=n)

    # 5. Setup plot
    plt.figure(figsize=(16, 8))

    # 6. Loop through samples and display random samples
    for i, targ_sample in enumerate(random_samples_idx):
        # Index the dataset once: every access loads the image and runs the
        # (possibly random) transforms again.
        sample = dataset[targ_sample]
        targ_image, targ_label = sample[0], sample[1]
        # 7. Adjust image tensor shape for plotting
        targ_image_adjust = targ_image.permute(1, 2, 0)

        # Plot adjusted samples
        plt.subplot(1, n, i+1)
        plt.imshow(targ_image_adjust)
        plt.axis(False)

        # Build the title from whatever information is available; without
        # class names the title is just the shape (or omitted entirely).
        title_parts = []
        if classes:
            title_parts.append(f"class: {classes[targ_label]}")
        if display_shape:
            title_parts.append(f"shape:{targ_image_adjust.shape}")
        if title_parts:
            plt.title("\n".join(title_parts))

In [None]:
# Display 5 random images from the ImageFolder-created dataset
display_random_images(train_data,
                      class_names,
                      5)

In [None]:
# Display 5 random images from the ImageFolderCustom-created dataset
display_random_images(train_data_custom,
                      class_names,
                      5)

### 5.4 Turn custom loaded images into dataloader


In [None]:
from torch.utils.data import DataLoader
# Batch size and worker count shared by both custom DataLoaders
BATCH_SIZE = 32
NUM_WORKS = 0
print(NUM_WORKS)
# Training batches are reshuffled every epoch
train_dataloader_custom = DataLoader(
    dataset=train_data_custom,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKS
)
# Test batches keep the dataset order
test_dataloader_custom = DataLoader(
    dataset=test_data_custom,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKS
)
train_dataloader_custom, test_dataloader_custom

In [None]:
# Pull one batch from the custom training DataLoader and show its shapes
img_custom, img_label = next(iter(train_dataloader_custom))
img_custom.shape, img_label.shape

## 6. Other forms of transforms (data augmentation)

Data augmentation is the process of altering your data in such a way that you artificially increase the diversity of your training set.


In [None]:
from torchvision import transforms

# Rebinds train_transform/test_transform from the earlier 64x64 versions.
# Training pipeline: resize to 224x224, apply TrivialAugmentWide, convert to tensor
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor()
])

# Test pipeline: resize and convert only, no augmentation
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [None]:
# Get all image paths (every .jpg two levels below image_path)
image_path_list = list(image_path.glob("*/*/*.jpg"))

# Plot 3 random images before/after the augmenting transform;
# seed=None means a different selection on each run
plot_transformed_images(image_paths=image_path_list,
                        transform=train_transform,
                        n=3,
                        seed=None)

## 7. Model 0: TinyVGG without data augmentation

In [None]:
# Create simple transform: resize to 64x64 and convert to tensor (no augmentation)
simple_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

### 7.1 Creating transforms and loading data for Model 0

1. Load the data, turning each of our training and test folders first into a Dataset with torchvision.datasets.ImageFolder()
2. Then into a DataLoader using torch.utils.data.DataLoader().

We'll set `batch_size` and `num_workers` through the `BATCH_SIZE` and `NUM_WORKS` constants in the cell below (a common choice is batch_size=32 and num_workers equal to the number of CPUs on your machine, but this will depend on what machine you're using).

In [None]:
# Load and transform data: both splits use the same non-augmenting transform
from torchvision import datasets
train_data_sample = datasets.ImageFolder(root=train_dir,
                                         transform=simple_transform)
test_data_sample = datasets.ImageFolder(root=test_dir,
                                        transform=simple_transform)

# Turn data into DataLoaders
import os
from torch.utils.data import DataLoader

# Setup batch size and number of workers
BATCH_SIZE = 4
NUM_WORKS = 0

print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKS} workers")

# Create DataLoader's: shuffle the training data, keep the test data in order
train_dataloader_simple = DataLoader(dataset=train_data_sample,
                                     batch_size=BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=NUM_WORKS)
test_dataloader_simple = DataLoader(dataset=test_data_sample,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=NUM_WORKS)

train_dataloader_simple, test_dataloader_simple

### 7.2 Create TinyVGG model class

In [None]:
class TinyVGG(nn.Module):
    """TinyVGG-style CNN: two convolutional blocks and a linear classifier.

    Each block is Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d. Every
    convolution uses kernel_size=2, stride=1, padding=1, which grows the
    spatial size by one pixel. The classifier is sized for 64x64 inputs:
    64 -> 65 -> 66 -> (pool, stride 1) 65 -> 66 -> 67 -> (pool, stride 2) 33.

    Args:
        input_shape (int): Number of input channels.
        hidden_units (int): Number of channels produced by every convolution.
        output_shape (int): Number of output logits (one per class).
    """

    def __init__(self,input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=2,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=2,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            # stride=1: this pool shrinks the feature map by one pixel only
            nn.MaxPool2d(kernel_size=2,
                         stride=1,
                         padding=0)
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=2,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=2,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            # stride=2: halves the feature map (67 -> 33 for 64x64 inputs)
            nn.MaxPool2d(kernel_size=2,
                         stride=2,
                         padding=0)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Flattened size is channels * 33 * 33. The channel count is
            # hidden_units, not a fixed 10, so any hidden_units value works.
            nn.Linear(in_features=hidden_units*33*33,
                      out_features=output_shape)
        )

    def forward(self, x):
        """Runs x through both conv blocks and the classifier; returns logits."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return self.classifier(x)

In [None]:
torch.manual_seed(42)
# One output logit per class in the training data (a fixed 10 would add
# logits for classes that do not exist and could be predicted by argmax)
model_0 = TinyVGG(input_shape=3,
                  hidden_units=10,
                  output_shape=len(train_data.classes)).to(device)
model_0

### 7.3 Try a forward pass on a single image (to test the model)
To do a forward pass on a single image, let's:

1. Get a batch of images and labels from the DataLoader.
2. Get a single image from the batch and unsqueeze() the image so it has a batch size of 1 (so its shape fits the model).
3. Perform inference on a single image (making sure to send the image to the target device).
4. Print out what's happening and convert the model's raw output logits to prediction probabilities with torch.softmax() (since we're working with multi-class data) and convert the prediction probabilities to prediction labels with torch.argmax().

In [None]:
# 1. Get one batch of images and labels from the training DataLoader
img_batch, label_batch = next(iter(train_dataloader_simple))
print(img_batch.shape, label_batch.shape)

In [None]:
# 2. Take the first image of the batch and add a leading dimension of size 1
# so it can be passed to the model as a batch of one
img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]
img_single.shape, label_single

In [None]:
# 3. Forward pass on the single image in eval mode, without gradient tracking
model_0.eval()
with torch.inference_mode():
    preds = model_0(img_single.to(device))

# 4. Print the logits' shape, the predicted label (softmax over dim 1, then
# argmax) and the true label for comparison
print(preds.shape)
print(torch.softmax(preds, dim=1).argmax(dim=1))
print(label_single)

### 7.4 Use torchinfo to get an idea of the shapes going through our model

In [None]:
from torchinfo import summary
# Summarise model_0 for an input of size [1, 3, 64, 64]
summary(model=model_0, input_size=[1, 3, 64, 64])

### 7.5 Create train & test loop function
Specifically, we're going to make three functions:

* train_step() - takes in a model, a DataLoader, a loss function and an optimizer and trains the model on the DataLoader.
* test_step() - takes in a model, a DataLoader and a loss function and evaluates the model on the DataLoader.
* train() - calls train_step() and test_step() together for a given number of epochs and returns a results dictionary.

In [None]:
from tqdm import tqdm
from common_function import train_step, test_step
# 1. Take in various parameters required for training and test steps
def train(model: nn.Module,
               train_dataloader: torch.utils.data.DataLoader,
               test_dataloader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer,
               loss_fn: torch.nn.Module=nn.CrossEntropyLoss(),
               epochs: int=5):
    """Runs train_step() and test_step() once per epoch and records the metrics.

    Both step functions are called with the notebook-level `device`. Each is
    expected to return a pair, used here as (loss, accuracy).

    Args:
        model: Model handed to both step functions.
        train_dataloader: DataLoader passed to train_step().
        test_dataloader: DataLoader passed to test_step().
        optimizer: Optimizer passed to train_step().
        loss_fn: Loss function passed to both step functions.
        epochs: Number of epochs to run.

    Returns:
        dict with keys "train_loss", "train_acc", "test_loss" and "test_acc",
        each holding one value per epoch.
    """
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
    # One empty history list per metric
    results = {metric: [] for metric in metric_names}

    for epoch in tqdm(range(epochs)):
        # Training pass followed by evaluation pass
        train_loss, train_acc = train_step(
            model=model,
            dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device,
        )
        test_loss, test_acc = test_step(
            model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn,
            device=device,
        )

        # Progress line for this epoch
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # Store the metrics; tensors are unwrapped with .item(), anything
        # else is stored unchanged
        epoch_values = (train_loss, train_acc, test_loss, test_acc)
        for metric, value in zip(metric_names, epoch_values):
            if isinstance(value, torch.Tensor):
                value = value.item()
            results[metric].append(value)

    return results

### 7.7 train and evaluate Model 0

In [None]:
# Set random seeds (both the global torch seed and the MPS one)
torch.manual_seed(42)
torch.mps.manual_seed(42)

# Set number of epochs
NUM_EPOCHS = 5

# Recreate an instance of TinyVGG with one output per training class
model_0 = TinyVGG(input_shape=3,
                  hidden_units=10,
                  output_shape=len(train_data.classes)).to(device=device)

# Setup loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(),
                             lr=0.001)

# Start the timer
from timeit import default_timer as timer
start_time = timer()

# Train model_0 on the non-augmented DataLoaders
model_0_result = train(model=model_0,
                       train_dataloader=train_dataloader_simple,
                       test_dataloader=test_dataloader_simple,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=NUM_EPOCHS)
# End the timer and print out how long it took
end_time = timer()

print(f"Total training time: {end_time-start_time:.3f} s.")

In [None]:
# Show the per-epoch metrics returned by train().
model_0_result

### 7.8 plot the loss curves of Model 0
From the print outs of our model_0 training, it didn't look like it did too well.

But we can further evaluate it by plotting the model's loss curves.

Loss curves show the model's results over time.


In [None]:
# Get the keys of the model_0_result dictionary
model_0_result.keys()

In [None]:
def plot_loss_curves(results: Dict[str, List[float]]):
    """Draws loss and accuracy curves from a results dictionary.

    The left panel shows train/test loss and the right panel shows
    train/test accuracy, both plotted against the epoch index.

    Args:
        results (dict): dictionary containing list of values, e.g.
            {"train_loss": [...],
             "train_acc": [...],
             "test_loss": [...],
             "test_acc": [...]}
    """
    # Look up all four series first so a missing key fails before any
    # figure is created
    series = {
        key: results[key]
        for key in ('train_loss', 'test_loss', 'train_acc', 'test_acc')
    }

    # One x position per recorded epoch
    epoch_axis = range(len(results['train_loss']))

    plt.figure(figsize=(15, 7))

    # (panel title, curves drawn in that panel), left to right
    panels = (
        ('Loss', ('train_loss', 'test_loss')),
        ('Accuracy', ('train_acc', 'test_acc')),
    )
    for position, (panel_title, curve_keys) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key in curve_keys:
            plt.plot(epoch_axis, series[key], label=key)
        plt.title(panel_title)
        plt.xlabel('Epochs')
        plt.legend()

In [None]:
# Plot the loss and accuracy curves recorded while training model_0.
plot_loss_curves(results=model_0_result)

## 8. What should an ideal loss curve look like
https://developers.google.com/machine-learning/crash-course/overfitting/interpreting-loss-curves?hl=ja

### 8.1 How to deal with overfitting

|Method to prevent overfitting	|What is it?|
|-|-|
|Get more data	|Having more data gives the model more opportunities to learn patterns, patterns which may be more generalizable to new examples.|
|Simplify your model	|If the current model is already overfitting the training data, it may be too complicated of a model. This means it's learning the patterns of the data too well and isn't able to generalize well to unseen data. One way to simplify a model is to reduce the number of layers it uses or to reduce the number of hidden units in each layer.|
|Use data augmentation	|Data augmentation manipulates the training data in a way that's harder for the model to learn, as it artificially adds more variety to the data. If a model is able to learn patterns in augmented data, the model may be able to generalize better to unseen data.|
|Use transfer learning	|Transfer learning involves leveraging the patterns (also called pretrained weights) one model has learned to use as the foundation for your own task. In our case, we could use one computer vision model pretrained on a large variety of images and then tweak it slightly to be more specialized for food images.|
|Use dropout layers	|Dropout layers randomly remove connections between hidden layers in neural networks, effectively simplifying a model but also making the remaining connections better. See torch.nn.Dropout() for more.|
|Use learning rate decay	|The idea here is to slowly decrease the learning rate as a model trains. This is akin to reaching for a coin at the back of a couch. The closer you get, the smaller your steps. The same with the learning rate, the closer you get to convergence, the smaller you'll want your weight updates to be.|
|Use early stopping	|Early stopping stops model training before it begins to overfit. As in, say the model's loss has stopped decreasing for the past 10 epochs (this number is arbitrary), you may want to stop the model training here and go with the model weights that had the lowest loss (10 epochs prior).|

### 8.2 How to deal with underfitting

|Method to prevent underfitting|	What is it?|
|-|-|
|Add more layers/units to your model	|If your model is underfitting, it may not have enough capability to learn the required patterns/weights/representations of the data to be predictive. One way to add more predictive power to your model is to increase the number of hidden layers/units within those layers.|
|Tweak the learning rate	|Perhaps your model's learning rate is too high to begin with. And it's trying to update its weights each epoch too much, in turn not learning anything. In this case, you might lower the learning rate and see what happens.|
|Use transfer learning	|Transfer learning is capable of preventing overfitting and underfitting. It involves using the patterns from a previously working model and adjusting them to your own problem.|
|Train for longer	|Sometimes a model just needs more time to learn representations of data. If you find in your smaller experiments your model isn't learning anything, perhaps letting it train for more epochs may result in better performance.|
|Use less regularization	|Perhaps your model is underfitting because you're trying to prevent overfitting too much. Holding back on regularization techniques can help your model fit the data better.|

## 9. Model 1: TinyVGG with data augmentation

### 9.1 Create transform with data augmentation

In [None]:
 # Create training transform with TriviailAugment
from torchvision import transforms

# Test pipeline: resize to 64x64 and convert to a tensor, with no augmentation.
test_trainsform_simple = transforms.Compose(
    [transforms.Resize(size=(64, 64)), transforms.ToTensor()]
)

# Training pipeline: the same resize/tensor steps with TrivialAugmentWide in between.
train_trainsform_trivial = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.TrivialAugmentWide(num_magnitude_bins=31),
        transforms.ToTensor(),
    ]
)

### 9.2 Create train and test Dataset's and DataLoader's

In [None]:
# Build one ImageFolder dataset per split
from torchvision import datasets

# Training images go through the augmenting transform ...
train_data_augmented = datasets.ImageFolder(
    root=train_dir,
    transform=train_trainsform_trivial,
)
# ... while test images only get the plain resize + tensor transform.
test_data_simple = datasets.ImageFolder(
    root=test_dir,
    transform=test_trainsform_simple,
)

In [None]:
# Turn our Datasets into DataLoaders
import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKS = 1

# Only the training loader shuffles its samples.
train_dataloader_argumented = DataLoader(
    dataset=train_data_augmented,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKS,
)
test_dataloader_simple = DataLoader(
    dataset=test_data_simple,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKS,
)
train_dataloader_argumented, test_dataloader_simple

### 9.3 Construct and train Model 1

In [None]:
# Create model_1 and send it to the target device
# Seed first so the model is constructed from a fixed RNG state
torch.manual_seed(42)
# Size the output layer from the dataset this model is actually trained on
# (train_data_augmented), rather than from an earlier section's dataset.
model_1 = TinyVGG(input_shape=3,  # presumably the number of colour channels — confirm against TinyVGG
                  hidden_units=10,
                  output_shape=len(train_data_augmented.classes)).to(device=device)
model_1

In [None]:
from timeit import default_timer as timer

# Fix the random seeds (default generator and MPS) before training
torch.manual_seed(42)
torch.mps.manual_seed(42)

# How many passes over the training data to run
NUM_EPOCHS = 5

# Cross-entropy loss with an Adam optimizer over model_1's parameters
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_1.parameters(), lr=0.001)

# Time the whole training run: augmented data for training, plain data for testing
start_time = timer()
model_1_results = train(
    model=model_1,
    train_dataloader=train_dataloader_argumented,
    test_dataloader=test_dataloader_simple,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
)
end_time = timer()
print(f"total training time: {end_time-start_time:.3f} s.")

### 9.4 Plot the loss curves of model 1

In [None]:
# Plot the curves for the results that train() returned for Model 1
# (plot_loss_curves is defined elsewhere in this notebook)
plot_loss_curves(model_1_results)

## 10. Compare model results

In [None]:
import pandas as pd
# Convert each model's training results into a DataFrame so they can be compared
# NOTE(review): model_0_result (singular) vs model_1_results (plural) — confirm the
# Model 0 variable really has this name where it was assigned
model_0_df = pd.DataFrame(model_0_result)
model_1_df = pd.DataFrame(model_1_results)
# Display the Model 0 results table
model_0_df

In [None]:
# Setup a plot
plt.figure(figsize=(15, 10))

# Get number of epochs
epochs = range(len(model_0_df))

# (results column, panel title) for each of the four panels of the 2x2 grid
metric_panels = [
    ("train_loss", "Train Loss"),
    ("test_loss", "Test Loss"),
    ("train_acc", "Train Accuracy"),
    ("test_acc", "Test Accuracy"),
]

for panel_index, (metric_column, panel_title) in enumerate(metric_panels, start=1):
    plt.subplot(2, 2, panel_index)
    # Always draw Model 0 first so each model keeps the same line colour in
    # every panel (the plotting order decides the colour assigned to a line).
    plt.plot(epochs, model_0_df[metric_column], label="Model 0")
    plt.plot(epochs, model_1_df[metric_column], label="Model 1")
    plt.title(panel_title)
    plt.xlabel("Epochs")
    plt.legend()

## 11. Make a prediction on a custom image

In [None]:
import requests

# Set up the custom image path inside the data folder (data_path is Path("data"))
# NOTE(review): nothing in this cell downloads the image, so the file must
# already exist at this location before it is read
custom_image_path = data_path / "custom-image-sushi.jpeg"
# Show the kind of path object we are holding
type(custom_image_path)

In [None]:
import torchvision

# Read the image file into a tensor and cast its values to float32
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)
custom_image = custom_image[:3]  # keep only the first three channels (RGB), dropping any extra channel such as alpha
# Divide the pixel values by 255 (in place) to get them between [0, 1]
custom_image /= 255
# Print out the image's shape, dtype and data
print(custom_image.shape)
print(custom_image.dtype)
print(custom_image)

### 11.2 Predicting on custom images with a trained pytorch model
Our model was trained on images with shape [3, 64, 64], whereas our custom image was loaded as [4, 998, 944] and, even after keeping only its first three (RGB) channels, is still [3, 998, 944].

How could we make sure our custom image is the same shape as the images our model was trained on?

Are there any torchvision.transforms that could help?

Before we answer that question, let's plot the image with matplotlib to make sure it looks okay. Remember, we'll have to permute the dimensions from CHW to HWC to suit matplotlib's requirements.

In [None]:
# Plot the custom image at its original size
plt.imshow(custom_image.permute(1, 2, 0)) # need to permute image dimensions from CHW -> HWC otherwise matplotlib will error
# Show the tensor's (CHW) shape in the title
plt.title(f"Image shape: {custom_image.shape}")
# Hide the axes; the trailing semicolon keeps the notebook from echoing the return value
plt.axis(False);

In [None]:
# Create a transform pipeline that resizes the image to 64x64, the spatial size
# of the [3, 64, 64] images the model was trained on
custom_image_transform = transforms.Compose([
    transforms.Resize(size=(64, 64))
])

# The image is already a tensor, so the transform is applied to it directly
custom_image_transformed = custom_image_transform(custom_image)

# Check the shape after resizing
print(custom_image_transformed.shape)


In [None]:
# Plot the resized custom image
plt.imshow(custom_image_transformed.permute(1, 2, 0)) # need to permute image dimensions from CHW -> HWC otherwise matplotlib will error
# Show the resized tensor's (CHW) shape in the title
plt.title(f"Image shape: {custom_image_transformed.shape}")
# Hide the axes; the trailing semicolon keeps the notebook from echoing the return value
plt.axis(False);

In [None]:
# Put the model in evaluation mode and predict without tracking gradients
model_0.eval()
with torch.inference_mode():
    # Add a batch dimension ([C, H, W] -> [1, C, H, W]) and move the image to the target device
    custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0).to(device)
    print(custom_image_transformed_with_batch_size.shape)

    # The input is already on the target device, so the output of the forward
    # pass does not need a further .to(device)
    custom_image_pred = model_0(custom_image_transformed_with_batch_size)

custom_image_pred

In [None]:
# Convert logits -> prediction probabilities -> index of the most likely class
custom_image_label = torch.softmax(custom_image_pred, dim=1).argmax(dim=1)
# The batch holds a single image, so the one-element label tensor is used directly as the index into class_names
print(class_names[custom_image_label])

### 11.3 Putting custom image prediction together: building a function


In [None]:
def pred_and_plot_image(model: torch.nn.Module,
                        image_path: str,
                        class_names: List[str]=None,
                        transform=None,
                        device="cpu"):
    """Make a prediction on a target image and plot the image with its prediction.

    Args:
        model: Classifier used for the prediction. As a side effect it is moved
            to `device` and switched to evaluation mode.
        image_path: Location of the image file. A plain string or a path-like
            object (e.g. `pathlib.Path`) is accepted; it is converted to `str`
            before being read.
        class_names: Optional list mapping a predicted class index to a label.
            When omitted (or empty) the title shows the class index instead.
        transform: Optional callable applied to the float image tensor (values
            in [0, 1], at most three channels) before it is given to the model.
        device: Device on which the forward pass is run.

    Returns:
        None. The image is shown with matplotlib, titled with the predicted
        class and its probability.
    """
    # 1. Load in image and convert the tensor values to float32
    #    (str() lets callers pass a pathlib.Path as well as a plain string)
    target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)

    # 2. Divide the image pixel values by 255 to get them between [0,1]
    target_image /= 255

    # 3. Drop the fourth channel of 4-channel images, then transform if necessary
    if target_image.shape[0] == 4:
        target_image = target_image[:3]
    if transform is not None:
        target_image = transform(target_image)

    # 4. Make sure the model is on the target device
    model.to(device)

    # 5. Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # add an extra (batch) dimension to the image
        target_image = target_image.unsqueeze(dim=0)
        # Make a prediction on a copy of the batched image sent to the target device;
        # target_image itself stays where it is so it can be plotted below
        target_image_pred = model(target_image.to(device))

    # 6. Convert logits to prediction probabilities
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # 7. Convert prediction probabilities to a prediction label
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Pull plain Python numbers out of the single-image batch for indexing/formatting
    pred_index = target_image_pred_label.item()
    pred_prob = target_image_pred_probs.max().item()

    # 8. Plot the image alongside the prediction and prediction probability
    #    (remove only the batch dimension, then CHW -> HWC for matplotlib)
    plt.imshow(target_image.squeeze(dim=0).permute(1, 2, 0))
    if class_names:
        title = f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
    else:
        title = f"Pred: {pred_index} | Prob: {pred_prob:.3f}"
    plt.title(title)
    plt.axis(False)


In [None]:
# pred on our custom image
# custom_image_path is a pathlib.Path while image_path is annotated as `str`,
# so convert it explicitly (the same str() conversion is used where the image
# is read directly with torchvision.io.read_image earlier in the notebook)
pred_and_plot_image(model=model_0,
                    image_path=str(custom_image_path),
                    class_names=class_names,
                    transform=custom_image_transform,
                    device=device)