In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from torch.utils.data import DataLoader
import time
from torchvision.models import vgg19

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([T.ToTensor(),
                       T.Resize((224,224))])

train_set = CIFAR10(root='CIFAR10_data/',
                    train=True,
                    transform=transform,
                    download=True)
test_set = CIFAR10(root='CIFAR10_data/',
                   train=False,
                   transform=transform,
                   download=True)

BATCH_SIZE = 50

train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0)
test_loader = DataLoader(dataset=test_set,
                          batch_size=10,
                          shuffle=False,
                          num_workers=0)

net = vgg19(pretrained=True)
net.classifier[6] = torch.nn.Linear(4096, 10)
net.to(DEVICE)

cel = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

EPOCHS = 10
loss_lst = []
acc_lst = []
start_time = time.time()

for epoch in range(EPOCHS):
    batch_time = time.time()
    print(f'====== Epoch: {epoch+1:2d} / {EPOCHS} ======')
    net.train()
    l_sum = 0
    for batch_idx, (x,y) in enumerate(train_loader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        z = net(x)
        loss = cel(z, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        l_sum += loss.item()
        if (batch_idx+1) % 10 == 0:
            print(f'Batch: {batch_idx+1:2d}/{len(train_loader)} ',
                f'Loss: {loss.item():0.6f}')        
            
    loss_lst.append(l_sum/len(train_loader))

    net.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (x,y) in enumerate(test_loader):
            x, y = x.to(DEVICE), y.to(DEVICE)
            z = net(x)
            yhat = torch.argmax(z, dim=1)
            correct += torch.sum(y==yhat)
        
    accuracy = correct / len(test_set)
    acc_lst.append(accuracy)
    print(f'Accuracy: {accuracy.item()*100:0.2f}%')
    print("elapsed time:", time.time() - batch_time)
        
print("elapsed time:", time.time() - start_time)

Files already downloaded and verified
Files already downloaded and verified
Batch: 10/1000  Loss: 1.936302
Batch: 20/1000  Loss: 1.582902
Batch: 30/1000  Loss: 1.302579
Batch: 40/1000  Loss: 0.981182
Batch: 50/1000  Loss: 0.957966
Batch: 60/1000  Loss: 1.163240
Batch: 70/1000  Loss: 1.009224
Batch: 80/1000  Loss: 0.584713
Batch: 90/1000  Loss: 0.889232
Batch: 100/1000  Loss: 0.713211
Batch: 110/1000  Loss: 0.527913
Batch: 120/1000  Loss: 0.863501
Batch: 130/1000  Loss: 0.604608
Batch: 140/1000  Loss: 0.475462
Batch: 150/1000  Loss: 0.759447
Batch: 160/1000  Loss: 0.461535
Batch: 170/1000  Loss: 0.424075
Batch: 180/1000  Loss: 0.728401
Batch: 190/1000  Loss: 0.519665
Batch: 200/1000  Loss: 0.666761
Batch: 210/1000  Loss: 0.854255
Batch: 220/1000  Loss: 0.270021
Batch: 230/1000  Loss: 0.463320
Batch: 240/1000  Loss: 0.464961
Batch: 250/1000  Loss: 0.337852
Batch: 260/1000  Loss: 0.565331
Batch: 270/1000  Loss: 0.400102
Batch: 280/1000  Loss: 0.424279
Batch: 290/1000  Loss: 0.680034
Batch

KeyboardInterrupt: 