<div class="alert alert-warning"> <b>Google Colab Setup:</b> If you're running this notebook on Google Colab, please run the following two cells to mount your Google Drive, set the relevant paths, and change the current working directory. You may skip to the "Install relevant packages" cell if you're running the notebook locally.</div> 

In [None]:
# Mount Google Drive
import os
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Set up mount symlink
DRIVE_PATH = r'/content/gdrive/My\ Drive/demo9_data_augmentation'
DRIVE_PYTHON_PATH = DRIVE_PATH.replace('\\', '')
if not os.path.exists(DRIVE_PYTHON_PATH):
    %mkdir $DRIVE_PATH

## the space in `My Drive` causes some issues,
## make a symlink to avoid this
SYM_PATH = '/content/demo9_data_augmentation'
if not os.path.exists(SYM_PATH):
    !ln -s $DRIVE_PATH $SYM_PATH
    
# Change working directory
os.chdir(SYM_PATH)

In [None]:
# Install relevant packages
!pip install -r requirements.txt

In [None]:
# Loading dependencies
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from utils import run_training_loop, test_performance, show_data_augmentations

# Set up Jupyter notebook environment
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# CS 182 Demo: Data Augmentations

Time and time again, regularization has proven invaluable to machine learning practitioners. With particular focus on this class, we've discussed adapting our models to ensure they do not overfit to the training data. In other words, we want our model to generalize well to unseen data. 

Many of the approaches discussed (batch normalization, layer normalization, and dropout, to name a few) involve tuning and adjusting the inner workings of our deep learning architectures. Although Convolutional Neural Networks (CNNs) arose out of a desire for an architecture with invariances built in, in practice this often isn't enough. Data augmentation, the act of modifying our input training data, provides a different approach to regularizing our models.

See the following image from Sharon Y. Li's [Stanford AI Lab Blog post](https://ai.stanford.edu/blog/data-augmentation/) for an example of how data augmentation fits into the machine learning pipeline.

<img src="https://drive.google.com/uc?id=1ju6SFtwobhE5sEMD9ZeDBcmyVkKwgGkq" width="800px" align="center"></img>

[Figure 1](https://ai.stanford.edu/blog/data-augmentation/). Data augmentations apply a sequence of transformation functions tuned by human experts to the original data. The augmented data will be used for training downstream models.

## Part 1: Data Augmentation in Theory

### Acting as a Regularizer
Regularization is necessary for deep learning models to generalize well to unseen data (i.e. test data), and it can be introduced into models through explicit methods such as adding a weighted $L_{1}$ or $L_{2}$ penalty, for example, to the loss function. However, regularization can also be introduced implicitly into the model through data augmentation. In previous course material we saw that a least squares problem with data augmentation is equivalent to an $L_{2}$ regularized least squares problem.

Given a dataset $D=\{(x_i, y_i)\}^{m}_{i=1}$ consisting of $m$ datapoints, where $x_{i} \in \mathbb{R}^p$ and $y_{i} \in \mathbb{R}$, we can augment our data by adding Gaussian noise such that for each training point $x_{i}$, we have:

$$\tilde{X}_i = x_i + N_i \text{, where } N_i \sim \mathcal{N}(0,\,\sigma^{2}I)$$

We can construct $\tilde{X} \in \mathbb{R}^{m \times p}$ by stacking our $\tilde{X}_i$ terms together and can construct our original design matrix $X \in \mathbb{R}^{m \times p}$ by stacking our $x_{i}$ vectors together. 

Thus, it can be shown that minimizing the expected least squares objective for the noisy (i.e. augmented) data matrix is equivalent to minimizing the least squares objective with $L_{2}$ regularization: 

\begin{equation}
   \arg \min_{w} \mathbb{E}[\|y - \tilde{X}w\|^2_{2}] = \arg \min_{w}\frac{1}{m}\|y - Xw\|^2_{2} + \lambda\|w\|^2_{2} \text{, where } \lambda = \sigma^{2}
\end{equation}
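As a quick numerical sanity check of this equivalence, the sketch below compares the closed-form ridge solution with ordinary least squares fit on many noisy copies of the data. The problem sizes, the noise level, and the number of noisy copies `K` are arbitrary choices made for illustration, not values from the course material, and stacking copies only approximates the expectation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small synthetic regression problem (sizes and noise level are arbitrary)
m, p, sigma = 50, 3, 0.5
X = rng.normal(size=(m, p))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=m)

# Closed-form minimizer of (1/m)||y - Xw||^2 + sigma^2 ||w||^2
w_ridge = np.linalg.solve(X.T @ X + m * sigma**2 * np.eye(p), X.T @ y)

# Monte Carlo stand-in for the expectation: ordinary least squares on
# K independently noised copies of the design matrix, stacked row-wise
K = 2000
X_noisy = np.concatenate([X + sigma * rng.normal(size=(m, p)) for _ in range(K)])
y_repeated = np.tile(y, K)
w_aug, *_ = np.linalg.lstsq(X_noisy, y_repeated, rcond=None)

# Plain least squares on the clean data, for comparison
w_ols, *_ = np.linalg.lstsq(X, y, rcond=None)

print("ridge:    ", w_ridge)
print("augmented:", w_aug)
print("OLS:      ", w_ols)
```

The ridge and augmented solutions should agree closely, while the unregularized solution has a larger norm.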

The equivalence above shows the regularizing effect that data augmentation can have on a model: the noise variance $\sigma^{2}$ plays the role of the regularization strength $\lambda$. Think about how we can generalize the idea of adding random Gaussian noise from the least squares setting to a computer vision problem. To achieve a similar regularizing effect, one can add random Gaussian noise to each pixel of an image. 
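One way per-pixel Gaussian noise could be written as a transform is sketched below. The class name `AddGaussianNoise` and the default `sigma` are hypothetical (this is not a built-in `torchvision` transform), and the sketch assumes float images with values in $[0, 1]$. Clamping keeps the result a valid image, at the cost of making the noise no longer exactly Gaussian near the boundaries.

```python
import torch

class AddGaussianNoise:
    """Hypothetical transform: adds i.i.d. Gaussian noise to every pixel."""

    def __init__(self, sigma=0.05):
        self.sigma = sigma

    def __call__(self, img):
        # img: float tensor of shape (C, H, W) with values in [0, 1]
        noisy = img + self.sigma * torch.randn_like(img)
        # Clamp so the result is still a valid image
        return noisy.clamp(0.0, 1.0)

torch.manual_seed(0)
image = torch.rand(3, 32, 32)   # random tensor standing in for a real image
noisy_image = AddGaussianNoise(sigma=0.1)(image)
print(noisy_image.shape)
```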

### Example
Consider an edge or a pattern that consistently appears near the center of an image in a subset of the training data. For example, this might be the stripes of a zebra, which appear in the center because the photos are consistently centered on the zebra. The model (i.e. CNN) will latch onto that edge or pattern, as it is designed to do. Due to the translational equivariance of convolutions (and the approximate invariance added by pooling), the model should still be able to detect the zebra and its stripes if they were shifted around the image, thus providing the model with information to inform its prediction (ideally, 'zebra').

However, in reality, images are not always as clean-cut as the training examples in datasets such as CIFAR-10. The image could be of poor quality (e.g. blurry), so that the model identifies features poorly, or it could contain glare that distracts the model from the patterns and edges needed for a correct classification. It might even be that the subset of zebra images is entirely centered on a docile zebra standing horizontally, meaning a rotated or 'active' zebra may be misclassified.

Thus, data augmentation during training provides a way of implicitly regularizing the model by artificially creating scenarios that could be realistically seen in the real world. These data augmentations force the model to adapt to these changes by relying less on exploiting patterns from idealized versions of images. By augmenting our data, we are approximating what our data looks like in the real world.

Data augmentation is also an avenue for exploiting domain knowledge to produce models that are more accurate while remaining robust. Experts and scientists can encode what matters in their domain by selecting data augmentations that accurately represent what the model may encounter after training.

## Part 2: Basic Augmentations

These augmentations are used to promote invariance to small semantically insignificant changes. A few basic augmentations are:
1. Random Cropping
2. Rescaling
3. Rotations
4. Subset
5. Color Adjustment
6. Blurring

### Using PyTorch for Data Augmentation

There are many data augmentations implemented in the `torchvision.transforms` module. Below are a few examples of common data augmentations used with CNNs. Let's load in a public domain image of a [Golden Retriever](https://www.publicdomainpictures.net/en/view-image.php?image=35696&picture=golden-retriever-dog) to work with.

In [None]:
# Read in image as NumPy array
dog = plt.imread(fname="images/dog.jpg", format=None)

# Convert NumPy array to Tensor
dog = torch.from_numpy(dog)

fig, ax = plt.subplots()
ax.imshow(dog)
ax.set_xticks(ticks=[])
ax.set_yticks(ticks=[])
plt.title("Original Image");

<div class="alert alert-warning"> <b>Permuting Your Images: </b>The transformations found within the <code>transforms</code> library expect RGB images to be of shape $(3, H, W)$ but the <code>imshow</code> method from <code>matplotlib.pyplot</code> expects RGB images of shape $(H, W, 3)$. Thus, we must use PyTorch's <code>permute</code> method to reorder the images' dimensions into the correct shape.</div> 
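As a small illustration of the reordering described above, the sketch below uses a random tensor as a stand-in for a loaded image (the height and width are arbitrary) and moves the channel dimension to the front and back again.

```python
import torch

# Stand-in for an image loaded with plt.imread: shape (H, W, 3), dtype uint8
image_hwc = torch.randint(0, 256, (120, 160, 3), dtype=torch.uint8)

# (H, W, 3) -> (3, H, W): the layout the transforms expect
image_chw = image_hwc.permute(2, 0, 1)

# (3, H, W) -> (H, W, 3): the layout imshow expects
image_back = image_chw.permute(1, 2, 0)

print(image_hwc.shape, image_chw.shape, image_back.shape)
```

Note that `permute` only changes how the dimensions are ordered; the pixel values themselves are untouched, so the round trip recovers the original tensor.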

### Random Rotation
Use the `transforms.RandomRotation` transform to rotate an image by a random angle, up to a maximum number of degrees in either direction. You can find the official PyTorch documentation [here](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomRotation.html).

**Example:** Random rotations work to introduce rotational invariance and equivariance to your model. Think about the case of recognizing handwritten digits. A slightly tilted "7" should still be classified as a "7". By randomly rotating our original data, we are forcing the model to be invariant to the rotation angle of the "7".

In [None]:
# Set maximum degree you'd like to rotate the image by
MAX_DEGREE = 45

# Rotate image using rotation transformation
rotation = transforms.RandomRotation(degrees=MAX_DEGREE,
                                     interpolation=transforms.functional.InterpolationMode.NEAREST,
                                     expand=False,
                                     center=None,
                                     fill=0)

# Randomly rotate the original image four times
show_data_augmentations(original_im=dog, transform_f=rotation, title="Rotated Images")

### Random Crop
Use the `transforms.RandomResizedCrop` transform to crop the given image at a random location and then resize the crop to match the original image size. We often fix our model architecture so that it always takes an input image of the same size, so it's important to resize after cropping. You can find the official PyTorch documentation [here](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomResizedCrop.html).

**Example:** Random crops help to introduce scale invariance to our models since the objects in our training data may not always be of the same scale. Think about our toy example of classifying handwritten digits. Sometimes we might get a "7" written quite small, but other times the "7" may take up the entire image. Other times, we may not see the entirety of the "7"; adding random crops helps the model to generalize to different scales of handwritten digits.

In [None]:
# Load in a fresh, untransformed image
dog = torch.from_numpy(plt.imread(fname="images/dog.jpg", format=None))

# Crop the input image based on crop size
crop = transforms.RandomResizedCrop(size=dog.shape[:2],
                                    scale=(0.08, 1.0),
                                    ratio=(0.75, 1.3333333333333333),
                                    interpolation=transforms.functional.InterpolationMode.BILINEAR,
                                    antialias=True)

# Randomly crop the original image four times
show_data_augmentations(original_im=dog, transform_f=crop, title="Cropped Images")

Notice that the image size before and after cropping is the same:

In [None]:
print(f'Before Crop: {dog.shape}')
print(f'After Crop: {crop(dog.permute(2, 0, 1)).permute(1, 2, 0).shape}')

### Gaussian Blur
Use the `transforms.GaussianBlur` transform to blur the image with a Gaussian kernel whose standard deviation is chosen at random from a given range. You can find the official PyTorch documentation [here](https://pytorch.org/vision/main/generated/torchvision.transforms.GaussianBlur.html).

**Example:** The application of smoothing operators such as Gaussian filters to input images is often used to eliminate the noisy, high-frequency components. Say you have a grainy image of a dog and want to use a CNN to classify the image correctly. By convolving the image with a Gaussian kernel, we are filtering out the grainy, high-frequency components, thus leaving only a blurred version of the dog; doing so helps the model learn only the most important features of the input image.

In [None]:
# Load in a fresh, untransformed image
dog = torch.from_numpy(plt.imread(fname="images/dog.jpg", format=None))

# Set kernel size and the range from which the standard deviation (sigma) is sampled
KERNEL_SIZE = 13
SIGMA_RANGE = (0.1, 200)

# Pass a Gaussian filter over the image
blur = transforms.GaussianBlur(kernel_size=KERNEL_SIZE, sigma=SIGMA_RANGE)

# Randomly blur the original image four times
show_data_augmentations(original_im=dog, transform_f=blur, title="Blurred Images")

<div class="alert alert-info"> <b>Try it out:</b> Play around with the <code>KERNEL_SIZE</code> and <code>SIGMA_RANGE</code> variables. How do both the kernel size and the sigma value affect the resulting blur? What happens if you use a small kernel with a large sigma and vice versa?</div>
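One way to reason about the question above is to look at the kernel weights directly. The sketch below builds a normalized 1D Gaussian kernel by sampling the Gaussian at integer offsets from the center. The helper `gaussian_kernel_1d` is hypothetical and written for illustration; it is not claimed to match `torchvision`'s internal implementation exactly.

```python
import numpy as np

def gaussian_kernel_1d(kernel_size, sigma):
    """Normalized 1D Gaussian weights centered on the middle tap."""
    half = (kernel_size - 1) / 2
    x = np.linspace(-half, half, kernel_size)
    weights = np.exp(-0.5 * (x / sigma) ** 2)
    return weights / weights.sum()

# Small sigma: almost all weight sits on the center pixel (very little blur)
print(gaussian_kernel_1d(kernel_size=5, sigma=0.1).round(3))

# Large sigma with a small kernel: weights are nearly uniform (box-like blur)
print(gaussian_kernel_1d(kernel_size=5, sigma=200).round(3))
```

A 2D blur kernel can be obtained as the outer product of two such 1D kernels, which is why the kernel size bounds how far the blur can spread regardless of how large sigma is.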

### Color Jitter

Use the `transforms.ColorJitter` transform to randomly adjust the brightness, contrast, saturation, and hue of an image. You can find the official PyTorch documentation [here](https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html).

**Example:** Adding color jitter to your training images allows the network to generalize to many different real-world scenarios by encouraging the model to be invariant to the lighting conditions. When trying to classify an input image as a "dog", we don't want the network to rely on images of dogs that were taken in ideal lighting conditions. Both a brightly lit and a dimly lit image of a dog should be classified as "dog".

In [None]:
# Load in a fresh, untransformed image
dog = torch.from_numpy(plt.imread(fname="images/dog.jpg", format=None))

# Set ColorJitter parameters
BRIGHTNESS = 0.5
CONTRAST = 0.5
SATURATION = 0.1
HUE = 0.05

# Add color jitter to the image
jitter = transforms.ColorJitter(brightness=BRIGHTNESS, contrast=CONTRAST, saturation=SATURATION, hue=HUE)

# Randomly jitter the original image four times
show_data_augmentations(original_im=dog, transform_f=jitter, title="Jittered Images")

## Part 3: Advanced Augmentations

Advanced data augmentations provide more complex and unique training examples to further regularize our model. The real world is complicated and variable, meaning unseen (by the model) data are likely to be complex and dynamic. Thus, advanced augmentations provide practitioners an ability to introduce more heavily augmented, and sometimes more realistic, data to further improve the model's robustness.

### Composing Multiple Augmentations 

As machine learning practitioners, we've found that composing multiple data augmentations together, most often within a PyTorch `Dataset`, helps to regularize the model and improve its performance. Below is an example of one such composition of basic data augmentations.

In [None]:
# Load in a fresh, untransformed image
dog = torch.from_numpy(plt.imread(fname="images/dog.jpg", format=None))

composed_transform = transforms.Compose([
        transforms.RandomRotation(degrees=20, 
                                  interpolation=transforms.functional.InterpolationMode.NEAREST, 
                                  expand=False, 
                                  center=None, 
                                  fill=0),
        transforms.GaussianBlur(kernel_size=13, sigma=7),
        transforms.ColorJitter(brightness=0.3, 
                               contrast=0.3, 
                               saturation=0.1,
                               hue=0.1)
        ])

# Randomly augment the original image four times
show_data_augmentations(original_im=dog, transform_f=composed_transform, title="Augmented Images")

### More Aggressive Data Augmentations

As shown in lecture, there are several more aggressive data augmentations that are used in practice. Empirically, they are found to have a regularizing effect on the model. Some examples are below:

<img src="https://drive.google.com/uc?id=1x_ednmeFKO-iWficKQdcMnlRaIsHiJ3r" width="800px" align="center"></img>

You can read more about the above data augmentations in the original [PixMix paper](https://arxiv.org/pdf/2112.05135.pdf) from 2022, but we will not explicitly cover them in practice within this demo.

## Part 4: Augmentations in Practice

To showcase the empirical effects of data augmentation, run through the following example, which uses a ResNet architecture for an image classification task on the CIFAR-10 dataset.

### CIFAR-10

The CIFAR-10 dataset is a collection of 60,000 labeled 32x32 color images that is commonly used for training computer vision models. Each image is labeled with a class from one of the following categories: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. CIFAR-10 is commonly used as a baseline for a model’s image classification abilities, and we will be using it to train and compare a base model and a data augmented model.

### Loading the Model Architecture

To show the empirical effects of data augmentations in practice, we'll utilize the ResNet-18 architecture with untrained weights. The original [ResNet](https://arxiv.org/abs/1512.03385) architecture allowed for much deeper models due to its innovative approach towards solving the "vanishing gradient" problem: the skip connection. ResNet has proven to be one of the most successful architectures used for object classification and recognition. You can examine the specific layers and parameters below.

<img src="https://drive.google.com/uc?id=1U3X5I40imN2RRsidPh1_BT1pd-DqWWiz" width="600px" align="center"></img>

[Figure 2.](https://www.pluralsight.com/guides/introduction-to-resnet) The ResNet-18 architecture.

In [None]:
# Run this to load the model
model = torchvision.models.resnet18(weights=None)
model

We will use the same ResNet-18 architecture to train two models: a base model with no data augmentation, and another model trained with data augmentation. Below, we use methods from the `torchvision.transforms` module to compose a set of transformations that we can apply to the training data. We augment the data by performing one random horizontal flip followed by one random crop. The PyTorch code is shown below:

**Transformation Code**
```python
data_aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomCrop(size=32, padding=4),
    transforms.ToTensor(),
])

base_transform = transforms.Compose([
    transforms.ToTensor(),
])
```

### Training the Models

For the sake of time and computation, we have omitted the training process and instead load the trained models below. To reproduce the model training, please reference the `model_train.py` and `utils.py` files located at the root of the directory.

For reference, our training hyperparameters are described below:
```python
batch_size=128
num_epochs=15
learning_rate=0.001
loss = nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate)
```
More information about the Adam optimizer can be found in the official PyTorch documentation [here](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html).

In [None]:
# Load trained models from models folder
base_model = torch.load("models/resnet18_base.pt", map_location="cpu")
aug_model = torch.load("models/resnet18_data_augment.pt", map_location="cpu")

Let's take a look at the training and validation loss curves for each model:

<img src="https://drive.google.com/uc?id=1oHsuSsSX4mBXFSHqmpYV7LWu-wIfTIco" width="800px" align="left"></img>
<img src="https://drive.google.com/uc?id=1el9SeHIfo0KVMGEynMkIih9qnAbk_yvO" width="800px" align="left"></img>

<div class="alert alert-info"> <b>What do you notice?</b> Consider the trends of the train and validation curves with respect to each other.</div>

Now, let's examine how each model performs on the test dataset. First, we must load in the test dataset:

In [None]:
cifar10_test = torchvision.datasets.CIFAR10(root="data",
                                            train=False,
                                            download=True,
                                            transform=transforms.ToTensor())

Run the following cells to test the performances of both the base model and data augmented model. You'll see that the model trained using the data augmentations performs better than the base model on unseen test data. You can find the `test_performance` function in the `utils.py` file.

In [None]:
# This may take up to a minute to run
test_performance(model=base_model, test_data=cifar10_test, batch_size=128, device="cpu")

In [None]:
# This may take up to a minute to run
test_performance(model=aug_model, test_data=cifar10_test, batch_size=128, device="cpu")
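The `test_performance` helper lives in `utils.py` and is not reproduced here. As a rough idea of how test accuracy is commonly computed, the sketch below defines a hypothetical `evaluate_accuracy` function and exercises it on a toy linear model with random data; neither the function nor the toy setup is taken from this demo's code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate_accuracy(model, dataset, batch_size=128, device="cpu"):
    """Hypothetical helper: fraction of samples whose top-1 prediction is correct."""
    model = model.to(device)
    model.eval()   # disable dropout, use running batch-norm statistics
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
    return correct / total

# Toy check with a stand-in model and random data (not CIFAR-10)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_data = TensorDataset(torch.rand(64, 3, 32, 32), torch.randint(0, 10, (64,)))
accuracy = evaluate_accuracy(toy_model, toy_data)
print(f"accuracy: {accuracy:.3f}")
```

Note that no augmentation is applied at evaluation time: the test images are only converted to tensors, so both models are scored on exactly the same inputs.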

Although the base model's training loss was nearly half that of the data augmented model, the base model performed poorly on the validation and test sets. The data augmented model, on the other hand, achieved better accuracy on the test set than the base model. This example displays the regularization effect data augmentation has on model training.