In [2]:
import os

import numpy as np
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import models
# tqdm adds a progress bar
from tqdm import tqdm

# ResNet-50 with a 10-way output layer (no pretrained weights are requested),
# then moved onto the GPU.
resnet50 = models.resnet50(num_classes=10)
resnet50 = resnet50.cuda()

class ToOneHot:
    """Target transform that turns a class index into a float one-hot vector.

    Args:
        num_classes: Length of the produced one-hot vector.

    Calling the instance with ``label`` (a Python int, or anything
    ``torch.as_tensor`` can convert to an integer tensor) returns a float
    tensor with a trailing dimension of size ``num_classes``.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, label):
        # as_tensor does not copy-construct (and therefore does not warn) when
        # `label` is already a tensor; one_hot only accepts int64 indices, so
        # the dtype is forced here instead of relying on the input's dtype.
        index = torch.as_tensor(label, dtype=torch.long)
        encoded = torch.nn.functional.one_hot(index, num_classes=self.num_classes)
        return encoded.float()

def to_rgb(image):
    """Return ``image`` converted to "RGB" mode.

    Args:
        image: Any object exposing a ``convert(mode)`` method
            (used below on the images produced by the MNIST dataset).

    Returns:
        Whatever ``image.convert("RGB")`` returns.
    """
    rgb_image = image.convert("RGB")
    return rgb_image

# Per-image preprocessing, applied in this order.
preprocessing_steps = [
    transforms.Lambda(to_rgb),      # convert to RGB
    transforms.Resize((224, 224)),  # ResNet input size
    transforms.ToTensor(),          # image -> tensor
]
transform = transforms.Compose(preprocessing_steps)

# MNIST train / test splits. Images go through `transform`; labels are turned
# into 10-way one-hot vectors. Files are downloaded into data/ when missing.
train_dataset = datasets.MNIST(
    root='data/',
    train=True,
    download=True,
    transform=transform,
    target_transform=ToOneHot(10),
)
test_dataset = datasets.MNIST(
    root='data/',
    train=False,
    download=True,
    transform=transform,
    target_transform=ToOneHot(10),
)

# Show the concrete classes of the model and of the training dataset.
model_type = type(resnet50)
dataset_type = type(train_dataset)
print(model_type)
print(dataset_type)

# print("len(train_dataset) = {}".format(len(train_dataset)))
# print("len(test_dataset) = {}".format(len(test_dataset)))
# print(train_dataset.data.shape)
# print(train_dataset.targets.shape)
# print(train_dataset.data[0])
# print(train_dataset.targets[0])

# Batch loaders: the training split is reshuffled each epoch, the test split
# keeps its original order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=200,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=200,
    shuffle=False,
)

# Peek at a single batch to check what the loader yields.
# A batch is a 2-element sequence: (images, one-hot labels). With the
# transform and batch size configured above, the recorded output is:
#   2
#   torch.Size([200, 3, 224, 224])
#   torch.Size([200, 10])
# followed by the one-hot label vector of the first sample.
train_data = next(iter(train_loader))
print(len(train_data))
print(train_data[0].shape)
print(train_data[1].shape)
print(train_data[1][0])

# The model is ready; next, how is the data fed to it?
# print(resnet50)

# Cross-entropy loss (moved to the GPU) comparing the model outputs with the
# one-hot label vectors produced by ToOneHot.
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()

# Plain SGD with momentum over every model parameter.
optimizer = optim.SGD(
    params=resnet50.parameters(),
    lr=0.01,
    momentum=0.5,
)



<class 'torchvision.models.resnet.ResNet'>
<class 'torchvision.datasets.mnist.MNIST'>
2
torch.Size([200, 3, 224, 224])
torch.Size([200, 10])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])


In [None]:
# Train for two epochs, printing the current batch loss every 20 batches.
# Put the model in training mode explicitly: another cell in this notebook
# calls resnet50.eval(), and re-running this cell afterwards would otherwise
# train with BatchNorm in inference mode.
resnet50.train()
for epoch in range(2):
    # `i` is the 1-based batch counter within the current epoch.
    for i, (img, label) in enumerate(tqdm(train_loader), start=1):
        img = img.cuda()
        label = label.cuda()
        out = resnet50(img)
        loss = criterion(out, label)
        # The forward pass does not compute gradients; they are only produced
        # by backward(), so stale gradients are cleared right before it.
        optimizer.zero_grad()
        loss.backward()
        # Update the weights.
        optimizer.step()
        if i % 20 == 0:
            # item() extracts the value of a single-element tensor as a Python
            # number; no detach() or cpu() is needed for a scalar loss.
            print("loss = {}".format(loss.item()))

In [4]:
# Save the entire model object. torch.save fails when the target directory is
# missing, so create it first (no-op if it already exists).
# NOTE: this pickles the whole module; such a file must only be loaded from a
# trusted source, and loading it needs the model class to be importable.
os.makedirs('./models', exist_ok=True)
torch.save(resnet50, './models/resnet50.pth')

In [5]:
# Save only the parameters/buffers (state_dict). torch.save fails when the
# target directory is missing, so create it first (no-op if it already exists).
os.makedirs('./models', exist_ok=True)
torch.save(resnet50.state_dict(), './models/resnet50_state_dict.pth')

In [5]:
import torchvision
from torchvision import models
from torchvision import datasets, transforms
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
# tqdm 增加进度条显示
from tqdm import tqdm
import numpy as np

# Restore the trained network from the state_dict checkpoint written by the
# save cell above: rebuild the same architecture, then load the weights.
# This replaces torch.load() on the fully pickled model, which unpickles
# arbitrary objects and is rejected by torch.load's newer weights-only default.
resnet50 = models.resnet50(num_classes=10)
state_dict = torch.load('./models/resnet50_state_dict.pth', map_location='cpu')
resnet50.load_state_dict(state_dict)
resnet50 = resnet50.cuda()

# Inference mode for layers such as BatchNorm.
resnet50.eval()
eval_loss = []      # one mean loss value per test batch
eval_correct = 0    # running count of correctly classified samples
# Without no_grad() GPU memory usage balloons during evaluation, because
# autograd keeps recording the forward pass.
with torch.no_grad():
    for img, label in tqdm(test_loader):
        img = img.cuda()
        label = label.cuda()
        out = resnet50(img)
        batch_loss = criterion(out, label)
        eval_loss.append(batch_loss.item())
        # Predicted class = index of the largest output; true class = index
        # of the 1 in the one-hot label.
        pred = out.argmax(dim=1)
        labels = label.argmax(dim=1)
        num_correct = (pred == labels).sum().item()
        # print(num_correct)
        eval_correct += num_correct
print("len(test_loader) = {}".format(len(test_loader)))
# Average of the per-batch losses.
avg_loss = np.mean(eval_loss)
print(avg_loss)
print("avg loss = {}, eval_acc = {}".format(avg_loss, eval_correct/len(test_dataset)))

  0%|          | 0/50 [00:00<?, ?it/s]

  2%|▏         | 1/50 [00:00<00:17,  2.80it/s]

195


  4%|▍         | 2/50 [00:00<00:17,  2.82it/s]

193


  6%|▌         | 3/50 [00:01<00:16,  2.85it/s]

188


  8%|▊         | 4/50 [00:01<00:16,  2.87it/s]

190


 10%|█         | 5/50 [00:01<00:15,  2.87it/s]

191


 12%|█▏        | 6/50 [00:02<00:15,  2.88it/s]

189


 14%|█▍        | 7/50 [00:02<00:14,  2.87it/s]

186


 16%|█▌        | 8/50 [00:02<00:14,  2.87it/s]

192


 18%|█▊        | 9/50 [00:03<00:14,  2.87it/s]

194


 20%|██        | 10/50 [00:03<00:13,  2.86it/s]

194


 22%|██▏       | 11/50 [00:03<00:13,  2.87it/s]

182


 24%|██▍       | 12/50 [00:04<00:13,  2.87it/s]

191


 26%|██▌       | 13/50 [00:04<00:12,  2.87it/s]

190


 28%|██▊       | 14/50 [00:04<00:12,  2.88it/s]

193


 30%|███       | 15/50 [00:05<00:12,  2.88it/s]

197


 32%|███▏      | 16/50 [00:05<00:11,  2.89it/s]

194


 34%|███▍      | 17/50 [00:05<00:11,  2.91it/s]

189


 36%|███▌      | 18/50 [00:06<00:10,  2.91it/s]

190


 38%|███▊      | 19/50 [00:06<00:10,  2.90it/s]

191


 40%|████      | 20/50 [00:06<00:10,  2.89it/s]

186


 42%|████▏     | 21/50 [00:07<00:10,  2.89it/s]

190


 44%|████▍     | 22/50 [00:07<00:09,  2.83it/s]

185


 46%|████▌     | 23/50 [00:08<00:09,  2.84it/s]

189


 48%|████▊     | 24/50 [00:08<00:09,  2.85it/s]

193


 50%|█████     | 25/50 [00:08<00:08,  2.86it/s]

188


 52%|█████▏    | 26/50 [00:09<00:08,  2.86it/s]

198


 54%|█████▍    | 27/50 [00:09<00:08,  2.84it/s]

200


 56%|█████▌    | 28/50 [00:09<00:07,  2.84it/s]

197


 58%|█████▊    | 29/50 [00:10<00:07,  2.85it/s]

195


 60%|██████    | 30/50 [00:10<00:06,  2.86it/s]

193


 62%|██████▏   | 31/50 [00:10<00:06,  2.86it/s]

185


 64%|██████▍   | 32/50 [00:11<00:06,  2.86it/s]

200


 66%|██████▌   | 33/50 [00:11<00:05,  2.86it/s]

186


 68%|██████▊   | 34/50 [00:11<00:05,  2.85it/s]

194


 70%|███████   | 35/50 [00:12<00:05,  2.86it/s]

199


 72%|███████▏  | 36/50 [00:12<00:04,  2.85it/s]

200


 74%|███████▍  | 37/50 [00:12<00:04,  2.86it/s]

198


 76%|███████▌  | 38/50 [00:13<00:04,  2.86it/s]

199


 78%|███████▊  | 39/50 [00:13<00:03,  2.86it/s]

198


 80%|████████  | 40/50 [00:13<00:03,  2.86it/s]

194


 82%|████████▏ | 41/50 [00:14<00:03,  2.87it/s]

195


 84%|████████▍ | 42/50 [00:14<00:02,  2.84it/s]

193


 86%|████████▌ | 43/50 [00:15<00:02,  2.84it/s]

195


 88%|████████▊ | 44/50 [00:15<00:02,  2.84it/s]

200


 90%|█████████ | 45/50 [00:15<00:01,  2.84it/s]

200


 92%|█████████▏| 46/50 [00:16<00:01,  2.85it/s]

195


 94%|█████████▍| 47/50 [00:16<00:01,  2.85it/s]

200


 96%|█████████▌| 48/50 [00:16<00:00,  2.85it/s]

197


 98%|█████████▊| 49/50 [00:17<00:00,  2.86it/s]

185


100%|██████████| 50/50 [00:17<00:00,  2.86it/s]

198
len(test_loader) = 50
0.12795925172045827
avg loss = 0.12795925172045827, eval_acc = 0.9654



