# VGGNet 구현 - 20192253 Hongchan Yoon

In [1]:
import torch
import numpy as np

# Seed every RNG this notebook draws from: torch (e.g. DataLoader shuffling) and NumPy,
# which the custom augmentations (RandomCrop, RandomHorizontalFlip, RandomRGBColourShift)
# use via np.random.
random_seed = 1
# cuDNN is disabled entirely for reproducibility. NOTE(review): this also slows GPU convs;
# torch.backends.cudnn.deterministic = True is the usual lighter-weight alternative.
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
np.random.seed(random_seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('current device: ',device)

current device:  cpu


## 1. Dataset Preparation

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import matplotlib.pyplot as plt # for visualization
%matplotlib inline

import numpy as np
import math
from PIL import Image

### 1-1. Load CIFAR100 Dataset & perform PCA (밑의 RGB ColourShift에 사용)

In [None]:
# Load CIFAR-100 dataset (no transform: samples are PIL images)
train_dataset = torchvision.datasets.CIFAR100(root='./', train=True, download=True)
test_dataset = torchvision.datasets.CIFAR100(root='./', train=False, download=True)

# Extract RGB pixel values from the training dataset as an (N_pixels, 3) uint8 array.
# `train_dataset.data` already holds the raw uint8 images (N, H, W, 3); reshaping it yields
# the same rows in the same order as converting every sample to PIL and back, without
# 50k per-image conversions and a large intermediate list.
pixels = train_dataset.data.reshape(-1, 3)

# 3x3 covariance matrix of the R, G, B values (np.cov computes in float64).
# Note: this is on the raw 0-255 scale.
cov_matrix = np.cov(pixels, rowvar=False)

# The covariance matrix is symmetric, so use the symmetric eigensolver.
# eigh returns eigenvalues in ascending order, with eigenvectors as columns.
eig_vals, eig_vecs = np.linalg.eigh(cov_matrix)

# Reorder to descending eigenvalue (first principal component first), keeping each
# eigenvector column paired with its eigenvalue.
sorted_indices = np.argsort(eig_vals)[::-1]
eig_vals = eig_vals[sorted_indices]
eig_vecs = eig_vecs[:, sorted_indices]

print("Eigenvalues:", eig_vals)
print("Eigenvectors:\n", eig_vecs)

### 1-2. Define Data Augmentation

In [2]:
def convert2numpy(image):
    """Return ``image`` as a NumPy array that the caller may modify in place.

    Tensors are detached and moved to the CPU. Everything else (PIL images, arrays,
    nested lists) goes through ``np.array``, which always copies.
    """
    if torch.is_tensor(image):
        # .numpy() shares memory with a CPU tensor; copy so in-place edits by callers
        # (e.g. RandomCrop zeroing a patch) never leak back into the source tensor,
        # matching the copy semantics of the np.array branch.
        return image.detach().cpu().numpy().copy()
    return np.array(image)

class RandomCrop(object):
    """Cutout-style augmentation: zero one random ``crop_pixel`` x ``crop_pixel`` patch.

    VGGNet paper: "To further augment the training set, the crops underwent random
    horizontal flipping and random RGB colour shift". The paper trains on ILSVRC-2014 and
    takes random crops after rescaling to a training scale S. This notebook uses 32x32
    CIFAR100 images instead, so it skips the rescale-and-crop step and punches a hole
    (5x5 by default) into the image. Despite the name, the output has the same size as
    the input.

    The input is converted with ``convert2numpy`` and indexed as H x W (x C). A
    C x H x W tensor would therefore be treated incorrectly, so pass PIL images or
    H x W x C arrays. Returns a NumPy array.
    """

    def __init__(self, crop_pixel: int = 5):
        self.crop_pixel = crop_pixel

    def __call__(self, image):
        image = convert2numpy(image)
        size = self.crop_pixel
        if size <= 0:
            return image
        height, width = image.shape[:2]
        # randint's upper bound is exclusive, so the +1 lets the patch reach the last
        # row/column. Height and width are sampled separately so non-square images work.
        # A patch larger than the image starts at 0 and, via slice clipping, blanks it all.
        top = np.random.randint(0, max(height - size, 0) + 1)
        left = np.random.randint(0, max(width - size, 0) + 1)
        image[top:top + size, left:left + size] = 0
        return image

# Random Horizontal Flipping
# Flips the image left-right with the given probability.
class RandomHorizontalFlip(object):
    """Mirror the image left-right with probability ``probability``.

    Accepts:
    - PIL images;
    - NumPy arrays laid out H x W (x C), which is what RandomCrop returns;
    - tensors whose last axis is width (e.g. C x H x W after ToTensor).
    """

    def __init__(self, probability=0.3):
        if not 0.0 <= probability <= 1.0:
            raise ValueError(f"probability must be in [0, 1], got {probability}")
        self.probability = probability

    def __call__(self, image):
        # Stored on the instance: records whether the most recent call flipped.
        self.execute = np.random.rand() < self.probability
        if not self.execute:
            return image
        if isinstance(image, np.ndarray):
            # ndarray.transpose is an axis permutation, not a PIL flip, so reverse the
            # width axis explicitly. Make the result contiguous because ToTensor
            # (torch.from_numpy) rejects arrays with negative strides.
            return np.ascontiguousarray(image[:, ::-1])
        if torch.is_tensor(image):
            return torch.flip(image, dims=[-1])
        return image.transpose(Image.FLIP_LEFT_RIGHT)
        
# Random RGB Colour Shift
# Applies the colour shift with the given probability.
class RandomRGBColourShift(object):
    """Random RGB colour shift (AlexNet-style PCA colour augmentation).

    The VGGNet paper gives no details on the RGB colour shift and only cites AlexNet, so
    this follows AlexNet: "To each training image, we add multiples of the found principal
    components with magnitudes proportional to the corresponding eigenvalues times a random
    variable drawn from a Gaussian with mean zero and standard deviation 0.1".

    Concretely, every pixel's RGB value gets ``eig_vecs @ (alpha * eig_vals)`` added, with
    ``alpha ~ N(0, alpha_std)`` drawn once per call. The result is clipped to [0, 255] and
    returned as a uint8 array.

    NOTE(review): the PCA cell computes eig_vals from raw 0-255 pixel values, so
    alpha * eig_vals is likely much larger than a few intensity levels and will mostly
    saturate after clipping. AlexNet-style implementations commonly run the PCA on
    [0, 1]-scaled pixels (and rescale the shift). Confirm the scale before training with this.
    """

    def __init__(self, eig_vecs, eig_vals, alpha_std=0.1, probability=1.0):
        """
        Parameters:
        - eig_vecs: matrix whose columns are eigenvectors of the RGB covariance matrix
        - eig_vals: the matching eigenvalues
        - alpha_std: standard deviation of the Gaussian from which alphas are drawn
        - probability: probability of applying the colour shift (default 1.0 = always)
        """
        if not 0.0 <= probability <= 1.0:
            raise ValueError(f"probability must be in [0, 1], got {probability}")
        self.eig_vecs = np.asarray(eig_vecs)
        self.eig_vals = np.asarray(eig_vals)
        self.alpha_std = alpha_std
        self.probability = probability

    def __call__(self, image):
        """Return the colour-shifted image (H x W x C, uint8) as a new array."""
        # Only draw the extra gate sample when the shift is actually optional, so the
        # default (always apply) consumes the same random numbers as an unconditional shift.
        if self.probability < 1.0 and np.random.rand() >= self.probability:
            return np.asarray(image)
        alpha = np.random.normal(0, self.alpha_std, 3)
        shift = self.eig_vecs @ (alpha * self.eig_vals)
        # np.array always copies, so the caller's image is never modified; it also
        # accepts PIL images.
        shifted = np.array(image, dtype=np.float32)
        shifted[..., :3] += shift  # per-channel offset for R, G, B
        return np.clip(shifted, 0, 255).astype(np.uint8)

def imshow(img, mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)):
    """Display one normalized C x H x W image with matplotlib.

    Undoes ``Normalize(mean, std)`` channel-wise (``x * std + mean``). The defaults are the
    CIFAR-100 channel statistics; pass the mean/std given to Normalize if they differ (e.g.
    ``(0.5,) * 3, (0.5,) * 3``). Values are clipped to [0, 1] so matplotlib does not warn
    about or distort out-of-range floats.
    """
    img = np.asarray(img, dtype=np.float32)
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    img = np.clip(img * std + mean, 0.0, 1.0)
    # matplotlib expects H x W x C
    plt.imshow(np.transpose(img, (1, 2, 0)))

