In [1]:
import torch
import torch.nn as nn

from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from torch.utils.data import DataLoader
import time

class VGGNet(nn.Module):
    """VGG-19-style image classifier: 16 3x3 conv layers followed by 3 fully connected layers.

    The convolutional trunk has five stages with 2, 2, 4, 4, 4 convolutions and
    64, 128, 256, 512, 512 output channels. Each convolution is followed by ReLU,
    and each stage ends with a 2x2 max-pool.

    An adaptive average pool resizes the final feature map to 7x7, so the
    classifier input size (512*7*7) no longer depends on the input resolution.
    For 224x224 inputs the map is already 7x7 and the pool is a no-op. Inputs
    must be at least 32x32 so that the five pools leave a non-empty feature map.

    Args:
        num_classes: number of output logits.
        hidden_features: width of the two hidden fully connected layers
            (4096 in the original VGG design, kept as the default).
    """

    def __init__(self, num_classes: int, hidden_features: int = 4096):
        super(VGGNet, self).__init__()

        # Each stage is built in the same module order as explicit Sequentials
        # (conv, relu, conv, relu, ..., maxpool), so state_dict keys such as
        # layer1.0.weight / layer3.6.weight stay stable across versions.
        self.layer1 = self._make_stage(3, 64, num_convs=2)
        self.layer2 = self._make_stage(64, 128, num_convs=2)
        self.layer3 = self._make_stage(128, 256, num_convs=4)
        self.layer4 = self._make_stage(256, 512, num_convs=4)
        self.layer5 = self._make_stage(512, 512, num_convs=4)

        # Parameter-free, so it adds no state_dict entries. It pins the spatial
        # size fed to the classifier at 7x7 for any input resolution >= 32x32.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes)
        )

        self.initialize_weights()

    @staticmethod
    def _make_stage(in_channels: int, out_channels: int, num_convs: int) -> nn.Sequential:
        """Build one stage: `num_convs` x (3x3 conv with padding 1 -> ReLU), then a 2x2 max-pool."""
        modules = []
        for i in range(num_convs):
            modules.append(nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
                                     out_channels=out_channels,
                                     kernel_size=3, stride=1, padding=1))
            modules.append(nn.ReLU())
        modules.append(nn.MaxPool2d(2))
        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch of shape (N, 3, H, W) to class logits of shape (N, num_classes)."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def initialize_weights(self) -> None:
        """Initialise parameters in place.

        Conv weights use Kaiming-normal init (fan_out, ReLU). Linear weights are
        drawn from N(0, 0.01). All biases are set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

# ---- Configuration --------------------------------------------------------
SEED = 0
torch.manual_seed(SEED)  # makes weight init and the shuffle order repeatable

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 50
NUM_WORKERS = 12
EPOCHS = 10
LEARNING_RATE = 1e-4
LOG_EVERY = 10  # print the current training loss every LOG_EVERY batches

# ---- Data -----------------------------------------------------------------
# Convert PIL images to tensors, then upsample to the 224x224 resolution VGG was designed for.
transform = T.Compose([T.ToTensor(),
                       T.Resize((224, 224))])

train_set = CIFAR10(root='CIFAR10_data/',
                    train=True,
                    transform=transform,
                    download=True)
test_set = CIFAR10(root='CIFAR10_data/',
                   train=False,
                   transform=transform,
                   download=True)

train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS)
# Evaluation runs under no_grad, so the training batch size fits in memory too;
# parallel workers keep the evaluation pass from being data-loading bound.
test_loader = DataLoader(dataset=test_set,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=NUM_WORKERS)


# ---- Training / evaluation helpers ---------------------------------------
def train_one_epoch(model, loader, loss_fn, optimizer, device, log_every=LOG_EVERY):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    loss_sum = 0.0
    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        loss_sum += batch_loss
        if (batch_idx + 1) % log_every == 0:
            print(f'Batch: {batch_idx+1:2d}/{len(loader)}  Loss: {batch_loss:0.6f}')
    return loss_sum / len(loader)


def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over `loader` as a plain Python float in [0, 1]."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            yhat = torch.argmax(model(x), dim=1)
            # .item() keeps the running count on the host as a Python int, so the
            # result is not a device tensor.
            correct += (yhat == y).sum().item()
            total += y.numel()
    return correct / total


# ---- Train ----------------------------------------------------------------
net = VGGNet(num_classes=10).to(DEVICE)

cel = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

loss_lst = []  # mean training loss per epoch
acc_lst = []   # test accuracy per epoch, as Python floats (safe to plot directly)

start_time = time.time()
for epoch in range(EPOCHS):
    epoch_start = time.time()
    print(f'====== Epoch: {epoch+1:2d} / {EPOCHS} ======')

    loss_lst.append(train_one_epoch(net, train_loader, cel, optimizer, DEVICE))

    accuracy = evaluate(net, test_loader, DEVICE)
    acc_lst.append(accuracy)
    print(f'Accuracy: {accuracy*100:0.2f}%')
    print("elapsed time:", time.time() - epoch_start)

print("elapsed time:", time.time() - start_time)

Files already downloaded and verified
Files already downloaded and verified
Batch: 10/1000  Loss: 2.302355
Batch: 20/1000  Loss: 2.296149
Batch: 30/1000  Loss: 2.268746
Batch: 40/1000  Loss: 2.145165
Batch: 50/1000  Loss: 2.248774
Batch: 60/1000  Loss: 2.156821
Batch: 70/1000  Loss: 2.118255
Batch: 80/1000  Loss: 1.831013
Batch: 90/1000  Loss: 2.266618
Batch: 100/1000  Loss: 2.069404
Batch: 110/1000  Loss: 1.882095
Batch: 120/1000  Loss: 1.970169
Batch: 130/1000  Loss: 2.107259
Batch: 140/1000  Loss: 1.897148
Batch: 150/1000  Loss: 2.039683
Batch: 160/1000  Loss: 2.051947
Batch: 170/1000  Loss: 1.913484
Batch: 180/1000  Loss: 1.885376
Batch: 190/1000  Loss: 1.611502
Batch: 200/1000  Loss: 1.739829
Batch: 210/1000  Loss: 2.132336
Batch: 220/1000  Loss: 1.872237
Batch: 230/1000  Loss: 1.742311
Batch: 240/1000  Loss: 1.798359
Batch: 250/1000  Loss: 1.818119
Batch: 260/1000  Loss: 1.507444
Batch: 270/1000  Loss: 1.646263
Batch: 280/1000  Loss: 1.538991
Batch: 290/1000  Loss: 1.630717
Batch

KeyboardInterrupt: 