In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from torch.utils.data import DataLoader
import time
from torchvision.models import vgg19
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global torch RNG so later random operations (DataLoader shuffling,
# initialisation of newly created layers) are repeatable.
# NOTE(review): cuDNN kernels may still introduce run-to-run nondeterminism on GPU.
SEED = 0
torch.manual_seed(SEED)

# The pretrained torchvision VGG-19 weights were trained on ImageNet images
# normalised with these per-channel statistics; feeding un-normalised [0, 1]
# tensors puts the inputs off-distribution for the frozen feature extractor.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Upsample the 32x32 PIL image to 224x224 first, then convert to a float tensor
# in [0, 1] and normalise with the ImageNet statistics above.
transform = T.Compose([T.Resize((224, 224)),
                       T.ToTensor(),
                       T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

DATA_ROOT = 'CIFAR10_data/'
train_set = CIFAR10(root=DATA_ROOT,
                    train=True,
                    transform=transform,
                    download=True)
test_set = CIFAR10(root=DATA_ROOT,
                   train=False,
                   transform=transform,
                   download=True)

BATCH_SIZE = 50       # training mini-batch size
TEST_BATCH_SIZE = 10  # evaluation batch size (no gradients are kept during eval)
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0)
test_loader = DataLoader(dataset=test_set,
                         batch_size=TEST_BATCH_SIZE,
                         shuffle=False,
                         num_workers=0)
# Load VGG-19 with ImageNet-pretrained weights.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=VGG19_Weights.IMAGENET1K_V1`; kept here so older torchvision still works.
net = vgg19(pretrained=True)

# Freeze the convolutional feature extractor; only the classifier head is trained.
for param in net.features.parameters():
    param.requires_grad = False

# Swap the final classifier layer for one with a CIFAR-10-sized output.
# The input width is read from the existing layer instead of being hard-coded.
NUM_CLASSES = 10
net.classifier[6] = torch.nn.Linear(net.classifier[6].in_features, NUM_CLASSES)
net.to(DEVICE)

cel = torch.nn.CrossEntropyLoss()

# Give the optimizer only the parameters that will actually receive gradients,
# so it does not iterate over (or keep references to) the frozen feature weights.
LR = 0.0001
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LR)

def _train_one_epoch(model, loader, criterion, opt, device, log_every=10):
    """Run one optimisation pass over `loader`.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of (inputs, targets) batches.
        criterion: loss function; assumed to use mean reduction, like the
            default `CrossEntropyLoss()` used in this notebook.
        opt: optimizer stepping the model's trainable parameters.
        device: device the batches are moved to.
        log_every: print the current batch loss every this many batches.

    Returns:
        Mean per-sample training loss over the epoch (0.0 for an empty loader).
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss * x.size(0)
        n_seen += x.size(0)
        if (batch_idx + 1) % log_every == 0:
            print(f'Batch: {batch_idx+1:2d}/{len(loader)} ',
                  f'Loss: {batch_loss:0.6f}')
    return loss_sum / n_seen if n_seen else 0.0


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` on `loader` as a float in [0, 1].

    Runs in eval mode with gradients disabled. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = torch.argmax(model(x), dim=1)
            # .item() keeps the running count a Python int, not a device tensor.
            correct += torch.sum(preds == y).item()
            total += y.size(0)
    return correct / total if total else 0.0


EPOCHS = 10
loss_lst = []  # mean training loss per epoch (Python floats)
acc_lst = []   # test accuracy per epoch, fraction in [0, 1] (Python floats)
start_time = time.time()
for epoch in range(EPOCHS):
    epoch_start = time.time()
    print(f'====== Epoch: {epoch+1:2d} / {EPOCHS} ======')

    loss_lst.append(_train_one_epoch(net, train_loader, cel, optimizer, DEVICE))

    accuracy = _evaluate(net, test_loader, DEVICE)
    acc_lst.append(accuracy)
    print(f'Accuracy: {accuracy*100:0.2f}%')
    print("elapsed time:", time.time() - epoch_start)

print("elapsed time:", time.time() - start_time)

Files already downloaded and verified
Files already downloaded and verified
Batch: 10/1000  Loss: 2.023846
Batch: 20/1000  Loss: 1.672120
Batch: 30/1000  Loss: 1.108423
Batch: 40/1000  Loss: 1.157998
Batch: 50/1000  Loss: 1.268990
Batch: 60/1000  Loss: 0.961299
Batch: 70/1000  Loss: 1.373873
Batch: 80/1000  Loss: 1.017216
Batch: 90/1000  Loss: 0.836885
