In [4]:
import torch
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from PIL import Image
from datasets import load_dataset
import os
# Set the HF_HUB_* timeout environment variables to "60" for this process.
# NOTE(review): these assignments run after `datasets` has already been
# imported above; if the hub client reads its settings at import time they
# will not take effect here. Confirm, and confirm the variable names against
# the huggingface_hub environment-variable documentation.
os.environ["HF_HUB_READ_TIMEOUT"] = "60"
os.environ["HF_HUB_CONNECT_TIMEOUT"] = "60"

In [5]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    Five convolution layers feed three fully connected layers. The first two
    convolutions are followed by ReLU, local response normalisation and
    max-pooling; the fifth by ReLU and max-pooling. Dropout is applied after
    each of the first two fully connected layers.

    The first fully connected layer is sized for a 6x6x256 feature map, which
    is what a 227x227 input produces. Inputs that yield a different grid are
    adaptively average-pooled to 6x6 before the classifier instead of failing
    with a shape mismatch. Inputs too small to survive the pooling stages
    still raise inside those layers.

    Args:
        num_classes: Number of outputs of the final linear layer.
        task: Accepted for interface compatibility; it is not read anywhere
            in this class.
    """

    def __init__(self, num_classes, task='classification'):
        super().__init__()
        # Only define structure, not computation
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=96,
                               kernel_size=(11,11),
                               stride=4
                              )
        self.conv2 = nn.Conv2d(in_channels=96,
                               out_channels=256,
                               kernel_size=(5,5),
                               padding=2
                              )
        self.conv3 = nn.Conv2d(in_channels=256,
                               out_channels=384,
                               kernel_size=(3,3),
                               padding=1
                              )
        self.conv4 = nn.Conv2d(in_channels=384,
                               out_channels=384,
                               kernel_size=(3,3),
                               padding=1
                              )
        self.conv5 = nn.Conv2d(in_channels=384,
                               out_channels=256,
                               kernel_size=(3,3),
                               padding=1
                              )
        # NOTE(review): torch's LocalResponseNorm divides alpha by `size`
        # inside its formula; if the goal is to reproduce the paper's
        # normalisation exactly, confirm whether alpha should be size * 1e-4.
        self.lrn1 = nn.LocalResponseNorm(size=5, k = 2, alpha=1e-4, beta=0.75)
        self.lrn2 = nn.LocalResponseNorm(size=5, k = 2, alpha=1e-4, beta=0.75)

        # Parameter-free layers are shared between stages.
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(6*6*256, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, num_classes)."""
        x = self.max_pool(self.lrn1(self.relu(self.conv1(x))))
        x = self.max_pool(self.lrn2(self.relu(self.conv2(x))))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.max_pool(self.relu(self.conv5(x)))
        # fc1 expects a 6x6 grid. Only resample when the grid differs, so the
        # 227x227 path stays exactly as before. The functional form adds no
        # submodule, leaving state_dict keys and the printed repr unchanged.
        if tuple(x.shape[-2:]) != (6, 6):
            x = F.adaptive_avg_pool2d(x, (6, 6))
        x = self.dropout(self.relu(self.fc1(x.flatten(start_dim=1))))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

In [6]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 200-way classifier. The .to() call is kept as the cell's last expression so
# the notebook displays the module summary.
model = AlexNet(num_classes=200)
model.to(device)

