In [1]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch import optim
from torchvision import datasets, transforms
from torchvision.io import read_image
from torch.utils.data import DataLoader
import torchvision.transforms.functional as F
import models

In [2]:
def show(imgs):
    """Display one or more image tensors side by side in a single matplotlib row.

    Args:
        imgs: a single image tensor, or a list/tuple of image tensors, in a layout
            accepted by ``torchvision.transforms.functional.to_pil_image``
            (presumably (C, H, W); the eval cell passes a (1, 28, 28) MNIST digit).

    Single-channel images are drawn with a grey colormap instead of matplotlib's
    default false-colour map. Nothing is returned; the figure is rendered by the
    notebook's inline backend.
    """
    # Treat a tuple as a sequence of images too, not as one image.
    if isinstance(imgs, tuple):
        imgs = list(imgs)
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # detach() so tensors that require grad can still be converted to PIL.
        pixels = np.asarray(F.to_pil_image(img.detach()))
        # A 1-channel tensor becomes a 2-D array here; without cmap="gray" imshow
        # would render it with the default (coloured) colormap.
        cmap = "gray" if pixels.ndim == 2 else None
        axs[0, i].imshow(pixels, cmap=cmap)
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [3]:
# Load one Pascal VOC image, take a random 224x224 crop, and convert the uint8
# pixels to float so the image can be fed to the classifier in the next cell.
# NOTE: `input` shadows the Python builtin; later cells rely on this name.
transform = transforms.RandomCrop(224)
input = transforms.functional.convert_image_dtype(
    transform(read_image("./data/Pascal VOC 2012/VOCdevkit/VOC2012/JPEGImages/2007_000027.jpg")),
    torch.float,
)
input.size(), input.dtype

(torch.Size([3, 224, 224]), torch.float32)

In [4]:
# Sanity-check the VGG forward pass on the single cropped image (batch of 1).
model = models.VGG()
model.eval()
# Inference only: no_grad avoids building and keeping an autograd graph for this check.
with torch.no_grad():
    output = model(input.unsqueeze(0))
# NOTE: the model is left in eval mode here; call model.train() before training it.
output.size()

torch.Size([1, 1000])

In [5]:
# Hyper-parameters for the MNIST training / evaluation run.
batch_size, epochs, log_batch_inx = (
    100,  # samples per mini-batch
    10,   # number of passes over the training set
    50,   # by its name, a logging interval in batches — TODO(review): confirm which loop reads it
)

In [6]:
# MNIST train/test splits (downloaded into ./data/MNIST on first run), converted to tensors.
train_data = datasets.MNIST(root="./data/MNIST",
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True
                            )

# Reshuffle the training set every epoch.
train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True
                          )

test_data = datasets.MNIST(root="./data/MNIST",
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True
                           )

# Evaluation does not need shuffling. A fixed order leaves the accuracy unchanged but makes
# the evaluation loop's per-batch debug output (printed tensors / shown digits) reproducible.
test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size,
                         shuffle=False
                         )

In [7]:
# Classification loss, plus SGD with momentum over every parameter of `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)

In [None]:
# Main training loop. Author's note: with VGG this run does not complete (the network is too deep).
# The VGG sanity-check cell left the model in eval mode; switch to training mode so that
# train-time behaviour (e.g. dropout / batch-norm updates, if the model has them) is active.
model.train()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log every `log_batch_inx` batches (configured alongside batch_size / epochs).
        if batch_idx % log_batch_inx == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [None]:
# Evaluate on the MNIST test set; every 10th batch, print debug info and show the first digit.
num_correct = 0
num_samples = 0
model.eval()
with torch.no_grad():
    for batch_idx, (data, labels) in enumerate(test_loader):
        output = model(data)
        _, predictions = torch.max(output, dim=1)
        # .item() keeps the counter a plain Python int, so the summary prints "9800"
        # instead of "tensor(9800)".
        num_correct += (predictions == labels).sum().item()
        num_samples += predictions.size(0)
        if batch_idx % 10 == 0:
            print(torch.min(output))
            show(data[0].view(-1, 28, 28))
            print(output[0])
            print(predictions[0], labels[0])
    # Guard against an empty test loader, which would otherwise divide by zero.
    if num_samples:
        print(f'Got {num_correct} / {num_samples} with accuracy {100.0 * num_correct / num_samples:.2f}')
    else:
        print('Got 0 / 0: the test loader yielded no samples')
# Restore training mode for any later training cells.
model.train()

In [8]:
# Run the FCN on the full-size VOC image; unlike the VGG cell, no random crop is applied.
# NOTE(review): presumably models.FCN accepts arbitrary spatial input sizes — confirm.
input = read_image("./data/Pascal VOC 2012/VOCdevkit/VOC2012/JPEGImages/2007_000027.jpg")
input = transforms.functional.convert_image_dtype(input, torch.float)
model = models.FCN()
model.eval()
# Inference only: no_grad avoids building and keeping an autograd graph.
with torch.no_grad():
    output = model(input.unsqueeze(0))