In [None]:
import torch
torch.manual_seed(42)  # seed torch's global RNG so repeated runs draw the same random numbers

from data.dataset import get_dataloaders_pca
from alexnet.augmentation.pca import ImageNetPCA

# Fit the colour PCA on half of the dataset: the full dataset would add little
# statistical information while making the computation considerably more expensive.
train_loader, _, _ = get_dataloaders_pca(
    batch_size=512,
    sample_size=0.5,
)

# `eigenvalues` and `eigenvectors` are reused by the PCA augmentation step in the
# training loop further down.
pca = ImageNetPCA(train_loader)
eigenvalues, eigenvectors = pca.fit()

print("Eigenvalues:", eigenvalues)
print("Eigenvectors:", eigenvectors)

<torch._C.Generator at 0x7f22a4177fb0>

In [None]:
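# Cached PCA results — presumably copied from an earlier run of the PCA cell above (TODO confirm).
# Uncomment to reuse them instead of re-running the computationally expensive PCA fit.
# eigenvalues = torch.tensor([0.0042, 0.0185, 0.1957])
# eigenvectors = torch.tensor([
#     [ 0.3957, -0.7228, -0.5666],
#     [-0.8135,  0.0105, -0.5815],
#     [ 0.4262,  0.6910, -0.5838],
# ])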

In [None]:
from alexnet.base import AlexNet
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = AlexNet().to(device)

# Optimiser hyperparameters, according to section 5 (Details of learning) of the paper.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

optimizer = SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
# Scheduler is left on its default settings; the training loop steps it with the
# validation loss once per epoch.
lr_scheduler = ReduceLROnPlateau(optimizer)
criterion = CrossEntropyLoss()

In [None]:
from data.dataset import get_dataloaders_training
from tqdm import tqdm

# Train AlexNet with PCA colour augmentation and validate after every epoch.
# Relies on `model`, `optimizer`, `lr_scheduler`, `criterion`, `device`,
# `eigenvalues` and `eigenvectors` defined in earlier cells.
train_loader, val_loader, test_loader = get_dataloaders_training(batch_size=128, sample_size=1)
epochs = 30

for epoch in range(epochs):
    # Running sums of the per-batch losses; turned into means at the end of the epoch.
    train_loss = 0.0
    val_loss = 0.0

    # ---- training pass ----
    model.train()  # Set the model to training mode
    with tqdm(total=len(train_loader.dataset), desc=f"Epoch {epoch+1}/{epochs}", unit=" images") as pbar:
        for batch in train_loader:
            images, labels = batch['pixel_values'], batch['labels']

            # PCA Augmentation according to section 4.1 (Data Augmentation) of the paper:
            # draw one Gaussian coefficient per principal component for every image,
            # scale it by the matching eigenvalue and project it back through the eigenvectors.
            noise = torch.normal(0, 0.1, size=(images.size(0), 3))  # (batch_size, 3)
            eigen_noise = noise * eigenvalues  # assumes eigenvalues has shape (3,) — TODO confirm
            # (batch_size, 3) @ (3, 3) = (batch_size, 3); assumes eigenvectors are stored as columns — TODO confirm
            principal_components = torch.mm(eigen_noise, eigenvectors.t())
            # Reshape to (batch_size, 3, 1, 1) so the same per-channel offset is broadcast
            # over every pixel (assumes images are (batch_size, 3, H, W)).
            images = images + principal_components.view(images.size(0), 3, 1, 1)

            images, labels = images.to(device), labels.to(device)

            # zero the gradients, forward pass, compute loss, backward pass, and update weights
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pbar.update(images.size(0))

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        # len(val_loader) is the number of batches, including a possibly partial last
        # batch, so the progress bar total matches the number of pbar.update(1) calls.
        with tqdm(total=len(val_loader), desc=f"Validation {epoch+1}/{epochs}", unit=" batches") as pbar:
            for batch in val_loader:
                val_images, val_labels = batch['pixel_values'], batch['labels']
                # The first two dimensions are (batch_size, num_crops): fold the crops into
                # the batch dimension so the model sees them as ordinary images.
                batch_size, num_crops = val_images.shape[:2]
                val_images = val_images.view(-1, *val_images.shape[2:])  # (batch_size * num_crops, channels, height, width)
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                val_outputs = model(val_images)
                # Restore the crop dimension and average the model outputs over the crops of each image.
                val_outputs = val_outputs.view(batch_size, num_crops, -1)
                avg_outputs = val_outputs.mean(dim=1)
                val_loss += criterion(avg_outputs, val_labels).item()

                pbar.update(1)

    # Average over the loaders that were actually iterated. The validation loss must be
    # divided by the number of *validation* batches, not test batches.
    train_loss /= len(train_loader)
    val_loss /= len(val_loader)
    print(f"Completed Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Step the learning rate scheduler
    lr_scheduler.step(val_loss)


Epoch 1/20:   0%|          | 0/1281167 [00:00<?, ? images/s]

tensor([[[[ 1.9902e-02]],

         [[ 1.8627e-02]],

         [[ 1.7190e-02]]],


        [[[-1.6112e-03]],

         [[-1.0936e-03]],

         [[-4.6614e-04]]],


        [[[ 1.6785e-02]],

         [[ 1.7708e-02]],

         [[ 1.7695e-02]]],


        [[[-2.8728e-02]],

         [[-3.0941e-02]],

         [[-3.3200e-02]]],


        [[[ 2.0170e-03]],

         [[ 3.0865e-04]],

         [[-5.6753e-04]]],


        [[[-1.3797e-02]],

         [[-1.5921e-02]],

         [[-1.7019e-02]]],


        [[[-8.2993e-03]],

         [[-7.7794e-03]],

         [[-6.9246e-03]]],


        [[[-1.0892e-02]],

         [[-9.1443e-03]],

         [[-9.0016e-03]]],


        [[[ 1.8452e-03]],

         [[ 1.9317e-04]],

         [[-7.7451e-04]]],


        [[[ 8.6487e-03]],

         [[ 8.0257e-03]],

         [[ 9.3974e-03]]],


        [[[ 3.1069e-03]],

         [[ 2.8870e-03]],

         [[ 1.6782e-03]]],


        [[[ 6.1404e-03]],

         [[ 6.3421e-03]],

         [[ 5.9453e-03]]],


    

Epoch 1/20:   0%|          | 0/1281167 [00:00<?, ? images/s]


In [None]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure the checkpoint folder exists first; otherwise the
# save would fail only after the whole training run has finished.
import os

checkpoint_path = f".checkpoint/alexnet_lrn_{epochs}_epochs.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)