AlexNet(
  (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lrn1): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
  (lrn2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
  (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=200, bias=True)
)

In [7]:
# Per-component scales and 3x3 RGB basis passed to AlexNetPCAJitter.
# NOTE(review): going by the names, these are presumably the eigenvalues and
# eigenvectors of the ImageNet RGB pixel covariance; confirm the source.
IMAGENET_EIGVAL = torch.tensor((0.2175, 0.0188, 0.0045))
IMAGENET_EIGVEC = torch.tensor((
    (-0.5675,  0.7192,  0.4009),
    (-0.5808, -0.0045, -0.8140),
    (-0.5836, -0.6948,  0.4203),
))

In [8]:
class AlexNetPCAJitter(torch.nn.Module):
    """Random per-channel colour offset built from a 3-component basis.

    Each call draws ``alpha`` as three standard-normal samples scaled by
    ``alpha_std``, forms ``eigvec @ (alpha * eigval)`` and adds the resulting
    three values to the image, one per channel, broadcast over height and
    width. The result is not clamped.

    Args:
        eigval: Tensor of three per-component scales.
        eigvec: 3x3 tensor whose product with the scaled ``alpha`` gives the
            per-channel shift.
        alpha_std: Standard deviation of the random coefficients.
    """

    def __init__(self, eigval, eigvec, alpha_std=0.1):
        super().__init__()
        self.eigval = eigval
        self.eigvec = eigvec
        self.alpha_std = alpha_std

    def forward(self, img):
        """Return ``img`` plus one freshly sampled colour shift.

        Args:
            img: Tensor whose last three dimensions are (3, H, W). The
                original author expects values in [0, 1].

        Returns:
            A tensor shaped like ``img``. Floating-point images keep their
            dtype and device; other dtypes follow torch's usual promotion
            against the shift.
        """
        # img: Tensor (3, H, W) in [0,1]
        alpha = torch.randn(3) * self.alpha_std
        rgb_shift = (self.eigvec @ (alpha * self.eigval)).view(3, 1, 1)
        # Align the shift with the image. Otherwise float64 eigen-data would
        # silently promote a float32 image to float64, and an image on
        # another device would raise a device-mismatch error.
        if torch.is_floating_point(img):
            rgb_shift = rgb_shift.to(device=img.device, dtype=img.dtype)
        else:
            rgb_shift = rgb_shift.to(device=img.device)
        return img + rgb_shift

In [10]:
# All three splits come from the same Hub dataset; only `split` differs.
_DATASET_REPO = 'slegroux/tiny-imagenet-200-clean'

train_dataset = load_dataset(_DATASET_REPO, split='train')
valid_dataset = load_dataset(_DATASET_REPO, split='validation')
test_dataset = load_dataset(_DATASET_REPO, split='test')

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/151M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/7.54M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/7.57M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/98179 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4909 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4923 [00:00<?, ? examples/s]

In [11]:
# Channel statistics used by the Normalize step of both pipelines.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random 227 crop, random horizontal flip, tensor
# conversion, PCA colour jitter, then normalisation.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(227),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    AlexNetPCAJitter(IMAGENET_EIGVAL, IMAGENET_EIGVEC),
    transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
])

# Evaluation pipeline: same resize and normalisation, centre crop in place of
# the random crop, and none of the random augmentations.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(227),
    transforms.ToTensor(),
    transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
])


In [12]:
def preprocess(example, transform):
    """Apply ``transform`` to every item under the ``"image"`` key.

    Args:
        example: Mapping with an ``"image"`` key holding an iterable of
            items. It is modified in place.
        transform: Callable invoked once per item, in order.

    Returns:
        The same ``example`` object, with ``"image"`` replaced by a list of
        the transformed items.
    """
    transformed = []
    for picture in example["image"]:
        transformed.append(transform(picture))
    example["image"] = transformed
    return example

def _augment_batch(batch):
    """Run the training pipeline (``transform``) over one batch."""
    return preprocess(batch, transform)


def _eval_batch(batch):
    """Run the evaluation pipeline (``test_transform``) over one batch."""
    return preprocess(batch, test_transform)


# Transforms are attached lazily: training gets the augmenting pipeline,
# validation and test share the evaluation one.
train_dataset = train_dataset.with_transform(_augment_batch)
valid_dataset = valid_dataset.with_transform(_eval_batch)
test_dataset = test_dataset.with_transform(_eval_batch)

In [13]:
# Settings common to all three loaders.
_loader_kwargs = dict(batch_size=64, num_workers=4)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [14]:
# Cross-entropy applied to the model's unnormalised class scores.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over every model parameter.
# NOTE(review): the weight decay here is 5e-5 (originally written 0.5e-4).
# The AlexNet paper reports 5e-4; confirm the smaller value is intentional.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=5e-5,
)

