### Data Augmentation
- Data augmentation is a technique for increasing the diversity of the training set by applying random (but realistic) transformations such as rotation, translation, flipping, and color jittering.
- This improves the model's ability to generalize and reduces overfitting.

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms , datasets , models
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm.auto import tqdm


In [None]:
# ---- Optimisation hyperparameters ----
learning_rate = 1e-4
num_epochs = 60
batch_size = 64  # samples per mini-batch handed to the DataLoaders

# ---- Dataset location and image geometry ----
data_set_root = "./data"  # root directory passed to the torchvision dataset
image_size = 96           # target size given to transforms.Resize

# ---- Checkpoint settings ----
# NOTE(review): these three values are not used in this cell; presumably later
# cells read them when saving/restoring the model -- confirm.
model_name = "Resnet18_STL10"
save_dir = "./data/models"
start_from_checkpoint = False

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training pipeline: resize, apply a randomly chosen AutoAugment policy, convert
# to a tensor, then normalise each of the 3 channels with mean 0.5 / std 0.5
# (maps ToTensor's [0, 1] output to [-1, 1]).
train_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.AutoAugment(),  # Applying AutoAugment i.e. multiple augmentations
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Evaluation pipeline: identical preprocessing but WITHOUT random augmentation,
# so test metrics are deterministic and measured on undistorted images.
test_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Loading the STL10 dataset (downloaded into data_set_root on first use)
train_data = datasets.STL10(root=data_set_root, split="train", download=True, transform=train_transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# The test split must not be augmented: use test_transform, not train_transform.
test_data = datasets.STL10(root=data_set_root, split="test", download=True, transform=test_transform)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Checking out some augmented images
def show_augmented_images(data_loader, mean=0.5, std=0.5):
    """Display up to the first 16 images of one batch as a grid, 4 per row.

    The images coming out of the loader have been normalised as
    ``(x - mean) / std``; that is undone here so the pixel values are back in
    ``[0, 1]`` before plotting. Without this step ``plt.imshow`` clips the
    negative values and the preview looks dark and washed out.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches, where
            ``images`` is a float tensor of shape ``(N, C, H, W)``.
        mean: Mean used by the normalisation step of the transform
            (default 0.5, matching the transform defined in this notebook).
        std: Standard deviation used by the normalisation step (default 0.5).

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    # Only the images are needed; the labels of the batch are ignored.
    images, _ = next(iter(data_loader))

    # Undo Normalize and clamp to the range imshow expects for float RGB data.
    images = (images[:16] * std + mean).clamp(0, 1)

    grid_img = make_grid(images, nrow=4)
    plt.figure(figsize=(10, 10))
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(grid_img.permute(1, 2, 0).cpu())
    plt.axis("off")
    plt.show()

  0%|          | 3.08M/2.64G [00:05<1:24:05, 523kB/s]  


KeyboardInterrupt: 