## Apply augmentation to the dataset

In [None]:
# CIFAR-100 per-channel mean/std used by Normalize.
# (0.4914, 0.4822, 0.4465) / (0.247, 0.243, 0.261) are the CIFAR-10 statistics and do not
# match this dataset.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# NOTE: RandomRGBColourShift (defined above) is not part of this pipeline. The PCA cell
# computes its eigenvalues on raw 0-255 pixels, so check the shift magnitude before adding it.
train_transform = torchvision.transforms.Compose([
    # Flip first, while the sample is still a PIL image (RandomCrop converts it to a NumPy
    # array), so the flip always receives its native input type.
    RandomHorizontalFlip(),
    RandomCrop(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(32),
    torchvision.transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(32),
    torchvision.transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(root='./', train=True, download=True, transform=train_transform),
    batch_size=128, shuffle=True, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(root='./', train=False, download=True, transform=test_transform),
    batch_size=128, shuffle=False, num_workers=0
)

## Plot the augmented images.

In [None]:
# Get one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)
# Convert images to numpy for display
images = images.numpy()

# CIFAR-100 has 100 fine-grained labels (0-99). A 10-entry CIFAR-10 name list would raise
# IndexError for any label >= 10, so take the names from the dataset itself.
classes = train_loader.dataset.classes

# Visualize the first 20 images of the batch on a 2 x 10 grid
n_show = 20
fig = plt.figure(figsize=(25, 4))
for idx in range(n_show):
    ax = fig.add_subplot(2, n_show // 2, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx].item()])