In [15]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    The model is switched to training mode, and each batch goes through a
    forward pass, a backward pass and one optimiser step.

    Args:
        model: Module called on the batch's ``'image'`` tensor.
        dataloader: Iterable of mappings with ``'image'`` and ``'label'``
            tensors.
        optimizer: Optimiser whose gradients are cleared before each backward
            pass and stepped after it.
        criterion: Callable ``(output, label) -> loss`` returning a scalar
            tensor.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch size
        and divided by the number of samples seen, and the fraction of
        samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in dataloader:
        inputs = batch['image'].to(device)
        targets = batch['label'].to(device)

        # Forward
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        # Backward
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Statistics: weight the batch loss by its size so the epoch figure
        # is a per-sample average even when the last batch is smaller.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(1).indices
        n_correct += predicted.eq(targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def validate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    The model is switched to evaluation mode and every batch is processed
    with gradient tracking disabled.

    Args:
        model: Module called on the batch's ``'image'`` tensor.
        dataloader: Iterable of mappings with ``'image'`` and ``'label'``
            tensors.
        criterion: Callable ``(output, label) -> loss`` returning a scalar
            tensor.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch size
        and divided by the number of samples seen, and the fraction of
        samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    with torch.no_grad():
        model.eval()

        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        for batch in dataloader:
            inputs = batch['image'].to(device)
            targets = batch['label'].to(device)

            # Forward only; no optimiser is involved here.
            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.max(1).indices
            n_correct += predicted.eq(targets).sum().item()
            n_seen += targets.size(0)

        return loss_sum / n_seen, n_correct / n_seen

In [16]:
# Train for a fixed number of epochs, scoring the validation split after each
# one and checkpointing the weights whenever validation accuracy improves.
num_epochs = 40
best_val_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Strict ">" means a tie keeps the earlier checkpoint on disk.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

    summary = (
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )
    print(summary)


Epoch [1/40] Train Loss: 5.2779, Train Acc: 0.0059 | Val Loss: 5.1351, Val Acc: 0.0096
Epoch [2/40] Train Loss: 5.0419, Train Acc: 0.0183 | Val Loss: 4.7858, Val Acc: 0.0432
Epoch [3/40] Train Loss: 4.5257, Train Acc: 0.0668 | Val Loss: 4.3281, Val Acc: 0.0862
Epoch [4/40] Train Loss: 4.0019, Train Acc: 0.1325 | Val Loss: 3.8969, Val Acc: 0.1587
Epoch [5/40] Train Loss: 3.6457, Train Acc: 0.1863 | Val Loss: 3.4635, Val Acc: 0.2267
Epoch [6/40] Train Loss: 3.3738, Train Acc: 0.2327 | Val Loss: 3.3392, Val Acc: 0.2373
Epoch [7/40] Train Loss: 3.1635, Train Acc: 0.2731 | Val Loss: 3.1543, Val Acc: 0.2758
Epoch [8/40] Train Loss: 2.9950, Train Acc: 0.3045 | Val Loss: 2.9320, Val Acc: 0.3233
Epoch [9/40] Train Loss: 2.8524, Train Acc: 0.3323 | Val Loss: 2.9145, Val Acc: 0.3251
Epoch [10/40] Train Loss: 2.7170, Train Acc: 0.3581 | Val Loss: 2.8940, Val Acc: 0.3345
Epoch [11/40] Train Loss: 2.6158, Train Acc: 0.3775 | Val Loss: 2.7697, Val Acc: 0.3591
Epoch [12/40] Train Loss: 2.5146, Train A

In [23]:
# Score the test split with the same routine used for validation.
# The loss is bound to a real name rather than `_`: in a notebook, assigning
# to `_` stops IPython from updating its `_`/`__`/`___` last-output shortcuts.
# NOTE(review): this evaluates whatever weights `model` currently holds. The
# training loop saves the best-validation weights to "best_model.pth";
# confirm whether those should be loaded before reporting test accuracy.
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(test_acc)

0.3558805606337599
